From 5751463942cc91f1f1ffaf6e2ac633d7a0085f25 Mon Sep 17 00:00:00 2001 From: Ellie Huxtable Date: Tue, 13 Apr 2021 19:14:07 +0100 Subject: Add history sync, resolves #13 (#31) * Add encryption * Add login and register command * Add count endpoint * Write initial sync push * Add single sync command Confirmed working for one client only * Automatically sync on a configurable frequency * Add key command, key arg to login * Only load session if it exists * Use sync and history timestamps for download * Bind other key code Seems like some systems have this code for up arrow? I'm not sure why, and it's not an easy one to google. * Simplify upload * Try and fix download sync loop * Change sync order to avoid uploading what we just downloaded * Multiline import fix * Fix time parsing * Fix importing history with no time * Add hostname to sync * Use hostname to filter sync * Fixes * Add binding * Stuff from yesterday * Set cursor modes * Make clippy happy * Bump version --- src/api.rs | 36 ++++++++++++ src/command/history.rs | 30 ++++++---- src/command/login.rs | 48 ++++++++++++++++ src/command/mod.rs | 34 +++++++++++- src/command/register.rs | 54 ++++++++++++++++++ src/command/search.rs | 3 +- src/command/server.rs | 4 +- src/command/sync.rs | 15 +++++ src/local/api_client.rs | 94 +++++++++++++++++++++++++++++++ src/local/database.rs | 55 ++++++++++++++++-- src/local/encryption.rs | 108 ++++++++++++++++++++++++++++++++++++ src/local/history.rs | 11 ++-- src/local/import.rs | 118 ++++++++++++++++++++++++++++----------- src/local/mod.rs | 3 + src/local/sync.rs | 135 +++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 19 ++++--- src/remote/auth.rs | 92 +++++++++++++++++++------------ src/remote/database.rs | 2 +- src/remote/models.rs | 16 ++++-- src/remote/server.rs | 26 +++++++-- src/remote/views.rs | 144 ++++++++++++++++++++++++++++++++++++++++-------- src/schema.rs | 4 +- src/settings.rs | 131 ++++++++++++++++++++++++++++++++++++------- src/shell/atuin.zsh | 26 +++------ src/utils.rs | 24 ++++++++ 25 files changed, 1055 insertions(+), 177 deletions(-) create mode 100644 src/api.rs create mode 100644 src/command/login.rs create mode 100644 src/command/register.rs create mode 100644 src/command/sync.rs create mode 100644 src/local/api_client.rs create mode 100644 src/local/encryption.rs create mode 100644 src/local/sync.rs create mode 100644 src/utils.rs (limited to 'src') diff --git a/src/api.rs b/src/api.rs new file mode 100644 index 00000000..90977404 --- /dev/null +++ b/src/api.rs @@ -0,0 +1,36 @@ +use chrono::Utc; + +// This is shared between the client and the server, and has the data structures +// representing the requests/responses for each method. +// TODO: Properly define responses rather than using json! + +#[derive(Debug, Serialize, Deserialize)] +pub struct RegisterRequest { + pub email: String, + pub username: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginRequest { + pub username: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AddHistoryRequest { + pub id: String, + pub timestamp: chrono::DateTime, + pub data: String, + pub hostname: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CountResponse { + pub count: i64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ListHistoryResponse { + pub history: Vec, +} diff --git a/src/command/history.rs b/src/command/history.rs index 05aed4b9..3b4a717c 100644 --- a/src/command/history.rs +++ b/src/command/history.rs @@ -1,10 +1,13 @@ use std::env; use eyre::Result; +use fork::{fork, Fork}; use structopt::StructOpt; use crate::local::database::Database; use crate::local::history::History; +use crate::local::sync; +use crate::settings::Settings; #[derive(StructOpt)] pub enum Cmd { @@ -50,21 +53,13 @@ fn print_list(h: &[History]) { } impl Cmd { - pub fn run(&self, db: &mut impl Database) -> Result<()> { + pub fn run(&self, settings: &Settings, db: &mut impl Database) -> Result<()> { match self { Self::Start { command: words } => { let command = words.join(" "); let cwd = env::current_dir()?.display().to_string(); - let h = History::new( - chrono::Utc::now().timestamp_nanos(), - command, - cwd, - -1, - -1, - None, - None, - ); + let h = History::new(chrono::Utc::now(), command, cwd, -1, -1, None, None); // print the ID // we use this as the key for calling end @@ -76,10 +71,23 @@ impl Cmd { Self::End { id, exit } => { let mut h = db.load(id)?; h.exit = *exit; - h.duration = chrono::Utc::now().timestamp_nanos() - h.timestamp; + h.duration = chrono::Utc::now().timestamp_nanos() - h.timestamp.timestamp_nanos(); db.update(&h)?; + if settings.local.should_sync()? { + match fork() { + Ok(Fork::Parent(child)) => { + debug!("launched sync background process with PID {}", child); + } + Ok(Fork::Child) => { + debug!("running periodic background sync"); + sync::sync(settings, false, db)?; + } + Err(_) => println!("Fork failed"), + } + } + Ok(()) } diff --git a/src/command/login.rs b/src/command/login.rs new file mode 100644 index 00000000..4f58b77f --- /dev/null +++ b/src/command/login.rs @@ -0,0 +1,48 @@ +use std::collections::HashMap; +use std::fs::File; +use std::io::prelude::*; + +use eyre::Result; +use structopt::StructOpt; + +use crate::settings::Settings; + +#[derive(StructOpt)] +#[structopt(setting(structopt::clap::AppSettings::DeriveDisplayOrder))] +pub struct Cmd { + #[structopt(long, short)] + pub username: String, + + #[structopt(long, short)] + pub password: String, + + #[structopt(long, short, about = "the encryption key for your account")] + pub key: String, +} + +impl Cmd { + pub fn run(&self, settings: &Settings) -> Result<()> { + let mut map = HashMap::new(); + map.insert("username", self.username.clone()); + map.insert("password", self.password.clone()); + + let url = format!("{}/login", settings.local.sync_address); + let client = reqwest::blocking::Client::new(); + let resp = client.post(url).json(&map).send()?; + + let session = resp.json::>()?; + let session = session["session"].clone(); + + let session_path = settings.local.session_path.as_str(); + let mut file = File::create(session_path)?; + file.write_all(session.as_bytes())?; + + let key_path = settings.local.key_path.as_str(); + let mut file = File::create(key_path)?; + file.write_all(&base64::decode(self.key.clone())?)?; + + println!("Logged in!"); + + Ok(()) + } +} diff --git a/src/command/mod.rs b/src/command/mod.rs index a5ea0228..eeb11a87 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -9,9 +9,12 @@ mod event; mod history; mod import; mod init; +mod login; +mod register; mod search; mod server; mod stats; +mod sync; #[derive(StructOpt)] pub enum AtuinCmd { @@ -38,6 +41,21 @@ pub enum AtuinCmd { #[structopt(about = "interactive history search")] Search { query: Vec }, + + #[structopt(about = "sync with the configured server")] + Sync { + #[structopt(long, short, about = "force re-download everything")] + force: bool, + }, + + #[structopt(about = "login to the configured server")] + Login(login::Cmd), + + #[structopt(about = "register with the configured server")] + Register(register::Cmd), + + #[structopt(about = "print the encryption key for transfer to another machine")] + Key, } pub fn uuid_v4() -> String { @@ -47,13 +65,27 @@ pub fn uuid_v4() -> String { impl AtuinCmd { pub fn run(self, db: &mut impl Database, settings: &Settings) -> Result<()> { match self { - Self::History(history) => history.run(db), + Self::History(history) => history.run(settings, db), Self::Import(import) => import.run(db), Self::Server(server) => server.run(settings), Self::Stats(stats) => stats.run(db, settings), Self::Init => init::init(), Self::Search { query } => search::run(&query, db), + Self::Sync { force } => sync::run(settings, force, db), + Self::Login(l) => l.run(settings), + Self::Register(r) => register::run( + settings, + r.username.as_str(), + r.email.as_str(), + r.password.as_str(), + ), + Self::Key => { + let key = std::fs::read(settings.local.key_path.as_str())?; + println!("{}", base64::encode(key)); + Ok(()) + } + Self::Uuid => { println!("{}", uuid_v4()); Ok(()) diff --git a/src/command/register.rs b/src/command/register.rs new file mode 100644 index 00000000..62bbeaeb --- /dev/null +++ b/src/command/register.rs @@ -0,0 +1,54 @@ +use std::collections::HashMap; +use std::fs::File; +use std::io::prelude::*; + +use eyre::{eyre, Result}; +use structopt::StructOpt; + +use crate::settings::Settings; + +#[derive(StructOpt)] +#[structopt(setting(structopt::clap::AppSettings::DeriveDisplayOrder))] +pub struct Cmd { + #[structopt(long, short)] + pub username: String, + + #[structopt(long, short)] + pub email: String, + + #[structopt(long, short)] + pub password: String, +} + +pub fn run(settings: &Settings, username: &str, email: &str, password: &str) -> Result<()> { + let mut map = HashMap::new(); + map.insert("username", username); + map.insert("email", email); + map.insert("password", password); + + let url = format!("{}/user/{}", settings.local.sync_address, username); + let resp = reqwest::blocking::get(url)?; + + if resp.status().is_success() { + println!("Username is already in use! Please try another."); + return Ok(()); + } + + let url = format!("{}/register", settings.local.sync_address); + let client = reqwest::blocking::Client::new(); + let resp = client.post(url).json(&map).send()?; + + if !resp.status().is_success() { + println!("Failed to register user - please check your details and try again"); + return Err(eyre!("failed to register user")); + } + + let session = resp.json::>()?; + let session = session["session"].clone(); + + let path = settings.local.session_path.as_str(); + let mut file = File::create(path)?; + file.write_all(session.as_bytes())?; + + Ok(()) +} diff --git a/src/command/search.rs b/src/command/search.rs index d51e29ef..b9f3987c 100644 --- a/src/command/search.rs +++ b/src/command/search.rs @@ -171,7 +171,8 @@ fn select_history(query: &[String], db: &mut impl Database) -> Result { .iter() .enumerate() .map(|(i, m)| { - let mut content = Span::raw(m.command.to_string()); + let mut content = + Span::raw(m.command.to_string().replace("\n", " ").replace("\t", " ")); if let Some(selected) = app.results_state.selected() { if selected == i { diff --git a/src/command/server.rs b/src/command/server.rs index 5156f409..ba2a9a2f 100644 --- a/src/command/server.rs +++ b/src/command/server.rs @@ -24,10 +24,10 @@ impl Cmd { match self { Self::Start { host, port } => { let host = host.as_ref().map_or( - settings.remote.host.clone(), + settings.server.host.clone(), std::string::ToString::to_string, ); - let port = port.map_or(settings.remote.port, |p| p); + let port = port.map_or(settings.server.port, |p| p); server::launch(settings, host, port); } diff --git a/src/command/sync.rs b/src/command/sync.rs new file mode 100644 index 00000000..facbe578 --- /dev/null +++ b/src/command/sync.rs @@ -0,0 +1,15 @@ +use eyre::Result; + +use crate::local::database::Database; +use crate::local::sync; +use crate::settings::Settings; + +pub fn run(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> { + sync::sync(settings, force, db)?; + println!( + "Sync complete! {} items in database, force: {}", + db.history_count()?, + force + ); + Ok(()) +} diff --git a/src/local/api_client.rs b/src/local/api_client.rs new file mode 100644 index 00000000..434c07ba --- /dev/null +++ b/src/local/api_client.rs @@ -0,0 +1,94 @@ +use chrono::Utc; +use eyre::Result; +use reqwest::header::AUTHORIZATION; + +use crate::api::{AddHistoryRequest, CountResponse, ListHistoryResponse}; +use crate::local::encryption::{decrypt, load_key}; +use crate::local::history::History; +use crate::settings::Settings; +use crate::utils::hash_str; + +pub struct Client<'a> { + settings: &'a Settings, +} + +impl<'a> Client<'a> { + pub const fn new(settings: &'a Settings) -> Self { + Client { settings } + } + + pub fn count(&self) -> Result { + let url = format!("{}/sync/count", self.settings.local.sync_address); + let client = reqwest::blocking::Client::new(); + + let resp = client + .get(url) + .header( + AUTHORIZATION, + format!("Token {}", self.settings.local.session_token), + ) + .send()?; + + let count = resp.json::()?; + + Ok(count.count) + } + + pub fn get_history( + &self, + sync_ts: chrono::DateTime, + history_ts: chrono::DateTime, + host: Option, + ) -> Result> { + let key = load_key(self.settings)?; + + let host = match host { + None => hash_str(&format!("{}:{}", whoami::hostname(), whoami::username())), + Some(h) => h, + }; + + // this allows for syncing between users on the same machine + let url = format!( + "{}/sync/history?sync_ts={}&history_ts={}&host={}", + self.settings.local.sync_address, + sync_ts.to_rfc3339(), + history_ts.to_rfc3339(), + host, + ); + let client = reqwest::blocking::Client::new(); + + let resp = client + .get(url) + .header( + AUTHORIZATION, + format!("Token {}", self.settings.local.session_token), + ) + .send()?; + + let history = resp.json::()?; + let history = history + .history + .iter() + .map(|h| serde_json::from_str(h).expect("invalid base64")) + .map(|h| decrypt(&h, &key).expect("failed to decrypt history! check your key")) + .collect(); + + Ok(history) + } + + pub fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> { + let client = reqwest::blocking::Client::new(); + + let url = format!("{}/history", self.settings.local.sync_address); + client + .post(url) + .json(history) + .header( + AUTHORIZATION, + format!("Token {}", self.settings.local.session_token), + ) + .send()?; + + Ok(()) + } +} diff --git a/src/local/database.rs b/src/local/database.rs index ad7078e5..977f11cc 100644 --- a/src/local/database.rs +++ b/src/local/database.rs @@ -1,3 +1,4 @@ +use chrono::prelude::*; use chrono::Utc; use std::path::Path; @@ -21,6 +22,10 @@ pub trait Database { fn update(&self, h: &History) -> Result<()>; fn history_count(&self) -> Result; + fn first(&self) -> Result; + fn last(&self) -> Result; + fn before(&self, timestamp: chrono::DateTime, count: i64) -> Result>; + fn prefix_search(&self, query: &str) -> Result>; } @@ -44,9 +49,7 @@ impl Sqlite { let conn = Connection::open(path)?; - if create { - Self::setup_db(&conn)?; - } + Self::setup_db(&conn)?; Ok(Self { conn }) } @@ -70,6 +73,14 @@ impl Sqlite { [], )?; + conn.execute( + "create table if not exists history_encrypted ( + id text primary key, + data blob not null + )", + [], + )?; + Ok(()) } @@ -87,7 +98,7 @@ impl Sqlite { ) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", params![ h.id, - h.timestamp, + h.timestamp.timestamp_nanos(), h.duration, h.exit, h.command, @@ -146,7 +157,7 @@ impl Database for Sqlite { "update history set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8 where id = ?1", - params![h.id, h.timestamp, h.duration, h.exit, h.command, h.cwd, h.session, h.hostname], + params![h.id, h.timestamp.timestamp_nanos(), h.duration, h.exit, h.command, h.cwd, h.session, h.hostname], )?; Ok(()) @@ -183,6 +194,38 @@ impl Database for Sqlite { Ok(history_iter.filter_map(Result::ok).collect()) } + fn first(&self) -> Result { + let mut stmt = self + .conn + .prepare("SELECT * FROM history order by timestamp asc limit 1")?; + + let history = stmt.query_row(params![], |row| history_from_sqlite_row(None, row))?; + + Ok(history) + } + + fn last(&self) -> Result { + let mut stmt = self + .conn + .prepare("SELECT * FROM history order by timestamp desc limit 1")?; + + let history = stmt.query_row(params![], |row| history_from_sqlite_row(None, row))?; + + Ok(history) + } + + fn before(&self, timestamp: chrono::DateTime, count: i64) -> Result> { + let mut stmt = self.conn.prepare( + "SELECT * FROM history where timestamp <= ? order by timestamp desc limit ?", + )?; + + let history_iter = stmt.query_map(params![timestamp.timestamp_nanos(), count], |row| { + history_from_sqlite_row(None, row) + })?; + + Ok(history_iter.filter_map(Result::ok).collect()) + } + fn query(&self, query: &str, params: impl Params) -> Result> { let mut stmt = self.conn.prepare(query)?; @@ -218,7 +261,7 @@ fn history_from_sqlite_row( Ok(History { id, - timestamp: row.get(1)?, + timestamp: Utc.timestamp_nanos(row.get(1)?), duration: row.get(2)?, exit: row.get(3)?, command: row.get(4)?, diff --git a/src/local/encryption.rs b/src/local/encryption.rs new file mode 100644 index 00000000..3c1699e3 --- /dev/null +++ b/src/local/encryption.rs @@ -0,0 +1,108 @@ +// The general idea is that we NEVER send cleartext history to the server +// This way the odds of anything private ending up where it should not are +// very low +// The server authenticates via the usual username and password. This has +// nothing to do with the encryption, and is purely authentication! The client +// generates its own secret key, and encrypts all shell history with libsodium's +// secretbox. The data is then sent to the server, where it is stored. All +// clients must share the secret in order to be able to sync, as it is needed +// to decrypt + +use std::fs::File; +use std::io::prelude::*; +use std::path::PathBuf; + +use eyre::{eyre, Result}; +use sodiumoxide::crypto::secretbox; + +use crate::local::history::History; +use crate::settings::Settings; + +#[derive(Debug, Serialize, Deserialize)] +pub struct EncryptedHistory { + pub ciphertext: Vec, + pub nonce: secretbox::Nonce, +} + +// Loads the secret key, will create + save if it doesn't exist +pub fn load_key(settings: &Settings) -> Result { + let path = settings.local.key_path.as_str(); + + if PathBuf::from(path).exists() { + let bytes = std::fs::read(path)?; + let key: secretbox::Key = rmp_serde::from_read_ref(&bytes)?; + Ok(key) + } else { + let key = secretbox::gen_key(); + let buf = rmp_serde::to_vec(&key)?; + + let mut file = File::create(path)?; + file.write_all(&buf)?; + + Ok(key) + } +} + +pub fn encrypt(history: &History, key: &secretbox::Key) -> Result { + // serialize with msgpack + let buf = rmp_serde::to_vec(history)?; + + let nonce = secretbox::gen_nonce(); + + let ciphertext = secretbox::seal(&buf, &nonce, key); + + Ok(EncryptedHistory { ciphertext, nonce }) +} + +pub fn decrypt(encrypted_history: &EncryptedHistory, key: &secretbox::Key) -> Result { + let plaintext = secretbox::open(&encrypted_history.ciphertext, &encrypted_history.nonce, key) + .map_err(|_| eyre!("failed to open secretbox - invalid key?"))?; + + let history = rmp_serde::from_read_ref(&plaintext)?; + + Ok(history) +} + +#[cfg(test)] +mod test { + use sodiumoxide::crypto::secretbox; + + use crate::local::history::History; + + use super::{decrypt, encrypt}; + + #[test] + fn test_encrypt_decrypt() { + let key1 = secretbox::gen_key(); + let key2 = secretbox::gen_key(); + + let history = History::new( + chrono::Utc::now(), + "ls".to_string(), + "/home/ellie".to_string(), + 0, + 1, + Some("beep boop".to_string()), + Some("booop".to_string()), + ); + + let e1 = encrypt(&history, &key1).unwrap(); + let e2 = encrypt(&history, &key2).unwrap(); + + assert_ne!(e1.ciphertext, e2.ciphertext); + assert_ne!(e1.nonce, e2.nonce); + + // test decryption works + // this should pass + match decrypt(&e1, &key1) { + Err(e) => assert!(false, "failed to decrypt, got {}", e), + Ok(h) => assert_eq!(h, history), + }; + + // this should err + match decrypt(&e2, &key1) { + Ok(_) => assert!(false, "expected an error decrypting with invalid key"), + Err(_) => {} + }; + } +} diff --git a/src/local/history.rs b/src/local/history.rs index 0ca112bd..1712f8b9 100644 --- a/src/local/history.rs +++ b/src/local/history.rs @@ -1,12 +1,15 @@ use std::env; use std::hash::{Hash, Hasher}; +use chrono::Utc; + use crate::command::uuid_v4; -#[derive(Debug, Clone)] +// Any new fields MUST be Optional<>! +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct History { pub id: String, - pub timestamp: i64, + pub timestamp: chrono::DateTime, pub duration: i64, pub exit: i64, pub command: String, @@ -17,7 +20,7 @@ pub struct History { impl History { pub fn new( - timestamp: i64, + timestamp: chrono::DateTime, command: String, cwd: String, exit: i64, @@ -29,7 +32,7 @@ impl History { .or_else(|| env::var("ATUIN_SESSION").ok()) .unwrap_or_else(uuid_v4); let hostname = - hostname.unwrap_or_else(|| hostname::get().unwrap().to_str().unwrap().to_string()); + hostname.unwrap_or_else(|| format!("{}:{}", whoami::hostname(), whoami::username())); Self { id: uuid_v4(), diff --git a/src/local/import.rs b/src/local/import.rs index 9bf79c72..d0f679c9 100644 --- a/src/local/import.rs +++ b/src/local/import.rs @@ -4,7 +4,9 @@ use std::io::{BufRead, BufReader, Seek, SeekFrom}; use std::{fs::File, path::Path}; -use eyre::{Result, WrapErr}; +use chrono::prelude::*; +use chrono::Utc; +use eyre::{eyre, Result}; use super::history::History; @@ -13,6 +15,7 @@ pub struct Zsh { file: BufReader, pub loc: u64, + pub counter: i64, } // this could probably be sped up @@ -32,19 +35,23 @@ impl Zsh { Ok(Self { file: buf, loc: loc as u64, + counter: 0, }) } } -fn parse_extended(line: &str) -> History { +fn parse_extended(line: &str, counter: i64) -> History { let line = line.replacen(": ", "", 2); let (time, duration) = line.split_once(':').unwrap(); let (duration, command) = duration.split_once(';').unwrap(); - let time = time.parse::().map_or_else( - |_| chrono::Utc::now().timestamp_nanos(), - |t| t * 1_000_000_000, - ); + let time = time + .parse::() + .unwrap_or_else(|_| chrono::Utc::now().timestamp()); + + let offset = chrono::Duration::milliseconds(counter); + let time = Utc.timestamp(time, 0); + let time = time + offset; let duration = duration.parse::().map_or(-1, |t| t * 1_000_000_000); @@ -60,6 +67,18 @@ fn parse_extended(line: &str) -> History { ) } +impl Zsh { + fn read_line(&mut self) -> Option> { + let mut line = String::new(); + + match self.file.read_line(&mut line) { + Ok(0) => None, + Ok(_) => Some(Ok(line)), + Err(e) => Some(Err(eyre!("failed to read line: {}", e))), // we can skip past things like invalid utf8 + } + } +} + impl Iterator for Zsh { type Item = Result; @@ -68,54 +87,89 @@ impl Iterator for Zsh { // These lines begin with : // So, if the line begins with :, parse it. Otherwise it's just // the command - let mut line = String::new(); + let line = self.read_line()?; - match self.file.read_line(&mut line) { - Ok(0) => None, - Ok(_) => { - let extended = line.starts_with(':'); - - if extended { - Some(Ok(parse_extended(line.as_str()))) - } else { - Some(Ok(History::new( - chrono::Utc::now().timestamp_nanos(), // what else? :/ - line.trim_end().to_string(), - String::from("unknown"), - -1, - -1, - None, - None, - ))) - } + if let Err(e) = line { + return Some(Err(e)); // :( + } + + let mut line = line.unwrap(); + + while line.ends_with("\\\n") { + let next_line = self.read_line()?; + + if next_line.is_err() { + // There's a chance that the last line of a command has invalid + // characters, the only safe thing to do is break :/ + // usually just invalid utf8 or smth + // however, we really need to avoid missing history, so it's + // better to have some items that should have been part of + // something else, than to miss things. So break. + break; } - Err(e) => Some(Err(e).wrap_err("failed to parse line")), + + line.push_str(next_line.unwrap().as_str()); + } + + // We have to handle the case where a line has escaped newlines. + // Keep reading until we have a non-escaped newline + + let extended = line.starts_with(':'); + + if extended { + self.counter += 1; + Some(Ok(parse_extended(line.as_str(), self.counter))) + } else { + let time = chrono::Utc::now(); + let offset = chrono::Duration::seconds(self.counter); + let time = time - offset; + + self.counter += 1; + + Some(Ok(History::new( + time, + line.trim_end().to_string(), + String::from("unknown"), + -1, + -1, + None, + None, + ))) } } } #[cfg(test)] mod test { + use chrono::prelude::*; + use chrono::Utc; + use super::parse_extended; #[test] fn test_parse_extended_simple() { - let parsed = parse_extended(": 1613322469:0;cargo install atuin"); + let parsed = parse_extended(": 1613322469:0;cargo install atuin", 0); assert_eq!(parsed.command, "cargo install atuin"); assert_eq!(parsed.duration, 0); - assert_eq!(parsed.timestamp, 1_613_322_469_000_000_000); + assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0)); - let parsed = parse_extended(": 1613322469:10;cargo install atuin;cargo update"); + let parsed = parse_extended(": 1613322469:10;cargo install atuin;cargo update", 0); assert_eq!(parsed.command, "cargo install atuin;cargo update"); assert_eq!(parsed.duration, 10_000_000_000); - assert_eq!(parsed.timestamp, 1_613_322_469_000_000_000); + assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0)); - let parsed = parse_extended(": 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷"); + let parsed = parse_extended(": 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", 0); assert_eq!(parsed.command, "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷"); assert_eq!(parsed.duration, 10_000_000_000); - assert_eq!(parsed.timestamp, 1_613_322_469_000_000_000); + assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0)); + + let parsed = parse_extended(": 1613322469:10;cargo install \\n atuin\n", 0); + + assert_eq!(parsed.command, "cargo install \\n atuin"); + assert_eq!(parsed.duration, 10_000_000_000); + assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0)); } } diff --git a/src/local/mod.rs b/src/local/mod.rs index a11ee213..9fe31292 100644 --- a/src/local/mod.rs +++ b/src/local/mod.rs @@ -1,3 +1,6 @@ +pub mod api_client; pub mod database; +pub mod encryption; pub mod history; pub mod import; +pub mod sync; diff --git a/src/local/sync.rs b/src/local/sync.rs new file mode 100644 index 00000000..c22d2f27 --- /dev/null +++ b/src/local/sync.rs @@ -0,0 +1,135 @@ +use std::convert::TryInto; + +use chrono::prelude::*; +use eyre::Result; + +use crate::local::api_client; +use crate::local::database::Database; +use crate::local::encryption::{encrypt, load_key}; +use crate::settings::{Local, Settings, HISTORY_PAGE_SIZE}; +use crate::{api::AddHistoryRequest, utils::hash_str}; + +// Currently sync is kinda naive, and basically just pages backwards through +// history. This means newly added stuff shows up properly! We also just use +// the total count in each database to indicate whether a sync is needed. +// I think this could be massively improved! If we had a way of easily +// indicating count per time period (hour, day, week, year, etc) then we can +// easily pinpoint where we are missing data and what needs downloading. Start +// with year, then find the week, then the day, then the hour, then download it +// all! The current naive approach will do for now. + +// Check if remote has things we don't, and if so, download them. +// Returns (num downloaded, total local) +fn sync_download( + force: bool, + client: &api_client::Client, + db: &mut impl Database, +) -> Result<(i64, i64)> { + let remote_count = client.count()?; + + let initial_local = db.history_count()?; + let mut local_count = initial_local; + + let mut last_sync = if force { + Utc.timestamp_millis(0) + } else { + Local::last_sync()? + }; + + let mut last_timestamp = Utc.timestamp_millis(0); + + let host = if force { Some(String::from("")) } else { None }; + + while remote_count > local_count { + let page = client.get_history(last_sync, last_timestamp, host.clone())?; + + if page.len() < HISTORY_PAGE_SIZE.try_into().unwrap() { + break; + } + + db.save_bulk(&page)?; + + local_count = db.history_count()?; + + let page_last = page + .last() + .expect("could not get last element of page") + .timestamp; + + // in the case of a small sync frequency, it's possible for history to + // be "lost" between syncs. In this case we need to rewind the sync + // timestamps + if page_last == last_timestamp { + last_timestamp = Utc.timestamp_millis(0); + last_sync = last_sync - chrono::Duration::hours(1); + } else { + last_timestamp = page_last; + } + } + + Ok((local_count - initial_local, local_count)) +} + +// Check if we have things remote doesn't, and if so, upload them +fn sync_upload( + settings: &Settings, + _force: bool, + client: &api_client::Client, + db: &mut impl Database, +) -> Result<()> { + let initial_remote_count = client.count()?; + let mut remote_count = initial_remote_count; + + let local_count = db.history_count()?; + + let key = load_key(settings)?; // encryption key + + // first just try the most recent set + + let mut cursor = Utc::now(); + + while local_count > remote_count { + let last = db.before(cursor, HISTORY_PAGE_SIZE)?; + let mut buffer = Vec::::new(); + + if last.is_empty() { + break; + } + + for i in last { + let data = encrypt(&i, &key)?; + let data = serde_json::to_string(&data)?; + + let add_hist = AddHistoryRequest { + id: i.id, + timestamp: i.timestamp, + data, + hostname: hash_str(i.hostname.as_str()), + }; + + buffer.push(add_hist); + } + + // anything left over outside of the 100 block size + client.post_history(&buffer)?; + cursor = buffer.last().unwrap().timestamp; + + remote_count = client.count()?; + } + + Ok(()) +} + +pub fn sync(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> { + let client = api_client::Client::new(settings); + + sync_upload(settings, force, &client, db)?; + + let download = sync_download(force, &client, db)?; + + debug!("sync downloaded {}", download.0); + + Local::save_sync_time()?; + + Ok(()) +} diff --git a/src/main.rs b/src/main.rs index bac75362..ae459807 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,7 @@ use std::path::PathBuf; use eyre::{eyre, Result}; -use structopt::StructOpt; +use structopt::{clap::AppSettings, StructOpt}; #[macro_use] extern crate log; @@ -30,18 +30,21 @@ use command::AtuinCmd; use local::database::Sqlite; use settings::Settings; +mod api; mod command; mod local; mod remote; mod settings; +mod utils; pub mod schema; #[derive(StructOpt)] #[structopt( author = "Ellie Huxtable ", - version = "0.4.0", - about = "Magical shell history" + version = "0.5.0", + about = "Magical shell history", + global_settings(&[AppSettings::ColoredHelp, AppSettings::DeriveDisplayOrder]) )] struct Atuin { #[structopt(long, parse(from_os_str), help = "db file path")] @@ -52,9 +55,7 @@ struct Atuin { } impl Atuin { - fn run(self) -> Result<()> { - let settings = Settings::new()?; - + fn run(self, settings: &Settings) -> Result<()> { let db_path = if let Some(db_path) = self.db { let path = db_path .to_str() @@ -67,11 +68,13 @@ impl Atuin { let mut db = Sqlite::new(db_path)?; - self.atuin.run(&mut db, &settings) + self.atuin.run(&mut db, settings) } } fn main() -> Result<()> { + let settings = Settings::new()?; + fern::Dispatch::new() .format(|out, message, record| { out.finish(format_args!( @@ -85,5 +88,5 @@ fn main() -> Result<()> { .chain(std::io::stdout()) .apply()?; - Atuin::from_args().run() + Atuin::from_args().run(&settings) } diff --git a/src/remote/auth.rs b/src/remote/auth.rs index 8f9e9b46..cf61b077 100644 --- a/src/remote/auth.rs +++ b/src/remote/auth.rs @@ -1,6 +1,8 @@ use self::diesel::prelude::*; +use eyre::Result; use rocket::http::Status; use rocket::request::{self, FromRequest, Outcome, Request}; +use rocket::State; use rocket_contrib::databases::diesel; use sodiumoxide::crypto::pwhash::argon2id13; @@ -9,7 +11,11 @@ use uuid::Uuid; use super::models::{NewSession, NewUser, Session, User}; use super::views::ApiResponse; + +use crate::api::{LoginRequest, RegisterRequest}; use crate::schema::{sessions, users}; +use crate::settings::Settings; +use crate::utils::hash_secret; use super::database::AtuinDbConn; @@ -19,20 +25,6 @@ pub enum KeyError { Invalid, } -pub fn hash_str(secret: &str) -> String { - sodiumoxide::init().unwrap(); - let hash = argon2id13::pwhash( - secret.as_bytes(), - argon2id13::OPSLIMIT_INTERACTIVE, - argon2id13::MEMLIMIT_INTERACTIVE, - ) - .unwrap(); - let texthash = std::str::from_utf8(&hash.0).unwrap().to_string(); - - // postgres hates null chars. don't do that to postgres - texthash.trim_end_matches('\u{0}').to_string() -} - pub fn verify_str(secret: &str, verify: &str) -> bool { sodiumoxide::init().unwrap(); @@ -95,19 +87,54 @@ impl<'a, 'r> FromRequest<'a, 'r> for User { } } -#[derive(Deserialize)] -pub struct Register { - email: String, - password: String, +#[get("/user/")] +#[allow(clippy::clippy::needless_pass_by_value)] +pub fn get_user(user: String, conn: AtuinDbConn) -> ApiResponse { + use crate::schema::users::dsl::{username, users}; + + let user: Result = users + .select(username) + .filter(username.eq(user)) + .first(&*conn); + + if user.is_err() { + return ApiResponse { + json: json!({ + "message": "could not find user", + }), + status: Status::NotFound, + }; + } + + let user = user.unwrap(); + + ApiResponse { + json: json!({ "username": user.as_str() }), + status: Status::Ok, + } } #[post("/register", data = "")] #[allow(clippy::clippy::needless_pass_by_value)] -pub fn register(conn: AtuinDbConn, register: Json) -> ApiResponse { - let hashed = hash_str(register.password.as_str()); +pub fn register( + conn: AtuinDbConn, + register: Json, + settings: State, +) -> ApiResponse { + if !settings.server.open_registration { + return ApiResponse { + status: Status::BadRequest, + json: json!({ + "message": "registrations are not open" + }), + }; + } + + let hashed = hash_secret(register.password.as_str()); let new_user = NewUser { email: register.email.as_str(), + username: register.username.as_str(), password: hashed.as_str(), }; @@ -119,8 +146,7 @@ pub fn register(conn: AtuinDbConn, register: Json) -> ApiResponse { return ApiResponse { status: Status::BadRequest, json: json!({ - "status": "error", - "message": "failed to create user - is the email already in use?", + "message": "failed to create user - username or email in use?", }), }; } @@ -139,32 +165,26 @@ pub fn register(conn: AtuinDbConn, register: Json) -> ApiResponse { { Ok(_) => ApiResponse { status: Status::Ok, - json: json!({"status": "ok", "message": "user created!", "session": token}), + json: json!({"message": "user created!", "session": token}), }, Err(_) => ApiResponse { status: Status::BadRequest, - json: json!({"status": "error", "message": "failed to create user"}), + json: json!({ "message": "failed to create user"}), }, } } -#[derive(Deserialize)] -pub struct Login { - email: String, - password: String, -} - #[post("/login", data = "")] #[allow(clippy::clippy::needless_pass_by_value)] -pub fn login(conn: AtuinDbConn, login: Json) -> ApiResponse { +pub fn login(conn: AtuinDbConn, login: Json) -> ApiResponse { let user = users::table - .filter(users::email.eq(login.email.as_str())) + .filter(users::username.eq(login.username.as_str())) .first(&*conn); if user.is_err() { return ApiResponse { status: Status::NotFound, - json: json!({"status": "error", "message": "user not found"}), + json: json!({"message": "user not found"}), }; } @@ -178,7 +198,7 @@ pub fn login(conn: AtuinDbConn, login: Json) -> ApiResponse { if session.is_err() { return ApiResponse { status: Status::InternalServerError, - json: json!({"status": "error", "message": "something went wrong"}), + json: json!({"message": "something went wrong"}), }; } @@ -187,7 +207,7 @@ pub fn login(conn: AtuinDbConn, login: Json) -> ApiResponse { if !verified { return ApiResponse { status: Status::NotFound, - json: json!({"status": "error", "message": "user not found"}), + json: json!({"message": "user not found"}), }; } @@ -195,6 +215,6 @@ pub fn login(conn: AtuinDbConn, login: Json) -> ApiResponse { ApiResponse { status: Status::Ok, - json: json!({"status": "ok", "token": session.token}), + json: json!({"session": session.token}), } } diff --git a/src/remote/database.rs b/src/remote/database.rs index fabd07de..ddcffda0 100644 --- a/src/remote/database.rs +++ b/src/remote/database.rs @@ -8,7 +8,7 @@ pub struct AtuinDbConn(diesel::PgConnection); // TODO: connection pooling pub fn establish_connection(settings: &Settings) -> PgConnection { - let database_url = &settings.remote.db_uri; + let database_url = &settings.server.db_uri; PgConnection::establish(database_url) .unwrap_or_else(|_| panic!("Error connecting to {}", database_url)) } diff --git a/src/remote/models.rs b/src/remote/models.rs index 058b2f0b..7f6f7766 100644 --- a/src/remote/models.rs +++ b/src/remote/models.rs @@ -1,23 +1,26 @@ -use chrono::naive::NaiveDateTime; +use chrono::prelude::*; use crate::schema::{history, sessions, users}; -#[derive(Identifiable, Queryable, Associations)] +#[derive(Deserialize, Serialize, Identifiable, Queryable, Associations)] #[table_name = "history"] #[belongs_to(User)] pub struct History { pub id: i64, - pub client_id: String, + pub client_id: String, // a client generated ID pub user_id: i64, - pub mac: String, + pub hostname: String, pub timestamp: NaiveDateTime, pub data: String, + + pub created_at: NaiveDateTime, } #[derive(Identifiable, Queryable, Associations)] pub struct User { pub id: i64, + pub username: String, pub email: String, pub password: String, } @@ -35,8 +38,8 @@ pub struct Session { pub struct NewHistory<'a> { pub client_id: &'a str, pub user_id: i64, - pub mac: &'a str, - pub timestamp: NaiveDateTime, + pub hostname: String, + pub timestamp: chrono::NaiveDateTime, pub data: &'a str, } @@ -44,6 +47,7 @@ pub struct NewHistory<'a> { #[derive(Insertable)] #[table_name = "users"] pub struct NewUser<'a> { + pub username: &'a str, pub email: &'a str, pub password: &'a str, } diff --git a/src/remote/server.rs b/src/remote/server.rs index cd2ca7b8..de58397d 100644 --- a/src/remote/server.rs +++ b/src/remote/server.rs @@ -17,13 +17,15 @@ use super::auth::*; embed_migrations!("migrations"); pub fn launch(settings: &Settings, host: String, port: u16) { + let settings: Settings = settings.clone(); // clone so rocket can manage it + let mut database_config = HashMap::new(); let mut databases = HashMap::new(); - database_config.insert("url", Value::from(settings.remote.db_uri.clone())); + database_config.insert("url", Value::from(settings.server.db_uri.clone())); databases.insert("atuin", Value::from(database_config)); - let connection = establish_connection(settings); + let connection = establish_connection(&settings); embedded_migrations::run(&connection).expect("failed to run migrations"); let config = Config::build(Environment::Production) @@ -36,8 +38,20 @@ pub fn launch(settings: &Settings, host: String, port: u16) { let app = rocket::custom(config); - app.mount("/", routes![index, register, add_history, login]) - .attach(AtuinDbConn::fairing()) - .register(catchers![internal_error, bad_request]) - .launch(); + app.mount( + "/", + routes![ + index, + register, + add_history, + login, + get_user, + sync_count, + sync_list + ], + ) + .manage(settings) + .attach(AtuinDbConn::fairing()) + .register(catchers![internal_error, bad_request]) + .launch(); } diff --git a/src/remote/views.rs b/src/remote/views.rs index 2af3f369..08dff13e 100644 --- a/src/remote/views.rs +++ b/src/remote/views.rs @@ -1,14 +1,22 @@ -use self::diesel::prelude::*; +use chrono::Utc; +use rocket::http::uri::Uri; +use rocket::http::RawStr; use rocket::http::{ContentType, Status}; +use rocket::request::FromFormValue; use rocket::request::Request; use rocket::response; use rocket::response::{Responder, Response}; use rocket_contrib::databases::diesel; use rocket_contrib::json::{Json, JsonValue}; -use super::database::AtuinDbConn; -use super::models::{NewHistory, User}; +use self::diesel::prelude::*; + +use crate::api::AddHistoryRequest; use crate::schema::history; +use crate::settings::HISTORY_PAGE_SIZE; + +use super::database::AtuinDbConn; +use super::models::{History, NewHistory, User}; #[derive(Debug)] pub struct ApiResponse { @@ -46,40 +54,36 @@ pub fn bad_request(_req: &Request) -> ApiResponse { } } -#[derive(Deserialize)] -pub struct AddHistory { - id: String, - timestamp: i64, - data: String, - mac: String, -} - #[post("/history", data = "")] #[allow( clippy::clippy::cast_sign_loss, clippy::cast_possible_truncation, clippy::clippy::needless_pass_by_value )] -pub fn add_history(conn: AtuinDbConn, user: User, add_history: Json) -> ApiResponse { - let secs: i64 = add_history.timestamp / 1_000_000_000; - let nanosecs: u32 = (add_history.timestamp - (secs * 1_000_000_000)) as u32; - let datetime = chrono::NaiveDateTime::from_timestamp(secs, nanosecs); - - let new_history = NewHistory { - client_id: add_history.id.as_str(), - user_id: user.id, - mac: add_history.mac.as_str(), - timestamp: datetime, - data: add_history.data.as_str(), - }; +pub fn add_history( + conn: AtuinDbConn, + user: User, + add_history: Json>, +) -> ApiResponse { + let new_history: Vec = add_history + .iter() + .map(|h| NewHistory { + client_id: h.id.as_str(), + hostname: h.hostname.to_string(), + user_id: user.id, + timestamp: h.timestamp.naive_utc(), + data: h.data.as_str(), + }) + .collect(); match diesel::insert_into(history::table) .values(&new_history) + .on_conflict_do_nothing() .execute(&*conn) { Ok(_) => ApiResponse { status: Status::Ok, - json: json!({"status": "ok", "message": "history added", "id": new_history.client_id}), + json: json!({"status": "ok", "message": "history added"}), }, Err(_) => ApiResponse { status: Status::BadRequest, @@ -87,3 +91,95 @@ pub fn add_history(conn: AtuinDbConn, user: User, add_history: Json) }, } } + +#[get("/sync/count")] +#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)] +pub fn sync_count(conn: AtuinDbConn, user: User) -> ApiResponse { + use crate::schema::history::dsl::*; + + // we need to return the number of history items we have for this user + // in the future I'd like to use something like a merkel tree to calculate + // which day specifically needs syncing + let count = history + .filter(user_id.eq(user.id)) + .count() + .first::(&*conn); + + if count.is_err() { + error!("failed to count: {}", count.err().unwrap()); + + return ApiResponse { + json: json!({"message": "internal server error"}), + status: Status::InternalServerError, + }; + } + + ApiResponse { + status: Status::Ok, + json: json!({"count": count.ok()}), + } +} + +pub struct UtcDateTime(chrono::DateTime); + +impl<'v> FromFormValue<'v> for UtcDateTime { + type Error = &'v RawStr; + + fn from_form_value(form_value: &'v RawStr) -> Result { + let time = Uri::percent_decode(form_value.as_bytes()).map_err(|_| form_value)?; + let time = time.to_string(); + + match chrono::DateTime::parse_from_rfc3339(time.as_str()) { + Ok(t) => Ok(UtcDateTime(t.with_timezone(&Utc))), + Err(e) => { + error!("failed to parse time {}, got: {}", time, e); + Err(form_value) + } + } + } +} + +// Request a list of all history items added to the DB after a given timestamp. +// Provide the current hostname, so that we don't send the client data that +// originated from them +#[get("/sync/history?&&")] +#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)] +pub fn sync_list( + conn: AtuinDbConn, + user: User, + sync_ts: UtcDateTime, + history_ts: UtcDateTime, + host: String, +) -> ApiResponse { + use crate::schema::history::dsl::*; + + // we need to return the number of history items we have for this user + // in the future I'd like to use something like a merkel tree to calculate + // which day specifically needs syncing + // TODO: Allow for configuring the page size, both from params, and setting + // the max in config. 100 is fine for now. + let h = history + .filter(user_id.eq(user.id)) + .filter(hostname.ne(host)) + .filter(created_at.ge(sync_ts.0.naive_utc())) + .filter(timestamp.ge(history_ts.0.naive_utc())) + .order(timestamp.asc()) + .limit(HISTORY_PAGE_SIZE) + .load::(&*conn); + + if let Err(e) = h { + error!("failed to load history: {}", e); + + return ApiResponse { + json: json!({"message": "internal server error"}), + status: Status::InternalServerError, + }; + } + + let user_data: Vec = h.unwrap().iter().map(|i| i.data.to_string()).collect(); + + ApiResponse { + status: Status::Ok, + json: json!({ "history": user_data }), + } +} diff --git a/src/schema.rs b/src/schema.rs index efa9ddcc..84bf5bab 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -3,9 +3,10 @@ table! { id -> Int8, client_id -> Text, user_id -> Int8, - mac -> Varchar, + hostname -> Text, timestamp -> Timestamp, data -> Varchar, + created_at -> Timestamp, } } @@ -20,6 +21,7 @@ table! { table! { users (id) { id -> Int8, + username -> Varchar, email -> Varchar, password -> Varchar, } diff --git a/src/settings.rs b/src/settings.rs index 0e554bed..dcf69a7c 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -1,31 +1,90 @@ -use std::path::PathBuf; +use std::fs::{create_dir_all, File}; +use std::io::prelude::*; +use std::path::{Path, PathBuf}; -use config::{Config, File}; +use chrono::prelude::*; +use chrono::Utc; +use config::{Config, File as ConfigFile}; use directories::ProjectDirs; use eyre::{eyre, Result}; -use std::fs; +use parse_duration::parse; -#[derive(Debug, Deserialize)] +pub const HISTORY_PAGE_SIZE: i64 = 100; + +#[derive(Clone, Debug, Deserialize)] pub struct Local { pub dialect: String, - pub sync: bool, + pub auto_sync: bool, pub sync_address: String, pub sync_frequency: String, pub db_path: String, + pub key_path: String, + pub session_path: String, + + // This is automatically loaded when settings is created. Do not set in + // config! Keep secrets and settings apart. + pub session_token: String, } -#[derive(Debug, Deserialize)] -pub struct Remote { +impl Local { + pub fn save_sync_time() -> Result<()> { + let sync_time_path = ProjectDirs::from("com", "elliehuxtable", "atuin") + .ok_or_else(|| eyre!("could not determine key file location"))?; + let sync_time_path = sync_time_path.data_dir().join("last_sync_time"); + + std::fs::write(sync_time_path, Utc::now().to_rfc3339())?; + + Ok(()) + } + + pub fn last_sync() -> Result> { + let sync_time_path = ProjectDirs::from("com", "elliehuxtable", "atuin"); + + if sync_time_path.is_none() { + debug!("failed to load projectdirs, not syncing"); + return Err(eyre!("could not load project dirs")); + } + + let sync_time_path = sync_time_path.unwrap(); + let sync_time_path = sync_time_path.data_dir().join("last_sync_time"); + + if !sync_time_path.exists() { + return Ok(Utc.ymd(1970, 1, 1).and_hms(0, 0, 0)); + } + + let time = std::fs::read_to_string(sync_time_path)?; + let time = chrono::DateTime::parse_from_rfc3339(time.as_str())?; + + Ok(time.with_timezone(&Utc)) + } + + pub fn should_sync(&self) -> Result { + if !self.auto_sync { + return Ok(false); + } + + match parse(self.sync_frequency.as_str()) { + Ok(d) => { + let d = chrono::Duration::from_std(d).unwrap(); + Ok(Utc::now() - Local::last_sync()? >= d) + } + Err(e) => Err(eyre!("failed to check sync: {}", e)), + } + } +} + +#[derive(Clone, Debug, Deserialize)] +pub struct Server { pub host: String, pub port: u16, pub db_uri: String, pub open_registration: bool, } -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct Settings { pub local: Local, - pub remote: Remote, + pub server: Server, } impl Settings { @@ -33,7 +92,7 @@ impl Settings { let config_dir = ProjectDirs::from("com", "elliehuxtable", "atuin").unwrap(); let config_dir = config_dir.config_dir(); - fs::create_dir_all(config_dir)?; + create_dir_all(config_dir)?; let mut config_file = PathBuf::new(); config_file.push(config_dir); @@ -45,31 +104,61 @@ impl Settings { let mut s = Config::new(); let db_path = ProjectDirs::from("com", "elliehuxtable", "atuin") - .ok_or_else(|| { - eyre!("could not determine db file location\nspecify one using the --db flag") - })? + .ok_or_else(|| eyre!("could not determine db file location"))? .data_dir() .join("history.db"); + let key_path = ProjectDirs::from("com", "elliehuxtable", "atuin") + .ok_or_else(|| eyre!("could not determine key file location"))? + .data_dir() + .join("key"); + + let session_path = ProjectDirs::from("com", "elliehuxtable", "atuin") + .ok_or_else(|| eyre!("could not determine session file location"))? + .data_dir() + .join("session"); + s.set_default("local.db_path", db_path.to_str())?; + s.set_default("local.key_path", key_path.to_str())?; + s.set_default("local.session_path", session_path.to_str())?; s.set_default("local.dialect", "us")?; - s.set_default("local.sync", false)?; + s.set_default("local.auto_sync", true)?; s.set_default("local.sync_frequency", "5m")?; - s.set_default("local.sync_address", "https://atuin.ellie.wtf")?; + s.set_default("local.sync_address", "https://api.atuin.sh")?; - s.set_default("remote.host", "127.0.0.1")?; - s.set_default("remote.port", 8888)?; - s.set_default("remote.open_registration", false)?; - s.set_default("remote.db_uri", "please set a postgres url")?; + s.set_default("server.host", "127.0.0.1")?; + s.set_default("server.port", 8888)?; + s.set_default("server.open_registration", false)?; + s.set_default("server.db_uri", "please set a postgres url")?; if config_file.exists() { - s.merge(File::with_name(config_file.to_str().unwrap()))?; + s.merge(ConfigFile::with_name(config_file.to_str().unwrap()))?; + } else { + let example_config = include_bytes!("../config.toml"); + let mut file = File::create(config_file)?; + file.write_all(example_config)?; } // all paths should be expanded let db_path = s.get_str("local.db_path")?; let db_path = shellexpand::full(db_path.as_str())?; - s.set("local.db.path", db_path.to_string())?; + s.set("local.db_path", db_path.to_string())?; + + let key_path = s.get_str("local.key_path")?; + let key_path = shellexpand::full(key_path.as_str())?; + s.set("local.key_path", key_path.to_string())?; + + let session_path = s.get_str("local.session_path")?; + let session_path = shellexpand::full(session_path.as_str())?; + s.set("local.session_path", session_path.to_string())?; + + // Finally, set the auth token + if Path::new(session_path.to_string().as_str()).exists() { + let token = std::fs::read_to_string(session_path.to_string())?; + s.set("local.session_token", token)?; + } else { + s.set("local.session_token", "not logged in")?; + } s.try_into() .map_err(|e| eyre!("failed to deserialize: {}", e)) diff --git a/src/shell/atuin.zsh b/src/shell/atuin.zsh index 8407efd2..d2abf3c1 100644 --- a/src/shell/atuin.zsh +++ b/src/shell/atuin.zsh @@ -1,4 +1,6 @@ # Source this in your ~/.zshrc +autoload -U add-zsh-hook + export ATUIN_SESSION=$(atuin uuid) export ATUIN_HISTORY="atuin history list" export ATUIN_BINDKEYS="true" @@ -20,24 +22,12 @@ _atuin_search(){ emulate -L zsh zle -I + # Switch to cursor mode, then back to application + echoti rmkx # swap stderr and stdout, so that the tui stuff works # TODO: not this output=$(atuin search $BUFFER 3>&1 1>&2 2>&3) - - if [[ -n $output ]] ; then - LBUFFER=$output - fi - - zle reset-prompt -} - -_atuin_up_search(){ - emulate -L zsh - zle -I - - # swap stderr and stdout, so that the tui stuff works - # TODO: not this - output=$(atuin search $BUFFER 3>&1 1>&2 2>&3) + echoti smkx if [[ -n $output ]] ; then LBUFFER=$output @@ -50,9 +40,11 @@ add-zsh-hook preexec _atuin_preexec add-zsh-hook precmd _atuin_precmd zle -N _atuin_search_widget _atuin_search -zle -N _atuin_up_search_widget _atuin_up_search if [[ $ATUIN_BINDKEYS == "true" ]]; then bindkey '^r' _atuin_search_widget - bindkey '^[[A' _atuin_up_search_widget + + # depends on terminal mode + bindkey '^[[A' _atuin_search_widget + bindkey '^[OA' _atuin_search_widget fi diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 00000000..b395b148 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,24 @@ +use crypto::digest::Digest; +use crypto::sha2::Sha256; +use sodiumoxide::crypto::pwhash::argon2id13; + +pub fn hash_secret(secret: &str) -> String { + sodiumoxide::init().unwrap(); + let hash = argon2id13::pwhash( + secret.as_bytes(), + argon2id13::OPSLIMIT_INTERACTIVE, + argon2id13::MEMLIMIT_INTERACTIVE, + ) + .unwrap(); + let texthash = std::str::from_utf8(&hash.0).unwrap().to_string(); + + // postgres hates null chars. don't do that to postgres + texthash.trim_end_matches('\u{0}').to_string() +} + +pub fn hash_str(string: &str) -> String { + let mut hasher = Sha256::new(); + hasher.input_str(string); + + hasher.result_str() +} -- cgit v1.3.1