aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorEllie Huxtable <e@elm.sh>2021-04-13 19:14:07 +0100
committerGitHub <noreply@github.com>2021-04-13 19:14:07 +0100
commit5751463942cc91f1f1ffaf6e2ac633d7a0085f25 (patch)
treef7b5b9a4702c4c3ef29aa60d36612f61ffeae052 /src
parentUpdate config (diff)
downloadatuin-5751463942cc91f1f1ffaf6e2ac633d7a0085f25.zip
Add history sync, resolves #13 (#31)
* Add encryption * Add login and register command * Add count endpoint * Write initial sync push * Add single sync command Confirmed working for one client only * Automatically sync on a configurable frequency * Add key command, key arg to login * Only load session if it exists * Use sync and history timestamps for download * Bind other key code Seems like some systems have this code for up arrow? I'm not sure why, and it's not an easy one to google. * Simplify upload * Try and fix download sync loop * Change sync order to avoid uploading what we just downloaded * Multiline import fix * Fix time parsing * Fix importing history with no time * Add hostname to sync * Use hostname to filter sync * Fixes * Add binding * Stuff from yesterday * Set cursor modes * Make clippy happy * Bump version
Diffstat (limited to 'src')
-rw-r--r--src/api.rs36
-rw-r--r--src/command/history.rs30
-rw-r--r--src/command/login.rs48
-rw-r--r--src/command/mod.rs34
-rw-r--r--src/command/register.rs54
-rw-r--r--src/command/search.rs3
-rw-r--r--src/command/server.rs4
-rw-r--r--src/command/sync.rs15
-rw-r--r--src/local/api_client.rs94
-rw-r--r--src/local/database.rs55
-rw-r--r--src/local/encryption.rs108
-rw-r--r--src/local/history.rs11
-rw-r--r--src/local/import.rs116
-rw-r--r--src/local/mod.rs3
-rw-r--r--src/local/sync.rs135
-rw-r--r--src/main.rs19
-rw-r--r--src/remote/auth.rs92
-rw-r--r--src/remote/database.rs2
-rw-r--r--src/remote/models.rs16
-rw-r--r--src/remote/server.rs26
-rw-r--r--src/remote/views.rs144
-rw-r--r--src/schema.rs4
-rw-r--r--src/settings.rs131
-rw-r--r--src/shell/atuin.zsh26
-rw-r--r--src/utils.rs24
25 files changed, 1054 insertions, 176 deletions
diff --git a/src/api.rs b/src/api.rs
new file mode 100644
index 00000000..90977404
--- /dev/null
+++ b/src/api.rs
@@ -0,0 +1,36 @@
+use chrono::Utc;
+
+// This is shared between the client and the server, and has the data structures
+// representing the requests/responses for each method.
+// TODO: Properly define responses rather than using json!
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct RegisterRequest {
+ pub email: String,
+ pub username: String,
+ pub password: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct LoginRequest {
+ pub username: String,
+ pub password: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct AddHistoryRequest {
+ pub id: String,
+ pub timestamp: chrono::DateTime<Utc>,
+ pub data: String,
+ pub hostname: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CountResponse {
+ pub count: i64,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ListHistoryResponse {
+ pub history: Vec<String>,
+}
diff --git a/src/command/history.rs b/src/command/history.rs
index 05aed4b9..3b4a717c 100644
--- a/src/command/history.rs
+++ b/src/command/history.rs
@@ -1,10 +1,13 @@
use std::env;
use eyre::Result;
+use fork::{fork, Fork};
use structopt::StructOpt;
use crate::local::database::Database;
use crate::local::history::History;
+use crate::local::sync;
+use crate::settings::Settings;
#[derive(StructOpt)]
pub enum Cmd {
@@ -50,21 +53,13 @@ fn print_list(h: &[History]) {
}
impl Cmd {
- pub fn run(&self, db: &mut impl Database) -> Result<()> {
+ pub fn run(&self, settings: &Settings, db: &mut impl Database) -> Result<()> {
match self {
Self::Start { command: words } => {
let command = words.join(" ");
let cwd = env::current_dir()?.display().to_string();
- let h = History::new(
- chrono::Utc::now().timestamp_nanos(),
- command,
- cwd,
- -1,
- -1,
- None,
- None,
- );
+ let h = History::new(chrono::Utc::now(), command, cwd, -1, -1, None, None);
// print the ID
// we use this as the key for calling end
@@ -76,10 +71,23 @@ impl Cmd {
Self::End { id, exit } => {
let mut h = db.load(id)?;
h.exit = *exit;
- h.duration = chrono::Utc::now().timestamp_nanos() - h.timestamp;
+ h.duration = chrono::Utc::now().timestamp_nanos() - h.timestamp.timestamp_nanos();
db.update(&h)?;
+ if settings.local.should_sync()? {
+ match fork() {
+ Ok(Fork::Parent(child)) => {
+ debug!("launched sync background process with PID {}", child);
+ }
+ Ok(Fork::Child) => {
+ debug!("running periodic background sync");
+ sync::sync(settings, false, db)?;
+ }
+ Err(_) => println!("Fork failed"),
+ }
+ }
+
Ok(())
}
diff --git a/src/command/login.rs b/src/command/login.rs
new file mode 100644
index 00000000..4f58b77f
--- /dev/null
+++ b/src/command/login.rs
@@ -0,0 +1,48 @@
+use std::collections::HashMap;
+use std::fs::File;
+use std::io::prelude::*;
+
+use eyre::Result;
+use structopt::StructOpt;
+
+use crate::settings::Settings;
+
+#[derive(StructOpt)]
+#[structopt(setting(structopt::clap::AppSettings::DeriveDisplayOrder))]
+pub struct Cmd {
+ #[structopt(long, short)]
+ pub username: String,
+
+ #[structopt(long, short)]
+ pub password: String,
+
+ #[structopt(long, short, about = "the encryption key for your account")]
+ pub key: String,
+}
+
+impl Cmd {
+ pub fn run(&self, settings: &Settings) -> Result<()> {
+ let mut map = HashMap::new();
+ map.insert("username", self.username.clone());
+ map.insert("password", self.password.clone());
+
+ let url = format!("{}/login", settings.local.sync_address);
+ let client = reqwest::blocking::Client::new();
+ let resp = client.post(url).json(&map).send()?;
+
+ let session = resp.json::<HashMap<String, String>>()?;
+ let session = session["session"].clone();
+
+ let session_path = settings.local.session_path.as_str();
+ let mut file = File::create(session_path)?;
+ file.write_all(session.as_bytes())?;
+
+ let key_path = settings.local.key_path.as_str();
+ let mut file = File::create(key_path)?;
+ file.write_all(&base64::decode(self.key.clone())?)?;
+
+ println!("Logged in!");
+
+ Ok(())
+ }
+}
diff --git a/src/command/mod.rs b/src/command/mod.rs
index a5ea0228..eeb11a87 100644
--- a/src/command/mod.rs
+++ b/src/command/mod.rs
@@ -9,9 +9,12 @@ mod event;
mod history;
mod import;
mod init;
+mod login;
+mod register;
mod search;
mod server;
mod stats;
+mod sync;
#[derive(StructOpt)]
pub enum AtuinCmd {
@@ -38,6 +41,21 @@ pub enum AtuinCmd {
#[structopt(about = "interactive history search")]
Search { query: Vec<String> },
+
+ #[structopt(about = "sync with the configured server")]
+ Sync {
+ #[structopt(long, short, about = "force re-download everything")]
+ force: bool,
+ },
+
+ #[structopt(about = "login to the configured server")]
+ Login(login::Cmd),
+
+ #[structopt(about = "register with the configured server")]
+ Register(register::Cmd),
+
+ #[structopt(about = "print the encryption key for transfer to another machine")]
+ Key,
}
pub fn uuid_v4() -> String {
@@ -47,13 +65,27 @@ pub fn uuid_v4() -> String {
impl AtuinCmd {
pub fn run(self, db: &mut impl Database, settings: &Settings) -> Result<()> {
match self {
- Self::History(history) => history.run(db),
+ Self::History(history) => history.run(settings, db),
Self::Import(import) => import.run(db),
Self::Server(server) => server.run(settings),
Self::Stats(stats) => stats.run(db, settings),
Self::Init => init::init(),
Self::Search { query } => search::run(&query, db),
+ Self::Sync { force } => sync::run(settings, force, db),
+ Self::Login(l) => l.run(settings),
+ Self::Register(r) => register::run(
+ settings,
+ r.username.as_str(),
+ r.email.as_str(),
+ r.password.as_str(),
+ ),
+ Self::Key => {
+ let key = std::fs::read(settings.local.key_path.as_str())?;
+ println!("{}", base64::encode(key));
+ Ok(())
+ }
+
Self::Uuid => {
println!("{}", uuid_v4());
Ok(())
diff --git a/src/command/register.rs b/src/command/register.rs
new file mode 100644
index 00000000..62bbeaeb
--- /dev/null
+++ b/src/command/register.rs
@@ -0,0 +1,54 @@
+use std::collections::HashMap;
+use std::fs::File;
+use std::io::prelude::*;
+
+use eyre::{eyre, Result};
+use structopt::StructOpt;
+
+use crate::settings::Settings;
+
+#[derive(StructOpt)]
+#[structopt(setting(structopt::clap::AppSettings::DeriveDisplayOrder))]
+pub struct Cmd {
+ #[structopt(long, short)]
+ pub username: String,
+
+ #[structopt(long, short)]
+ pub email: String,
+
+ #[structopt(long, short)]
+ pub password: String,
+}
+
+pub fn run(settings: &Settings, username: &str, email: &str, password: &str) -> Result<()> {
+ let mut map = HashMap::new();
+ map.insert("username", username);
+ map.insert("email", email);
+ map.insert("password", password);
+
+ let url = format!("{}/user/{}", settings.local.sync_address, username);
+ let resp = reqwest::blocking::get(url)?;
+
+ if resp.status().is_success() {
+ println!("Username is already in use! Please try another.");
+ return Ok(());
+ }
+
+ let url = format!("{}/register", settings.local.sync_address);
+ let client = reqwest::blocking::Client::new();
+ let resp = client.post(url).json(&map).send()?;
+
+ if !resp.status().is_success() {
+ println!("Failed to register user - please check your details and try again");
+ return Err(eyre!("failed to register user"));
+ }
+
+ let session = resp.json::<HashMap<String, String>>()?;
+ let session = session["session"].clone();
+
+ let path = settings.local.session_path.as_str();
+ let mut file = File::create(path)?;
+ file.write_all(session.as_bytes())?;
+
+ Ok(())
+}
diff --git a/src/command/search.rs b/src/command/search.rs
index d51e29ef..b9f3987c 100644
--- a/src/command/search.rs
+++ b/src/command/search.rs
@@ -171,7 +171,8 @@ fn select_history(query: &[String], db: &mut impl Database) -> Result<String> {
.iter()
.enumerate()
.map(|(i, m)| {
- let mut content = Span::raw(m.command.to_string());
+ let mut content =
+ Span::raw(m.command.to_string().replace("\n", " ").replace("\t", " "));
if let Some(selected) = app.results_state.selected() {
if selected == i {
diff --git a/src/command/server.rs b/src/command/server.rs
index 5156f409..ba2a9a2f 100644
--- a/src/command/server.rs
+++ b/src/command/server.rs
@@ -24,10 +24,10 @@ impl Cmd {
match self {
Self::Start { host, port } => {
let host = host.as_ref().map_or(
- settings.remote.host.clone(),
+ settings.server.host.clone(),
std::string::ToString::to_string,
);
- let port = port.map_or(settings.remote.port, |p| p);
+ let port = port.map_or(settings.server.port, |p| p);
server::launch(settings, host, port);
}
diff --git a/src/command/sync.rs b/src/command/sync.rs
new file mode 100644
index 00000000..facbe578
--- /dev/null
+++ b/src/command/sync.rs
@@ -0,0 +1,15 @@
+use eyre::Result;
+
+use crate::local::database::Database;
+use crate::local::sync;
+use crate::settings::Settings;
+
+pub fn run(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> {
+ sync::sync(settings, force, db)?;
+ println!(
+ "Sync complete! {} items in database, force: {}",
+ db.history_count()?,
+ force
+ );
+ Ok(())
+}
diff --git a/src/local/api_client.rs b/src/local/api_client.rs
new file mode 100644
index 00000000..434c07ba
--- /dev/null
+++ b/src/local/api_client.rs
@@ -0,0 +1,94 @@
+use chrono::Utc;
+use eyre::Result;
+use reqwest::header::AUTHORIZATION;
+
+use crate::api::{AddHistoryRequest, CountResponse, ListHistoryResponse};
+use crate::local::encryption::{decrypt, load_key};
+use crate::local::history::History;
+use crate::settings::Settings;
+use crate::utils::hash_str;
+
+pub struct Client<'a> {
+ settings: &'a Settings,
+}
+
+impl<'a> Client<'a> {
+ pub const fn new(settings: &'a Settings) -> Self {
+ Client { settings }
+ }
+
+ pub fn count(&self) -> Result<i64> {
+ let url = format!("{}/sync/count", self.settings.local.sync_address);
+ let client = reqwest::blocking::Client::new();
+
+ let resp = client
+ .get(url)
+ .header(
+ AUTHORIZATION,
+ format!("Token {}", self.settings.local.session_token),
+ )
+ .send()?;
+
+ let count = resp.json::<CountResponse>()?;
+
+ Ok(count.count)
+ }
+
+ pub fn get_history(
+ &self,
+ sync_ts: chrono::DateTime<Utc>,
+ history_ts: chrono::DateTime<Utc>,
+ host: Option<String>,
+ ) -> Result<Vec<History>> {
+ let key = load_key(self.settings)?;
+
+ let host = match host {
+ None => hash_str(&format!("{}:{}", whoami::hostname(), whoami::username())),
+ Some(h) => h,
+ };
+
+ // this allows for syncing between users on the same machine
+ let url = format!(
+ "{}/sync/history?sync_ts={}&history_ts={}&host={}",
+ self.settings.local.sync_address,
+ sync_ts.to_rfc3339(),
+ history_ts.to_rfc3339(),
+ host,
+ );
+ let client = reqwest::blocking::Client::new();
+
+ let resp = client
+ .get(url)
+ .header(
+ AUTHORIZATION,
+ format!("Token {}", self.settings.local.session_token),
+ )
+ .send()?;
+
+ let history = resp.json::<ListHistoryResponse>()?;
+ let history = history
+ .history
+ .iter()
+ .map(|h| serde_json::from_str(h).expect("invalid base64"))
+ .map(|h| decrypt(&h, &key).expect("failed to decrypt history! check your key"))
+ .collect();
+
+ Ok(history)
+ }
+
+ pub fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
+ let client = reqwest::blocking::Client::new();
+
+ let url = format!("{}/history", self.settings.local.sync_address);
+ client
+ .post(url)
+ .json(history)
+ .header(
+ AUTHORIZATION,
+ format!("Token {}", self.settings.local.session_token),
+ )
+ .send()?;
+
+ Ok(())
+ }
+}
diff --git a/src/local/database.rs b/src/local/database.rs
index ad7078e5..977f11cc 100644
--- a/src/local/database.rs
+++ b/src/local/database.rs
@@ -1,3 +1,4 @@
+use chrono::prelude::*;
use chrono::Utc;
use std::path::Path;
@@ -21,6 +22,10 @@ pub trait Database {
fn update(&self, h: &History) -> Result<()>;
fn history_count(&self) -> Result<i64>;
+ fn first(&self) -> Result<History>;
+ fn last(&self) -> Result<History>;
+ fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>>;
+
fn prefix_search(&self, query: &str) -> Result<Vec<History>>;
}
@@ -44,9 +49,7 @@ impl Sqlite {
let conn = Connection::open(path)?;
- if create {
- Self::setup_db(&conn)?;
- }
+ Self::setup_db(&conn)?;
Ok(Self { conn })
}
@@ -70,6 +73,14 @@ impl Sqlite {
[],
)?;
+ conn.execute(
+ "create table if not exists history_encrypted (
+ id text primary key,
+ data blob not null
+ )",
+ [],
+ )?;
+
Ok(())
}
@@ -87,7 +98,7 @@ impl Sqlite {
) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
params![
h.id,
- h.timestamp,
+ h.timestamp.timestamp_nanos(),
h.duration,
h.exit,
h.command,
@@ -146,7 +157,7 @@ impl Database for Sqlite {
"update history
set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8
where id = ?1",
- params![h.id, h.timestamp, h.duration, h.exit, h.command, h.cwd, h.session, h.hostname],
+ params![h.id, h.timestamp.timestamp_nanos(), h.duration, h.exit, h.command, h.cwd, h.session, h.hostname],
)?;
Ok(())
@@ -183,6 +194,38 @@ impl Database for Sqlite {
Ok(history_iter.filter_map(Result::ok).collect())
}
+ fn first(&self) -> Result<History> {
+ let mut stmt = self
+ .conn
+ .prepare("SELECT * FROM history order by timestamp asc limit 1")?;
+
+ let history = stmt.query_row(params![], |row| history_from_sqlite_row(None, row))?;
+
+ Ok(history)
+ }
+
+ fn last(&self) -> Result<History> {
+ let mut stmt = self
+ .conn
+ .prepare("SELECT * FROM history order by timestamp desc limit 1")?;
+
+ let history = stmt.query_row(params![], |row| history_from_sqlite_row(None, row))?;
+
+ Ok(history)
+ }
+
+ fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>> {
+ let mut stmt = self.conn.prepare(
+ "SELECT * FROM history where timestamp <= ? order by timestamp desc limit ?",
+ )?;
+
+ let history_iter = stmt.query_map(params![timestamp.timestamp_nanos(), count], |row| {
+ history_from_sqlite_row(None, row)
+ })?;
+
+ Ok(history_iter.filter_map(Result::ok).collect())
+ }
+
fn query(&self, query: &str, params: impl Params) -> Result<Vec<History>> {
let mut stmt = self.conn.prepare(query)?;
@@ -218,7 +261,7 @@ fn history_from_sqlite_row(
Ok(History {
id,
- timestamp: row.get(1)?,
+ timestamp: Utc.timestamp_nanos(row.get(1)?),
duration: row.get(2)?,
exit: row.get(3)?,
command: row.get(4)?,
diff --git a/src/local/encryption.rs b/src/local/encryption.rs
new file mode 100644
index 00000000..3c1699e3
--- /dev/null
+++ b/src/local/encryption.rs
@@ -0,0 +1,108 @@
+// The general idea is that we NEVER send cleartext history to the server
+// This way the odds of anything private ending up where it should not are
+// very low
+// The server authenticates via the usual username and password. This has
+// nothing to do with the encryption, and is purely authentication! The client
+// generates its own secret key, and encrypts all shell history with libsodium's
+// secretbox. The data is then sent to the server, where it is stored. All
+// clients must share the secret in order to be able to sync, as it is needed
+// to decrypt
+
+use std::fs::File;
+use std::io::prelude::*;
+use std::path::PathBuf;
+
+use eyre::{eyre, Result};
+use sodiumoxide::crypto::secretbox;
+
+use crate::local::history::History;
+use crate::settings::Settings;
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct EncryptedHistory {
+ pub ciphertext: Vec<u8>,
+ pub nonce: secretbox::Nonce,
+}
+
+// Loads the secret key, will create + save if it doesn't exist
+pub fn load_key(settings: &Settings) -> Result<secretbox::Key> {
+ let path = settings.local.key_path.as_str();
+
+ if PathBuf::from(path).exists() {
+ let bytes = std::fs::read(path)?;
+ let key: secretbox::Key = rmp_serde::from_read_ref(&bytes)?;
+ Ok(key)
+ } else {
+ let key = secretbox::gen_key();
+ let buf = rmp_serde::to_vec(&key)?;
+
+ let mut file = File::create(path)?;
+ file.write_all(&buf)?;
+
+ Ok(key)
+ }
+}
+
+pub fn encrypt(history: &History, key: &secretbox::Key) -> Result<EncryptedHistory> {
+ // serialize with msgpack
+ let buf = rmp_serde::to_vec(history)?;
+
+ let nonce = secretbox::gen_nonce();
+
+ let ciphertext = secretbox::seal(&buf, &nonce, key);
+
+ Ok(EncryptedHistory { ciphertext, nonce })
+}
+
+pub fn decrypt(encrypted_history: &EncryptedHistory, key: &secretbox::Key) -> Result<History> {
+ let plaintext = secretbox::open(&encrypted_history.ciphertext, &encrypted_history.nonce, key)
+ .map_err(|_| eyre!("failed to open secretbox - invalid key?"))?;
+
+ let history = rmp_serde::from_read_ref(&plaintext)?;
+
+ Ok(history)
+}
+
+#[cfg(test)]
+mod test {
+ use sodiumoxide::crypto::secretbox;
+
+ use crate::local::history::History;
+
+ use super::{decrypt, encrypt};
+
+ #[test]
+ fn test_encrypt_decrypt() {
+ let key1 = secretbox::gen_key();
+ let key2 = secretbox::gen_key();
+
+ let history = History::new(
+ chrono::Utc::now(),
+ "ls".to_string(),
+ "/home/ellie".to_string(),
+ 0,
+ 1,
+ Some("beep boop".to_string()),
+ Some("booop".to_string()),
+ );
+
+ let e1 = encrypt(&history, &key1).unwrap();
+ let e2 = encrypt(&history, &key2).unwrap();
+
+ assert_ne!(e1.ciphertext, e2.ciphertext);
+ assert_ne!(e1.nonce, e2.nonce);
+
+ // test decryption works
+ // this should pass
+ match decrypt(&e1, &key1) {
+ Err(e) => assert!(false, "failed to decrypt, got {}", e),
+ Ok(h) => assert_eq!(h, history),
+ };
+
+ // this should err
+ match decrypt(&e2, &key1) {
+ Ok(_) => assert!(false, "expected an error decrypting with invalid key"),
+ Err(_) => {}
+ };
+ }
+}
diff --git a/src/local/history.rs b/src/local/history.rs
index 0ca112bd..1712f8b9 100644
--- a/src/local/history.rs
+++ b/src/local/history.rs
@@ -1,12 +1,15 @@
use std::env;
use std::hash::{Hash, Hasher};
+use chrono::Utc;
+
use crate::command::uuid_v4;
-#[derive(Debug, Clone)]
+// Any new fields MUST be Optional<>!
+#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct History {
pub id: String,
- pub timestamp: i64,
+ pub timestamp: chrono::DateTime<Utc>,
pub duration: i64,
pub exit: i64,
pub command: String,
@@ -17,7 +20,7 @@ pub struct History {
impl History {
pub fn new(
- timestamp: i64,
+ timestamp: chrono::DateTime<Utc>,
command: String,
cwd: String,
exit: i64,
@@ -29,7 +32,7 @@ impl History {
.or_else(|| env::var("ATUIN_SESSION").ok())
.unwrap_or_else(uuid_v4);
let hostname =
- hostname.unwrap_or_else(|| hostname::get().unwrap().to_str().unwrap().to_string());
+ hostname.unwrap_or_else(|| format!("{}:{}", whoami::hostname(), whoami::username()));
Self {
id: uuid_v4(),
diff --git a/src/local/import.rs b/src/local/import.rs
index 9bf79c72..d0f679c9 100644
--- a/src/local/import.rs
+++ b/src/local/import.rs
@@ -4,7 +4,9 @@
use std::io::{BufRead, BufReader, Seek, SeekFrom};
use std::{fs::File, path::Path};
-use eyre::{Result, WrapErr};
+use chrono::prelude::*;
+use chrono::Utc;
+use eyre::{eyre, Result};
use super::history::History;
@@ -13,6 +15,7 @@ pub struct Zsh {
file: BufReader<File>,
pub loc: u64,
+ pub counter: i64,
}
// this could probably be sped up
@@ -32,19 +35,23 @@ impl Zsh {
Ok(Self {
file: buf,
loc: loc as u64,
+ counter: 0,
})
}
}
-fn parse_extended(line: &str) -> History {
+fn parse_extended(line: &str, counter: i64) -> History {
let line = line.replacen(": ", "", 2);
let (time, duration) = line.split_once(':').unwrap();
let (duration, command) = duration.split_once(';').unwrap();
- let time = time.parse::<i64>().map_or_else(
- |_| chrono::Utc::now().timestamp_nanos(),
- |t| t * 1_000_000_000,
- );
+ let time = time
+ .parse::<i64>()
+ .unwrap_or_else(|_| chrono::Utc::now().timestamp());
+
+ let offset = chrono::Duration::milliseconds(counter);
+ let time = Utc.timestamp(time, 0);
+ let time = time + offset;
let duration = duration.parse::<i64>().map_or(-1, |t| t * 1_000_000_000);
@@ -60,6 +67,18 @@ fn parse_extended(line: &str) -> History {
)
}
+impl Zsh {
+ fn read_line(&mut self) -> Option<Result<String>> {
+ let mut line = String::new();
+
+ match self.file.read_line(&mut line) {
+ Ok(0) => None,
+ Ok(_) => Some(Ok(line)),
+ Err(e) => Some(Err(eyre!("failed to read line: {}", e))), // we can skip past things like invalid utf8
+ }
+ }
+}
+
impl Iterator for Zsh {
type Item = Result<History>;
@@ -68,54 +87,89 @@ impl Iterator for Zsh {
// These lines begin with :
// So, if the line begins with :, parse it. Otherwise it's just
// the command
- let mut line = String::new();
+ let line = self.read_line()?;
- match self.file.read_line(&mut line) {
- Ok(0) => None,
- Ok(_) => {
- let extended = line.starts_with(':');
+ if let Err(e) = line {
+ return Some(Err(e)); // :(
+ }
- if extended {
- Some(Ok(parse_extended(line.as_str())))
- } else {
- Some(Ok(History::new(
- chrono::Utc::now().timestamp_nanos(), // what else? :/
- line.trim_end().to_string(),
- String::from("unknown"),
- -1,
- -1,
- None,
- None,
- )))
- }
+ let mut line = line.unwrap();
+
+ while line.ends_with("\\\n") {
+ let next_line = self.read_line()?;
+
+ if next_line.is_err() {
+ // There's a chance that the last line of a command has invalid
+ // characters, the only safe thing to do is break :/
+ // usually just invalid utf8 or smth
+ // however, we really need to avoid missing history, so it's
+ // better to have some items that should have been part of
+ // something else, than to miss things. So break.
+ break;
}
- Err(e) => Some(Err(e).wrap_err("failed to parse line")),
+
+ line.push_str(next_line.unwrap().as_str());
+ }
+
+ // We have to handle the case where a line has escaped newlines.
+ // Keep reading until we have a non-escaped newline
+
+ let extended = line.starts_with(':');
+
+ if extended {
+ self.counter += 1;
+ Some(Ok(parse_extended(line.as_str(), self.counter)))
+ } else {
+ let time = chrono::Utc::now();
+ let offset = chrono::Duration::seconds(self.counter);
+ let time = time - offset;
+
+ self.counter += 1;
+
+ Some(Ok(History::new(
+ time,
+ line.trim_end().to_string(),
+ String::from("unknown"),
+ -1,
+ -1,
+ None,
+ None,
+ )))
}
}
}
#[cfg(test)]
mod test {
+ use chrono::prelude::*;
+ use chrono::Utc;
+
use super::parse_extended;
#[test]
fn test_parse_extended_simple() {
- let parsed = parse_extended(": 1613322469:0;cargo install atuin");
+ let parsed = parse_extended(": 1613322469:0;cargo install atuin", 0);
assert_eq!(parsed.command, "cargo install atuin");
assert_eq!(parsed.duration, 0);
- assert_eq!(parsed.timestamp, 1_613_322_469_000_000_000);
+ assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0));
- let parsed = parse_extended(": 1613322469:10;cargo install atuin;cargo update");
+ let parsed = parse_extended(": 1613322469:10;cargo install atuin;cargo update", 0);
assert_eq!(parsed.command, "cargo install atuin;cargo update");
assert_eq!(parsed.duration, 10_000_000_000);
- assert_eq!(parsed.timestamp, 1_613_322_469_000_000_000);
+ assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0));
- let parsed = parse_extended(": 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷");
+ let parsed = parse_extended(": 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", 0);
assert_eq!(parsed.command, "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷");
assert_eq!(parsed.duration, 10_000_000_000);
- assert_eq!(parsed.timestamp, 1_613_322_469_000_000_000);
+ assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0));
+
+ let parsed = parse_extended(": 1613322469:10;cargo install \\n atuin\n", 0);
+
+ assert_eq!(parsed.command, "cargo install \\n atuin");
+ assert_eq!(parsed.duration, 10_000_000_000);
+ assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0));
}
}
diff --git a/src/local/mod.rs b/src/local/mod.rs
index a11ee213..9fe31292 100644
--- a/src/local/mod.rs
+++ b/src/local/mod.rs
@@ -1,3 +1,6 @@
+pub mod api_client;
pub mod database;
+pub mod encryption;
pub mod history;
pub mod import;
+pub mod sync;
diff --git a/src/local/sync.rs b/src/local/sync.rs
new file mode 100644
index 00000000..c22d2f27
--- /dev/null
+++ b/src/local/sync.rs
@@ -0,0 +1,135 @@
+use std::convert::TryInto;
+
+use chrono::prelude::*;
+use eyre::Result;
+
+use crate::local::api_client;
+use crate::local::database::Database;
+use crate::local::encryption::{encrypt, load_key};
+use crate::settings::{Local, Settings, HISTORY_PAGE_SIZE};
+use crate::{api::AddHistoryRequest, utils::hash_str};
+
+// Currently sync is kinda naive, and basically just pages backwards through
+// history. This means newly added stuff shows up properly! We also just use
+// the total count in each database to indicate whether a sync is needed.
+// I think this could be massively improved! If we had a way of easily
+// indicating count per time period (hour, day, week, year, etc) then we can
+// easily pinpoint where we are missing data and what needs downloading. Start
+// with year, then find the week, then the day, then the hour, then download it
+// all! The current naive approach will do for now.
+
+// Check if remote has things we don't, and if so, download them.
+// Returns (num downloaded, total local)
+fn sync_download(
+ force: bool,
+ client: &api_client::Client,
+ db: &mut impl Database,
+) -> Result<(i64, i64)> {
+ let remote_count = client.count()?;
+
+ let initial_local = db.history_count()?;
+ let mut local_count = initial_local;
+
+ let mut last_sync = if force {
+ Utc.timestamp_millis(0)
+ } else {
+ Local::last_sync()?
+ };
+
+ let mut last_timestamp = Utc.timestamp_millis(0);
+
+ let host = if force { Some(String::from("")) } else { None };
+
+ while remote_count > local_count {
+ let page = client.get_history(last_sync, last_timestamp, host.clone())?;
+
+ if page.len() < HISTORY_PAGE_SIZE.try_into().unwrap() {
+ break;
+ }
+
+ db.save_bulk(&page)?;
+
+ local_count = db.history_count()?;
+
+ let page_last = page
+ .last()
+ .expect("could not get last element of page")
+ .timestamp;
+
+ // in the case of a small sync frequency, it's possible for history to
+ // be "lost" between syncs. In this case we need to rewind the sync
+ // timestamps
+ if page_last == last_timestamp {
+ last_timestamp = Utc.timestamp_millis(0);
+ last_sync = last_sync - chrono::Duration::hours(1);
+ } else {
+ last_timestamp = page_last;
+ }
+ }
+
+ Ok((local_count - initial_local, local_count))
+}
+
+// Check if we have things remote doesn't, and if so, upload them
+fn sync_upload(
+ settings: &Settings,
+ _force: bool,
+ client: &api_client::Client,
+ db: &mut impl Database,
+) -> Result<()> {
+ let initial_remote_count = client.count()?;
+ let mut remote_count = initial_remote_count;
+
+ let local_count = db.history_count()?;
+
+ let key = load_key(settings)?; // encryption key
+
+ // first just try the most recent set
+
+ let mut cursor = Utc::now();
+
+ while local_count > remote_count {
+ let last = db.before(cursor, HISTORY_PAGE_SIZE)?;
+ let mut buffer = Vec::<AddHistoryRequest>::new();
+
+ if last.is_empty() {
+ break;
+ }
+
+ for i in last {
+ let data = encrypt(&i, &key)?;
+ let data = serde_json::to_string(&data)?;
+
+ let add_hist = AddHistoryRequest {
+ id: i.id,
+ timestamp: i.timestamp,
+ data,
+ hostname: hash_str(i.hostname.as_str()),
+ };
+
+ buffer.push(add_hist);
+ }
+
+ // anything left over outside of the 100 block size
+ client.post_history(&buffer)?;
+ cursor = buffer.last().unwrap().timestamp;
+
+ remote_count = client.count()?;
+ }
+
+ Ok(())
+}
+
+pub fn sync(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> {
+ let client = api_client::Client::new(settings);
+
+ sync_upload(settings, force, &client, db)?;
+
+ let download = sync_download(force, &client, db)?;
+
+ debug!("sync downloaded {}", download.0);
+
+ Local::save_sync_time()?;
+
+ Ok(())
+}
diff --git a/src/main.rs b/src/main.rs
index bac75362..ae459807 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -6,7 +6,7 @@
use std::path::PathBuf;
use eyre::{eyre, Result};
-use structopt::StructOpt;
+use structopt::{clap::AppSettings, StructOpt};
#[macro_use]
extern crate log;
@@ -30,18 +30,21 @@ use command::AtuinCmd;
use local::database::Sqlite;
use settings::Settings;
+mod api;
mod command;
mod local;
mod remote;
mod settings;
+mod utils;
pub mod schema;
#[derive(StructOpt)]
#[structopt(
author = "Ellie Huxtable <e@elm.sh>",
- version = "0.4.0",
- about = "Magical shell history"
+ version = "0.5.0",
+ about = "Magical shell history",
+ global_settings(&[AppSettings::ColoredHelp, AppSettings::DeriveDisplayOrder])
)]
struct Atuin {
#[structopt(long, parse(from_os_str), help = "db file path")]
@@ -52,9 +55,7 @@ struct Atuin {
}
impl Atuin {
- fn run(self) -> Result<()> {
- let settings = Settings::new()?;
-
+ fn run(self, settings: &Settings) -> Result<()> {
let db_path = if let Some(db_path) = self.db {
let path = db_path
.to_str()
@@ -67,11 +68,13 @@ impl Atuin {
let mut db = Sqlite::new(db_path)?;
- self.atuin.run(&mut db, &settings)
+ self.atuin.run(&mut db, settings)
}
}
fn main() -> Result<()> {
+ let settings = Settings::new()?;
+
fern::Dispatch::new()
.format(|out, message, record| {
out.finish(format_args!(
@@ -85,5 +88,5 @@ fn main() -> Result<()> {
.chain(std::io::stdout())
.apply()?;
- Atuin::from_args().run()
+ Atuin::from_args().run(&settings)
}
diff --git a/src/remote/auth.rs b/src/remote/auth.rs
index 8f9e9b46..cf61b077 100644
--- a/src/remote/auth.rs
+++ b/src/remote/auth.rs
@@ -1,6 +1,8 @@
use self::diesel::prelude::*;
+use eyre::Result;
use rocket::http::Status;
use rocket::request::{self, FromRequest, Outcome, Request};
+use rocket::State;
use rocket_contrib::databases::diesel;
use sodiumoxide::crypto::pwhash::argon2id13;
@@ -9,7 +11,11 @@ use uuid::Uuid;
use super::models::{NewSession, NewUser, Session, User};
use super::views::ApiResponse;
+
+use crate::api::{LoginRequest, RegisterRequest};
use crate::schema::{sessions, users};
+use crate::settings::Settings;
+use crate::utils::hash_secret;
use super::database::AtuinDbConn;
@@ -19,20 +25,6 @@ pub enum KeyError {
Invalid,
}
-pub fn hash_str(secret: &str) -> String {
- sodiumoxide::init().unwrap();
- let hash = argon2id13::pwhash(
- secret.as_bytes(),
- argon2id13::OPSLIMIT_INTERACTIVE,
- argon2id13::MEMLIMIT_INTERACTIVE,
- )
- .unwrap();
- let texthash = std::str::from_utf8(&hash.0).unwrap().to_string();
-
- // postgres hates null chars. don't do that to postgres
- texthash.trim_end_matches('\u{0}').to_string()
-}
-
pub fn verify_str(secret: &str, verify: &str) -> bool {
sodiumoxide::init().unwrap();
@@ -95,19 +87,54 @@ impl<'a, 'r> FromRequest<'a, 'r> for User {
}
}
-#[derive(Deserialize)]
-pub struct Register {
- email: String,
- password: String,
+#[get("/user/<user>")]
+#[allow(clippy::clippy::needless_pass_by_value)]
+pub fn get_user(user: String, conn: AtuinDbConn) -> ApiResponse {
+ use crate::schema::users::dsl::{username, users};
+
+ let user: Result<String, diesel::result::Error> = users
+ .select(username)
+ .filter(username.eq(user))
+ .first(&*conn);
+
+ if user.is_err() {
+ return ApiResponse {
+ json: json!({
+ "message": "could not find user",
+ }),
+ status: Status::NotFound,
+ };
+ }
+
+ let user = user.unwrap();
+
+ ApiResponse {
+ json: json!({ "username": user.as_str() }),
+ status: Status::Ok,
+ }
}
#[post("/register", data = "<register>")]
#[allow(clippy::clippy::needless_pass_by_value)]
-pub fn register(conn: AtuinDbConn, register: Json<Register>) -> ApiResponse {
- let hashed = hash_str(register.password.as_str());
+pub fn register(
+ conn: AtuinDbConn,
+ register: Json<RegisterRequest>,
+ settings: State<Settings>,
+) -> ApiResponse {
+ if !settings.server.open_registration {
+ return ApiResponse {
+ status: Status::BadRequest,
+ json: json!({
+ "message": "registrations are not open"
+ }),
+ };
+ }
+
+ let hashed = hash_secret(register.password.as_str());
let new_user = NewUser {
email: register.email.as_str(),
+ username: register.username.as_str(),
password: hashed.as_str(),
};
@@ -119,8 +146,7 @@ pub fn register(conn: AtuinDbConn, register: Json<Register>) -> ApiResponse {
return ApiResponse {
status: Status::BadRequest,
json: json!({
- "status": "error",
- "message": "failed to create user - is the email already in use?",
+ "message": "failed to create user - username or email in use?",
}),
};
}
@@ -139,32 +165,26 @@ pub fn register(conn: AtuinDbConn, register: Json<Register>) -> ApiResponse {
{
Ok(_) => ApiResponse {
status: Status::Ok,
- json: json!({"status": "ok", "message": "user created!", "session": token}),
+ json: json!({"message": "user created!", "session": token}),
},
Err(_) => ApiResponse {
status: Status::BadRequest,
- json: json!({"status": "error", "message": "failed to create user"}),
+ json: json!({ "message": "failed to create user"}),
},
}
}
-#[derive(Deserialize)]
-pub struct Login {
- email: String,
- password: String,
-}
-
#[post("/login", data = "<login>")]
#[allow(clippy::clippy::needless_pass_by_value)]
-pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse {
+pub fn login(conn: AtuinDbConn, login: Json<LoginRequest>) -> ApiResponse {
let user = users::table
- .filter(users::email.eq(login.email.as_str()))
+ .filter(users::username.eq(login.username.as_str()))
.first(&*conn);
if user.is_err() {
return ApiResponse {
status: Status::NotFound,
- json: json!({"status": "error", "message": "user not found"}),
+ json: json!({"message": "user not found"}),
};
}
@@ -178,7 +198,7 @@ pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse {
if session.is_err() {
return ApiResponse {
status: Status::InternalServerError,
- json: json!({"status": "error", "message": "something went wrong"}),
+ json: json!({"message": "something went wrong"}),
};
}
@@ -187,7 +207,7 @@ pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse {
if !verified {
return ApiResponse {
status: Status::NotFound,
- json: json!({"status": "error", "message": "user not found"}),
+ json: json!({"message": "user not found"}),
};
}
@@ -195,6 +215,6 @@ pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse {
ApiResponse {
status: Status::Ok,
- json: json!({"status": "ok", "token": session.token}),
+ json: json!({"session": session.token}),
}
}
diff --git a/src/remote/database.rs b/src/remote/database.rs
index fabd07de..ddcffda0 100644
--- a/src/remote/database.rs
+++ b/src/remote/database.rs
@@ -8,7 +8,7 @@ pub struct AtuinDbConn(diesel::PgConnection);
// TODO: connection pooling
pub fn establish_connection(settings: &Settings) -> PgConnection {
- let database_url = &settings.remote.db_uri;
+ let database_url = &settings.server.db_uri;
PgConnection::establish(database_url)
.unwrap_or_else(|_| panic!("Error connecting to {}", database_url))
}
diff --git a/src/remote/models.rs b/src/remote/models.rs
index 058b2f0b..7f6f7766 100644
--- a/src/remote/models.rs
+++ b/src/remote/models.rs
@@ -1,23 +1,26 @@
-use chrono::naive::NaiveDateTime;
+use chrono::prelude::*;
use crate::schema::{history, sessions, users};
-#[derive(Identifiable, Queryable, Associations)]
+#[derive(Deserialize, Serialize, Identifiable, Queryable, Associations)]
#[table_name = "history"]
#[belongs_to(User)]
pub struct History {
pub id: i64,
- pub client_id: String,
+ pub client_id: String, // a client generated ID
pub user_id: i64,
- pub mac: String,
+ pub hostname: String,
pub timestamp: NaiveDateTime,
pub data: String,
+
+ pub created_at: NaiveDateTime,
}
#[derive(Identifiable, Queryable, Associations)]
pub struct User {
pub id: i64,
+ pub username: String,
pub email: String,
pub password: String,
}
@@ -35,8 +38,8 @@ pub struct Session {
pub struct NewHistory<'a> {
pub client_id: &'a str,
pub user_id: i64,
- pub mac: &'a str,
- pub timestamp: NaiveDateTime,
+ pub hostname: String,
+ pub timestamp: chrono::NaiveDateTime,
pub data: &'a str,
}
@@ -44,6 +47,7 @@ pub struct NewHistory<'a> {
#[derive(Insertable)]
#[table_name = "users"]
pub struct NewUser<'a> {
+ pub username: &'a str,
pub email: &'a str,
pub password: &'a str,
}
diff --git a/src/remote/server.rs b/src/remote/server.rs
index cd2ca7b8..de58397d 100644
--- a/src/remote/server.rs
+++ b/src/remote/server.rs
@@ -17,13 +17,15 @@ use super::auth::*;
embed_migrations!("migrations");
pub fn launch(settings: &Settings, host: String, port: u16) {
+ let settings: Settings = settings.clone(); // clone so rocket can manage it
+
let mut database_config = HashMap::new();
let mut databases = HashMap::new();
- database_config.insert("url", Value::from(settings.remote.db_uri.clone()));
+ database_config.insert("url", Value::from(settings.server.db_uri.clone()));
databases.insert("atuin", Value::from(database_config));
- let connection = establish_connection(settings);
+ let connection = establish_connection(&settings);
embedded_migrations::run(&connection).expect("failed to run migrations");
let config = Config::build(Environment::Production)
@@ -36,8 +38,20 @@ pub fn launch(settings: &Settings, host: String, port: u16) {
let app = rocket::custom(config);
- app.mount("/", routes![index, register, add_history, login])
- .attach(AtuinDbConn::fairing())
- .register(catchers![internal_error, bad_request])
- .launch();
+ app.mount(
+ "/",
+ routes![
+ index,
+ register,
+ add_history,
+ login,
+ get_user,
+ sync_count,
+ sync_list
+ ],
+ )
+ .manage(settings)
+ .attach(AtuinDbConn::fairing())
+ .register(catchers![internal_error, bad_request])
+ .launch();
}
diff --git a/src/remote/views.rs b/src/remote/views.rs
index 2af3f369..08dff13e 100644
--- a/src/remote/views.rs
+++ b/src/remote/views.rs
@@ -1,14 +1,22 @@
-use self::diesel::prelude::*;
+use chrono::Utc;
+use rocket::http::uri::Uri;
+use rocket::http::RawStr;
use rocket::http::{ContentType, Status};
+use rocket::request::FromFormValue;
use rocket::request::Request;
use rocket::response;
use rocket::response::{Responder, Response};
use rocket_contrib::databases::diesel;
use rocket_contrib::json::{Json, JsonValue};
-use super::database::AtuinDbConn;
-use super::models::{NewHistory, User};
+use self::diesel::prelude::*;
+
+use crate::api::AddHistoryRequest;
use crate::schema::history;
+use crate::settings::HISTORY_PAGE_SIZE;
+
+use super::database::AtuinDbConn;
+use super::models::{History, NewHistory, User};
#[derive(Debug)]
pub struct ApiResponse {
@@ -46,40 +54,36 @@ pub fn bad_request(_req: &Request) -> ApiResponse {
}
}
-#[derive(Deserialize)]
-pub struct AddHistory {
- id: String,
- timestamp: i64,
- data: String,
- mac: String,
-}
-
#[post("/history", data = "<add_history>")]
#[allow(
clippy::clippy::cast_sign_loss,
clippy::cast_possible_truncation,
clippy::clippy::needless_pass_by_value
)]
-pub fn add_history(conn: AtuinDbConn, user: User, add_history: Json<AddHistory>) -> ApiResponse {
- let secs: i64 = add_history.timestamp / 1_000_000_000;
- let nanosecs: u32 = (add_history.timestamp - (secs * 1_000_000_000)) as u32;
- let datetime = chrono::NaiveDateTime::from_timestamp(secs, nanosecs);
-
- let new_history = NewHistory {
- client_id: add_history.id.as_str(),
- user_id: user.id,
- mac: add_history.mac.as_str(),
- timestamp: datetime,
- data: add_history.data.as_str(),
- };
+pub fn add_history(
+ conn: AtuinDbConn,
+ user: User,
+ add_history: Json<Vec<AddHistoryRequest>>,
+) -> ApiResponse {
+ let new_history: Vec<NewHistory> = add_history
+ .iter()
+ .map(|h| NewHistory {
+ client_id: h.id.as_str(),
+ hostname: h.hostname.to_string(),
+ user_id: user.id,
+ timestamp: h.timestamp.naive_utc(),
+ data: h.data.as_str(),
+ })
+ .collect();
match diesel::insert_into(history::table)
.values(&new_history)
+ .on_conflict_do_nothing()
.execute(&*conn)
{
Ok(_) => ApiResponse {
status: Status::Ok,
- json: json!({"status": "ok", "message": "history added", "id": new_history.client_id}),
+ json: json!({"status": "ok", "message": "history added"}),
},
Err(_) => ApiResponse {
status: Status::BadRequest,
@@ -87,3 +91,95 @@ pub fn add_history(conn: AtuinDbConn, user: User, add_history: Json<AddHistory>)
},
}
}
+
+#[get("/sync/count")]
+#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)]
+pub fn sync_count(conn: AtuinDbConn, user: User) -> ApiResponse {
+ use crate::schema::history::dsl::*;
+
+ // we need to return the number of history items we have for this user
+ // in the future I'd like to use something like a merkel tree to calculate
+ // which day specifically needs syncing
+ let count = history
+ .filter(user_id.eq(user.id))
+ .count()
+ .first::<i64>(&*conn);
+
+ if count.is_err() {
+ error!("failed to count: {}", count.err().unwrap());
+
+ return ApiResponse {
+ json: json!({"message": "internal server error"}),
+ status: Status::InternalServerError,
+ };
+ }
+
+ ApiResponse {
+ status: Status::Ok,
+ json: json!({"count": count.ok()}),
+ }
+}
+
+pub struct UtcDateTime(chrono::DateTime<Utc>);
+
+impl<'v> FromFormValue<'v> for UtcDateTime {
+ type Error = &'v RawStr;
+
+ fn from_form_value(form_value: &'v RawStr) -> Result<UtcDateTime, &'v RawStr> {
+ let time = Uri::percent_decode(form_value.as_bytes()).map_err(|_| form_value)?;
+ let time = time.to_string();
+
+ match chrono::DateTime::parse_from_rfc3339(time.as_str()) {
+ Ok(t) => Ok(UtcDateTime(t.with_timezone(&Utc))),
+ Err(e) => {
+ error!("failed to parse time {}, got: {}", time, e);
+ Err(form_value)
+ }
+ }
+ }
+}
+
+// Request a list of all history items added to the DB after a given timestamp.
+// Provide the current hostname, so that we don't send the client data that
+// originated from them
+#[get("/sync/history?<sync_ts>&<history_ts>&<host>")]
+#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)]
+pub fn sync_list(
+ conn: AtuinDbConn,
+ user: User,
+ sync_ts: UtcDateTime,
+ history_ts: UtcDateTime,
+ host: String,
+) -> ApiResponse {
+ use crate::schema::history::dsl::*;
+
+ // we need to return the number of history items we have for this user
+ // in the future I'd like to use something like a merkel tree to calculate
+ // which day specifically needs syncing
+ // TODO: Allow for configuring the page size, both from params, and setting
+ // the max in config. 100 is fine for now.
+ let h = history
+ .filter(user_id.eq(user.id))
+ .filter(hostname.ne(host))
+ .filter(created_at.ge(sync_ts.0.naive_utc()))
+ .filter(timestamp.ge(history_ts.0.naive_utc()))
+ .order(timestamp.asc())
+ .limit(HISTORY_PAGE_SIZE)
+ .load::<History>(&*conn);
+
+ if let Err(e) = h {
+ error!("failed to load history: {}", e);
+
+ return ApiResponse {
+ json: json!({"message": "internal server error"}),
+ status: Status::InternalServerError,
+ };
+ }
+
+ let user_data: Vec<String> = h.unwrap().iter().map(|i| i.data.to_string()).collect();
+
+ ApiResponse {
+ status: Status::Ok,
+ json: json!({ "history": user_data }),
+ }
+}
diff --git a/src/schema.rs b/src/schema.rs
index efa9ddcc..84bf5bab 100644
--- a/src/schema.rs
+++ b/src/schema.rs
@@ -3,9 +3,10 @@ table! {
id -> Int8,
client_id -> Text,
user_id -> Int8,
- mac -> Varchar,
+ hostname -> Text,
timestamp -> Timestamp,
data -> Varchar,
+ created_at -> Timestamp,
}
}
@@ -20,6 +21,7 @@ table! {
table! {
users (id) {
id -> Int8,
+ username -> Varchar,
email -> Varchar,
password -> Varchar,
}
diff --git a/src/settings.rs b/src/settings.rs
index 0e554bed..dcf69a7c 100644
--- a/src/settings.rs
+++ b/src/settings.rs
@@ -1,31 +1,90 @@
-use std::path::PathBuf;
+use std::fs::{create_dir_all, File};
+use std::io::prelude::*;
+use std::path::{Path, PathBuf};
-use config::{Config, File};
+use chrono::prelude::*;
+use chrono::Utc;
+use config::{Config, File as ConfigFile};
use directories::ProjectDirs;
use eyre::{eyre, Result};
-use std::fs;
+use parse_duration::parse;
-#[derive(Debug, Deserialize)]
+pub const HISTORY_PAGE_SIZE: i64 = 100;
+
+#[derive(Clone, Debug, Deserialize)]
pub struct Local {
pub dialect: String,
- pub sync: bool,
+ pub auto_sync: bool,
pub sync_address: String,
pub sync_frequency: String,
pub db_path: String,
+ pub key_path: String,
+ pub session_path: String,
+
+ // This is automatically loaded when settings is created. Do not set in
+ // config! Keep secrets and settings apart.
+ pub session_token: String,
}
-#[derive(Debug, Deserialize)]
-pub struct Remote {
+impl Local {
+ pub fn save_sync_time() -> Result<()> {
+ let sync_time_path = ProjectDirs::from("com", "elliehuxtable", "atuin")
+ .ok_or_else(|| eyre!("could not determine key file location"))?;
+ let sync_time_path = sync_time_path.data_dir().join("last_sync_time");
+
+ std::fs::write(sync_time_path, Utc::now().to_rfc3339())?;
+
+ Ok(())
+ }
+
+ pub fn last_sync() -> Result<chrono::DateTime<Utc>> {
+ let sync_time_path = ProjectDirs::from("com", "elliehuxtable", "atuin");
+
+ if sync_time_path.is_none() {
+ debug!("failed to load projectdirs, not syncing");
+ return Err(eyre!("could not load project dirs"));
+ }
+
+ let sync_time_path = sync_time_path.unwrap();
+ let sync_time_path = sync_time_path.data_dir().join("last_sync_time");
+
+ if !sync_time_path.exists() {
+ return Ok(Utc.ymd(1970, 1, 1).and_hms(0, 0, 0));
+ }
+
+ let time = std::fs::read_to_string(sync_time_path)?;
+ let time = chrono::DateTime::parse_from_rfc3339(time.as_str())?;
+
+ Ok(time.with_timezone(&Utc))
+ }
+
+ pub fn should_sync(&self) -> Result<bool> {
+ if !self.auto_sync {
+ return Ok(false);
+ }
+
+ match parse(self.sync_frequency.as_str()) {
+ Ok(d) => {
+ let d = chrono::Duration::from_std(d).unwrap();
+ Ok(Utc::now() - Local::last_sync()? >= d)
+ }
+ Err(e) => Err(eyre!("failed to check sync: {}", e)),
+ }
+ }
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct Server {
pub host: String,
pub port: u16,
pub db_uri: String,
pub open_registration: bool,
}
-#[derive(Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
pub struct Settings {
pub local: Local,
- pub remote: Remote,
+ pub server: Server,
}
impl Settings {
@@ -33,7 +92,7 @@ impl Settings {
let config_dir = ProjectDirs::from("com", "elliehuxtable", "atuin").unwrap();
let config_dir = config_dir.config_dir();
- fs::create_dir_all(config_dir)?;
+ create_dir_all(config_dir)?;
let mut config_file = PathBuf::new();
config_file.push(config_dir);
@@ -45,31 +104,61 @@ impl Settings {
let mut s = Config::new();
let db_path = ProjectDirs::from("com", "elliehuxtable", "atuin")
- .ok_or_else(|| {
- eyre!("could not determine db file location\nspecify one using the --db flag")
- })?
+ .ok_or_else(|| eyre!("could not determine db file location"))?
.data_dir()
.join("history.db");
+ let key_path = ProjectDirs::from("com", "elliehuxtable", "atuin")
+ .ok_or_else(|| eyre!("could not determine key file location"))?
+ .data_dir()
+ .join("key");
+
+ let session_path = ProjectDirs::from("com", "elliehuxtable", "atuin")
+ .ok_or_else(|| eyre!("could not determine session file location"))?
+ .data_dir()
+ .join("session");
+
s.set_default("local.db_path", db_path.to_str())?;
+ s.set_default("local.key_path", key_path.to_str())?;
+ s.set_default("local.session_path", session_path.to_str())?;
s.set_default("local.dialect", "us")?;
- s.set_default("local.sync", false)?;
+ s.set_default("local.auto_sync", true)?;
s.set_default("local.sync_frequency", "5m")?;
- s.set_default("local.sync_address", "https://atuin.ellie.wtf")?;
+ s.set_default("local.sync_address", "https://api.atuin.sh")?;
- s.set_default("remote.host", "127.0.0.1")?;
- s.set_default("remote.port", 8888)?;
- s.set_default("remote.open_registration", false)?;
- s.set_default("remote.db_uri", "please set a postgres url")?;
+ s.set_default("server.host", "127.0.0.1")?;
+ s.set_default("server.port", 8888)?;
+ s.set_default("server.open_registration", false)?;
+ s.set_default("server.db_uri", "please set a postgres url")?;
if config_file.exists() {
- s.merge(File::with_name(config_file.to_str().unwrap()))?;
+ s.merge(ConfigFile::with_name(config_file.to_str().unwrap()))?;
+ } else {
+ let example_config = include_bytes!("../config.toml");
+ let mut file = File::create(config_file)?;
+ file.write_all(example_config)?;
}
// all paths should be expanded
let db_path = s.get_str("local.db_path")?;
let db_path = shellexpand::full(db_path.as_str())?;
- s.set("local.db.path", db_path.to_string())?;
+ s.set("local.db_path", db_path.to_string())?;
+
+ let key_path = s.get_str("local.key_path")?;
+ let key_path = shellexpand::full(key_path.as_str())?;
+ s.set("local.key_path", key_path.to_string())?;
+
+ let session_path = s.get_str("local.session_path")?;
+ let session_path = shellexpand::full(session_path.as_str())?;
+ s.set("local.session_path", session_path.to_string())?;
+
+ // Finally, set the auth token
+ if Path::new(session_path.to_string().as_str()).exists() {
+ let token = std::fs::read_to_string(session_path.to_string())?;
+ s.set("local.session_token", token)?;
+ } else {
+ s.set("local.session_token", "not logged in")?;
+ }
s.try_into()
.map_err(|e| eyre!("failed to deserialize: {}", e))
diff --git a/src/shell/atuin.zsh b/src/shell/atuin.zsh
index 8407efd2..d2abf3c1 100644
--- a/src/shell/atuin.zsh
+++ b/src/shell/atuin.zsh
@@ -1,4 +1,6 @@
# Source this in your ~/.zshrc
+autoload -U add-zsh-hook
+
export ATUIN_SESSION=$(atuin uuid)
export ATUIN_HISTORY="atuin history list"
export ATUIN_BINDKEYS="true"
@@ -20,24 +22,12 @@ _atuin_search(){
emulate -L zsh
zle -I
+ # Switch to cursor mode, then back to application
+ echoti rmkx
# swap stderr and stdout, so that the tui stuff works
# TODO: not this
output=$(atuin search $BUFFER 3>&1 1>&2 2>&3)
-
- if [[ -n $output ]] ; then
- LBUFFER=$output
- fi
-
- zle reset-prompt
-}
-
-_atuin_up_search(){
- emulate -L zsh
- zle -I
-
- # swap stderr and stdout, so that the tui stuff works
- # TODO: not this
- output=$(atuin search $BUFFER 3>&1 1>&2 2>&3)
+ echoti smkx
if [[ -n $output ]] ; then
LBUFFER=$output
@@ -50,9 +40,11 @@ add-zsh-hook preexec _atuin_preexec
add-zsh-hook precmd _atuin_precmd
zle -N _atuin_search_widget _atuin_search
-zle -N _atuin_up_search_widget _atuin_up_search
if [[ $ATUIN_BINDKEYS == "true" ]]; then
bindkey '^r' _atuin_search_widget
- bindkey '^[[A' _atuin_up_search_widget
+
+ # depends on terminal mode
+ bindkey '^[[A' _atuin_search_widget
+ bindkey '^[OA' _atuin_search_widget
fi
diff --git a/src/utils.rs b/src/utils.rs
new file mode 100644
index 00000000..b395b148
--- /dev/null
+++ b/src/utils.rs
@@ -0,0 +1,24 @@
+use crypto::digest::Digest;
+use crypto::sha2::Sha256;
+use sodiumoxide::crypto::pwhash::argon2id13;
+
+pub fn hash_secret(secret: &str) -> String {
+ sodiumoxide::init().unwrap();
+ let hash = argon2id13::pwhash(
+ secret.as_bytes(),
+ argon2id13::OPSLIMIT_INTERACTIVE,
+ argon2id13::MEMLIMIT_INTERACTIVE,
+ )
+ .unwrap();
+ let texthash = std::str::from_utf8(&hash.0).unwrap().to_string();
+
+ // postgres hates null chars. don't do that to postgres
+ texthash.trim_end_matches('\u{0}').to_string()
+}
+
+pub fn hash_str(string: &str) -> String {
+ let mut hasher = Sha256::new();
+ hasher.input_str(string);
+
+ hasher.result_str()
+}