aboutsummaryrefslogtreecommitdiffstats
path: root/src/local
diff options
context:
space:
mode:
authorEllie Huxtable <e@elm.sh>2021-04-13 19:14:07 +0100
committerGitHub <noreply@github.com>2021-04-13 19:14:07 +0100
commit5751463942cc91f1f1ffaf6e2ac633d7a0085f25 (patch)
treef7b5b9a4702c4c3ef29aa60d36612f61ffeae052 /src/local
parentUpdate config (diff)
downloadatuin-5751463942cc91f1f1ffaf6e2ac633d7a0085f25.zip
Add history sync, resolves #13 (#31)
* Add encryption * Add login and register command * Add count endpoint * Write initial sync push * Add single sync command Confirmed working for one client only * Automatically sync on a configurable frequency * Add key command, key arg to login * Only load session if it exists * Use sync and history timestamps for download * Bind other key code Seems like some systems have this code for up arrow? I'm not sure why, and it's not an easy one to google. * Simplify upload * Try and fix download sync loop * Change sync order to avoid uploading what we just downloaded * Multiline import fix * Fix time parsing * Fix importing history with no time * Add hostname to sync * Use hostname to filter sync * Fixes * Add binding * Stuff from yesterday * Set cursor modes * Make clippy happy * Bump version
Diffstat (limited to 'src/local')
-rw-r--r--src/local/api_client.rs94
-rw-r--r--src/local/database.rs55
-rw-r--r--src/local/encryption.rs108
-rw-r--r--src/local/history.rs11
-rw-r--r--src/local/import.rs116
-rw-r--r--src/local/mod.rs3
-rw-r--r--src/local/sync.rs135
7 files changed, 481 insertions, 41 deletions
diff --git a/src/local/api_client.rs b/src/local/api_client.rs
new file mode 100644
index 00000000..434c07ba
--- /dev/null
+++ b/src/local/api_client.rs
@@ -0,0 +1,94 @@
+use chrono::Utc;
+use eyre::Result;
+use reqwest::header::AUTHORIZATION;
+
+use crate::api::{AddHistoryRequest, CountResponse, ListHistoryResponse};
+use crate::local::encryption::{decrypt, load_key};
+use crate::local::history::History;
+use crate::settings::Settings;
+use crate::utils::hash_str;
+
+/// Thin blocking HTTP client for the sync server's API.
+///
+/// Borrows the application [`Settings`] and reads the sync address and the
+/// session token from `settings.local` on every request.
+pub struct Client<'a> {
+    settings: &'a Settings,
+}
+
+impl<'a> Client<'a> {
+    /// Creates a client backed by `settings`.
+    pub const fn new(settings: &'a Settings) -> Self {
+        Client { settings }
+    }
+
+    // Value of the `Authorization` header sent with every request.
+    fn auth_header(&self) -> String {
+        format!("Token {}", self.settings.local.session_token)
+    }
+
+    /// Returns the history count reported by the server at `/sync/count`.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the request cannot be sent, the server answers with an error
+    /// status, or the body is not a valid `CountResponse`.
+    pub fn count(&self) -> Result<i64> {
+        let url = format!("{}/sync/count", self.settings.local.sync_address);
+
+        let resp = reqwest::blocking::Client::new()
+            .get(url)
+            .header(AUTHORIZATION, self.auth_header())
+            .send()?
+            .error_for_status()?;
+
+        Ok(resp.json::<CountResponse>()?.count)
+    }
+
+    /// Downloads one page of history from `/sync/history` and decrypts it.
+    ///
+    /// `sync_ts`, `history_ts` and `host` are sent as the query parameters of
+    /// the same names. When `host` is `None`, the hash of this machine's
+    /// `hostname:username` pair is sent instead.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the key cannot be loaded, the request fails or returns an
+    /// error status, or any returned entry cannot be parsed or decrypted.
+    pub fn get_history(
+        &self,
+        sync_ts: chrono::DateTime<Utc>,
+        history_ts: chrono::DateTime<Utc>,
+        host: Option<String>,
+    ) -> Result<Vec<History>> {
+        let key = load_key(self.settings)?;
+
+        // this allows for syncing between users on the same machine
+        let host = host.unwrap_or_else(|| {
+            hash_str(&format!("{}:{}", whoami::hostname(), whoami::username()))
+        });
+
+        let url = format!("{}/sync/history", self.settings.local.sync_address);
+
+        // Let reqwest build the query string so the values are percent-encoded.
+        // RFC 3339 timestamps contain `+` (as in `+00:00`), which would
+        // otherwise be decoded as a space on the receiving side.
+        let resp = reqwest::blocking::Client::new()
+            .get(url)
+            .query(&[
+                ("sync_ts", sync_ts.to_rfc3339()),
+                ("history_ts", history_ts.to_rfc3339()),
+                ("host", host),
+            ])
+            .header(AUTHORIZATION, self.auth_header())
+            .send()?
+            .error_for_status()?;
+
+        let page = resp.json::<ListHistoryResponse>()?;
+
+        // Each entry is a JSON-encoded `EncryptedHistory`. Report bad data as an
+        // error rather than panicking on whatever the server returned.
+        page.history
+            .iter()
+            .map(|h| -> Result<History> {
+                let encrypted =
+                    serde_json::from_str::<crate::local::encryption::EncryptedHistory>(h)
+                        .map_err(|e| eyre::eyre!("invalid encrypted history payload: {}", e))?;
+
+                decrypt(&encrypted, &key)
+            })
+            .collect()
+    }
+
+    /// Uploads a batch of already-encrypted history entries to `/history`.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the request cannot be sent or the server answers with an error
+    /// status, so a rejected upload is not mistaken for a successful one.
+    pub fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
+        let url = format!("{}/history", self.settings.local.sync_address);
+
+        reqwest::blocking::Client::new()
+            .post(url)
+            .json(history)
+            .header(AUTHORIZATION, self.auth_header())
+            .send()?
+            .error_for_status()?;
+
+        Ok(())
+    }
+}
diff --git a/src/local/database.rs b/src/local/database.rs
index ad7078e5..977f11cc 100644
--- a/src/local/database.rs
+++ b/src/local/database.rs
@@ -1,3 +1,4 @@
+use chrono::prelude::*;
use chrono::Utc;
use std::path::Path;
@@ -21,6 +22,10 @@ pub trait Database {
fn update(&self, h: &History) -> Result<()>;
fn history_count(&self) -> Result<i64>;
+ fn first(&self) -> Result<History>;
+ fn last(&self) -> Result<History>;
+ fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>>;
+
fn prefix_search(&self, query: &str) -> Result<Vec<History>>;
}
@@ -44,9 +49,7 @@ impl Sqlite {
let conn = Connection::open(path)?;
- if create {
- Self::setup_db(&conn)?;
- }
+ Self::setup_db(&conn)?;
Ok(Self { conn })
}
@@ -70,6 +73,14 @@ impl Sqlite {
[],
)?;
+ conn.execute(
+ "create table if not exists history_encrypted (
+ id text primary key,
+ data blob not null
+ )",
+ [],
+ )?;
+
Ok(())
}
@@ -87,7 +98,7 @@ impl Sqlite {
) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
params![
h.id,
- h.timestamp,
+ h.timestamp.timestamp_nanos(),
h.duration,
h.exit,
h.command,
@@ -146,7 +157,7 @@ impl Database for Sqlite {
"update history
set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8
where id = ?1",
- params![h.id, h.timestamp, h.duration, h.exit, h.command, h.cwd, h.session, h.hostname],
+ params![h.id, h.timestamp.timestamp_nanos(), h.duration, h.exit, h.command, h.cwd, h.session, h.hostname],
)?;
Ok(())
@@ -183,6 +194,38 @@ impl Database for Sqlite {
Ok(history_iter.filter_map(Result::ok).collect())
}
+ // Fetches the single history row with the smallest `timestamp`.
+ // Returns an error if the query fails or yields no row.
+ fn first(&self) -> Result<History> {
+     let sql = "SELECT * FROM history order by timestamp asc limit 1";
+     let mut statement = self.conn.prepare(sql)?;
+
+     let oldest = statement.query_row(params![], |row| history_from_sqlite_row(None, row))?;
+
+     Ok(oldest)
+ }
+
+ // Fetches the single history row with the largest `timestamp`.
+ // Returns an error if the query fails or yields no row.
+ fn last(&self) -> Result<History> {
+     let sql = "SELECT * FROM history order by timestamp desc limit 1";
+     let mut statement = self.conn.prepare(sql)?;
+
+     let newest = statement.query_row(params![], |row| history_from_sqlite_row(None, row))?;
+
+     Ok(newest)
+ }
+
+ // Returns up to `count` rows whose timestamp is at or before `timestamp`,
+ // newest first. The bound is inclusive. Rows that fail to convert are
+ // dropped rather than reported.
+ fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>> {
+     let sql = "SELECT * FROM history where timestamp <= ? order by timestamp desc limit ?";
+     let mut statement = self.conn.prepare(sql)?;
+
+     // timestamps are stored as nanoseconds since the epoch
+     let cutoff = timestamp.timestamp_nanos();
+
+     let rows = statement.query_map(params![cutoff, count], |row| {
+         history_from_sqlite_row(None, row)
+     })?;
+
+     Ok(rows.flatten().collect())
+ }
+
fn query(&self, query: &str, params: impl Params) -> Result<Vec<History>> {
let mut stmt = self.conn.prepare(query)?;
@@ -218,7 +261,7 @@ fn history_from_sqlite_row(
Ok(History {
id,
- timestamp: row.get(1)?,
+ timestamp: Utc.timestamp_nanos(row.get(1)?),
duration: row.get(2)?,
exit: row.get(3)?,
command: row.get(4)?,
diff --git a/src/local/encryption.rs b/src/local/encryption.rs
new file mode 100644
index 00000000..3c1699e3
--- /dev/null
+++ b/src/local/encryption.rs
@@ -0,0 +1,108 @@
+// The general idea is that we NEVER send cleartext history to the server
+// This way the odds of anything private ending up where it should not are
+// very low
+// The server authenticates via the usual username and password. This has
+// nothing to do with the encryption, and is purely authentication! The client
+// generates its own secret key, and encrypts all shell history with libsodium's
+// secretbox. The data is then sent to the server, where it is stored. All
+// clients must share the secret in order to be able to sync, as it is needed
+// to decrypt
+
+use std::fs::File;
+use std::io::prelude::*;
+use std::path::PathBuf;
+
+use eyre::{eyre, Result};
+use sodiumoxide::crypto::secretbox;
+
+use crate::local::history::History;
+use crate::settings::Settings;
+
+/// Encrypted form of a single `History` entry: what `encrypt` produces and
+/// `decrypt` consumes.
+#[derive(Debug, Serialize, Deserialize)]
+pub struct EncryptedHistory {
+ /// Output of `secretbox::seal` over the msgpack-serialized entry.
+ pub ciphertext: Vec<u8>,
+ /// Nonce the entry was sealed with; needed again to open `ciphertext`.
+ pub nonce: secretbox::Nonce,
+}
+
+/// Loads the secret key, creating and saving a new one if it doesn't exist.
+///
+/// The key lives at `settings.local.key_path`, encoded as MessagePack. On Unix
+/// a newly created key file is restricted to mode `0600` before the key is
+/// written, so other local users cannot read the secret.
+///
+/// # Errors
+///
+/// Fails if the key file cannot be read or decoded, or if a new key cannot be
+/// encoded or written to disk.
+pub fn load_key(settings: &Settings) -> Result<secretbox::Key> {
+    let path = settings.local.key_path.as_str();
+
+    if PathBuf::from(path).exists() {
+        let bytes = std::fs::read(path)?;
+        let key: secretbox::Key = rmp_serde::from_read_ref(&bytes)?;
+        return Ok(key);
+    }
+
+    let key = secretbox::gen_key();
+    let buf = rmp_serde::to_vec(&key)?;
+
+    let mut file = File::create(path)?;
+
+    // Tighten permissions while the file is still empty: the default mode is
+    // usually readable by other users, and this file holds the only secret
+    // protecting the synced history.
+    #[cfg(unix)]
+    {
+        use std::os::unix::fs::PermissionsExt;
+
+        file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
+    }
+
+    file.write_all(&buf)?;
+
+    Ok(key)
+}
+
+/// Encrypts one history entry with `key`.
+///
+/// The entry is serialized with msgpack and sealed with a freshly generated
+/// nonce, which is returned alongside the ciphertext.
+pub fn encrypt(history: &History, key: &secretbox::Key) -> Result<EncryptedHistory> {
+    let serialized = rmp_serde::to_vec(history)?;
+    let nonce = secretbox::gen_nonce();
+
+    Ok(EncryptedHistory {
+        ciphertext: secretbox::seal(&serialized, &nonce, key),
+        nonce,
+    })
+}
+
+/// Opens an encrypted entry with `key` and decodes the msgpack payload.
+///
+/// Returns an error if the secretbox cannot be opened (for example with the
+/// wrong key) or if the plaintext is not a valid `History`.
+pub fn decrypt(encrypted_history: &EncryptedHistory, key: &secretbox::Key) -> Result<History> {
+    let EncryptedHistory { ciphertext, nonce } = encrypted_history;
+
+    match secretbox::open(ciphertext, nonce, key) {
+        Ok(plaintext) => Ok(rmp_serde::from_read_ref(&plaintext)?),
+        Err(_) => Err(eyre!("failed to open secretbox - invalid key?")),
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use sodiumoxide::crypto::secretbox;
+
+    use crate::local::history::History;
+
+    use super::{decrypt, encrypt};
+
+    // Encrypting the same entry under two keys must give different output;
+    // only the matching key may decrypt it.
+    #[test]
+    fn test_encrypt_decrypt() {
+        let key1 = secretbox::gen_key();
+        let key2 = secretbox::gen_key();
+
+        let history = History::new(
+            chrono::Utc::now(),
+            "ls".to_string(),
+            "/home/ellie".to_string(),
+            0,
+            1,
+            Some("beep boop".to_string()),
+            Some("booop".to_string()),
+        );
+
+        let e1 = encrypt(&history, &key1).unwrap();
+        let e2 = encrypt(&history, &key2).unwrap();
+
+        assert_ne!(e1.ciphertext, e2.ciphertext);
+        assert_ne!(e1.nonce, e2.nonce);
+
+        // the matching key must round-trip the entry
+        let decrypted =
+            decrypt(&e1, &key1).unwrap_or_else(|e| panic!("failed to decrypt, got {}", e));
+        assert_eq!(decrypted, history);
+
+        // a different key must be rejected
+        assert!(
+            decrypt(&e2, &key1).is_err(),
+            "expected an error decrypting with invalid key"
+        );
+    }
+}
diff --git a/src/local/history.rs b/src/local/history.rs
index 0ca112bd..1712f8b9 100644
--- a/src/local/history.rs
+++ b/src/local/history.rs
@@ -1,12 +1,15 @@
use std::env;
use std::hash::{Hash, Hasher};
+use chrono::Utc;
+
use crate::command::uuid_v4;
-#[derive(Debug, Clone)]
+// Any new fields MUST be Option<>!
+#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct History {
pub id: String,
- pub timestamp: i64,
+ pub timestamp: chrono::DateTime<Utc>,
pub duration: i64,
pub exit: i64,
pub command: String,
@@ -17,7 +20,7 @@ pub struct History {
impl History {
pub fn new(
- timestamp: i64,
+ timestamp: chrono::DateTime<Utc>,
command: String,
cwd: String,
exit: i64,
@@ -29,7 +32,7 @@ impl History {
.or_else(|| env::var("ATUIN_SESSION").ok())
.unwrap_or_else(uuid_v4);
let hostname =
- hostname.unwrap_or_else(|| hostname::get().unwrap().to_str().unwrap().to_string());
+ hostname.unwrap_or_else(|| format!("{}:{}", whoami::hostname(), whoami::username()));
Self {
id: uuid_v4(),
diff --git a/src/local/import.rs b/src/local/import.rs
index 9bf79c72..d0f679c9 100644
--- a/src/local/import.rs
+++ b/src/local/import.rs
@@ -4,7 +4,9 @@
use std::io::{BufRead, BufReader, Seek, SeekFrom};
use std::{fs::File, path::Path};
-use eyre::{Result, WrapErr};
+use chrono::prelude::*;
+use chrono::Utc;
+use eyre::{eyre, Result};
use super::history::History;
@@ -13,6 +15,7 @@ pub struct Zsh {
file: BufReader<File>,
pub loc: u64,
+ pub counter: i64,
}
// this could probably be sped up
@@ -32,19 +35,23 @@ impl Zsh {
Ok(Self {
file: buf,
loc: loc as u64,
+ counter: 0,
})
}
}
-fn parse_extended(line: &str) -> History {
+fn parse_extended(line: &str, counter: i64) -> History {
let line = line.replacen(": ", "", 2);
let (time, duration) = line.split_once(':').unwrap();
let (duration, command) = duration.split_once(';').unwrap();
- let time = time.parse::<i64>().map_or_else(
- |_| chrono::Utc::now().timestamp_nanos(),
- |t| t * 1_000_000_000,
- );
+ let time = time
+ .parse::<i64>()
+ .unwrap_or_else(|_| chrono::Utc::now().timestamp());
+
+ let offset = chrono::Duration::milliseconds(counter);
+ let time = Utc.timestamp(time, 0);
+ let time = time + offset;
let duration = duration.parse::<i64>().map_or(-1, |t| t * 1_000_000_000);
@@ -60,6 +67,18 @@ fn parse_extended(line: &str) -> History {
)
}
+impl Zsh {
+    // Reads the next raw line from the history file.
+    // Returns `None` at end of file and `Some(Err(..))` when the read fails,
+    // so the caller can decide whether to continue.
+    fn read_line(&mut self) -> Option<Result<String>> {
+        let mut buf = String::new();
+
+        let bytes_read = match self.file.read_line(&mut buf) {
+            Ok(n) => n,
+            // we can skip past things like invalid utf8
+            Err(e) => return Some(Err(eyre!("failed to read line: {}", e))),
+        };
+
+        if bytes_read == 0 {
+            None
+        } else {
+            Some(Ok(buf))
+        }
+    }
+}
+
impl Iterator for Zsh {
type Item = Result<History>;
@@ -68,54 +87,89 @@ impl Iterator for Zsh {
// These lines begin with :
// So, if the line begins with :, parse it. Otherwise it's just
// the command
- let mut line = String::new();
+ let line = self.read_line()?;
- match self.file.read_line(&mut line) {
- Ok(0) => None,
- Ok(_) => {
- let extended = line.starts_with(':');
+ if let Err(e) = line {
+ return Some(Err(e)); // :(
+ }
- if extended {
- Some(Ok(parse_extended(line.as_str())))
- } else {
- Some(Ok(History::new(
- chrono::Utc::now().timestamp_nanos(), // what else? :/
- line.trim_end().to_string(),
- String::from("unknown"),
- -1,
- -1,
- None,
- None,
- )))
- }
+ let mut line = line.unwrap();
+
+ while line.ends_with("\\\n") {
+ let next_line = self.read_line()?;
+
+ if next_line.is_err() {
+ // There's a chance that the last line of a command has invalid
+ // characters, the only safe thing to do is break :/
+ // usually just invalid utf8 or smth
+ // however, we really need to avoid missing history, so it's
+ // better to have some items that should have been part of
+ // something else, than to miss things. So break.
+ break;
}
- Err(e) => Some(Err(e).wrap_err("failed to parse line")),
+
+ line.push_str(next_line.unwrap().as_str());
+ }
+
+ // We have to handle the case where a line has escaped newlines.
+ // Keep reading until we have a non-escaped newline
+
+ let extended = line.starts_with(':');
+
+ if extended {
+ self.counter += 1;
+ Some(Ok(parse_extended(line.as_str(), self.counter)))
+ } else {
+ let time = chrono::Utc::now();
+ let offset = chrono::Duration::seconds(self.counter);
+ let time = time - offset;
+
+ self.counter += 1;
+
+ Some(Ok(History::new(
+ time,
+ line.trim_end().to_string(),
+ String::from("unknown"),
+ -1,
+ -1,
+ None,
+ None,
+ )))
}
}
}
#[cfg(test)]
mod test {
+ use chrono::prelude::*;
+ use chrono::Utc;
+
use super::parse_extended;
#[test]
fn test_parse_extended_simple() {
- let parsed = parse_extended(": 1613322469:0;cargo install atuin");
+ let parsed = parse_extended(": 1613322469:0;cargo install atuin", 0);
assert_eq!(parsed.command, "cargo install atuin");
assert_eq!(parsed.duration, 0);
- assert_eq!(parsed.timestamp, 1_613_322_469_000_000_000);
+ assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0));
- let parsed = parse_extended(": 1613322469:10;cargo install atuin;cargo update");
+ let parsed = parse_extended(": 1613322469:10;cargo install atuin;cargo update", 0);
assert_eq!(parsed.command, "cargo install atuin;cargo update");
assert_eq!(parsed.duration, 10_000_000_000);
- assert_eq!(parsed.timestamp, 1_613_322_469_000_000_000);
+ assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0));
- let parsed = parse_extended(": 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷");
+ let parsed = parse_extended(": 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", 0);
assert_eq!(parsed.command, "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷");
assert_eq!(parsed.duration, 10_000_000_000);
- assert_eq!(parsed.timestamp, 1_613_322_469_000_000_000);
+ assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0));
+
+ let parsed = parse_extended(": 1613322469:10;cargo install \\n atuin\n", 0);
+
+ assert_eq!(parsed.command, "cargo install \\n atuin");
+ assert_eq!(parsed.duration, 10_000_000_000);
+ assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0));
}
}
diff --git a/src/local/mod.rs b/src/local/mod.rs
index a11ee213..9fe31292 100644
--- a/src/local/mod.rs
+++ b/src/local/mod.rs
@@ -1,3 +1,6 @@
+pub mod api_client;
pub mod database;
+pub mod encryption;
pub mod history;
pub mod import;
+pub mod sync;
diff --git a/src/local/sync.rs b/src/local/sync.rs
new file mode 100644
index 00000000..c22d2f27
--- /dev/null
+++ b/src/local/sync.rs
@@ -0,0 +1,135 @@
+use std::convert::TryInto;
+
+use chrono::prelude::*;
+use eyre::Result;
+
+use crate::local::api_client;
+use crate::local::database::Database;
+use crate::local::encryption::{encrypt, load_key};
+use crate::settings::{Local, Settings, HISTORY_PAGE_SIZE};
+use crate::{api::AddHistoryRequest, utils::hash_str};
+
+// Currently sync is kinda naive, and basically just pages backwards through
+// history. This means newly added stuff shows up properly! We also just use
+// the total count in each database to indicate whether a sync is needed.
+// I think this could be massively improved! If we had a way of easily
+// indicating count per time period (hour, day, week, year, etc) then we can
+// easily pinpoint where we are missing data and what needs downloading. Start
+// with year, then find the week, then the day, then the hour, then download it
+// all! The current naive approach will do for now.
+
+// Check if remote has things we don't, and if so, download them.
+// Returns (num downloaded, total local)
+//
+// With `force`, paging starts from the epoch and an empty host is sent.
+// Otherwise paging starts from the last recorded sync time and the API client
+// picks the host.
+fn sync_download(
+    force: bool,
+    client: &api_client::Client,
+    db: &mut impl Database,
+) -> Result<(i64, i64)> {
+    let page_size: usize = HISTORY_PAGE_SIZE
+        .try_into()
+        .expect("HISTORY_PAGE_SIZE must fit in usize");
+
+    let remote_count = client.count()?;
+
+    let initial_local = db.history_count()?;
+    let mut local_count = initial_local;
+
+    let mut last_sync = if force {
+        Utc.timestamp_millis(0)
+    } else {
+        Local::last_sync()?
+    };
+
+    let mut last_timestamp = Utc.timestamp_millis(0);
+
+    let host = if force { Some(String::from("")) } else { None };
+
+    while remote_count > local_count {
+        let page = client.get_history(last_sync, last_timestamp, host.clone())?;
+
+        // Store the page BEFORE deciding whether to stop. A short page is the
+        // final one, but it still holds entries we do not have yet; breaking
+        // first would silently drop them.
+        if !page.is_empty() {
+            db.save_bulk(&page)?;
+            local_count = db.history_count()?;
+        }
+
+        if page.len() < page_size {
+            break;
+        }
+
+        let page_last = match page.last() {
+            Some(entry) => entry.timestamp,
+            None => break,
+        };
+
+        // in the case of a small sync frequency, it's possible for history to
+        // be "lost" between syncs. In this case we need to rewind the sync
+        // timestamps
+        if page_last == last_timestamp {
+            last_timestamp = Utc.timestamp_millis(0);
+            last_sync = last_sync - chrono::Duration::hours(1);
+        } else {
+            last_timestamp = page_last;
+        }
+    }
+
+    Ok((local_count - initial_local, local_count))
+}
+
+// Check if we have things remote doesn't, and if so, upload them
+//
+// Pages backwards through local history from "now", encrypting each entry and
+// posting it, until the remote count catches up with the local count or there
+// is nothing older left to send.
+fn sync_upload(
+    settings: &Settings,
+    _force: bool,
+    client: &api_client::Client,
+    db: &mut impl Database,
+) -> Result<()> {
+    let mut remote_count = client.count()?;
+
+    let local_count = db.history_count()?;
+
+    let key = load_key(settings)?; // encryption key
+
+    // first just try the most recent set
+    let mut cursor = Utc::now();
+
+    while local_count > remote_count {
+        let last = db.before(cursor, HISTORY_PAGE_SIZE)?;
+
+        if last.is_empty() {
+            break;
+        }
+
+        let mut buffer = Vec::<AddHistoryRequest>::with_capacity(last.len());
+
+        for i in last {
+            let data = encrypt(&i, &key)?;
+            let data = serde_json::to_string(&data)?;
+
+            buffer.push(AddHistoryRequest {
+                id: i.id,
+                timestamp: i.timestamp,
+                data,
+                hostname: hash_str(i.hostname.as_str()),
+            });
+        }
+
+        client.post_history(&buffer)?;
+        remote_count = client.count()?;
+
+        let oldest = buffer
+            .last()
+            .expect("buffer mirrors a non-empty page")
+            .timestamp;
+
+        // The page query includes entries at the cursor itself, so once the
+        // oldest entry is reached the same page comes back every time. If the
+        // cursor did not move there is nothing older to send; stop instead of
+        // re-uploading forever when the remote count never catches up.
+        if oldest == cursor {
+            break;
+        }
+
+        cursor = oldest;
+    }
+
+    Ok(())
+}
+
+// Runs one full sync: upload local history, download what is missing, then
+// record the sync time. Any failing step aborts the sync before the time is
+// saved.
+pub fn sync(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> {
+    let client = api_client::Client::new(settings);
+
+    // upload before download
+    sync_upload(settings, force, &client, db)?;
+
+    let (downloaded, _local_total) = sync_download(force, &client, db)?;
+    debug!("sync downloaded {}", downloaded);
+
+    Local::save_sync_time()?;
+
+    Ok(())
+}