aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore2
-rw-r--r--Cargo.lock1
-rw-r--r--Cargo.toml1
-rw-r--r--atuin-client/Cargo.toml1
-rw-r--r--atuin-client/record-migrations/20231127090831_create-store.sql15
-rw-r--r--atuin-client/src/api_client.rs46
-rw-r--r--atuin-client/src/history.rs210
-rw-r--r--atuin-client/src/history/store.rs219
-rw-r--r--atuin-client/src/kv.rs100
-rw-r--r--atuin-client/src/record/encryption.rs29
-rw-r--r--atuin-client/src/record/sqlite_store.rs250
-rw-r--r--atuin-client/src/record/store.rs36
-rw-r--r--atuin-client/src/record/sync.rs486
-rw-r--r--atuin-client/src/settings.rs9
-rw-r--r--atuin-common/src/record.rs166
-rw-r--r--atuin-server-database/src/lib.rs6
-rw-r--r--atuin-server-postgres/migrations/20231202170508_create-store.sql15
-rw-r--r--atuin-server-postgres/migrations/20231203124112_create-store-idx.sql2
-rw-r--r--atuin-server-postgres/src/lib.rs92
-rw-r--r--atuin-server-postgres/src/wrappers.rs7
-rw-r--r--atuin-server/src/handlers/mod.rs1
-rw-r--r--atuin-server/src/handlers/record.rs107
-rw-r--r--atuin-server/src/handlers/v0/mod.rs1
-rw-r--r--atuin-server/src/handlers/v0/record.rs111
-rw-r--r--atuin-server/src/router.rs11
-rw-r--r--atuin/src/command/client.rs14
-rw-r--r--atuin/src/command/client/history.rs70
-rw-r--r--atuin/src/command/client/kv.rs6
-rw-r--r--atuin/src/command/client/record.rs63
-rw-r--r--atuin/src/command/client/sync.rs15
30 files changed, 1473 insertions, 619 deletions
diff --git a/.gitignore b/.gitignore
index 17c0b070..3d544414 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,8 @@
+.DS_Store
/target
*/target
.env
.idea/
.vscode/
result
+publish.sh
diff --git a/Cargo.lock b/Cargo.lock
index 28c00fb8..46499c34 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -242,6 +242,7 @@ dependencies = [
"shellexpand",
"sql-builder",
"sqlx",
+ "thiserror",
"time",
"tokio",
"typed-builder",
diff --git a/Cargo.toml b/Cargo.toml
index 1d9f6176..9aa03831 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -43,6 +43,7 @@ uuid = { version = "1.3", features = ["v4", "v7", "serde"] }
whoami = "1.1.2"
typed-builder = "0.18.0"
pretty_assertions = "1.3.0"
+thiserror = "1.0"
[workspace.dependencies.reqwest]
version = "0.11"
diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml
index c3d9cab7..cbb8d016 100644
--- a/atuin-client/Cargo.toml
+++ b/atuin-client/Cargo.toml
@@ -48,6 +48,7 @@ rmp = { version = "0.8.11" }
typed-builder = { workspace = true }
tokio = { workspace = true }
semver = { workspace = true }
+thiserror = { workspace = true }
futures = "0.3"
crypto_secretbox = "0.1.1"
generic-array = { version = "0.14", features = ["serde"] }
diff --git a/atuin-client/record-migrations/20231127090831_create-store.sql b/atuin-client/record-migrations/20231127090831_create-store.sql
new file mode 100644
index 00000000..53d78860
--- /dev/null
+++ b/atuin-client/record-migrations/20231127090831_create-store.sql
@@ -0,0 +1,15 @@
+-- Add migration script here
+create table if not exists store (
+ id text primary key, -- globally unique ID
+
+ idx integer, -- incrementing integer ID unique per (host, tag)
+ host text not null, -- references the host row
+ tag text not null,
+
+ timestamp integer not null,
+ version text not null,
+ data blob not null,
+ cek blob not null
+);
+
+create unique index record_uniq ON store(host, tag, idx);
diff --git a/atuin-client/src/api_client.rs b/atuin-client/src/api_client.rs
index ae8df5ad..9007b9ab 100644
--- a/atuin-client/src/api_client.rs
+++ b/atuin-client/src/api_client.rs
@@ -13,11 +13,11 @@ use atuin_common::{
AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse,
},
- record::RecordIndex,
+ record::RecordStatus,
};
use atuin_common::{
api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION},
- record::{EncryptedData, HostId, Record, RecordId},
+ record::{EncryptedData, HostId, Record, RecordIdx},
};
use semver::Version;
use time::format_description::well_known::Rfc3339;
@@ -267,10 +267,18 @@ impl<'a> Client<'a> {
}
pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> {
- let url = format!("{}/record", self.sync_addr);
+ let url = format!("{}/api/v0/record", self.sync_addr);
let url = Url::parse(url.as_str())?;
- self.client.post(url).json(records).send().await?;
+ let resp = self.client.post(url).json(records).send().await?;
+ info!("posted records, got {}", resp.status());
+
+ if !resp.status().is_success() {
+ error!(
+ "failed to post records to server; got: {:?}",
+ resp.text().await
+ );
+ }
Ok(())
}
@@ -279,24 +287,22 @@ impl<'a> Client<'a> {
&self,
host: HostId,
tag: String,
- start: Option<RecordId>,
+ start: RecordIdx,
count: u64,
) -> Result<Vec<Record<EncryptedData>>> {
+ debug!(
+ "fetching record/s from host {}/{}/{}",
+ host.0.to_string(),
+ tag,
+ start
+ );
+
let url = format!(
- "{}/record/next?host={}&tag={}&count={}",
- self.sync_addr, host.0, tag, count
+ "{}/api/v0/record/next?host={}&tag={}&count={}&start={}",
+ self.sync_addr, host.0, tag, count, start
);
- let mut url = Url::parse(url.as_str())?;
- if let Some(start) = start {
- url.set_query(Some(
- format!(
- "host={}&tag={}&count={}&start={}",
- host.0, tag, count, start.0
- )
- .as_str(),
- ));
- }
+ let url = Url::parse(url.as_str())?;
let resp = self.client.get(url).send().await?;
@@ -305,8 +311,8 @@ impl<'a> Client<'a> {
Ok(records)
}
- pub async fn record_index(&self) -> Result<RecordIndex> {
- let url = format!("{}/record", self.sync_addr);
+ pub async fn record_status(&self) -> Result<RecordStatus> {
+ let url = format!("{}/api/v0/record", self.sync_addr);
let url = Url::parse(url.as_str())?;
let resp = self.client.get(url).send().await?;
@@ -317,6 +323,8 @@ impl<'a> Client<'a> {
let index = resp.json().await?;
+ debug!("got remote index {:?}", index);
+
Ok(index)
}
diff --git a/atuin-client/src/history.rs b/atuin-client/src/history.rs
index fbcb169c..2b2c41ee 100644
--- a/atuin-client/src/history.rs
+++ b/atuin-client/src/history.rs
@@ -1,12 +1,21 @@
+use rmp::decode::ValueReadError;
+use rmp::{decode::Bytes, Marker};
use std::env;
+use atuin_common::record::DecryptedData;
use atuin_common::utils::uuid_v7;
+
+use eyre::{bail, eyre, Result};
use regex::RegexSet;
use crate::{secrets::SECRET_PATTERNS, settings::Settings};
use time::OffsetDateTime;
mod builder;
+pub mod store;
+
+const HISTORY_VERSION: &str = "v0";
+const HISTORY_TAG: &str = "history";
/// Client-side history entry.
///
@@ -81,6 +90,108 @@ impl History {
}
}
+ pub fn serialize(&self) -> Result<DecryptedData> {
+ // This is pretty much the same as what we used for the old history, with one difference -
+ // it uses integers for timestamps rather than a string format.
+
+ use rmp::encode;
+
+ let mut output = vec![];
+
+ // write the version
+ encode::write_u16(&mut output, 0)?;
+ // INFO: ensure this is updated when adding new fields
+ encode::write_array_len(&mut output, 9)?;
+
+ encode::write_str(&mut output, &self.id)?;
+ encode::write_u64(&mut output, self.timestamp.unix_timestamp_nanos() as u64)?;
+ encode::write_sint(&mut output, self.duration)?;
+ encode::write_sint(&mut output, self.exit)?;
+ encode::write_str(&mut output, &self.command)?;
+ encode::write_str(&mut output, &self.cwd)?;
+ encode::write_str(&mut output, &self.session)?;
+ encode::write_str(&mut output, &self.hostname)?;
+
+ match self.deleted_at {
+ Some(d) => encode::write_u64(&mut output, d.unix_timestamp_nanos() as u64)?,
+ None => encode::write_nil(&mut output)?,
+ }
+
+ Ok(DecryptedData(output))
+ }
+
+ fn deserialize_v0(bytes: &[u8]) -> Result<History> {
+ use rmp::decode;
+
+ fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
+ eyre!("{err:?}")
+ }
+
+ let mut bytes = Bytes::new(bytes);
+
+ let version = decode::read_u16(&mut bytes).map_err(error_report)?;
+
+ if version != 0 {
+ bail!("expected decoding v0 record, found v{version}");
+ }
+
+ let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?;
+
+ if nfields != 9 {
+ bail!("cannot decrypt history from a different version of Atuin");
+ }
+
+ let bytes = bytes.remaining_slice();
+ let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
+
+ let mut bytes = Bytes::new(bytes);
+ let timestamp = decode::read_u64(&mut bytes).map_err(error_report)?;
+ let duration = decode::read_int(&mut bytes).map_err(error_report)?;
+ let exit = decode::read_int(&mut bytes).map_err(error_report)?;
+
+ let bytes = bytes.remaining_slice();
+ let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
+ let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
+ let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
+ let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
+
+ // if we have more fields, try and get the deleted_at
+ let mut bytes = Bytes::new(bytes);
+
+ let (deleted_at, bytes) = match decode::read_u64(&mut bytes) {
+ Ok(unix) => (Some(unix), bytes.remaining_slice()),
+ // we accept null here
+ Err(ValueReadError::TypeMismatch(Marker::Null)) => (None, bytes.remaining_slice()),
+ Err(err) => return Err(error_report(err)),
+ };
+
+ if !bytes.is_empty() {
+ bail!("trailing bytes in encoded history. malformed")
+ }
+
+ Ok(History {
+ id: id.to_owned(),
+ timestamp: OffsetDateTime::from_unix_timestamp_nanos(timestamp as i128)?,
+ duration,
+ exit,
+ command: command.to_owned(),
+ cwd: cwd.to_owned(),
+ session: session.to_owned(),
+ hostname: hostname.to_owned(),
+ deleted_at: deleted_at
+ .map(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128))
+ .transpose()?,
+ })
+ }
+
+ pub fn deserialize(bytes: &[u8], version: &str) -> Result<History> {
+ match version {
+ HISTORY_VERSION => Self::deserialize_v0(bytes),
+
+ _ => bail!("unknown version {version:?}"),
+ }
+ }
+
/// Builder for a history entry that is imported from shell history.
///
/// The only two required fields are `timestamp` and `command`.
@@ -202,8 +313,9 @@ impl History {
#[cfg(test)]
mod tests {
use regex::RegexSet;
+ use time::macros::datetime;
- use crate::settings::Settings;
+ use crate::{history::HISTORY_VERSION, settings::Settings};
use super::History;
@@ -274,4 +386,100 @@ mod tests {
assert!(stripe_key.should_save(&settings));
}
+
+ #[test]
+ fn test_serialize_deserialize() {
+ let bytes = [
+ 205, 0, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55,
+ 53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99,
+ 98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116,
+ 97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100,
+ 46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115,
+ 47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97,
+ 51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49,
+ 102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58,
+ 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192,
+ ];
+
+ let history = History {
+ id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned(),
+ timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00),
+ duration: 49206000,
+ exit: 0,
+ command: "git status".to_owned(),
+ cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(),
+ session: "b97d9a306f274473a203d2eba41f9457".to_owned(),
+ hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(),
+ deleted_at: None,
+ };
+
+ let serialized = history.serialize().expect("failed to serialize history");
+ assert_eq!(serialized.0, bytes);
+
+ let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION)
+ .expect("failed to deserialize history");
+ assert_eq!(history, deserialized);
+
+ // test the snapshot too
+ let deserialized =
+ History::deserialize(&bytes, HISTORY_VERSION).expect("failed to deserialize history");
+ assert_eq!(history, deserialized);
+ }
+
+ #[test]
+ fn test_serialize_deserialize_deleted() {
+ let history = History {
+ id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned(),
+ timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00),
+ duration: 49206000,
+ exit: 0,
+ command: "git status".to_owned(),
+ cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(),
+ session: "b97d9a306f274473a203d2eba41f9457".to_owned(),
+ hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(),
+ deleted_at: Some(datetime!(2023-11-19 20:18 +00:00)),
+ };
+
+ let serialized = history.serialize().expect("failed to serialize history");
+
+ let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION)
+ .expect("failed to deserialize history");
+
+ assert_eq!(history, deserialized);
+ }
+
+ #[test]
+ fn test_serialize_deserialize_version() {
+ // v0
+ let bytes_v0 = [
+ 205, 0, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55,
+ 53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99,
+ 98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116,
+ 97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100,
+ 46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115,
+ 47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97,
+ 51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49,
+ 102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58,
+ 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192,
+ ];
+
+ // some other version
+ let bytes_v1 = [
+ 205, 1, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55,
+ 53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99,
+ 98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116,
+ 97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100,
+ 46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115,
+ 47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97,
+ 51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49,
+ 102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58,
+ 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192,
+ ];
+
+ let deserialized = History::deserialize(&bytes_v0, HISTORY_VERSION);
+ assert!(deserialized.is_ok());
+
+ let deserialized = History::deserialize(&bytes_v1, HISTORY_VERSION);
+ assert!(deserialized.is_err());
+ }
}
diff --git a/atuin-client/src/history/store.rs b/atuin-client/src/history/store.rs
new file mode 100644
index 00000000..bf74a0a8
--- /dev/null
+++ b/atuin-client/src/history/store.rs
@@ -0,0 +1,219 @@
+use eyre::{bail, eyre, Result};
+use rmp::decode::Bytes;
+
+use crate::record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store};
+use atuin_common::record::{DecryptedData, Host, HostId, Record, RecordIdx};
+
+use super::{History, HISTORY_TAG, HISTORY_VERSION};
+
+#[derive(Debug)]
+pub struct HistoryStore {
+ pub store: SqliteStore,
+ pub host_id: HostId,
+ pub encryption_key: [u8; 32],
+}
+
+#[derive(Debug, Eq, PartialEq, Clone)]
+pub enum HistoryRecord {
+ Create(History), // Create a history record
+ Delete(String), // Delete a history record, identified by ID
+}
+
+impl HistoryRecord {
+ /// Serialize a history record, returning DecryptedData
+ /// The record will be of a certain type
+ /// We map those like so:
+ ///
+ /// HistoryRecord::Create -> 0
+ /// HistoryRecord::Delete-> 1
+ ///
+ /// This numeric identifier is then written as the first byte to the buffer. For history, we
+ /// append the serialized history right afterwards, to avoid having to handle serialization
+ /// twice.
+ ///
+ /// Deletion simply refers to the history by ID
+ pub fn serialize(&self) -> Result<DecryptedData> {
+ // probably don't actually need to use rmp here, but if we ever need to extend it, it's a
+ // nice wrapper around raw byte stuff
+ use rmp::encode;
+
+ let mut output = vec![];
+
+ match self {
+ HistoryRecord::Create(history) => {
+ // 0 -> a history create
+ encode::write_u8(&mut output, 0)?;
+
+ let bytes = history.serialize()?;
+
+ encode::write_bin(&mut output, &bytes.0)?;
+ }
+ HistoryRecord::Delete(id) => {
+ // 1 -> a history delete
+ encode::write_u8(&mut output, 1)?;
+ encode::write_str(&mut output, id)?;
+ }
+ };
+
+ Ok(DecryptedData(output))
+ }
+
+ pub fn deserialize(bytes: &[u8], version: &str) -> Result<Self> {
+ use rmp::decode;
+
+ fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
+ eyre!("{err:?}")
+ }
+
+ let mut bytes = Bytes::new(bytes);
+
+ let record_type = decode::read_u8(&mut bytes).map_err(error_report)?;
+
+ match record_type {
+ // 0 -> HistoryRecord::Create
+ 0 => {
+ // not super useful to us atm, but perhaps in the future
+ // written by write_bin above
+ let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?;
+
+ let record = History::deserialize(bytes.remaining_slice(), version)?;
+
+ Ok(HistoryRecord::Create(record))
+ }
+
+ // 1 -> HistoryRecord::Delete
+ 1 => {
+ let bytes = bytes.remaining_slice();
+ let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
+
+ if !bytes.is_empty() {
+ bail!(
+ "trailing bytes decoding HistoryRecord::Delete - malformed? got {bytes:?}"
+ );
+ }
+
+ Ok(HistoryRecord::Delete(id.to_string()))
+ }
+
+ n => {
+ bail!("unknown HistoryRecord type {n}")
+ }
+ }
+ }
+}
+
+impl HistoryStore {
+ pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> Self {
+ HistoryStore {
+ store,
+ host_id,
+ encryption_key,
+ }
+ }
+
+ async fn push_record(&self, record: HistoryRecord) -> Result<RecordIdx> {
+ let bytes = record.serialize()?;
+ let idx = self
+ .store
+ .last(self.host_id, HISTORY_TAG)
+ .await?
+ .map_or(0, |p| p.idx + 1);
+
+ let record = Record::builder()
+ .host(Host::new(self.host_id))
+ .version(HISTORY_VERSION.to_string())
+ .tag(HISTORY_TAG.to_string())
+ .idx(idx)
+ .data(bytes)
+ .build();
+
+ self.store
+ .push(&record.encrypt::<PASETO_V4>(&self.encryption_key))
+ .await?;
+
+ Ok(idx)
+ }
+
+ pub async fn delete(&self, id: String) -> Result<RecordIdx> {
+ let record = HistoryRecord::Delete(id);
+
+ self.push_record(record).await
+ }
+
+ pub async fn push(&self, history: History) -> Result<RecordIdx> {
+ // TODO(ellie): move the history store to its own file
+ // it's tiny rn so fine as is
+ let record = HistoryRecord::Create(history);
+
+ self.push_record(record).await
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use time::macros::datetime;
+
+ use crate::history::{store::HistoryRecord, HISTORY_VERSION};
+
+ use super::History;
+
+ #[test]
+ fn test_serialize_deserialize_create() {
+ let bytes = [
+ 204, 0, 196, 141, 205, 0, 0, 153, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49,
+ 55, 53, 55, 99, 100, 50, 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99,
+ 56, 49, 207, 23, 166, 251, 212, 181, 82, 0, 0, 100, 0, 162, 108, 115, 217, 41, 47, 85,
+ 115, 101, 114, 115, 47, 101, 108, 108, 105, 101, 47, 115, 114, 99, 47, 103, 105, 116,
+ 104, 117, 98, 46, 99, 111, 109, 47, 97, 116, 117, 105, 110, 115, 104, 47, 97, 116, 117,
+ 105, 110, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 97, 100, 56, 57, 55, 53, 57, 55,
+ 56, 53, 50, 53, 50, 55, 97, 51, 49, 99, 57, 57, 56, 48, 53, 57, 170, 98, 111, 111, 112,
+ 58, 101, 108, 108, 105, 101, 192,
+ ];
+
+ let history = History {
+ id: "018cd4fe81757cd2aee65cd7861f9c81".to_owned(),
+ timestamp: datetime!(2024-01-04 00:00:00.000000 +00:00),
+ duration: 100,
+ exit: 0,
+ command: "ls".to_owned(),
+ cwd: "/Users/ellie/src/github.com/atuinsh/atuin".to_owned(),
+ session: "018cd4fead897597852527a31c998059".to_owned(),
+ hostname: "boop:ellie".to_owned(),
+ deleted_at: None,
+ };
+
+ let record = HistoryRecord::Create(history);
+
+ let serialized = record.serialize().expect("failed to serialize history");
+ assert_eq!(serialized.0, bytes);
+
+ let deserialized = HistoryRecord::deserialize(&serialized.0, HISTORY_VERSION)
+ .expect("failed to deserialize HistoryRecord");
+ assert_eq!(deserialized, record);
+
+ // check the snapshot too
+ let deserialized = HistoryRecord::deserialize(&bytes, HISTORY_VERSION)
+ .expect("failed to deserialize HistoryRecord");
+ assert_eq!(deserialized, record);
+ }
+
+ #[test]
+ fn test_serialize_deserialize_delete() {
+ let bytes = [
+ 204, 1, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, 55, 53, 55, 99, 100, 50,
+ 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, 56, 49,
+ ];
+ let record = HistoryRecord::Delete("018cd4fe81757cd2aee65cd7861f9c81".to_string());
+
+ let serialized = record.serialize().expect("failed to serialize history");
+ assert_eq!(serialized.0, bytes);
+
+ let deserialized = HistoryRecord::deserialize(&serialized.0, HISTORY_VERSION)
+ .expect("failed to deserialize HistoryRecord");
+ assert_eq!(deserialized, record);
+
+ let deserialized = HistoryRecord::deserialize(&bytes, HISTORY_VERSION)
+ .expect("failed to deserialize HistoryRecord");
+ assert_eq!(deserialized, record);
+ }
+}
diff --git a/atuin-client/src/kv.rs b/atuin-client/src/kv.rs
index 1ca6b5e8..cee7063d 100644
--- a/atuin-client/src/kv.rs
+++ b/atuin-client/src/kv.rs
@@ -1,6 +1,6 @@
use std::collections::BTreeMap;
-use atuin_common::record::{DecryptedData, HostId};
+use atuin_common::record::{DecryptedData, Host, HostId};
use eyre::{bail, ensure, eyre, Result};
use serde::Deserialize;
@@ -89,7 +89,7 @@ impl KvStore {
pub async fn set(
&self,
- store: &mut (impl Store + Send + Sync),
+ store: &(impl Store + Send + Sync),
encryption_key: &[u8; 32],
host_id: HostId,
namespace: &str,
@@ -111,13 +111,16 @@ impl KvStore {
let bytes = record.serialize()?;
- let parent = store.tail(host_id, KV_TAG).await?.map(|entry| entry.id);
+ let idx = store
+ .last(host_id, KV_TAG)
+ .await?
+ .map_or(0, |entry| entry.idx + 1);
let record = atuin_common::record::Record::builder()
- .host(host_id)
+ .host(Host::new(host_id))
.version(KV_VERSION.to_string())
.tag(KV_TAG.to_string())
- .parent(parent)
+ .idx(idx)
.data(bytes)
.build();
@@ -137,43 +140,18 @@ impl KvStore {
namespace: &str,
key: &str,
) -> Result<Option<KvRecord>> {
- // Currently, this is O(n). When we have an actual KV store, it can be better
- // Just a poc for now!
+ // TODO: don't rebuild every time...
+ let map = self.build_kv(store, encryption_key).await?;
- // iterate records to find the value we want
- // start at the end, so we get the most recent version
- let tails = store.tag_tails(KV_TAG).await?;
+ let res = map.get(namespace);
- if tails.is_empty() {
- return Ok(None);
- }
-
- // first, decide on a record.
- // try getting the newest first
- // we always need a way of deciding the "winner" of a write
- // TODO(ellie): something better than last-write-wins, what if two write at the same time?
- let mut record = tails.iter().max_by_key(|r| r.timestamp).unwrap().clone();
-
- loop {
- let decrypted = match record.version.as_str() {
- KV_VERSION => record.decrypt::<PASETO_V4>(encryption_key)?,
- version => bail!("unknown version {version:?}"),
- };
-
- let kv = KvRecord::deserialize(&decrypted.data, &decrypted.version)?;
- if kv.key == key && kv.namespace == namespace {
- return Ok(Some(kv));
- }
+ if let Some(ns) = res {
+ let value = ns.get(key);
- if let Some(parent) = decrypted.parent {
- record = store.get(parent).await?;
- } else {
- break;
- }
+ Ok(value.cloned())
+ } else {
+ Ok(None)
}
-
- // if we get here, then... we didn't find the record with that key :(
- Ok(None)
}
// Build a kv map out of the linked list kv store
@@ -184,32 +162,30 @@ impl KvStore {
&self,
store: &impl Store,
encryption_key: &[u8; 32],
- ) -> Result<BTreeMap<String, BTreeMap<String, String>>> {
+ ) -> Result<BTreeMap<String, BTreeMap<String, KvRecord>>> {
let mut map = BTreeMap::new();
- let tails = store.tag_tails(KV_TAG).await?;
-
- if tails.is_empty() {
- return Ok(map);
- }
- let mut record = tails.iter().max_by_key(|r| r.timestamp).unwrap().clone();
+ // TODO: maybe don't load the entire tag into memory to build the kv
+ // we can be smart about it and only load values since the last build
+ // or, iterate/paginate
+ let tagged = store.all_tagged(KV_TAG).await?;
- loop {
+ // iterate through all tags and play each KV record at a time
+ // this is "last write wins"
+ // probably good enough for now, but revisit in future
+ for record in tagged {
let decrypted = match record.version.as_str() {
KV_VERSION => record.decrypt::<PASETO_V4>(encryption_key)?,
version => bail!("unknown version {version:?}"),
};
- let kv = KvRecord::deserialize(&decrypted.data, &decrypted.version)?;
+ let kv = KvRecord::deserialize(&decrypted.data, KV_VERSION)?;
- let ns = map.entry(kv.namespace).or_insert_with(BTreeMap::new);
- ns.entry(kv.key).or_insert_with(|| kv.value);
+ let ns = map
+ .entry(kv.namespace.clone())
+ .or_insert_with(BTreeMap::new);
- if let Some(parent) = decrypted.parent {
- record = store.get(parent).await?;
- } else {
- break;
- }
+ ns.insert(kv.key.clone(), kv);
}
Ok(map)
@@ -261,19 +237,27 @@ mod tests {
let map = kv.build_kv(&store, &key).await.unwrap();
assert_eq!(
- map.get("test-kv")
+ *map.get("test-kv")
.expect("map namespace not set")
.get("foo")
.expect("map key not set"),
- "bar"
+ KvRecord {
+ namespace: String::from("test-kv"),
+ key: String::from("foo"),
+ value: String::from("bar")
+ }
);
assert_eq!(
- map.get("test-kv")
+ *map.get("test-kv")
.expect("map namespace not set")
.get("1")
.expect("map key not set"),
- "2"
+ KvRecord {
+ namespace: String::from("test-kv"),
+ key: String::from("1"),
+ value: String::from("2")
+ }
);
}
}
diff --git a/atuin-client/src/record/encryption.rs b/atuin-client/src/record/encryption.rs
index 3074a9c2..c2cdaa6a 100644
--- a/atuin-client/src/record/encryption.rs
+++ b/atuin-client/src/record/encryption.rs
@@ -1,5 +1,5 @@
use atuin_common::record::{
- AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId,
+ AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId, RecordIdx,
};
use base64::{engine::general_purpose, Engine};
use eyre::{ensure, Context, Result};
@@ -170,10 +170,10 @@ struct AtuinFooter {
#[derive(Debug, Copy, Clone, Serialize)]
struct Assertions<'a> {
id: &'a RecordId,
+ idx: &'a RecordIdx,
version: &'a str,
tag: &'a str,
host: &'a HostId,
- parent: Option<&'a RecordId>,
}
impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
@@ -183,7 +183,7 @@ impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
version: ad.version,
tag: ad.tag,
host: ad.host,
- parent: ad.parent,
+ idx: ad.idx,
}
}
}
@@ -196,7 +196,10 @@ impl Assertions<'_> {
#[cfg(test)]
mod tests {
- use atuin_common::{record::Record, utils::uuid_v7};
+ use atuin_common::{
+ record::{Host, Record},
+ utils::uuid_v7,
+ };
use super::*;
@@ -209,7 +212,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
- parent: None,
+ idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -228,7 +231,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
- parent: None,
+ idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -252,7 +255,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
- parent: None,
+ idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -270,7 +273,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
- parent: None,
+ idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -294,7 +297,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
- parent: None,
+ idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -323,9 +326,10 @@ mod tests {
.id(RecordId(uuid_v7()))
.version("v0".to_owned())
.tag("kv".to_owned())
- .host(HostId(uuid_v7()))
+ .host(Host::new(HostId(uuid_v7())))
.timestamp(1687244806000000)
.data(DecryptedData(vec![1, 2, 3, 4]))
+ .idx(0)
.build();
let encrypted = record.encrypt::<PASETO_V4>(&key);
@@ -345,15 +349,16 @@ mod tests {
.id(RecordId(uuid_v7()))
.version("v0".to_owned())
.tag("kv".to_owned())
- .host(HostId(uuid_v7()))
+ .host(Host::new(HostId(uuid_v7())))
.timestamp(1687244806000000)
.data(DecryptedData(vec![1, 2, 3, 4]))
+ .idx(0)
.build();
let encrypted = record.encrypt::<PASETO_V4>(&key);
let mut enc1 = encrypted.clone();
- enc1.host = HostId(uuid_v7());
+ enc1.host = Host::new(HostId(uuid_v7()));
let _ = enc1
.decrypt::<PASETO_V4>(&key)
.expect_err("tampering with the host should result in auth failure");
diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs
index db709f20..8112aa96 100644
--- a/atuin-client/src/record/sqlite_store.rs
+++ b/atuin-client/src/record/sqlite_store.rs
@@ -8,17 +8,20 @@ use std::str::FromStr;
use async_trait::async_trait;
use eyre::{eyre, Result};
use fs_err as fs;
-use futures::TryStreamExt;
+
use sqlx::{
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
Row,
};
-use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
+use atuin_common::record::{
+ EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus,
+};
use uuid::Uuid;
use super::store::Store;
+#[derive(Debug, Clone)]
pub struct SqliteStore {
pool: SqlitePool,
}
@@ -38,6 +41,7 @@ impl SqliteStore {
let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
.journal_mode(SqliteJournalMode::Wal)
+ .foreign_keys(true)
.create_if_missing(true);
let pool = SqlitePoolOptions::new().connect_with(opts).await?;
@@ -61,14 +65,14 @@ impl SqliteStore {
) -> Result<()> {
// In sqlite, we are "limited" to i64. But that is still fine, until 2262.
sqlx::query(
- "insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek)
+ "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek)
values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
)
- .bind(r.id.0.as_simple().to_string())
- .bind(r.host.0.as_simple().to_string())
+ .bind(r.id.0.as_hyphenated().to_string())
+ .bind(r.idx as i64)
+ .bind(r.host.id.0.as_hyphenated().to_string())
.bind(r.tag.as_str())
.bind(r.timestamp as i64)
- .bind(r.parent.map(|p| p.0.as_simple().to_string()))
.bind(r.version.as_str())
.bind(r.data.data.as_str())
.bind(r.data.content_encryption_key.as_str())
@@ -79,20 +83,17 @@ impl SqliteStore {
}
fn query_row(row: SqliteRow) -> Record<EncryptedData> {
+ let idx: i64 = row.get("idx");
let timestamp: i64 = row.get("timestamp");
// tbh at this point things are pretty fucked so just panic
let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
- let parent: Option<&str> = row.get("parent");
-
- let parent = parent
- .map(|parent| Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB"));
Record {
id: RecordId(id),
- host: HostId(host),
- parent: parent.map(RecordId),
+ idx: idx as u64,
+ host: Host::new(HostId(host)),
timestamp: timestamp as u64,
tag: row.get("tag"),
version: row.get("version"),
@@ -122,8 +123,8 @@ impl Store for SqliteStore {
}
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
- let res = sqlx::query("select * from records where id = ?1")
- .bind(id.0.as_simple().to_string())
+ let res = sqlx::query("select * from store where store.id = ?1")
+ .bind(id.0.as_hyphenated().to_string())
.map(Self::query_row)
.fetch_one(&self.pool)
.await?;
@@ -131,20 +132,66 @@ impl Store for SqliteStore {
Ok(res)
}
- async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
- let res: (i64,) =
- sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
- .bind(host.0.as_simple().to_string())
+ async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
+ let res =
+ sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1")
+ .bind(host.0.as_hyphenated().to_string())
.bind(tag)
+ .map(Self::query_row)
.fetch_one(&self.pool)
+ .await;
+
+ match res {
+ Err(sqlx::Error::RowNotFound) => Ok(None),
+ Err(e) => Err(eyre!("an error occured: {}", e)),
+ Ok(record) => Ok(Some(record)),
+ }
+ }
+
+ async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
+ self.idx(host, tag, 0).await
+ }
+
+ async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
+ let last = self.last(host, tag).await?;
+
+ if let Some(last) = last {
+ return Ok(last.idx + 1);
+ }
+
+ return Ok(0);
+ }
+
+ async fn next(
+ &self,
+ host: HostId,
+ tag: &str,
+ idx: RecordIdx,
+ limit: u64,
+ ) -> Result<Vec<Record<EncryptedData>>> {
+ let res =
+ sqlx::query("select * from store where idx >= ?1 and host = ?2 and tag = ?3 limit ?4")
+ .bind(idx as i64)
+ .bind(host.0.as_hyphenated().to_string())
+ .bind(tag)
+ .bind(limit as i64)
+ .map(Self::query_row)
+ .fetch_all(&self.pool)
.await?;
- Ok(res.0 as u64)
+ Ok(res)
}
- async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> {
- let res = sqlx::query("select * from records where parent = ?1")
- .bind(record.id.0.as_simple().to_string())
+ async fn idx(
+ &self,
+ host: HostId,
+ tag: &str,
+ idx: RecordIdx,
+ ) -> Result<Option<Record<EncryptedData>>> {
+ let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3")
+ .bind(idx as i64)
+ .bind(host.0.as_hyphenated().to_string())
+ .bind(tag)
.map(Self::query_row)
.fetch_one(&self.pool)
.await;
@@ -156,58 +203,36 @@ impl Store for SqliteStore {
}
}
- async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
- let res = sqlx::query(
- "select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
- )
- .bind(host.0.as_simple().to_string())
- .bind(tag)
- .map(Self::query_row)
- .fetch_optional(&self.pool)
- .await?;
+ async fn status(&self) -> Result<RecordStatus> {
+ let mut status = RecordStatus::new();
- Ok(res)
- }
+ let res: Result<Vec<(String, String, i64)>, sqlx::Error> =
+ sqlx::query_as("select host, tag, max(idx) from store group by host, tag")
+ .fetch_all(&self.pool)
+ .await;
- async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
- let res = sqlx::query(
- "select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;",
- )
- .bind(tag)
- .bind(host.0.as_simple().to_string())
- .map(Self::query_row)
- .fetch_optional(&self.pool)
- .await?;
+ let res = match res {
+ Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)),
+ Ok(v) => v,
+ };
- Ok(res)
- }
+ for i in res {
+ let host = HostId(
+ Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"),
+ );
- async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
- let res = sqlx::query(
- "select * from records rp where tag=?1 and (select count(1) from records where parent=rp.id) = 0;",
- )
- .bind(tag)
- .map(Self::query_row)
- .fetch_all(&self.pool)
- .await?;
+ status.set_raw(host, i.1, i.2 as u64);
+ }
- Ok(res)
+ Ok(status)
}
- async fn tail_records(&self) -> Result<RecordIndex> {
- let res = sqlx::query(
- "select host, tag, id from records rp where (select count(1) from records where parent=rp.id) = 0;",
- )
- .map(|row: SqliteRow| {
- let host: Uuid= Uuid::from_str(row.get("host")).expect("invalid uuid in db host");
- let tag: String= row.get("tag");
- let id: Uuid= Uuid::from_str(row.get("id")).expect("invalid uuid in db id");
-
- (HostId(host), tag, RecordId(id))
- })
- .fetch(&self.pool)
- .try_collect()
- .await?;
+ async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
+ let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc")
+ .bind(tag)
+ .map(Self::query_row)
+ .fetch_all(&self.pool)
+ .await?;
Ok(res)
}
@@ -215,7 +240,7 @@ impl Store for SqliteStore {
#[cfg(test)]
mod tests {
- use atuin_common::record::{EncryptedData, HostId, Record};
+ use atuin_common::record::{EncryptedData, Host, HostId, Record};
use crate::record::{encryption::PASETO_V4, store::Store};
@@ -223,13 +248,14 @@ mod tests {
fn test_record() -> Record<EncryptedData> {
Record::builder()
- .host(HostId(atuin_common::utils::uuid_v7()))
+ .host(Host::new(HostId(atuin_common::utils::uuid_v7())))
.version("v1".into())
.tag(atuin_common::utils::uuid_v7().simple().to_string())
.data(EncryptedData {
data: "1234".into(),
content_encryption_key: "1234".into(),
})
+ .idx(0)
.build()
}
@@ -264,13 +290,49 @@ mod tests {
}
#[tokio::test]
+ async fn last() {
+ let db = SqliteStore::new(":memory:").await.unwrap();
+ let record = test_record();
+ db.push(&record).await.unwrap();
+
+ let last = db
+ .last(record.host.id, record.tag.as_str())
+ .await
+ .expect("failed to get store len");
+
+ assert_eq!(
+ last.unwrap().id,
+ record.id,
+ "expected to get back the same record that was inserted"
+ );
+ }
+
+ #[tokio::test]
+ async fn first() {
+ let db = SqliteStore::new(":memory:").await.unwrap();
+ let record = test_record();
+ db.push(&record).await.unwrap();
+
+ let first = db
+ .first(record.host.id, record.tag.as_str())
+ .await
+ .expect("failed to get store len");
+
+ assert_eq!(
+ first.unwrap().id,
+ record.id,
+ "expected to get back the same record that was inserted"
+ );
+ }
+
+ #[tokio::test]
async fn len() {
let db = SqliteStore::new(":memory:").await.unwrap();
let record = test_record();
db.push(&record).await.unwrap();
let len = db
- .len(record.host, record.tag.as_str())
+ .len(record.host.id, record.tag.as_str())
.await
.expect("failed to get store len");
@@ -290,8 +352,8 @@ mod tests {
db.push(&first).await.unwrap();
db.push(&second).await.unwrap();
- let first_len = db.len(first.host, first.tag.as_str()).await.unwrap();
- let second_len = db.len(second.host, second.tag.as_str()).await.unwrap();
+ let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap();
+ let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap();
assert_eq!(first_len, 1, "expected length of 1 after insert");
assert_eq!(second_len, 1, "expected length of 1 after insert");
@@ -305,14 +367,12 @@ mod tests {
db.push(&tail).await.expect("failed to push record");
for _ in 1..100 {
- tail = tail
- .new_child(vec![1, 2, 3, 4])
- .encrypt::<PASETO_V4>(&[0; 32]);
+ tail = tail.append(vec![1, 2, 3, 4]).encrypt::<PASETO_V4>(&[0; 32]);
db.push(&tail).await.unwrap();
}
assert_eq!(
- db.len(tail.host, tail.tag.as_str()).await.unwrap(),
+ db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
100,
"failed to insert 100 records"
);
@@ -328,50 +388,16 @@ mod tests {
records.push(tail.clone());
for _ in 1..10000 {
- tail = tail.new_child(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
+ tail = tail.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
records.push(tail.clone());
}
db.push_batch(records.iter()).await.unwrap();
assert_eq!(
- db.len(tail.host, tail.tag.as_str()).await.unwrap(),
+ db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
10000,
"failed to insert 10k records"
);
}
-
- #[tokio::test]
- async fn test_chain() {
- let db = SqliteStore::new(":memory:").await.unwrap();
-
- let mut records: Vec<Record<EncryptedData>> = Vec::with_capacity(1000);
-
- let mut tail = test_record();
- records.push(tail.clone());
-
- for _ in 1..1000 {
- tail = tail.new_child(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
- records.push(tail.clone());
- }
-
- db.push_batch(records.iter()).await.unwrap();
-
- let mut record = db
- .head(tail.host, tail.tag.as_str())
- .await
- .expect("in memory sqlite should not fail")
- .expect("entry exists");
-
- let mut count = 1;
-
- while let Some(next) = db.next(&record).await.unwrap() {
- assert_eq!(record.id, next.clone().parent.unwrap());
- record = next;
-
- count += 1;
- }
-
- assert_eq!(count, 1000);
- }
}
diff --git a/atuin-client/src/record/store.rs b/atuin-client/src/record/store.rs
index 45d554ef..a5c156d6 100644
--- a/atuin-client/src/record/store.rs
+++ b/atuin-client/src/record/store.rs
@@ -1,8 +1,7 @@
use async_trait::async_trait;
use eyre::Result;
-use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
-
+use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus};
/// A record store stores records
/// In more detail - we tend to need to process this into _another_ format to actually query it.
/// As is, the record store is intended as the source of truth for arbitratry data, which could
@@ -23,19 +22,30 @@ pub trait Store {
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
- /// Get the record that follows this record
- async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>;
+ async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+ async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
- /// Get the first record for a given host and tag
- async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+ /// Get the next `limit` records, after and including the given index
+ async fn next(
+ &self,
+ host: HostId,
+ tag: &str,
+ idx: RecordIdx,
+ limit: u64,
+ ) -> Result<Vec<Record<EncryptedData>>>;
- /// Get the last record for a given host and tag
- async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+ /// Get the first record for a given host and tag
+ async fn idx(
+ &self,
+ host: HostId,
+ tag: &str,
+ idx: RecordIdx,
+ ) -> Result<Option<Record<EncryptedData>>>;
- // Get the last record for all hosts for a given tag, useful for the read path of apps.
- async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>;
+ async fn status(&self) -> Result<RecordStatus>;
- // Get the latest host/tag/record tuple for every set in the store. useful for building an
- // index
- async fn tail_records(&self) -> Result<RecordIndex>;
+ /// Get every start record for a given tag, regardless of host.
+ /// Useful when actually operating on synchronized data, and will often have conflict
+ /// resolution applied.
+ async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>;
}
diff --git a/atuin-client/src/record/sync.rs b/atuin-client/src/record/sync.rs
index 56be0638..2694e0ff 100644
--- a/atuin-client/src/record/sync.rs
+++ b/atuin-client/src/record/sync.rs
@@ -1,27 +1,51 @@
// do a sync :O
+use std::cmp::Ordering;
+
use eyre::Result;
+use thiserror::Error;
use super::store::Store;
use crate::{api_client::Client, settings::Settings};
-use atuin_common::record::{Diff, HostId, RecordId, RecordIndex};
+use atuin_common::record::{Diff, HostId, RecordIdx, RecordStatus};
+
+#[derive(Error, Debug)]
+pub enum SyncError {
+ #[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
+ LocalAheadOtherHost,
+
+ #[error("an issue with the local database occured")]
+ LocalStoreError,
+
+ #[error("something has gone wrong with the sync logic: {msg:?}")]
+ SyncLogicError { msg: String },
+
+ #[error("a request to the sync server failed")]
+ RemoteRequestError,
+}
#[derive(Debug, Eq, PartialEq)]
pub enum Operation {
- // Either upload or download until the tail matches the below
+ // Either upload or download until the states matches the below
Upload {
- tail: RecordId,
+ local: RecordIdx,
+ remote: Option<RecordIdx>,
host: HostId,
tag: String,
},
Download {
- tail: RecordId,
+ local: Option<RecordIdx>,
+ remote: RecordIdx,
+ host: HostId,
+ tag: String,
+ },
+ Noop {
host: HostId,
tag: String,
},
}
-pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Diff>, RecordIndex)> {
+pub async fn diff(settings: &Settings, store: &impl Store) -> Result<(Vec<Diff>, RecordStatus)> {
let client = Client::new(
&settings.sync_address,
&settings.session_token,
@@ -29,8 +53,8 @@ pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Di
settings.network_timeout,
)?;
- let local_index = store.tail_records().await?;
- let remote_index = client.record_index().await?;
+ let local_index = store.status().await?;
+ let remote_index = client.record_status().await?;
let diff = local_index.diff(&remote_index);
@@ -41,39 +65,57 @@ pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Di
// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download.
// In theory this could be done as a part of the diffing stage, but it's easier to reason
// about and test this way
-pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Operation>> {
+pub async fn operations(
+ diffs: Vec<Diff>,
+ _store: &impl Store,
+) -> Result<Vec<Operation>, SyncError> {
let mut operations = Vec::with_capacity(diffs.len());
for diff in diffs {
- // First, try to fetch the tail
- // If it exists locally, then that means we need to update the remote
- // host until it has the same tail. Ie, upload.
- // If it does not exist locally, that means remote is ahead of us.
- // Therefore, we need to download until our local tail matches
- let record = store.get(diff.tail).await;
-
- let op = if record.is_ok() {
- // if local has the ID, then we should find the actual tail of this
- // store, so we know what we need to update the remote to.
- let tail = store
- .tail(diff.host, diff.tag.as_str())
- .await?
- .expect("failed to fetch last record, expected tag/host to exist");
-
- // TODO(ellie) update the diffing so that it stores the context of the current tail
- // that way, we can determine how much we need to upload.
- // For now just keep uploading until tails match
+ let op = match (diff.local, diff.remote) {
+ // We both have it! Could be either. Compare.
+ (Some(local), Some(remote)) => match local.cmp(&remote) {
+ Ordering::Equal => Operation::Noop {
+ host: diff.host,
+ tag: diff.tag,
+ },
+ Ordering::Greater => Operation::Upload {
+ local,
+ remote: Some(remote),
+ host: diff.host,
+ tag: diff.tag,
+ },
+ Ordering::Less => Operation::Download {
+ local: Some(local),
+ remote,
+ host: diff.host,
+ tag: diff.tag,
+ },
+ },
- Operation::Upload {
- tail: tail.id,
+ // Remote has it, we don't. Gotta be download
+ (None, Some(remote)) => Operation::Download {
+ local: None,
+ remote,
host: diff.host,
tag: diff.tag,
- }
- } else {
- Operation::Download {
- tail: diff.tail,
+ },
+
+ // We have it, remote doesn't. Gotta be upload.
+ (Some(local), None) => Operation::Upload {
+ local,
+ remote: None,
host: diff.host,
tag: diff.tag,
+ },
+
+ // something is pretty fucked.
+ (None, None) => {
+ return Err(SyncError::SyncLogicError {
+ msg: String::from(
+ "diff has nothing for local or remote - (host, tag) does not exist",
+ ),
+ })
}
};
@@ -86,149 +128,130 @@ pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Oper
// with the same properties
operations.sort_by_key(|op| match op {
- Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
- Operation::Download { tail, host, .. } => ("download", *host, *tail),
+ Operation::Noop { host, tag } => (0, *host, tag.clone()),
+
+ Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
+
+ Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
});
Ok(operations)
}
async fn sync_upload(
- store: &mut impl Store,
- remote_index: &RecordIndex,
+ store: &impl Store,
client: &Client<'_>,
- op: (HostId, String, RecordId),
-) -> Result<i64> {
+ host: HostId,
+ tag: String,
+ local: RecordIdx,
+ remote: Option<RecordIdx>,
+) -> Result<i64, SyncError> {
+ let remote = remote.unwrap_or(0);
+ let expected = local - remote;
let upload_page_size = 100;
- let mut total = 0;
-
- // so. we have an upload operation, with the tail representing the state
- // we want to get the remote to
- let current_tail = remote_index.get(op.0, op.1.clone());
+ let mut progress = 0;
println!(
- "Syncing local {:?}/{}/{:?}, remote has {:?}",
- op.0, op.1, op.2, current_tail
+ "Uploading {} records to {}/{}",
+ expected,
+ host.0.as_simple(),
+ tag
);
- let start = if let Some(current_tail) = current_tail {
- current_tail
- } else {
- store
- .head(op.0, op.1.as_str())
+ // preload with the first entry if remote does not know of this store
+ loop {
+ let page = store
+ .next(host, tag.as_str(), remote + progress, upload_page_size)
.await
- .expect("failed to fetch host/tag head")
- .expect("host/tag not in current index")
- .id
- };
+ .map_err(|e| {
+ error!("failed to read upload page: {e:?}");
- debug!("starting push to remote from: {:?}", start);
+ SyncError::LocalStoreError
+ })?;
- // we have the start point for sync. it is either the head of the store if
- // the remote has no data for it, or the tail that the remote has
- // we need to iterate from the remote tail, and keep going until
- // remote tail = current local tail
+ client.post_records(&page).await.map_err(|e| {
+ error!("failed to post records: {e:?}");
- let mut record = if current_tail.is_some() {
- let r = store.get(start).await.unwrap();
- store.next(&r).await?
- } else {
- Some(store.get(start).await.unwrap())
- };
+ SyncError::RemoteRequestError
+ })?;
- let mut buf = Vec::with_capacity(upload_page_size);
-
- while let Some(r) = record {
- if buf.len() < upload_page_size {
- buf.push(r.clone());
- } else {
- client.post_records(&buf).await?;
+ println!(
+ "uploaded {} to remote, progress {}/{}",
+ page.len(),
+ progress,
+ expected
+ );
+ progress += page.len() as u64;
- // can we reset what we have? len = 0 but keep capacity
- buf = Vec::with_capacity(upload_page_size);
+ if progress >= expected {
+ break;
}
- record = store.next(&r).await?;
-
- total += 1;
}
- if !buf.is_empty() {
- client.post_records(&buf).await?;
- }
-
- Ok(total)
+ Ok(progress as i64)
}
async fn sync_download(
- store: &mut impl Store,
- remote_index: &RecordIndex,
+ store: &impl Store,
client: &Client<'_>,
- op: (HostId, String, RecordId),
-) -> Result<i64> {
- // TODO(ellie): implement variable page sizing like on history sync
- let download_page_size = 1000;
+ host: HostId,
+ tag: String,
+ local: Option<RecordIdx>,
+ remote: RecordIdx,
+) -> Result<i64, SyncError> {
+ let local = local.unwrap_or(0);
+ let expected = remote - local;
+ let download_page_size = 100;
+ let mut progress = 0;
- let mut total = 0;
+ println!(
+ "Downloading {} records from {}/{}",
+ expected,
+ host.0.as_simple(),
+ tag
+ );
- // We know that the remote is ahead of us, so let's keep downloading until both
- // 1) The remote stops returning full pages
- // 2) The tail equals what we expect
- //
- // If (1) occurs without (2), then something is wrong with our index calculation
- // and we should bail.
- let remote_tail = remote_index
- .get(op.0, op.1.clone())
- .expect("remote index does not contain expected tail during download");
- let local_tail = store.tail(op.0, op.1.as_str()).await?;
- //
- // We expect that the operations diff will represent the desired state
- // In this case, that contains the remote tail.
- assert_eq!(remote_tail, op.2);
+ // preload with the first entry if remote does not know of this store
+ loop {
+ let page = client
+ .next_records(host, tag.clone(), local + progress, download_page_size)
+ .await
+ .map_err(|_| SyncError::RemoteRequestError)?;
- println!("Downloading {:?}/{}/{:?} to local", op.0, op.1, op.2);
+ store
+ .push_batch(page.iter())
+ .await
+ .map_err(|_| SyncError::LocalStoreError)?;
- let mut records = client
- .next_records(
- op.0,
- op.1.clone(),
- local_tail.map(|r| r.id),
- download_page_size,
- )
- .await?;
+ println!(
+ "downloaded {} records from remote, progress {}/{}",
+ page.len(),
+ progress,
+ expected
+ );
- while !records.is_empty() {
- total += std::cmp::min(download_page_size, records.len() as u64);
- store.push_batch(records.iter()).await?;
+ progress += page.len() as u64;
- if records.last().unwrap().id == remote_tail {
+ if progress >= expected {
break;
}
-
- records = client
- .next_records(
- op.0,
- op.1.clone(),
- records.last().map(|r| r.id),
- download_page_size,
- )
- .await?;
}
- Ok(total as i64)
+ Ok(progress as i64)
}
pub async fn sync_remote(
operations: Vec<Operation>,
- remote_index: &RecordIndex,
- local_store: &mut impl Store,
+ local_store: &impl Store,
settings: &Settings,
-) -> Result<(i64, i64)> {
+) -> Result<(i64, i64), SyncError> {
let client = Client::new(
&settings.sync_address,
&settings.session_token,
settings.network_connect_timeout,
settings.network_timeout,
- )?;
+ )
+ .expect("failed to create client");
let mut uploaded = 0;
let mut downloaded = 0;
@@ -236,14 +259,23 @@ pub async fn sync_remote(
// this can totally run in parallel, but lets get it working first
for i in operations {
match i {
- Operation::Upload { tail, host, tag } => {
- uploaded +=
- sync_upload(local_store, remote_index, &client, (host, tag, tail)).await?
- }
- Operation::Download { tail, host, tag } => {
- downloaded +=
- sync_download(local_store, remote_index, &client, (host, tag, tail)).await?
+ Operation::Upload {
+ host,
+ tag,
+ local,
+ remote,
+ } => uploaded += sync_upload(local_store, &client, host, tag, local, remote).await?,
+
+ Operation::Download {
+ host,
+ tag,
+ local,
+ remote,
+ } => {
+ downloaded += sync_download(local_store, &client, host, tag, local, remote).await?
}
+
+ Operation::Noop { .. } => continue,
}
}
@@ -264,13 +296,16 @@ mod tests {
fn test_record() -> Record<EncryptedData> {
Record::builder()
- .host(HostId(atuin_common::utils::uuid_v7()))
+ .host(atuin_common::record::Host::new(HostId(
+ atuin_common::utils::uuid_v7(),
+ )))
.version("v1".into())
.tag(atuin_common::utils::uuid_v7().simple().to_string())
.data(EncryptedData {
data: String::new(),
content_encryption_key: String::new(),
})
+ .idx(0)
.build()
}
@@ -296,8 +331,8 @@ mod tests {
remote_store.push(&i).await.unwrap();
}
- let local_index = local_store.tail_records().await.unwrap();
- let remote_index = remote_store.tail_records().await.unwrap();
+ let local_index = local_store.status().await.unwrap();
+ let remote_index = remote_store.status().await.unwrap();
let diff = local_index.diff(&remote_index);
@@ -320,9 +355,10 @@ mod tests {
assert_eq!(
operations[0],
Operation::Upload {
- host: record.host,
+ host: record.host.id,
tag: record.tag,
- tail: record.id
+ local: record.idx,
+ remote: None,
}
);
}
@@ -333,12 +369,14 @@ mod tests {
// another. One upload, one download
let shared_record = test_record();
-
let remote_ahead = test_record();
+
let local_ahead = shared_record
- .new_child(vec![1, 2, 3])
+ .append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
+ assert_eq!(local_ahead.idx, 1);
+
let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store
let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store
@@ -350,15 +388,19 @@ mod tests {
assert_eq!(
operations,
vec![
- Operation::Download {
- tail: remote_ahead.id,
- host: remote_ahead.host,
- tag: remote_ahead.tag,
- },
+ // Or in otherwords, local is ahead by one
Operation::Upload {
- tail: local_ahead.id,
- host: local_ahead.host,
+ host: local_ahead.host.id,
tag: local_ahead.tag,
+ local: 1,
+ remote: Some(0),
+ },
+ // Or in other words, remote knows of a record in an entirely new store (tag)
+ Operation::Download {
+ host: remote_ahead.host.id,
+ tag: remote_ahead.tag,
+ local: None,
+ remote: 0,
},
]
);
@@ -371,66 +413,160 @@ mod tests {
// One known only by remote
let shared_record = test_record();
+ let local_only = test_record();
+
+ let local_only_20 = test_record();
+ let local_only_21 = local_only_20
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let local_only_22 = local_only_21
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let local_only_23 = local_only_22
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
- let remote_known = test_record();
- let local_known = test_record();
+ let remote_only = test_record();
+
+ let remote_only_20 = test_record();
+ let remote_only_21 = remote_only_20
+ .append(vec![2, 3, 2])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let remote_only_22 = remote_only_21
+ .append(vec![2, 3, 2])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let remote_only_23 = remote_only_22
+ .append(vec![2, 3, 2])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let remote_only_24 = remote_only_23
+ .append(vec![2, 3, 2])
+ .encrypt::<PASETO_V4>(&[0; 32]);
let second_shared = test_record();
let second_shared_remote_ahead = second_shared
- .new_child(vec![1, 2, 3])
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let second_shared_remote_ahead2 = second_shared_remote_ahead
+ .append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
- let local_ahead = shared_record
- .new_child(vec![1, 2, 3])
+ let third_shared = test_record();
+ let third_shared_local_ahead = third_shared
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let third_shared_local_ahead2 = third_shared_local_ahead
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+
+ let fourth_shared = test_record();
+ let fourth_shared_remote_ahead = fourth_shared
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let fourth_shared_remote_ahead2 = fourth_shared_remote_ahead
+ .append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let local = vec![
shared_record.clone(),
second_shared.clone(),
- local_known.clone(),
- local_ahead.clone(),
+ third_shared.clone(),
+ fourth_shared.clone(),
+ fourth_shared_remote_ahead.clone(),
+ // single store, only local has it
+ local_only.clone(),
+ // bigger store, also only known by local
+ local_only_20.clone(),
+ local_only_21.clone(),
+ local_only_22.clone(),
+ local_only_23.clone(),
+ // another shared store, but local is ahead on this one
+ third_shared_local_ahead.clone(),
+ third_shared_local_ahead2.clone(),
];
let remote = vec![
+ remote_only.clone(),
+ remote_only_20.clone(),
+ remote_only_21.clone(),
+ remote_only_22.clone(),
+ remote_only_23.clone(),
+ remote_only_24.clone(),
shared_record.clone(),
second_shared.clone(),
+ third_shared.clone(),
second_shared_remote_ahead.clone(),
- remote_known.clone(),
+ second_shared_remote_ahead2.clone(),
+ fourth_shared.clone(),
+ fourth_shared_remote_ahead.clone(),
+ fourth_shared_remote_ahead2.clone(),
]; // remote knows about the already-synced, and one new record in a new store
let (store, diff) = build_test_diff(local, remote).await;
let operations = sync::operations(diff, &store).await.unwrap();
- assert_eq!(operations.len(), 4);
+ assert_eq!(operations.len(), 7);
let mut result_ops = vec![
+ // We started with a shared record, but the remote knows of two newer records in the
+ // same store
+ Operation::Download {
+ local: Some(0),
+ remote: 2,
+ host: second_shared_remote_ahead.host.id,
+ tag: second_shared_remote_ahead.tag,
+ },
+ // We have a shared record, local knows of the first two but not the last
+ Operation::Download {
+ local: Some(1),
+ remote: 2,
+ host: fourth_shared_remote_ahead2.host.id,
+ tag: fourth_shared_remote_ahead2.tag,
+ },
+ // Remote knows of a store with a single record that local does not have
Operation::Download {
- tail: remote_known.id,
- host: remote_known.host,
- tag: remote_known.tag,
+ local: None,
+ remote: 0,
+ host: remote_only.host.id,
+ tag: remote_only.tag,
},
+ // Remote knows of a store with a bunch of records that local does not have
Operation::Download {
- tail: second_shared_remote_ahead.id,
- host: second_shared.host,
- tag: second_shared.tag,
+ local: None,
+ remote: 4,
+ host: remote_only_20.host.id,
+ tag: remote_only_20.tag,
},
+ // Local knows of a record in a store that remote does not have
Operation::Upload {
- tail: local_ahead.id,
- host: local_ahead.host,
- tag: local_ahead.tag,
+ local: 0,
+ remote: None,
+ host: local_only.host.id,
+ tag: local_only.tag,
},
+ // Local knows of 4 records in a store that remote does not have
Operation::Upload {
- tail: local_known.id,
- host: local_known.host,
- tag: local_known.tag,
+ local: 3,
+ remote: None,
+ host: local_only_20.host.id,
+ tag: local_only_20.tag,
+ },
+ // Local knows of 2 more records in a shared store that remote only has one of
+ Operation::Upload {
+ local: 2,
+ remote: Some(0),
+ host: third_shared.host.id,
+ tag: third_shared.tag,
},
];
result_ops.sort_by_key(|op| match op {
- Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
- Operation::Download { tail, host, .. } => ("download", *host, *tail),
+ Operation::Noop { host, tag } => (0, *host, tag.clone()),
+
+ Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
+
+ Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
});
- assert_eq!(operations, result_ops);
+ assert_eq!(result_ops, operations);
}
}
diff --git a/atuin-client/src/settings.rs b/atuin-client/src/settings.rs
index 7e251550..0798a890 100644
--- a/atuin-client/src/settings.rs
+++ b/atuin-client/src/settings.rs
@@ -173,6 +173,11 @@ impl Default for Stats {
}
}
+#[derive(Clone, Debug, Deserialize, Default)]
+pub struct Sync {
+ pub records: bool,
+}
+
#[derive(Clone, Debug, Deserialize)]
pub struct Settings {
pub dialect: Dialect,
@@ -217,6 +222,9 @@ pub struct Settings {
#[serde(default)]
pub stats: Stats,
+ #[serde(default)]
+ pub sync: Sync,
+
// This is automatically loaded when settings is created. Do not set in
// config! Keep secrets and settings apart.
#[serde(skip)]
@@ -427,6 +435,7 @@ impl Settings {
// muscle memory.
// New users will get the new default, that is more similar to what they are used to.
.set_default("enter_accept", false)?
+ .set_default("sync.records", false)?
.add_source(
Environment::with_prefix("atuin")
.prefix_separator("_")
diff --git a/atuin-common/src/record.rs b/atuin-common/src/record.rs
index cba0917a..e6ce2647 100644
--- a/atuin-common/src/record.rs
+++ b/atuin-common/src/record.rs
@@ -14,13 +14,34 @@ pub struct EncryptedData {
pub content_encryption_key: String,
}
-#[derive(Debug, PartialEq)]
+#[derive(Debug, PartialEq, PartialOrd, Ord, Eq)]
pub struct Diff {
pub host: HostId,
pub tag: String,
- pub tail: RecordId,
+ pub local: Option<RecordIdx>,
+ pub remote: Option<RecordIdx>,
}
+#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
+pub struct Host {
+ pub id: HostId,
+ pub name: String,
+}
+
+impl Host {
+ pub fn new(id: HostId) -> Self {
+ Host {
+ id,
+ name: String::new(),
+ }
+ }
+}
+
+new_uuid!(RecordId);
+new_uuid!(HostId);
+
+pub type RecordIdx = u64;
+
/// A single record stored inside of our local database
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)]
pub struct Record<Data> {
@@ -28,18 +49,14 @@ pub struct Record<Data> {
#[builder(default = RecordId(crate::utils::uuid_v7()))]
pub id: RecordId,
+ /// The integer record ID. This is only unique per (host, tag).
+ pub idx: RecordIdx,
+
/// The unique ID of the host.
// TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store
// as strings. I would rather avoid normalization, so store as UUID binary instead of
// encoding to a string and wasting much more storage.
- pub host: HostId,
-
- /// The ID of the parent entry
- // A store is technically just a double linked list
- // We can do some cheating with the timestamps, but should not rely upon them.
- // Clocks are tricksy.
- #[builder(default)]
- pub parent: Option<RecordId>,
+ pub host: Host,
/// The creation time in nanoseconds since unix epoch
#[builder(default = time::OffsetDateTime::now_utc().unix_timestamp_nanos() as u64)]
@@ -56,25 +73,22 @@ pub struct Record<Data> {
pub data: Data,
}
-new_uuid!(RecordId);
-new_uuid!(HostId);
-
/// Extra data from the record that should be encoded in the data
#[derive(Debug, Copy, Clone)]
pub struct AdditionalData<'a> {
pub id: &'a RecordId,
+ pub idx: &'a u64,
pub version: &'a str,
pub tag: &'a str,
pub host: &'a HostId,
- pub parent: Option<&'a RecordId>,
}
impl<Data> Record<Data> {
- pub fn new_child(&self, data: Vec<u8>) -> Record<DecryptedData> {
+ pub fn append(&self, data: Vec<u8>) -> Record<DecryptedData> {
Record::builder()
- .host(self.host)
+ .host(self.host.clone())
.version(self.version.clone())
- .parent(Some(self.id))
+ .idx(self.idx + 1)
.tag(self.tag.clone())
.data(DecryptedData(data))
.build()
@@ -84,74 +98,76 @@ impl<Data> Record<Data> {
/// An index representing the current state of the record stores
/// This can be both remote, or local, and compared in either direction
#[derive(Debug, Serialize, Deserialize)]
-pub struct RecordIndex {
- // A map of host -> tag -> tail
- pub hosts: HashMap<HostId, HashMap<String, RecordId>>,
+pub struct RecordStatus {
+ // A map of host -> tag -> max(idx)
+ pub hosts: HashMap<HostId, HashMap<String, RecordIdx>>,
}
-impl Default for RecordIndex {
+impl Default for RecordStatus {
fn default() -> Self {
Self::new()
}
}
-impl Extend<(HostId, String, RecordId)> for RecordIndex {
- fn extend<T: IntoIterator<Item = (HostId, String, RecordId)>>(&mut self, iter: T) {
- for (host, tag, tail_id) in iter {
- self.set_raw(host, tag, tail_id);
+impl Extend<(HostId, String, RecordIdx)> for RecordStatus {
+ fn extend<T: IntoIterator<Item = (HostId, String, RecordIdx)>>(&mut self, iter: T) {
+ for (host, tag, tail_idx) in iter {
+ self.set_raw(host, tag, tail_idx);
}
}
}
-impl RecordIndex {
- pub fn new() -> RecordIndex {
- RecordIndex {
+impl RecordStatus {
+ pub fn new() -> RecordStatus {
+ RecordStatus {
hosts: HashMap::new(),
}
}
/// Insert a new tail record into the store
pub fn set(&mut self, tail: Record<DecryptedData>) {
- self.set_raw(tail.host, tail.tag, tail.id)
+ self.set_raw(tail.host.id, tail.tag, tail.idx)
}
- pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordId) {
+ pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordIdx) {
self.hosts.entry(host).or_default().insert(tag, tail_id);
}
- pub fn get(&self, host: HostId, tag: String) -> Option<RecordId> {
+ pub fn get(&self, host: HostId, tag: String) -> Option<RecordIdx> {
self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned()
}
/// Diff this index with another, likely remote index.
/// The two diffs can then be reconciled, and the optimal change set calculated
/// Returns a tuple, with (host, tag, Option(OTHER))
- /// OTHER is set to the value of the tail on the other machine. For example, if the
- /// other machine has a different tail, it will be the differing tail. This is useful to
- /// check if the other index is ahead of us, or behind.
- /// If the other index does not have the (host, tag) pair, then the other value will be None.
+ /// OTHER is set to the value of the idx on the other machine. If it is greater than our index,
+ /// then we need to do some downloading. If it is smaller, then we need to do some uploading
+ /// Note that we cannot upload if we are not the owner of the record store - hosts can only
+ /// write to their own store.
pub fn diff(&self, other: &Self) -> Vec<Diff> {
let mut ret = Vec::new();
// First, we check if other has everything that self has
for (host, tag_map) in self.hosts.iter() {
- for (tag, tail) in tag_map.iter() {
+ for (tag, idx) in tag_map.iter() {
match other.get(*host, tag.clone()) {
// The other store is all up to date! No diff.
- Some(t) if t.eq(tail) => continue,
+ Some(t) if t.eq(idx) => continue,
- // The other store does exist, but it is either ahead or behind us. A diff regardless
+ // The other store does exist, and it is either ahead or behind us. A diff regardless
Some(t) => ret.push(Diff {
host: *host,
tag: tag.clone(),
- tail: t,
+ local: Some(*idx),
+ remote: Some(t),
}),
// The other store does not exist :O
None => ret.push(Diff {
host: *host,
tag: tag.clone(),
- tail: *tail,
+ local: Some(*idx),
+ remote: None,
}),
};
}
@@ -162,7 +178,7 @@ impl RecordIndex {
// account for that!
for (host, tag_map) in other.hosts.iter() {
- for (tag, tail) in tag_map.iter() {
+ for (tag, idx) in tag_map.iter() {
match self.get(*host, tag.clone()) {
// If we have this host/tag combo, the comparison and diff will have already happened above
Some(_) => continue,
@@ -170,13 +186,15 @@ impl RecordIndex {
None => ret.push(Diff {
host: *host,
tag: tag.clone(),
- tail: *tail,
+ remote: Some(*idx),
+ local: None,
}),
};
}
}
- ret.sort_by(|a, b| (a.host, a.tag.clone(), a.tail).cmp(&(b.host, b.tag.clone(), b.tail)));
+ // Stability is a nice property to have
+ ret.sort();
ret
}
}
@@ -201,14 +219,14 @@ impl Record<DecryptedData> {
id: &self.id,
version: &self.version,
tag: &self.tag,
- host: &self.host,
- parent: self.parent.as_ref(),
+ host: &self.host.id,
+ idx: &self.idx,
};
Record {
data: E::encrypt(self.data, ad, key),
id: self.id,
host: self.host,
- parent: self.parent,
+ idx: self.idx,
timestamp: self.timestamp,
version: self.version,
tag: self.tag,
@@ -222,14 +240,14 @@ impl Record<EncryptedData> {
id: &self.id,
version: &self.version,
tag: &self.tag,
- host: &self.host,
- parent: self.parent.as_ref(),
+ host: &self.host.id,
+ idx: &self.idx,
};
Ok(Record {
data: E::decrypt(self.data, ad, key)?,
id: self.id,
host: self.host,
- parent: self.parent,
+ idx: self.idx,
timestamp: self.timestamp,
version: self.version,
tag: self.tag,
@@ -245,14 +263,14 @@ impl Record<EncryptedData> {
id: &self.id,
version: &self.version,
tag: &self.tag,
- host: &self.host,
- parent: self.parent.as_ref(),
+ host: &self.host.id,
+ idx: &self.idx,
};
Ok(Record {
data: E::re_encrypt(self.data, ad, old_key, new_key)?,
id: self.id,
host: self.host,
- parent: self.parent,
+ idx: self.idx,
timestamp: self.timestamp,
version: self.version,
tag: self.tag,
@@ -262,31 +280,32 @@ impl Record<EncryptedData> {
#[cfg(test)]
mod tests {
- use crate::record::HostId;
+ use crate::record::{Host, HostId};
- use super::{DecryptedData, Diff, Record, RecordIndex};
+ use super::{DecryptedData, Diff, Record, RecordStatus};
use pretty_assertions::assert_eq;
fn test_record() -> Record<DecryptedData> {
Record::builder()
- .host(HostId(crate::utils::uuid_v7()))
+ .host(Host::new(HostId(crate::utils::uuid_v7())))
.version("v1".into())
.tag(crate::utils::uuid_v7().simple().to_string())
.data(DecryptedData(vec![0, 1, 2, 3]))
+ .idx(0)
.build()
}
#[test]
fn record_index() {
- let mut index = RecordIndex::new();
+ let mut index = RecordStatus::new();
let record = test_record();
index.set(record.clone());
- let tail = index.get(record.host, record.tag);
+ let tail = index.get(record.host.id, record.tag);
assert_eq!(
- record.id,
+ record.idx,
tail.expect("tail not in store"),
"tail in store did not match"
);
@@ -294,17 +313,17 @@ mod tests {
#[test]
fn record_index_overwrite() {
- let mut index = RecordIndex::new();
+ let mut index = RecordStatus::new();
let record = test_record();
- let child = record.new_child(vec![1, 2, 3]);
+ let child = record.append(vec![1, 2, 3]);
index.set(record.clone());
index.set(child.clone());
- let tail = index.get(record.host, record.tag);
+ let tail = index.get(record.host.id, record.tag);
assert_eq!(
- child.id,
+ child.idx,
tail.expect("tail not in store"),
"tail in store did not match"
);
@@ -314,8 +333,8 @@ mod tests {
fn record_index_no_diff() {
// Here, they both have the same version and should have no diff
- let mut index1 = RecordIndex::new();
- let mut index2 = RecordIndex::new();
+ let mut index1 = RecordStatus::new();
+ let mut index2 = RecordStatus::new();
let record1 = test_record();
@@ -331,11 +350,11 @@ mod tests {
fn record_index_single_diff() {
// Here, they both have the same stores, but one is ahead by a single record
- let mut index1 = RecordIndex::new();
- let mut index2 = RecordIndex::new();
+ let mut index1 = RecordStatus::new();
+ let mut index2 = RecordStatus::new();
let record1 = test_record();
- let record2 = record1.new_child(vec![1, 2, 3]);
+ let record2 = record1.append(vec![1, 2, 3]);
index1.set(record1);
index2.set(record2.clone());
@@ -346,9 +365,10 @@ mod tests {
assert_eq!(
diff[0],
Diff {
- host: record2.host,
+ host: record2.host.id,
tag: record2.tag,
- tail: record2.id
+ remote: Some(1),
+ local: Some(0)
}
);
}
@@ -356,14 +376,14 @@ mod tests {
#[test]
fn record_index_multi_diff() {
// A much more complex case, with a bunch more checks
- let mut index1 = RecordIndex::new();
- let mut index2 = RecordIndex::new();
+ let mut index1 = RecordStatus::new();
+ let mut index2 = RecordStatus::new();
let store1record1 = test_record();
- let store1record2 = store1record1.new_child(vec![1, 2, 3]);
+ let store1record2 = store1record1.append(vec![1, 2, 3]);
let store2record1 = test_record();
- let store2record2 = store2record1.new_child(vec![1, 2, 3]);
+ let store2record2 = store2record1.append(vec![1, 2, 3]);
let store3record1 = test_record();
diff --git a/atuin-server-database/src/lib.rs b/atuin-server-database/src/lib.rs
index d529655e..9b154ea1 100644
--- a/atuin-server-database/src/lib.rs
+++ b/atuin-server-database/src/lib.rs
@@ -14,7 +14,7 @@ use self::{
models::{History, NewHistory, NewSession, NewUser, Session, User},
};
use async_trait::async_trait;
-use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
+use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
use serde::{de::DeserializeOwned, Serialize};
use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset};
use tracing::instrument;
@@ -68,12 +68,12 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
user: &User,
host: HostId,
tag: String,
- start: Option<RecordId>,
+ start: Option<RecordIdx>,
count: u64,
) -> DbResult<Vec<Record<EncryptedData>>>;
// Return the tail record ID for each store, so (HostID, Tag, TailRecordID)
- async fn tail_records(&self, user: &User) -> DbResult<RecordIndex>;
+ async fn status(&self, user: &User) -> DbResult<RecordStatus>;
async fn count_history_range(&self, user: &User, range: Range<OffsetDateTime>)
-> DbResult<i64>;
diff --git a/atuin-server-postgres/migrations/20231202170508_create-store.sql b/atuin-server-postgres/migrations/20231202170508_create-store.sql
new file mode 100644
index 00000000..ffb57966
--- /dev/null
+++ b/atuin-server-postgres/migrations/20231202170508_create-store.sql
@@ -0,0 +1,15 @@
+-- Add migration script here
+create table store (
+ id uuid primary key, -- remember to use uuidv7 for happy indices <3
+ client_id uuid not null, -- I am too uncomfortable with the idea of a client-generated primary key, even though it's fine mathematically
+ host uuid not null, -- a unique identifier for the host
+ idx bigint not null, -- the index of the record in this store, identified by (host, tag)
+ timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision
+ version text not null,
+ tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host
+ data text not null, -- store the actual history data, encrypted. I don't wanna know!
+ cek text not null,
+
+ user_id bigint not null, -- allow multiple users
+ created_at timestamp not null default current_timestamp
+);
diff --git a/atuin-server-postgres/migrations/20231203124112_create-store-idx.sql b/atuin-server-postgres/migrations/20231203124112_create-store-idx.sql
new file mode 100644
index 00000000..56d67145
--- /dev/null
+++ b/atuin-server-postgres/migrations/20231203124112_create-store-idx.sql
@@ -0,0 +1,2 @@
+-- Add migration script here
+create unique index record_uniq ON store(user_id, host, tag, idx);
diff --git a/atuin-server-postgres/src/lib.rs b/atuin-server-postgres/src/lib.rs
index f22e6bee..c1de4d50 100644
--- a/atuin-server-postgres/src/lib.rs
+++ b/atuin-server-postgres/src/lib.rs
@@ -1,7 +1,7 @@
use std::ops::Range;
use async_trait::async_trait;
-use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
+use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
use atuin_server_database::{Database, DbError, DbResult};
use futures_util::TryStreamExt;
@@ -11,6 +11,7 @@ use sqlx::Row;
use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
use tracing::instrument;
+use uuid::Uuid;
use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
mod wrappers;
@@ -361,16 +362,16 @@ impl Database for Postgres {
let id = atuin_common::utils::uuid_v7();
sqlx::query(
- "insert into records
- (id, client_id, host, parent, timestamp, version, tag, data, cek, user_id)
+ "insert into store
+ (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
on conflict do nothing
",
)
.bind(id)
.bind(i.id)
- .bind(i.host)
- .bind(i.parent)
+ .bind(i.host.id)
+ .bind(i.idx as i64)
.bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
.bind(&i.version)
.bind(&i.tag)
@@ -393,62 +394,69 @@ impl Database for Postgres {
user: &User,
host: HostId,
tag: String,
- start: Option<RecordId>,
+ start: Option<RecordIdx>,
count: u64,
) -> DbResult<Vec<Record<EncryptedData>>> {
tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
- let mut ret = Vec::with_capacity(count as usize);
- let mut parent = start;
+ let start = start.unwrap_or(0);
- // yeah let's do something better
- for _ in 0..count {
- // a very much not ideal query. but it's simple at least?
- // we are basically using postgres as a kv store here, so... maybe consider using an actual
- // kv store?
- let record: Result<DbRecord, DbError> = sqlx::query_as(
- "select client_id, host, parent, timestamp, version, tag, data, cek from records
+ let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
+ "select client_id, host, idx, timestamp, version, tag, data, cek from store
where user_id = $1
and tag = $2
and host = $3
- and parent is not distinct from $4",
- )
- .bind(user.id)
- .bind(tag.clone())
- .bind(host)
- .bind(parent)
- .fetch_one(&self.pool)
- .await
- .map_err(fix_error);
+ and idx >= $4
+ order by idx asc
+ limit $5",
+ )
+ .bind(user.id)
+ .bind(tag.clone())
+ .bind(host)
+ .bind(start as i64)
+ .bind(count as i64)
+ .fetch_all(&self.pool)
+ .await
+ .map_err(fix_error);
- match record {
- Ok(record) => {
- let record: Record<EncryptedData> = record.into();
- ret.push(record.clone());
+ let ret = match records {
+ Ok(records) => {
+ let records: Vec<Record<EncryptedData>> = records
+ .into_iter()
+ .map(|f| {
+ let record: Record<EncryptedData> = f.into();
+ record
+ })
+ .collect();
- parent = Some(record.id);
- }
- Err(DbError::NotFound) => {
- tracing::debug!("hit tail of store: {:?}/{}", host, tag);
- return Ok(ret);
- }
- Err(e) => return Err(e),
+ records
}
- }
+ Err(DbError::NotFound) => {
+ tracing::debug!("no records found in store: {:?}/{}", host, tag);
+ return Ok(vec![]);
+ }
+ Err(e) => return Err(e),
+ };
Ok(ret)
}
- async fn tail_records(&self, user: &User) -> DbResult<RecordIndex> {
- const TAIL_RECORDS_SQL: &str = "select host, tag, client_id from records rp where (select count(1) from records where parent=rp.client_id and user_id = $1) = 0 and user_id = $1;";
+ async fn status(&self, user: &User) -> DbResult<RecordStatus> {
+ const STATUS_SQL: &str =
+ "select host, tag, max(idx) from store where user_id = $1 group by host, tag";
- let res = sqlx::query_as(TAIL_RECORDS_SQL)
+ let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL)
.bind(user.id)
- .fetch(&self.pool)
- .try_collect()
+ .fetch_all(&self.pool)
.await
.map_err(fix_error)?;
- Ok(res)
+ let mut status = RecordStatus::new();
+
+ for i in res {
+ status.set_raw(HostId(i.0), i.1, i.2 as u64);
+ }
+
+ Ok(status)
}
}
diff --git a/atuin-server-postgres/src/wrappers.rs b/atuin-server-postgres/src/wrappers.rs
index b4ae48ae..3ccf9c19 100644
--- a/atuin-server-postgres/src/wrappers.rs
+++ b/atuin-server-postgres/src/wrappers.rs
@@ -1,5 +1,5 @@
use ::sqlx::{FromRow, Result};
-use atuin_common::record::{EncryptedData, Record};
+use atuin_common::record::{EncryptedData, Host, Record};
use atuin_server_database::models::{History, Session, User};
use sqlx::{postgres::PgRow, Row};
use time::PrimitiveDateTime;
@@ -51,6 +51,7 @@ impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory {
impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord {
fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
let timestamp: i64 = row.try_get("timestamp")?;
+ let idx: i64 = row.try_get("idx")?;
let data = EncryptedData {
data: row.try_get("data")?,
@@ -59,8 +60,8 @@ impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord {
Ok(Self(Record {
id: row.try_get("client_id")?,
- host: row.try_get("host")?,
- parent: row.try_get("parent")?,
+ host: Host::new(row.try_get("host")?),
+ idx: idx as u64,
timestamp: timestamp as u64,
version: row.try_get("version")?,
tag: row.try_get("tag")?,
diff --git a/atuin-server/src/handlers/mod.rs b/atuin-server/src/handlers/mod.rs
index 18b1af8e..b66a20bf 100644
--- a/atuin-server/src/handlers/mod.rs
+++ b/atuin-server/src/handlers/mod.rs
@@ -8,6 +8,7 @@ pub mod history;
pub mod record;
pub mod status;
pub mod user;
+pub mod v0;
const VERSION: &str = env!("CARGO_PKG_VERSION");
diff --git a/atuin-server/src/handlers/record.rs b/atuin-server/src/handlers/record.rs
index 91b937b3..b5c07c5b 100644
--- a/atuin-server/src/handlers/record.rs
+++ b/atuin-server/src/handlers/record.rs
@@ -1,109 +1,46 @@
-use axum::{extract::Query, extract::State, Json};
+use axum::{response::IntoResponse, Json};
use http::StatusCode;
-use metrics::counter;
-use serde::Deserialize;
-use tracing::{error, instrument};
+use serde_json::json;
+use tracing::instrument;
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
-use crate::router::{AppState, UserAuth};
+use crate::router::UserAuth;
use atuin_server_database::Database;
-use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
+use atuin_common::record::{EncryptedData, Record};
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn post<DB: Database>(
UserAuth(user): UserAuth,
- state: State<AppState<DB>>,
- Json(records): Json<Vec<Record<EncryptedData>>>,
) -> Result<(), ErrorResponseStatus<'static>> {
- let State(AppState { database, settings }) = state;
+ // anyone who has actually used the old record store (a very small number) will see this error
+ // upon trying to sync.
+ // 1. The status endpoint will say that the server has nothing
+ // 2. The client will try to upload local records
+ // 3. Sync will fail with this error
- tracing::debug!(
- count = records.len(),
- user = user.username,
- "request to add records"
+ // If the client has no local records, they will see the empty index and do nothing. For the
+ // vast majority of users, this is the case.
+ return Err(
+ ErrorResponse::reply("record store deprecated; please upgrade")
+ .with_status(StatusCode::BAD_REQUEST),
);
-
- counter!("atuin_record_uploaded", records.len() as u64);
-
- let too_big = records
- .iter()
- .any(|r| r.data.data.len() >= settings.max_record_size || settings.max_record_size == 0);
-
- if too_big {
- counter!("atuin_record_too_large", 1);
-
- return Err(
- ErrorResponse::reply("could not add records; record too large")
- .with_status(StatusCode::BAD_REQUEST),
- );
- }
-
- if let Err(e) = database.add_records(&user, &records).await {
- error!("failed to add record: {}", e);
-
- return Err(ErrorResponse::reply("failed to add record")
- .with_status(StatusCode::INTERNAL_SERVER_ERROR));
- };
-
- Ok(())
}
#[instrument(skip_all, fields(user.id = user.id))]
-pub async fn index<DB: Database>(
- UserAuth(user): UserAuth,
- state: State<AppState<DB>>,
-) -> Result<Json<RecordIndex>, ErrorResponseStatus<'static>> {
- let State(AppState {
- database,
- settings: _,
- }) = state;
-
- let record_index = match database.tail_records(&user).await {
- Ok(index) => index,
- Err(e) => {
- error!("failed to get record index: {}", e);
+pub async fn index<DB: Database>(UserAuth(user): UserAuth) -> axum::response::Response {
+ let ret = json!({
+ "hosts": {}
+ });
- return Err(ErrorResponse::reply("failed to calculate record index")
- .with_status(StatusCode::INTERNAL_SERVER_ERROR));
- }
- };
-
- Ok(Json(record_index))
-}
-
-#[derive(Deserialize)]
-pub struct NextParams {
- host: HostId,
- tag: String,
- start: Option<RecordId>,
- count: u64,
+ ret.to_string().into_response()
}
#[instrument(skip_all, fields(user.id = user.id))]
-pub async fn next<DB: Database>(
- params: Query<NextParams>,
+pub async fn next(
UserAuth(user): UserAuth,
- state: State<AppState<DB>>,
) -> Result<Json<Vec<Record<EncryptedData>>>, ErrorResponseStatus<'static>> {
- let State(AppState {
- database,
- settings: _,
- }) = state;
- let params = params.0;
-
- let records = match database
- .next_records(&user, params.host, params.tag, params.start, params.count)
- .await
- {
- Ok(records) => records,
- Err(e) => {
- error!("failed to get record index: {}", e);
-
- return Err(ErrorResponse::reply("failed to calculate record index")
- .with_status(StatusCode::INTERNAL_SERVER_ERROR));
- }
- };
+ let records = Vec::new();
Ok(Json(records))
}
diff --git a/atuin-server/src/handlers/v0/mod.rs b/atuin-server/src/handlers/v0/mod.rs
new file mode 100644
index 00000000..78fb47b8
--- /dev/null
+++ b/atuin-server/src/handlers/v0/mod.rs
@@ -0,0 +1 @@
+pub(crate) mod record;
diff --git a/atuin-server/src/handlers/v0/record.rs b/atuin-server/src/handlers/v0/record.rs
new file mode 100644
index 00000000..79b2f80c
--- /dev/null
+++ b/atuin-server/src/handlers/v0/record.rs
@@ -0,0 +1,111 @@
+use axum::{extract::Query, extract::State, Json};
+use http::StatusCode;
+use metrics::counter;
+use serde::Deserialize;
+use tracing::{error, instrument};
+
+use crate::{
+ handlers::{ErrorResponse, ErrorResponseStatus, RespExt},
+ router::{AppState, UserAuth},
+};
+use atuin_server_database::Database;
+
+use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
+
+#[instrument(skip_all, fields(user.id = user.id))]
+pub async fn post<DB: Database>(
+ UserAuth(user): UserAuth,
+ state: State<AppState<DB>>,
+ Json(records): Json<Vec<Record<EncryptedData>>>,
+) -> Result<(), ErrorResponseStatus<'static>> {
+ let State(AppState { database, settings }) = state;
+
+ tracing::debug!(
+ count = records.len(),
+ user = user.username,
+ "request to add records"
+ );
+
+ counter!("atuin_record_uploaded", records.len() as u64);
+
+ let too_big = records
+ .iter()
+ .any(|r| r.data.data.len() >= settings.max_record_size || settings.max_record_size == 0);
+
+ if too_big {
+ counter!("atuin_record_too_large", 1);
+
+ return Err(
+ ErrorResponse::reply("could not add records; record too large")
+ .with_status(StatusCode::BAD_REQUEST),
+ );
+ }
+
+ if let Err(e) = database.add_records(&user, &records).await {
+ error!("failed to add record: {}", e);
+
+ return Err(ErrorResponse::reply("failed to add record")
+ .with_status(StatusCode::INTERNAL_SERVER_ERROR));
+ };
+
+ Ok(())
+}
+
+#[instrument(skip_all, fields(user.id = user.id))]
+pub async fn index<DB: Database>(
+ UserAuth(user): UserAuth,
+ state: State<AppState<DB>>,
+) -> Result<Json<RecordStatus>, ErrorResponseStatus<'static>> {
+ let State(AppState {
+ database,
+ settings: _,
+ }) = state;
+
+ let record_index = match database.status(&user).await {
+ Ok(index) => index,
+ Err(e) => {
+ error!("failed to get record index: {}", e);
+
+ return Err(ErrorResponse::reply("failed to calculate record index")
+ .with_status(StatusCode::INTERNAL_SERVER_ERROR));
+ }
+ };
+
+ Ok(Json(record_index))
+}
+
+#[derive(Deserialize)]
+pub struct NextParams {
+ host: HostId,
+ tag: String,
+ start: Option<RecordIdx>,
+ count: u64,
+}
+
+#[instrument(skip_all, fields(user.id = user.id))]
+pub async fn next<DB: Database>(
+ params: Query<NextParams>,
+ UserAuth(user): UserAuth,
+ state: State<AppState<DB>>,
+) -> Result<Json<Vec<Record<EncryptedData>>>, ErrorResponseStatus<'static>> {
+ let State(AppState {
+ database,
+ settings: _,
+ }) = state;
+ let params = params.0;
+
+ let records = match database
+ .next_records(&user, params.host, params.tag, params.start, params.count)
+ .await
+ {
+ Ok(records) => records,
+ Err(e) => {
+ error!("failed to get record index: {}", e);
+
+ return Err(ErrorResponse::reply("failed to calculate record index")
+ .with_status(StatusCode::INTERNAL_SERVER_ERROR));
+ }
+ };
+
+ Ok(Json(records))
+}
diff --git a/atuin-server/src/router.rs b/atuin-server/src/router.rs
index 42cfaa86..500e1a29 100644
--- a/atuin-server/src/router.rs
+++ b/atuin-server/src/router.rs
@@ -118,13 +118,16 @@ pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> R
.route("/sync/status", get(handlers::status::status))
.route("/history", post(handlers::history::add))
.route("/history", delete(handlers::history::delete))
- .route("/record", post(handlers::record::post))
- .route("/record", get(handlers::record::index))
- .route("/record/next", get(handlers::record::next))
.route("/user/:username", get(handlers::user::get))
.route("/account", delete(handlers::user::delete))
.route("/register", post(handlers::user::register))
- .route("/login", post(handlers::user::login));
+ .route("/login", post(handlers::user::login))
+ .route("/record", post(handlers::record::post::<DB>))
+ .route("/record", get(handlers::record::index::<DB>))
+ .route("/record/next", get(handlers::record::next))
+ .route("/api/v0/record", post(handlers::v0::record::post))
+ .route("/api/v0/record", get(handlers::v0::record::index))
+ .route("/api/v0/record/next", get(handlers::v0::record::next));
let path = settings.path.as_str();
if path.is_empty() {
diff --git a/atuin/src/command/client.rs b/atuin/src/command/client.rs
index d0f58c4a..9ca199fd 100644
--- a/atuin/src/command/client.rs
+++ b/atuin/src/command/client.rs
@@ -16,6 +16,7 @@ mod config;
mod history;
mod import;
mod kv;
+mod record;
mod search;
mod stats;
@@ -46,6 +47,9 @@ pub enum Cmd {
#[command(subcommand)]
Kv(kv::Cmd),
+ #[command(subcommand)]
+ Record(record::Cmd),
+
/// Print example configuration
#[command()]
DefaultConfig,
@@ -79,21 +83,23 @@ impl Cmd {
let record_store_path = PathBuf::from(settings.record_store_path.as_str());
let db = Sqlite::new(db_path).await?;
- let mut store = SqliteStore::new(record_store_path).await?;
+ let store = SqliteStore::new(record_store_path).await?;
match self {
- Self::History(history) => history.run(&settings, &db).await,
+ Self::History(history) => history.run(&settings, &db, store).await,
Self::Import(import) => import.run(&db).await,
Self::Stats(stats) => stats.run(&db, &settings).await,
Self::Search(search) => search.run(db, &mut settings).await,
#[cfg(feature = "sync")]
- Self::Sync(sync) => sync.run(settings, &db, &mut store).await,
+ Self::Sync(sync) => sync.run(settings, &db, &store).await,
#[cfg(feature = "sync")]
Self::Account(account) => account.run(settings).await,
- Self::Kv(kv) => kv.run(&settings, &mut store).await,
+ Self::Kv(kv) => kv.run(&settings, &store).await,
+
+ Self::Record(record) => record.run(&settings, &store).await,
Self::DefaultConfig => {
config::run();
diff --git a/atuin/src/command/client/history.rs b/atuin/src/command/client/history.rs
index 85ca69ff..e22ee6db 100644
--- a/atuin/src/command/client/history.rs
+++ b/atuin/src/command/client/history.rs
@@ -12,7 +12,9 @@ use runtime_format::{FormatKey, FormatKeyError, ParseSegment, ParsedFmt};
use atuin_client::{
database::{current_context, Database},
- history::History,
+ encryption,
+ history::{store::HistoryStore, History},
+ record::{self, sqlite_store::SqliteStore},
settings::Settings,
};
@@ -84,6 +86,10 @@ pub enum Cmd {
#[arg(long, short)]
format: Option<String>,
},
+
+ /// Import all old history.db data into the record store. Do not run more than once, and do not
+ /// run unless you know what you're doing (or the docs ask you to)
+ InitStore,
}
#[derive(Clone, Copy, Debug)]
@@ -266,11 +272,14 @@ impl Cmd {
// we use this as the key for calling end
println!("{}", h.id);
db.save(&h).await?;
+
Ok(())
}
async fn handle_end(
db: &impl Database,
+ store: SqliteStore,
+ history_store: HistoryStore,
settings: &Settings,
id: &str,
exit: i64,
@@ -300,10 +309,20 @@ impl Cmd {
};
db.update(&h).await?;
+ history_store.push(h).await?;
if settings.should_sync()? {
#[cfg(feature = "sync")]
{
+ if settings.sync.records {
+ let (diff, _) = record::sync::diff(settings, &store).await?;
+ let operations = record::sync::operations(diff, &store).await?;
+ let (uploaded, downloaded) =
+ record::sync::sync_remote(operations, &store, settings).await?;
+
+ println!("{uploaded}/{downloaded} up/down to record store");
+ }
+
debug!("running periodic background sync");
sync::sync(settings, false, db).await?;
}
@@ -367,13 +386,56 @@ impl Cmd {
Ok(())
}
- pub async fn run(self, settings: &Settings, db: &impl Database) -> Result<()> {
+ async fn init_store(
+ context: atuin_client::database::Context,
+ db: &impl Database,
+ store: HistoryStore,
+ ) -> Result<()> {
+ println!("Importing all history.db data into records.db");
+
+ let history = db
+ .list(
+ atuin_client::settings::FilterMode::Global,
+ &context,
+ None,
+ false,
+ true,
+ )
+ .await?;
+
+ for i in history {
+ println!("loaded {}", i.id);
+
+ if i.deleted_at.is_some() {
+ store.push(i.clone()).await?;
+ store.delete(i.id).await?;
+ } else {
+ store.push(i).await?;
+ }
+ }
+
+ Ok(())
+ }
+
+ pub async fn run(
+ self,
+ settings: &Settings,
+ db: &impl Database,
+ store: SqliteStore,
+ ) -> Result<()> {
let context = current_context();
+ let encryption_key: [u8; 32] = encryption::load_key(settings)
+ .context("could not load encryption key")?
+ .into();
+
+ let host_id = Settings::host_id().expect("failed to get host_id");
+ let history_store = HistoryStore::new(store.clone(), host_id, encryption_key);
+
match self {
Self::Start { command } => Self::handle_start(db, settings, &command).await,
Self::End { id, exit, duration } => {
- Self::handle_end(db, settings, &id, exit, duration).await
+ Self::handle_end(db, store, history_store, settings, &id, exit, duration).await
}
Self::List {
session,
@@ -408,6 +470,8 @@ impl Cmd {
Ok(())
}
+
+ Self::InitStore => Self::init_store(context, db, history_store).await,
}
}
}
diff --git a/atuin/src/command/client/kv.rs b/atuin/src/command/client/kv.rs
index 48ebe9e5..b97f31b7 100644
--- a/atuin/src/command/client/kv.rs
+++ b/atuin/src/command/client/kv.rs
@@ -35,11 +35,7 @@ pub enum Cmd {
}
impl Cmd {
- pub async fn run(
- &self,
- settings: &Settings,
- store: &mut (impl Store + Send + Sync),
- ) -> Result<()> {
+ pub async fn run(&self, settings: &Settings, store: &(impl Store + Send + Sync)) -> Result<()> {
let kv_store = KvStore::new();
let encryption_key: [u8; 32] = encryption::load_key(settings)
diff --git a/atuin/src/command/client/record.rs b/atuin/src/command/client/record.rs
new file mode 100644
index 00000000..3c91cdcc
--- /dev/null
+++ b/atuin/src/command/client/record.rs
@@ -0,0 +1,63 @@
+use clap::Subcommand;
+use eyre::Result;
+
+use atuin_client::{record::store::Store, settings::Settings};
+use time::OffsetDateTime;
+
+#[derive(Subcommand, Debug)]
+#[command(infer_subcommands = true)]
+pub enum Cmd {
+ Status,
+}
+
+impl Cmd {
+ pub async fn run(
+ &self,
+ _settings: &Settings,
+ store: &(impl Store + Send + Sync),
+ ) -> Result<()> {
+ let host_id = Settings::host_id().expect("failed to get host_id");
+
+ let status = store.status().await?;
+
+ // TODO: should probs build some data structure and then pretty-print it or smth
+ for (host, st) in &status.hosts {
+ let host_string = if host == &host_id {
+ format!("host: {} <- CURRENT HOST", host.0.as_hyphenated())
+ } else {
+ format!("host: {}", host.0.as_hyphenated())
+ };
+
+ println!("{host_string}");
+
+ for (tag, idx) in st {
+ println!("\tstore: {tag}");
+
+ let first = store.first(*host, tag).await?;
+ let last = store.last(*host, tag).await?;
+
+ println!("\t\tidx: {idx}");
+
+ if let Some(first) = first {
+ println!("\t\tfirst: {}", first.id.0.as_hyphenated());
+
+ let time =
+ OffsetDateTime::from_unix_timestamp_nanos(i128::from(first.timestamp))?;
+ println!("\t\t\tcreated: {time}");
+ }
+
+ if let Some(last) = last {
+ println!("\t\tlast: {}", last.id.0.as_hyphenated());
+
+ let time =
+ OffsetDateTime::from_unix_timestamp_nanos(i128::from(last.timestamp))?;
+ println!("\t\t\tcreated: {time}");
+ }
+ }
+
+ println!();
+ }
+
+ Ok(())
+ }
+}
diff --git a/atuin/src/command/client/sync.rs b/atuin/src/command/client/sync.rs
index 50a1d835..1d2cdf4f 100644
--- a/atuin/src/command/client/sync.rs
+++ b/atuin/src/command/client/sync.rs
@@ -45,7 +45,7 @@ impl Cmd {
self,
settings: Settings,
db: &impl Database,
- store: &mut (impl Store + Send + Sync),
+ store: &(impl Store + Send + Sync),
) -> Result<()> {
match self {
Self::Sync { force } => run(&settings, force, db, store).await,
@@ -75,14 +75,15 @@ async fn run(
settings: &Settings,
force: bool,
db: &impl Database,
- store: &mut (impl Store + Send + Sync),
+ store: &(impl Store + Send + Sync),
) -> Result<()> {
- let (diff, remote_index) = sync::diff(settings, store).await?;
- let operations = sync::operations(diff, store).await?;
- let (uploaded, downloaded) =
- sync::sync_remote(operations, &remote_index, store, settings).await?;
+ if settings.sync.records {
+ let (diff, _) = sync::diff(settings, store).await?;
+ let operations = sync::operations(diff, store).await?;
+ let (uploaded, downloaded) = sync::sync_remote(operations, store, settings).await?;
- println!("{uploaded}/{downloaded} up/down to record store");
+ println!("{uploaded}/{downloaded} up/down to record store");
+ }
atuin_client::sync::sync(settings, force, db).await?;