aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorConrad Ludgate <conradludgate@gmail.com>2023-06-12 09:04:35 +0100
committerGitHub <noreply@github.com>2023-06-12 09:04:35 +0100
commit8655c93853506acf05f6ae4e58bfc2c6198be254 (patch)
tree22d20b35636ad2eb717d58c93ae07378adbb76eb
parentMake Ctrl-d behaviour match other tools (#1040) (diff)
downloadatuin-8655c93853506acf05f6ae4e58bfc2c6198be254.zip
refactor server to allow pluggable db and tracing (#1036)
* refactor server to allow pluggable db and tracing * clean up * fix descriptions * remove dependencies
-rw-r--r--Cargo.lock104
-rw-r--r--Cargo.toml10
-rw-r--r--atuin-client/Cargo.toml1
-rw-r--r--atuin-server-database/Cargo.toml21
-rw-r--r--atuin-server-database/src/calendar.rs (renamed from atuin-server/src/calendar.rs)0
-rw-r--r--atuin-server-database/src/lib.rs220
-rw-r--r--atuin-server-database/src/models.rs (renamed from atuin-server/src/models.rs)3
-rw-r--r--atuin-server-postgres/Cargo.toml21
-rw-r--r--atuin-server-postgres/migrations/20210425153745_create_history.sql (renamed from atuin-server/migrations/20210425153745_create_history.sql)0
-rw-r--r--atuin-server-postgres/migrations/20210425153757_create_users.sql (renamed from atuin-server/migrations/20210425153757_create_users.sql)0
-rw-r--r--atuin-server-postgres/migrations/20210425153800_create_sessions.sql (renamed from atuin-server/migrations/20210425153800_create_sessions.sql)0
-rw-r--r--atuin-server-postgres/migrations/20220419082412_add_count_trigger.sql (renamed from atuin-server/migrations/20220419082412_add_count_trigger.sql)0
-rw-r--r--atuin-server-postgres/migrations/20220421073605_fix_count_trigger_delete.sql (renamed from atuin-server/migrations/20220421073605_fix_count_trigger_delete.sql)0
-rw-r--r--atuin-server-postgres/migrations/20220421174016_larger-commands.sql (renamed from atuin-server/migrations/20220421174016_larger-commands.sql)0
-rw-r--r--atuin-server-postgres/migrations/20220426172813_user-created-at.sql (renamed from atuin-server/migrations/20220426172813_user-created-at.sql)0
-rw-r--r--atuin-server-postgres/migrations/20220505082442_create-events.sql (renamed from atuin-server/migrations/20220505082442_create-events.sql)0
-rw-r--r--atuin-server-postgres/migrations/20220610074049_history-length.sql (renamed from atuin-server/migrations/20220610074049_history-length.sql)0
-rw-r--r--atuin-server-postgres/migrations/20230315220537_drop-events.sql (renamed from atuin-server/migrations/20230315220537_drop-events.sql)0
-rw-r--r--atuin-server-postgres/migrations/20230315224203_create-deleted.sql (renamed from atuin-server/migrations/20230315224203_create-deleted.sql)0
-rw-r--r--atuin-server-postgres/migrations/20230515221038_trigger-delete-only.sql (renamed from atuin-server/migrations/20230515221038_trigger-delete-only.sql)0
-rw-r--r--atuin-server-postgres/src/lib.rs332
-rw-r--r--atuin-server-postgres/src/wrappers.rs42
-rw-r--r--atuin-server/Cargo.toml4
-rw-r--r--atuin-server/src/auth.rs222
-rw-r--r--atuin-server/src/database.rs510
-rw-r--r--atuin-server/src/handlers/history.rs44
-rw-r--r--atuin-server/src/handlers/status.rs5
-rw-r--r--atuin-server/src/handlers/user.rs22
-rw-r--r--atuin-server/src/lib.rs43
-rw-r--r--atuin-server/src/router.rs20
-rw-r--r--atuin-server/src/settings.rs12
-rw-r--r--atuin/Cargo.toml7
-rw-r--r--atuin/src/command/server.rs5
33 files changed, 760 insertions, 888 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 16e67c25..a3ac81b1 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -97,6 +97,7 @@ dependencies = [
"atuin-client",
"atuin-common",
"atuin-server",
+ "atuin-server-postgres",
"base64 0.21.0",
"bitflags",
"cassowary",
@@ -104,7 +105,6 @@ dependencies = [
"clap",
"clap_complete",
"colored",
- "crossbeam-channel",
"crossterm",
"directories",
"env_logger",
@@ -160,7 +160,6 @@ dependencies = [
"serde_regex",
"sha2",
"shellexpand",
- "sodiumoxide",
"sql-builder",
"sqlx",
"tokio",
@@ -187,6 +186,7 @@ dependencies = [
"argon2",
"async-trait",
"atuin-common",
+ "atuin-server-database",
"axum",
"base64 0.21.0",
"chrono",
@@ -200,14 +200,39 @@ dependencies = [
"semver",
"serde",
"serde_json",
- "sodiumoxide",
- "sqlx",
"tokio",
"tower",
"tower-http",
"tracing",
"uuid",
- "whoami",
+]
+
+[[package]]
+name = "atuin-server-database"
+version = "15.0.0"
+dependencies = [
+ "async-trait",
+ "atuin-common",
+ "chrono",
+ "chronoutil",
+ "eyre",
+ "serde",
+ "tracing",
+ "uuid",
+]
+
+[[package]]
+name = "atuin-server-postgres"
+version = "15.0.0"
+dependencies = [
+ "async-trait",
+ "atuin-common",
+ "atuin-server-database",
+ "chrono",
+ "futures-util",
+ "serde",
+ "sqlx",
+ "tracing",
]
[[package]]
@@ -516,16 +541,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d0165d2900ae6778e36e80bbc4da3b5eefccee9ba939761f9c2882a5d9af3ff"
[[package]]
-name = "crossbeam-channel"
-version = "0.5.8"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200"
-dependencies = [
- "cfg-if",
- "crossbeam-utils",
-]
-
-[[package]]
name = "crossbeam-queue"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -632,15 +647,6 @@ dependencies = [
]
[[package]]
-name = "ed25519"
-version = "1.5.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1e9c280362032ea4203659fc489832d0204ef09f247a0506f170dafcac08c369"
-dependencies = [
- "signature",
-]
-
-[[package]]
name = "either"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1176,18 +1182,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5"
[[package]]
-name = "libsodium-sys"
-version = "0.2.7"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6b779387cd56adfbc02ea4a668e704f729be8d6a6abd2c27ca5ee537849a92fd"
-dependencies = [
- "cc",
- "libc",
- "pkg-config",
- "walkdir",
-]
-
-[[package]]
name = "libsqlite3-sys"
version = "0.24.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1876,15 +1870,6 @@ dependencies = [
]
[[package]]
-name = "same-file"
-version = "1.0.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
-dependencies = [
- "winapi-util",
-]
-
-[[package]]
name = "schannel"
version = "0.1.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2072,12 +2057,6 @@ dependencies = [
]
[[package]]
-name = "signature"
-version = "1.6.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e90531723b08e4d6d71b791108faf51f03e1b4a7784f96b2b87f852ebc247228"
-
-[[package]]
name = "slab"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2103,18 +2082,6 @@ dependencies = [
]
[[package]]
-name = "sodiumoxide"
-version = "0.2.7"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e26be3acb6c2d9a7aac28482586a7856436af4cfe7100031d219de2d2ecb0028"
-dependencies = [
- "ed25519",
- "libc",
- "libsodium-sys",
- "serde",
-]
-
-[[package]]
name = "spin"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2660,17 +2627,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
-name = "walkdir"
-version = "2.3.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56"
-dependencies = [
- "same-file",
- "winapi",
- "winapi-util",
-]
-
-[[package]]
name = "want"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index 652efb8a..00b0434e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,5 +1,12 @@
[workspace]
-members = ["atuin", "atuin-client", "atuin-server", "atuin-common"]
+members = [
+ "atuin",
+ "atuin-client",
+ "atuin-server",
+ "atuin-server-postgres",
+ "atuin-server-database",
+ "atuin-common",
+]
[workspace.package]
name = "atuin"
@@ -27,7 +34,6 @@ rand = { version = "0.8.5", features = ["std"] }
semver = "1.0.14"
serde = { version = "1.0.145", features = ["derive"] }
serde_json = "1.0.86"
-sodiumoxide = "0.2.6"
tokio = { version = "1", features = ["full"] }
uuid = { version = "1.2", features = ["v4"] }
whoami = "1.1.2"
diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml
index fee3eb5f..770d7741 100644
--- a/atuin-client/Cargo.toml
+++ b/atuin-client/Cargo.toml
@@ -53,7 +53,6 @@ memchr = "2.5"
# sync
urlencoding = { version = "2.1.0", optional = true }
-sodiumoxide = { workspace = true, optional = true }
reqwest = { workspace = true, optional = true }
hex = { version = "0.4", optional = true }
sha2 = { version = "0.10", optional = true }
diff --git a/atuin-server-database/Cargo.toml b/atuin-server-database/Cargo.toml
new file mode 100644
index 00000000..485b3246
--- /dev/null
+++ b/atuin-server-database/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "atuin-server-database"
+edition = "2021"
+description = "server database library for atuin"
+
+version = { workspace = true }
+authors = { workspace = true }
+license = { workspace = true }
+homepage = { workspace = true }
+repository = { workspace = true }
+
+[dependencies]
+atuin-common = { path = "../atuin-common", version = "15.0.0" }
+
+tracing = "0.1"
+chrono = { workspace = true }
+eyre = { workspace = true }
+uuid = { workspace = true }
+serde = { workspace = true }
+async-trait = { workspace = true }
+chronoutil = "0.2.3"
diff --git a/atuin-server/src/calendar.rs b/atuin-server-database/src/calendar.rs
index 7c05dce3..7c05dce3 100644
--- a/atuin-server/src/calendar.rs
+++ b/atuin-server-database/src/calendar.rs
diff --git a/atuin-server-database/src/lib.rs b/atuin-server-database/src/lib.rs
new file mode 100644
index 00000000..de33ba44
--- /dev/null
+++ b/atuin-server-database/src/lib.rs
@@ -0,0 +1,220 @@
+#![forbid(unsafe_code)]
+
+pub mod calendar;
+pub mod models;
+
+use std::{
+ collections::HashMap,
+ fmt::{Debug, Display},
+};
+
+use self::{
+ calendar::{TimePeriod, TimePeriodInfo},
+ models::{History, NewHistory, NewSession, NewUser, Session, User},
+};
+use async_trait::async_trait;
+use atuin_common::utils::get_days_from_month;
+use chrono::{Datelike, TimeZone};
+use chronoutil::RelativeDuration;
+use serde::{de::DeserializeOwned, Serialize};
+use tracing::instrument;
+
+#[derive(Debug)]
+pub enum DbError {
+ NotFound,
+ Other(eyre::Report),
+}
+
+impl Display for DbError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{self:?}")
+ }
+}
+
+impl std::error::Error for DbError {}
+
+pub type DbResult<T> = Result<T, DbError>;
+
+#[async_trait]
+pub trait Database: Sized + Clone + Send + Sync + 'static {
+ type Settings: Debug + Clone + DeserializeOwned + Serialize + Send + Sync + 'static;
+ async fn new(settings: &Self::Settings) -> DbResult<Self>;
+
+ async fn get_session(&self, token: &str) -> DbResult<Session>;
+ async fn get_session_user(&self, token: &str) -> DbResult<User>;
+ async fn add_session(&self, session: &NewSession) -> DbResult<()>;
+
+ async fn get_user(&self, username: &str) -> DbResult<User>;
+ async fn get_user_session(&self, u: &User) -> DbResult<Session>;
+ async fn add_user(&self, user: &NewUser) -> DbResult<i64>;
+ async fn delete_user(&self, u: &User) -> DbResult<()>;
+
+ async fn count_history(&self, user: &User) -> DbResult<i64>;
+ async fn count_history_cached(&self, user: &User) -> DbResult<i64>;
+
+ async fn delete_history(&self, user: &User, id: String) -> DbResult<()>;
+ async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>;
+
+ async fn count_history_range(
+ &self,
+ user: &User,
+ start: chrono::NaiveDateTime,
+ end: chrono::NaiveDateTime,
+ ) -> DbResult<i64>;
+
+ async fn list_history(
+ &self,
+ user: &User,
+ created_after: chrono::NaiveDateTime,
+ since: chrono::NaiveDateTime,
+ host: &str,
+ page_size: i64,
+ ) -> DbResult<Vec<History>>;
+
+ async fn add_history(&self, history: &[NewHistory]) -> DbResult<()>;
+
+ async fn oldest_history(&self, user: &User) -> DbResult<History>;
+
+ /// Count the history for a given year
+ #[instrument(skip_all)]
+ async fn count_history_year(&self, user: &User, year: i32) -> DbResult<i64> {
+ let start = chrono::Utc.ymd(year, 1, 1).and_hms_nano(0, 0, 0, 0);
+ let end = start + RelativeDuration::years(1);
+
+ let res = self
+ .count_history_range(user, start.naive_utc(), end.naive_utc())
+ .await?;
+ Ok(res)
+ }
+
+ /// Count the history for a given month
+ #[instrument(skip_all)]
+ async fn count_history_month(&self, user: &User, month: chrono::NaiveDate) -> DbResult<i64> {
+ let start = chrono::Utc
+ .ymd(month.year(), month.month(), 1)
+ .and_hms_nano(0, 0, 0, 0);
+
+ // ofc...
+ let end = if month.month() < 12 {
+ chrono::Utc
+ .ymd(month.year(), month.month() + 1, 1)
+ .and_hms_nano(0, 0, 0, 0)
+ } else {
+ chrono::Utc
+ .ymd(month.year() + 1, 1, 1)
+ .and_hms_nano(0, 0, 0, 0)
+ };
+
+ tracing::debug!("start: {}, end: {}", start, end);
+
+ let res = self
+ .count_history_range(user, start.naive_utc(), end.naive_utc())
+ .await?;
+ Ok(res)
+ }
+
+ /// Count the history for a given day
+ #[instrument(skip_all)]
+ async fn count_history_day(&self, user: &User, day: chrono::NaiveDate) -> DbResult<i64> {
+ let start = chrono::Utc
+ .ymd(day.year(), day.month(), day.day())
+ .and_hms_nano(0, 0, 0, 0);
+ let end = chrono::Utc
+ .ymd(day.year(), day.month(), day.day() + 1)
+ .and_hms_nano(0, 0, 0, 0);
+
+ let res = self
+ .count_history_range(user, start.naive_utc(), end.naive_utc())
+ .await?;
+ Ok(res)
+ }
+
+ #[instrument(skip_all)]
+ async fn calendar(
+ &self,
+ user: &User,
+ period: TimePeriod,
+ year: u64,
+ month: u64,
+ ) -> DbResult<HashMap<u64, TimePeriodInfo>> {
+ // TODO: Support different timezones. Right now we assume UTC and
+ // everything is stored as such. But it _should_ be possible to
+ // interpret the stored date with a different TZ
+
+ match period {
+ TimePeriod::YEAR => {
+ let mut ret = HashMap::new();
+ // First we need to work out how far back to calculate. Get the
+ // oldest history item
+ let oldest = self.oldest_history(user).await?.timestamp.year();
+ let current_year = chrono::Utc::now().year();
+
+ // All the years we need to get data for
+ // The upper bound is exclusive, so include current +1
+ let years = oldest..current_year + 1;
+
+ for year in years {
+ let count = self.count_history_year(user, year).await?;
+
+ ret.insert(
+ year as u64,
+ TimePeriodInfo {
+ count: count as u64,
+ hash: "".to_string(),
+ },
+ );
+ }
+
+ Ok(ret)
+ }
+
+ TimePeriod::MONTH => {
+ let mut ret = HashMap::new();
+
+ for month in 1..13 {
+ let count = self
+ .count_history_month(
+ user,
+ chrono::Utc.ymd(year as i32, month, 1).naive_utc(),
+ )
+ .await?;
+
+ ret.insert(
+ month as u64,
+ TimePeriodInfo {
+ count: count as u64,
+ hash: "".to_string(),
+ },
+ );
+ }
+
+ Ok(ret)
+ }
+
+ TimePeriod::DAY => {
+ let mut ret = HashMap::new();
+
+ for day in 1..get_days_from_month(year as i32, month as u32) {
+ let count = self
+ .count_history_day(
+ user,
+ chrono::Utc
+ .ymd(year as i32, month as u32, day as u32)
+ .naive_utc(),
+ )
+ .await?;
+
+ ret.insert(
+ day as u64,
+ TimePeriodInfo {
+ count: count as u64,
+ hash: "".to_string(),
+ },
+ );
+ }
+
+ Ok(ret)
+ }
+ }
+ }
+}
diff --git a/atuin-server/src/models.rs b/atuin-server-database/src/models.rs
index ee84f58a..a95ceba2 100644
--- a/atuin-server/src/models.rs
+++ b/atuin-server-database/src/models.rs
@@ -1,6 +1,5 @@
use chrono::prelude::*;
-#[derive(sqlx::FromRow)]
pub struct History {
pub id: i64,
pub client_id: String, // a client generated ID
@@ -22,7 +21,6 @@ pub struct NewHistory {
pub data: String,
}
-#[derive(sqlx::FromRow)]
pub struct User {
pub id: i64,
pub username: String,
@@ -30,7 +28,6 @@ pub struct User {
pub password: String,
}
-#[derive(sqlx::FromRow)]
pub struct Session {
pub id: i64,
pub user_id: i64,
diff --git a/atuin-server-postgres/Cargo.toml b/atuin-server-postgres/Cargo.toml
new file mode 100644
index 00000000..18864f6c
--- /dev/null
+++ b/atuin-server-postgres/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "atuin-server-postgres"
+edition = "2018"
+description = "server postgres database library for atuin"
+
+version = { workspace = true }
+authors = { workspace = true }
+license = { workspace = true }
+homepage = { workspace = true }
+repository = { workspace = true }
+
+[dependencies]
+atuin-common = { path = "../atuin-common", version = "15.0.0" }
+atuin-server-database = { path = "../atuin-server-database", version = "15.0.0" }
+
+tracing = "0.1"
+chrono = { workspace = true }
+serde = { workspace = true }
+sqlx = { workspace = true }
+async-trait = { workspace = true }
+futures-util = "0.3"
diff --git a/atuin-server/migrations/20210425153745_create_history.sql b/atuin-server-postgres/migrations/20210425153745_create_history.sql
index 2c2d17b0..2c2d17b0 100644
--- a/atuin-server/migrations/20210425153745_create_history.sql
+++ b/atuin-server-postgres/migrations/20210425153745_create_history.sql
diff --git a/atuin-server/migrations/20210425153757_create_users.sql b/atuin-server-postgres/migrations/20210425153757_create_users.sql
index a25dcced..a25dcced 100644
--- a/atuin-server/migrations/20210425153757_create_users.sql
+++ b/atuin-server-postgres/migrations/20210425153757_create_users.sql
diff --git a/atuin-server/migrations/20210425153800_create_sessions.sql b/atuin-server-postgres/migrations/20210425153800_create_sessions.sql
index c2fb6559..c2fb6559 100644
--- a/atuin-server/migrations/20210425153800_create_sessions.sql
+++ b/atuin-server-postgres/migrations/20210425153800_create_sessions.sql
diff --git a/atuin-server/migrations/20220419082412_add_count_trigger.sql b/atuin-server-postgres/migrations/20220419082412_add_count_trigger.sql
index dd1afa88..dd1afa88 100644
--- a/atuin-server/migrations/20220419082412_add_count_trigger.sql
+++ b/atuin-server-postgres/migrations/20220419082412_add_count_trigger.sql
diff --git a/atuin-server/migrations/20220421073605_fix_count_trigger_delete.sql b/atuin-server-postgres/migrations/20220421073605_fix_count_trigger_delete.sql
index 6198f300..6198f300 100644
--- a/atuin-server/migrations/20220421073605_fix_count_trigger_delete.sql
+++ b/atuin-server-postgres/migrations/20220421073605_fix_count_trigger_delete.sql
diff --git a/atuin-server/migrations/20220421174016_larger-commands.sql b/atuin-server-postgres/migrations/20220421174016_larger-commands.sql
index 0ac43433..0ac43433 100644
--- a/atuin-server/migrations/20220421174016_larger-commands.sql
+++ b/atuin-server-postgres/migrations/20220421174016_larger-commands.sql
diff --git a/atuin-server/migrations/20220426172813_user-created-at.sql b/atuin-server-postgres/migrations/20220426172813_user-created-at.sql
index a9138194..a9138194 100644
--- a/atuin-server/migrations/20220426172813_user-created-at.sql
+++ b/atuin-server-postgres/migrations/20220426172813_user-created-at.sql
diff --git a/atuin-server/migrations/20220505082442_create-events.sql b/atuin-server-postgres/migrations/20220505082442_create-events.sql
index 57e16ec7..57e16ec7 100644
--- a/atuin-server/migrations/20220505082442_create-events.sql
+++ b/atuin-server-postgres/migrations/20220505082442_create-events.sql
diff --git a/atuin-server/migrations/20220610074049_history-length.sql b/atuin-server-postgres/migrations/20220610074049_history-length.sql
index b1c23016..b1c23016 100644
--- a/atuin-server/migrations/20220610074049_history-length.sql
+++ b/atuin-server-postgres/migrations/20220610074049_history-length.sql
diff --git a/atuin-server/migrations/20230315220537_drop-events.sql b/atuin-server-postgres/migrations/20230315220537_drop-events.sql
index fe3cae17..fe3cae17 100644
--- a/atuin-server/migrations/20230315220537_drop-events.sql
+++ b/atuin-server-postgres/migrations/20230315220537_drop-events.sql
diff --git a/atuin-server/migrations/20230315224203_create-deleted.sql b/atuin-server-postgres/migrations/20230315224203_create-deleted.sql
index 9a9e6263..9a9e6263 100644
--- a/atuin-server/migrations/20230315224203_create-deleted.sql
+++ b/atuin-server-postgres/migrations/20230315224203_create-deleted.sql
diff --git a/atuin-server/migrations/20230515221038_trigger-delete-only.sql b/atuin-server-postgres/migrations/20230515221038_trigger-delete-only.sql
index 3d0bba52..3d0bba52 100644
--- a/atuin-server/migrations/20230515221038_trigger-delete-only.sql
+++ b/atuin-server-postgres/migrations/20230515221038_trigger-delete-only.sql
diff --git a/atuin-server-postgres/src/lib.rs b/atuin-server-postgres/src/lib.rs
new file mode 100644
index 00000000..0dc51daf
--- /dev/null
+++ b/atuin-server-postgres/src/lib.rs
@@ -0,0 +1,332 @@
+use async_trait::async_trait;
+use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
+use atuin_server_database::{Database, DbError, DbResult};
+use futures_util::TryStreamExt;
+use serde::{Deserialize, Serialize};
+use sqlx::postgres::PgPoolOptions;
+
+use sqlx::Row;
+
+use tracing::instrument;
+use wrappers::{DbHistory, DbSession, DbUser};
+
+mod wrappers;
+
+#[derive(Clone)]
+pub struct Postgres {
+ pool: sqlx::Pool<sqlx::postgres::Postgres>,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct PostgresSettings {
+ pub db_uri: String,
+}
+
+fn fix_error(error: sqlx::Error) -> DbError {
+ match error {
+ sqlx::Error::RowNotFound => DbError::NotFound,
+ error => DbError::Other(error.into()),
+ }
+}
+
+#[async_trait]
+impl Database for Postgres {
+ type Settings = PostgresSettings;
+ async fn new(settings: &PostgresSettings) -> DbResult<Self> {
+ let pool = PgPoolOptions::new()
+ .max_connections(100)
+ .connect(settings.db_uri.as_str())
+ .await
+ .map_err(fix_error)?;
+
+ sqlx::migrate!("./migrations")
+ .run(&pool)
+ .await
+ .map_err(|error| DbError::Other(error.into()))?;
+
+ Ok(Self { pool })
+ }
+
+ #[instrument(skip_all)]
+ async fn get_session(&self, token: &str) -> DbResult<Session> {
+ sqlx::query_as("select id, user_id, token from sessions where token = $1")
+ .bind(token)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)
+ .map(|DbSession(session)| session)
+ }
+
+ #[instrument(skip_all)]
+ async fn get_user(&self, username: &str) -> DbResult<User> {
+ sqlx::query_as("select id, username, email, password from users where username = $1")
+ .bind(username)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)
+ .map(|DbUser(user)| user)
+ }
+
+ #[instrument(skip_all)]
+ async fn get_session_user(&self, token: &str) -> DbResult<User> {
+ sqlx::query_as(
+ "select users.id, users.username, users.email, users.password from users
+ inner join sessions
+ on users.id = sessions.user_id
+ and sessions.token = $1",
+ )
+ .bind(token)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)
+ .map(|DbUser(user)| user)
+ }
+
+ #[instrument(skip_all)]
+ async fn count_history(&self, user: &User) -> DbResult<i64> {
+ // The cache is new, and the user might not yet have a cache value.
+ // They will have one as soon as they post up some new history, but handle that
+ // edge case.
+
+ let res: (i64,) = sqlx::query_as(
+ "select count(1) from history
+ where user_id = $1",
+ )
+ .bind(user.id)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(res.0)
+ }
+
+ #[instrument(skip_all)]
+ async fn count_history_cached(&self, user: &User) -> DbResult<i64> {
+ let res: (i32,) = sqlx::query_as(
+ "select total from total_history_count_user
+ where user_id = $1",
+ )
+ .bind(user.id)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(res.0 as i64)
+ }
+
+ async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
+ sqlx::query(
+ "update history
+ set deleted_at = $3
+ where user_id = $1
+ and client_id = $2
+ and deleted_at is null", // don't just keep setting it
+ )
+ .bind(user.id)
+ .bind(id)
+ .bind(chrono::Utc::now().naive_utc())
+ .fetch_all(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
+ // The cache is new, and the user might not yet have a cache value.
+ // They will have one as soon as they post up some new history, but handle that
+ // edge case.
+
+ let res = sqlx::query(
+ "select client_id from history
+ where user_id = $1
+ and deleted_at is not null",
+ )
+ .bind(user.id)
+ .fetch_all(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ let res = res
+ .iter()
+ .map(|row| row.get::<String, _>("client_id"))
+ .collect();
+
+ Ok(res)
+ }
+
+ #[instrument(skip_all)]
+ async fn count_history_range(
+ &self,
+ user: &User,
+ start: chrono::NaiveDateTime,
+ end: chrono::NaiveDateTime,
+ ) -> DbResult<i64> {
+ let res: (i64,) = sqlx::query_as(
+ "select count(1) from history
+ where user_id = $1
+ and timestamp >= $2::date
+ and timestamp < $3::date",
+ )
+ .bind(user.id)
+ .bind(start)
+ .bind(end)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(res.0)
+ }
+
+ #[instrument(skip_all)]
+ async fn list_history(
+ &self,
+ user: &User,
+ created_after: chrono::NaiveDateTime,
+ since: chrono::NaiveDateTime,
+ host: &str,
+ page_size: i64,
+ ) -> DbResult<Vec<History>> {
+ let res = sqlx::query_as(
+ "select id, client_id, user_id, hostname, timestamp, data, created_at from history
+ where user_id = $1
+ and hostname != $2
+ and created_at >= $3
+ and timestamp >= $4
+ order by timestamp asc
+ limit $5",
+ )
+ .bind(user.id)
+ .bind(host)
+ .bind(created_after)
+ .bind(since)
+ .bind(page_size)
+ .fetch(&self.pool)
+ .map_ok(|DbHistory(h)| h)
+ .try_collect()
+ .await
+ .map_err(fix_error)?;
+
+ Ok(res)
+ }
+
+ #[instrument(skip_all)]
+ async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
+ let mut tx = self.pool.begin().await.map_err(fix_error)?;
+
+ for i in history {
+ let client_id: &str = &i.client_id;
+ let hostname: &str = &i.hostname;
+ let data: &str = &i.data;
+
+ sqlx::query(
+ "insert into history
+ (client_id, user_id, hostname, timestamp, data)
+ values ($1, $2, $3, $4, $5)
+ on conflict do nothing
+ ",
+ )
+ .bind(client_id)
+ .bind(i.user_id)
+ .bind(hostname)
+ .bind(i.timestamp)
+ .bind(data)
+ .execute(&mut tx)
+ .await
+ .map_err(fix_error)?;
+ }
+
+ tx.commit().await.map_err(fix_error)?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn delete_user(&self, u: &User) -> DbResult<()> {
+ sqlx::query("delete from sessions where user_id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ sqlx::query("delete from users where id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ sqlx::query("delete from history where user_id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
+ let email: &str = &user.email;
+ let username: &str = &user.username;
+ let password: &str = &user.password;
+
+ let res: (i64,) = sqlx::query_as(
+ "insert into users
+ (username, email, password)
+ values($1, $2, $3)
+ returning id",
+ )
+ .bind(username)
+ .bind(email)
+ .bind(password)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(res.0)
+ }
+
+ #[instrument(skip_all)]
+ async fn add_session(&self, session: &NewSession) -> DbResult<()> {
+ let token: &str = &session.token;
+
+ sqlx::query(
+ "insert into sessions
+ (user_id, token)
+ values($1, $2)",
+ )
+ .bind(session.user_id)
+ .bind(token)
+ .execute(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn get_user_session(&self, u: &User) -> DbResult<Session> {
+ sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
+ .bind(u.id)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)
+ .map(|DbSession(session)| session)
+ }
+
+ #[instrument(skip_all)]
+ async fn oldest_history(&self, user: &User) -> DbResult<History> {
+ sqlx::query_as(
+ "select id, client_id, user_id, hostname, timestamp, data, created_at from history
+ where user_id = $1
+ order by timestamp asc
+ limit 1",
+ )
+ .bind(user.id)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)
+ .map(|DbHistory(h)| h)
+ }
+}
diff --git a/atuin-server-postgres/src/wrappers.rs b/atuin-server-postgres/src/wrappers.rs
new file mode 100644
index 00000000..cb3d5a96
--- /dev/null
+++ b/atuin-server-postgres/src/wrappers.rs
@@ -0,0 +1,42 @@
+use ::sqlx::{FromRow, Result};
+use atuin_server_database::models::{History, Session, User};
+use sqlx::{postgres::PgRow, Row};
+
+pub struct DbUser(pub User);
+pub struct DbSession(pub Session);
+pub struct DbHistory(pub History);
+
+impl<'a> FromRow<'a, PgRow> for DbUser {
+ fn from_row(row: &'a PgRow) -> Result<Self> {
+ Ok(Self(User {
+ id: row.try_get("id")?,
+ username: row.try_get("username")?,
+ email: row.try_get("email")?,
+ password: row.try_get("password")?,
+ }))
+ }
+}
+
+impl<'a> ::sqlx::FromRow<'a, PgRow> for DbSession {
+ fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
+ Ok(Self(Session {
+ id: row.try_get("id")?,
+ user_id: row.try_get("user_id")?,
+ token: row.try_get("token")?,
+ }))
+ }
+}
+
+impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory {
+ fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
+ Ok(Self(History {
+ id: row.try_get("id")?,
+ client_id: row.try_get("client_id")?,
+ user_id: row.try_get("user_id")?,
+ hostname: row.try_get("hostname")?,
+ timestamp: row.try_get("timestamp")?,
+ data: row.try_get("data")?,
+ created_at: row.try_get("created_at")?,
+ }))
+ }
+}
diff --git a/atuin-server/Cargo.toml b/atuin-server/Cargo.toml
index e4cbf3e0..f308fa30 100644
--- a/atuin-server/Cargo.toml
+++ b/atuin-server/Cargo.toml
@@ -11,20 +11,18 @@ repository = { workspace = true }
[dependencies]
atuin-common = { path = "../atuin-common", version = "15.0.0" }
+atuin-server-database = { path = "../atuin-server-database", version = "15.0.0" }
tracing = "0.1"
chrono = { workspace = true }
eyre = { workspace = true }
uuid = { workspace = true }
-whoami = { workspace = true }
config = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
-sodiumoxide = { workspace = true }
base64 = { workspace = true }
rand = { workspace = true }
tokio = { workspace = true }
-sqlx = { workspace = true }
async-trait = { workspace = true }
axum = "0.6.4"
http = "0.2"
diff --git a/atuin-server/src/auth.rs b/atuin-server/src/auth.rs
deleted file mode 100644
index 52a73108..00000000
--- a/atuin-server/src/auth.rs
+++ /dev/null
@@ -1,222 +0,0 @@
-/*
-use self::diesel::prelude::*;
-use eyre::Result;
-use rocket::http::Status;
-use rocket::request::{self, FromRequest, Outcome, Request};
-use rocket::State;
-use rocket_contrib::databases::diesel;
-use sodiumoxide::crypto::pwhash::argon2id13;
-
-use rocket_contrib::json::Json;
-use uuid::Uuid;
-
-use super::models::{NewSession, NewUser, Session, User};
-use super::views::ApiResponse;
-
-use crate::api::{LoginRequest, RegisterRequest};
-use crate::schema::{sessions, users};
-use crate::settings::Settings;
-use crate::utils::hash_secret;
-
-use super::database::AtuinDbConn;
-
-#[derive(Debug)]
-pub enum KeyError {
- Missing,
- Invalid,
-}
-
-pub fn verify_str(secret: &str, verify: &str) -> bool {
- sodiumoxide::init().unwrap();
-
- let mut padded = [0_u8; 128];
- secret.as_bytes().iter().enumerate().for_each(|(i, val)| {
- padded[i] = *val;
- });
-
- match argon2id13::HashedPassword::from_slice(&padded) {
- Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()),
- None => false,
- }
-}
-
-impl<'a, 'r> FromRequest<'a, 'r> for User {
- type Error = KeyError;
-
- fn from_request(request: &'a Request<'r>) -> request::Outcome<User, Self::Error> {
- let session: Vec<_> = request.headers().get("authorization").collect();
-
- if session.is_empty() {
- return Outcome::Failure((Status::BadRequest, KeyError::Missing));
- } else if session.len() > 1 {
- return Outcome::Failure((Status::BadRequest, KeyError::Invalid));
- }
-
- let session: Vec<_> = session[0].split(' ').collect();
-
- if session.len() != 2 {
- return Outcome::Failure((Status::BadRequest, KeyError::Invalid));
- }
-
- if session[0] != "Token" {
- return Outcome::Failure((Status::BadRequest, KeyError::Invalid));
- }
-
- let session = session[1];
-
- let db = request
- .guard::<AtuinDbConn>()
- .succeeded()
- .expect("failed to load database");
-
- let session = sessions::table
- .filter(sessions::token.eq(session))
- .first::<Session>(&*db);
-
- if session.is_err() {
- return Outcome::Failure((Status::Unauthorized, KeyError::Invalid));
- }
-
- let session = session.unwrap();
-
- let user = users::table.find(session.user_id).first(&*db);
-
- match user {
- Ok(user) => Outcome::Success(user),
- Err(_) => Outcome::Failure((Status::Unauthorized, KeyError::Invalid)),
- }
- }
-}
-
-#[get("/user/<user>")]
-#[allow(clippy::clippy::needless_pass_by_value)]
-pub fn get_user(user: String, conn: AtuinDbConn) -> ApiResponse {
- use crate::schema::users::dsl::{username, users};
-
- let user: Result<String, diesel::result::Error> = users
- .select(username)
- .filter(username.eq(user))
- .first(&*conn);
-
- if user.is_err() {
- return ApiResponse {
- json: json!({
- "message": "could not find user",
- }),
- status: Status::NotFound,
- };
- }
-
- let user = user.unwrap();
-
- ApiResponse {
- json: json!({ "username": user.as_str() }),
- status: Status::Ok,
- }
-}
-
-#[post("/register", data = "<register>")]
-#[allow(clippy::clippy::needless_pass_by_value)]
-pub fn register(
- conn: AtuinDbConn,
- register: Json<RegisterRequest>,
- settings: State<Settings>,
-) -> ApiResponse {
- if !settings.server.open_registration {
- return ApiResponse {
- status: Status::BadRequest,
- json: json!({
- "message": "registrations are not open"
- }),
- };
- }
-
- let hashed = hash_secret(register.password.as_str());
-
- let new_user = NewUser {
- email: register.email.as_str(),
- username: register.username.as_str(),
- password: hashed.as_str(),
- };
-
- let user = diesel::insert_into(users::table)
- .values(&new_user)
- .get_result(&*conn);
-
- if user.is_err() {
- return ApiResponse {
- status: Status::BadRequest,
- json: json!({
- "message": "failed to create user - username or email in use?",
- }),
- };
- }
-
- let user: User = user.unwrap();
- let token = Uuid::new_v4().to_simple().to_string();
-
- let new_session = NewSession {
- user_id: user.id,
- token: token.as_str(),
- };
-
- match diesel::insert_into(sessions::table)
- .values(&new_session)
- .execute(&*conn)
- {
- Ok(_) => ApiResponse {
- status: Status::Ok,
- json: json!({"message": "user created!", "session": token}),
- },
- Err(_) => ApiResponse {
- status: Status::BadRequest,
- json: json!({ "message": "failed to create user"}),
- },
- }
-}
-
-#[post("/login", data = "<login>")]
-#[allow(clippy::clippy::needless_pass_by_value)]
-pub fn login(conn: AtuinDbConn, login: Json<LoginRequest>) -> ApiResponse {
- let user = users::table
- .filter(users::username.eq(login.username.as_str()))
- .first(&*conn);
-
- if user.is_err() {
- return ApiResponse {
- status: Status::NotFound,
- json: json!({"message": "user not found"}),
- };
- }
-
- let user: User = user.unwrap();
-
- let session = sessions::table
- .filter(sessions::user_id.eq(user.id))
- .first(&*conn);
-
- // a session should exist...
- if session.is_err() {
- return ApiResponse {
- status: Status::InternalServerError,
- json: json!({"message": "something went wrong"}),
- };
- }
-
- let verified = verify_str(user.password.as_str(), login.password.as_str());
-
- if !verified {
- return ApiResponse {
- status: Status::NotFound,
- json: json!({"message": "user not found"}),
- };
- }
-
- let session: Session = session.unwrap();
-
- ApiResponse {
- status: Status::Ok,
- json: json!({"session": session.token}),
- }
-}
-*/
diff --git a/atuin-server/src/database.rs b/atuin-server/src/database.rs
deleted file mode 100644
index 894fab7b..00000000
--- a/atuin-server/src/database.rs
+++ /dev/null
@@ -1,510 +0,0 @@
-use std::collections::HashMap;
-
-use async_trait::async_trait;
-use chrono::{Datelike, TimeZone};
-use chronoutil::RelativeDuration;
-use sqlx::{postgres::PgPoolOptions, Result};
-
-use sqlx::Row;
-
-use tracing::{debug, instrument, warn};
-
-use super::{
- calendar::{TimePeriod, TimePeriodInfo},
- models::{History, NewHistory, NewSession, NewUser, Session, User},
-};
-use crate::settings::Settings;
-
-use atuin_common::utils::get_days_from_month;
-
-#[async_trait]
-pub trait Database {
- async fn get_session(&self, token: &str) -> Result<Session>;
- async fn get_session_user(&self, token: &str) -> Result<User>;
- async fn add_session(&self, session: &NewSession) -> Result<()>;
-
- async fn get_user(&self, username: &str) -> Result<User>;
- async fn get_user_session(&self, u: &User) -> Result<Session>;
- async fn add_user(&self, user: &NewUser) -> Result<i64>;
- async fn delete_user(&self, u: &User) -> Result<()>;
-
- async fn count_history(&self, user: &User) -> Result<i64>;
- async fn count_history_cached(&self, user: &User) -> Result<i64>;
-
- async fn delete_history(&self, user: &User, id: String) -> Result<()>;
- async fn deleted_history(&self, user: &User) -> Result<Vec<String>>;
-
- async fn count_history_range(
- &self,
- user: &User,
- start: chrono::NaiveDateTime,
- end: chrono::NaiveDateTime,
- ) -> Result<i64>;
- async fn count_history_day(&self, user: &User, date: chrono::NaiveDate) -> Result<i64>;
- async fn count_history_month(&self, user: &User, date: chrono::NaiveDate) -> Result<i64>;
- async fn count_history_year(&self, user: &User, year: i32) -> Result<i64>;
-
- async fn list_history(
- &self,
- user: &User,
- created_after: chrono::NaiveDateTime,
- since: chrono::NaiveDateTime,
- host: &str,
- page_size: i64,
- ) -> Result<Vec<History>>;
-
- async fn add_history(&self, history: &[NewHistory]) -> Result<()>;
-
- async fn oldest_history(&self, user: &User) -> Result<History>;
-
- async fn calendar(
- &self,
- user: &User,
- period: TimePeriod,
- year: u64,
- month: u64,
- ) -> Result<HashMap<u64, TimePeriodInfo>>;
-}
-
-#[derive(Clone)]
-pub struct Postgres {
- pool: sqlx::Pool<sqlx::postgres::Postgres>,
- settings: Settings,
-}
-
-impl Postgres {
- pub async fn new(settings: Settings) -> Result<Self> {
- let pool = PgPoolOptions::new()
- .max_connections(100)
- .connect(settings.db_uri.as_str())
- .await?;
-
- sqlx::migrate!("./migrations").run(&pool).await?;
-
- Ok(Self { pool, settings })
- }
-}
-
-#[async_trait]
-impl Database for Postgres {
- #[instrument(skip_all)]
- async fn get_session(&self, token: &str) -> Result<Session> {
- sqlx::query_as::<_, Session>("select id, user_id, token from sessions where token = $1")
- .bind(token)
- .fetch_one(&self.pool)
- .await
- }
-
- #[instrument(skip_all)]
- async fn get_user(&self, username: &str) -> Result<User> {
- sqlx::query_as::<_, User>(
- "select id, username, email, password from users where username = $1",
- )
- .bind(username)
- .fetch_one(&self.pool)
- .await
- }
-
- #[instrument(skip_all)]
- async fn get_session_user(&self, token: &str) -> Result<User> {
- sqlx::query_as::<_, User>(
- "select users.id, users.username, users.email, users.password from users
- inner join sessions
- on users.id = sessions.user_id
- and sessions.token = $1",
- )
- .bind(token)
- .fetch_one(&self.pool)
- .await
- }
-
- #[instrument(skip_all)]
- async fn count_history(&self, user: &User) -> Result<i64> {
- // The cache is new, and the user might not yet have a cache value.
- // They will have one as soon as they post up some new history, but handle that
- // edge case.
-
- let res: (i64,) = sqlx::query_as(
- "select count(1) from history
- where user_id = $1",
- )
- .bind(user.id)
- .fetch_one(&self.pool)
- .await?;
-
- Ok(res.0)
- }
-
- #[instrument(skip_all)]
- async fn count_history_cached(&self, user: &User) -> Result<i64> {
- let res: (i32,) = sqlx::query_as(
- "select total from total_history_count_user
- where user_id = $1",
- )
- .bind(user.id)
- .fetch_one(&self.pool)
- .await?;
-
- Ok(res.0 as i64)
- }
-
- async fn delete_history(&self, user: &User, id: String) -> Result<()> {
- sqlx::query(
- "update history
- set deleted_at = $3
- where user_id = $1
- and client_id = $2
- and deleted_at is null", // don't just keep setting it
- )
- .bind(user.id)
- .bind(id)
- .bind(chrono::Utc::now().naive_utc())
- .fetch_all(&self.pool)
- .await?;
-
- Ok(())
- }
-
- #[instrument(skip_all)]
- async fn deleted_history(&self, user: &User) -> Result<Vec<String>> {
- // The cache is new, and the user might not yet have a cache value.
- // They will have one as soon as they post up some new history, but handle that
- // edge case.
-
- let res = sqlx::query(
- "select client_id from history
- where user_id = $1
- and deleted_at is not null",
- )
- .bind(user.id)
- .fetch_all(&self.pool)
- .await?;
-
- let res = res
- .iter()
- .map(|row| row.get::<String, _>("client_id"))
- .collect();
-
- Ok(res)
- }
-
- #[instrument(skip_all)]
- async fn count_history_range(
- &self,
- user: &User,
- start: chrono::NaiveDateTime,
- end: chrono::NaiveDateTime,
- ) -> Result<i64> {
- let res: (i64,) = sqlx::query_as(
- "select count(1) from history
- where user_id = $1
- and timestamp >= $2::date
- and timestamp < $3::date",
- )
- .bind(user.id)
- .bind(start)
- .bind(end)
- .fetch_one(&self.pool)
- .await?;
-
- Ok(res.0)
- }
-
- // Count the history for a given year
- #[instrument(skip_all)]
- async fn count_history_year(&self, user: &User, year: i32) -> Result<i64> {
- let start = chrono::Utc.ymd(year, 1, 1).and_hms_nano(0, 0, 0, 0);
- let end = start + RelativeDuration::years(1);
-
- let res = self
- .count_history_range(user, start.naive_utc(), end.naive_utc())
- .await?;
- Ok(res)
- }
-
- // Count the history for a given month
- #[instrument(skip_all)]
- async fn count_history_month(&self, user: &User, month: chrono::NaiveDate) -> Result<i64> {
- let start = chrono::Utc
- .ymd(month.year(), month.month(), 1)
- .and_hms_nano(0, 0, 0, 0);
-
- // ofc...
- let end = if month.month() < 12 {
- chrono::Utc
- .ymd(month.year(), month.month() + 1, 1)
- .and_hms_nano(0, 0, 0, 0)
- } else {
- chrono::Utc
- .ymd(month.year() + 1, 1, 1)
- .and_hms_nano(0, 0, 0, 0)
- };
-
- debug!("start: {}, end: {}", start, end);
-
- let res = self
- .count_history_range(user, start.naive_utc(), end.naive_utc())
- .await?;
- Ok(res)
- }
-
- // Count the history for a given day
- #[instrument(skip_all)]
- async fn count_history_day(&self, user: &User, day: chrono::NaiveDate) -> Result<i64> {
- let start = chrono::Utc
- .ymd(day.year(), day.month(), day.day())
- .and_hms_nano(0, 0, 0, 0);
- let end = chrono::Utc
- .ymd(day.year(), day.month(), day.day() + 1)
- .and_hms_nano(0, 0, 0, 0);
-
- let res = self
- .count_history_range(user, start.naive_utc(), end.naive_utc())
- .await?;
- Ok(res)
- }
-
- #[instrument(skip_all)]
- async fn list_history(
- &self,
- user: &User,
- created_after: chrono::NaiveDateTime,
- since: chrono::NaiveDateTime,
- host: &str,
- page_size: i64,
- ) -> Result<Vec<History>> {
- let res = sqlx::query_as::<_, History>(
- "select id, client_id, user_id, hostname, timestamp, data, created_at from history
- where user_id = $1
- and hostname != $2
- and created_at >= $3
- and timestamp >= $4
- order by timestamp asc
- limit $5",
- )
- .bind(user.id)
- .bind(host)
- .bind(created_after)
- .bind(since)
- .bind(page_size)
- .fetch_all(&self.pool)
- .await?;
-
- Ok(res)
- }
-
- #[instrument(skip_all)]
- async fn add_history(&self, history: &[NewHistory]) -> Result<()> {
- let mut tx = self.pool.begin().await?;
-
- for i in history {
- let client_id: &str = &i.client_id;
- let hostname: &str = &i.hostname;
- let data: &str = &i.data;
-
- if data.len() > self.settings.max_history_length
- && self.settings.max_history_length != 0
- {
- // Don't return an error here. We want to insert as much of the
- // history list as we can, so log the error and continue going.
-
- warn!(
- "history too long, got length {}, max {}",
- data.len(),
- self.settings.max_history_length
- );
-
- continue;
- }
-
- sqlx::query(
- "insert into history
- (client_id, user_id, hostname, timestamp, data)
- values ($1, $2, $3, $4, $5)
- on conflict do nothing
- ",
- )
- .bind(client_id)
- .bind(i.user_id)
- .bind(hostname)
- .bind(i.timestamp)
- .bind(data)
- .execute(&mut tx)
- .await?;
- }
-
- tx.commit().await?;
-
- Ok(())
- }
-
- #[instrument(skip_all)]
- async fn delete_user(&self, u: &User) -> Result<()> {
- sqlx::query("delete from sessions where user_id = $1")
- .bind(u.id)
- .execute(&self.pool)
- .await?;
-
- sqlx::query("delete from users where id = $1")
- .bind(u.id)
- .execute(&self.pool)
- .await?;
-
- sqlx::query("delete from history where user_id = $1")
- .bind(u.id)
- .execute(&self.pool)
- .await?;
-
- Ok(())
- }
-
- #[instrument(skip_all)]
- async fn add_user(&self, user: &NewUser) -> Result<i64> {
- let email: &str = &user.email;
- let username: &str = &user.username;
- let password: &str = &user.password;
-
- let res: (i64,) = sqlx::query_as(
- "insert into users
- (username, email, password)
- values($1, $2, $3)
- returning id",
- )
- .bind(username)
- .bind(email)
- .bind(password)
- .fetch_one(&self.pool)
- .await?;
-
- Ok(res.0)
- }
-
- #[instrument(skip_all)]
- async fn add_session(&self, session: &NewSession) -> Result<()> {
- let token: &str = &session.token;
-
- sqlx::query(
- "insert into sessions
- (user_id, token)
- values($1, $2)",
- )
- .bind(session.user_id)
- .bind(token)
- .execute(&self.pool)
- .await?;
-
- Ok(())
- }
-
- #[instrument(skip_all)]
- async fn get_user_session(&self, u: &User) -> Result<Session> {
- sqlx::query_as::<_, Session>("select id, user_id, token from sessions where user_id = $1")
- .bind(u.id)
- .fetch_one(&self.pool)
- .await
- }
-
- #[instrument(skip_all)]
- async fn oldest_history(&self, user: &User) -> Result<History> {
- let res = sqlx::query_as::<_, History>(
- "select id, client_id, user_id, hostname, timestamp, data, created_at from history
- where user_id = $1
- order by timestamp asc
- limit 1",
- )
- .bind(user.id)
- .fetch_one(&self.pool)
- .await?;
-
- Ok(res)
- }
-
- #[instrument(skip_all)]
- async fn calendar(
- &self,
- user: &User,
- period: TimePeriod,
- year: u64,
- month: u64,
- ) -> Result<HashMap<u64, TimePeriodInfo>> {
- // TODO: Support different timezones. Right now we assume UTC and
- // everything is stored as such. But it _should_ be possible to
- // interpret the stored date with a different TZ
-
- match period {
- TimePeriod::YEAR => {
- let mut ret = HashMap::new();
- // First we need to work out how far back to calculate. Get the
- // oldest history item
- let oldest = self.oldest_history(user).await?.timestamp.year();
- let current_year = chrono::Utc::now().year();
-
- // All the years we need to get data for
- // The upper bound is exclusive, so include current +1
- let years = oldest..current_year + 1;
-
- for year in years {
- let count = self.count_history_year(user, year).await?;
-
- ret.insert(
- year as u64,
- TimePeriodInfo {
- count: count as u64,
- hash: "".to_string(),
- },
- );
- }
-
- Ok(ret)
- }
-
- TimePeriod::MONTH => {
- let mut ret = HashMap::new();
-
- for month in 1..13 {
- let count = self
- .count_history_month(
- user,
- chrono::Utc.ymd(year as i32, month, 1).naive_utc(),
- )
- .await?;
-
- ret.insert(
- month as u64,
- TimePeriodInfo {
- count: count as u64,
- hash: "".to_string(),
- },
- );
- }
-
- Ok(ret)
- }
-
- TimePeriod::DAY => {
- let mut ret = HashMap::new();
-
- for day in 1..get_days_from_month(year as i32, month as u32) {
- let count = self
- .count_history_day(
- user,
- chrono::Utc
- .ymd(year as i32, month as u32, day as u32)
- .naive_utc(),
- )
- .await?;
-
- ret.insert(
- day as u64,
- TimePeriodInfo {
- count: count as u64,
- hash: "".to_string(),
- },
- );
- }
-
- Ok(ret)
- }
- }
- }
-}
diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs
index 1c9dff5f..bb0aa321 100644
--- a/atuin-server/src/handlers/history.rs
+++ b/atuin-server/src/handlers/history.rs
@@ -10,18 +10,20 @@ use tracing::{debug, error, instrument};
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
use crate::{
- calendar::{TimePeriod, TimePeriodInfo},
- database::Database,
- models::{NewHistory, User},
- router::AppState,
+ router::{AppState, UserAuth},
utils::client_version_min,
};
+use atuin_server_database::{
+ calendar::{TimePeriod, TimePeriodInfo},
+ models::NewHistory,
+ Database,
+};
use atuin_common::api::*;
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn count<DB: Database>(
- user: User,
+ UserAuth(user): UserAuth,
state: State<AppState<DB>>,
) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.database;
@@ -42,7 +44,7 @@ pub async fn count<DB: Database>(
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn list<DB: Database>(
req: Query<SyncHistoryRequest>,
- user: User,
+ UserAuth(user): UserAuth,
headers: HeaderMap,
state: State<AppState<DB>>,
) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
@@ -101,7 +103,7 @@ pub async fn list<DB: Database>(
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn delete<DB: Database>(
- user: User,
+ UserAuth(user): UserAuth,
state: State<AppState<DB>>,
Json(req): Json<DeleteHistoryRequest>,
) -> Result<Json<MessageResponse>, ErrorResponseStatus<'static>> {
@@ -123,13 +125,15 @@ pub async fn delete<DB: Database>(
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn add<DB: Database>(
- user: User,
+ UserAuth(user): UserAuth,
state: State<AppState<DB>>,
Json(req): Json<Vec<AddHistoryRequest>>,
) -> Result<(), ErrorResponseStatus<'static>> {
+ let State(AppState { database, settings }) = state;
+
debug!("request to add {} history items", req.len());
- let history: Vec<NewHistory> = req
+ let mut history: Vec<NewHistory> = req
.into_iter()
.map(|h| NewHistory {
client_id: h.id,
@@ -140,8 +144,24 @@ pub async fn add<DB: Database>(
})
.collect();
- let db = &state.0.database;
- if let Err(e) = db.add_history(&history).await {
+ history.retain(|h| {
+ // keep if within limit, or limit is 0 (unlimited)
+ let keep = h.data.len() <= settings.max_history_length || settings.max_history_length == 0;
+
+ // Don't return an error here. We want to insert as much of the
+ // history list as we can, so log the error and continue going.
+ if !keep {
+ tracing::warn!(
+ "history too long, got length {}, max {}",
+ h.data.len(),
+ settings.max_history_length
+ );
+ }
+
+ keep
+ });
+
+ if let Err(e) = database.add_history(&history).await {
error!("failed to add history: {}", e);
return Err(ErrorResponse::reply("failed to add history")
@@ -155,7 +175,7 @@ pub async fn add<DB: Database>(
pub async fn calendar<DB: Database>(
Path(focus): Path<String>,
Query(params): Query<HashMap<String, u64>>,
- user: User,
+ UserAuth(user): UserAuth,
state: State<AppState<DB>>,
) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
let focus = focus.as_str();
diff --git a/atuin-server/src/handlers/status.rs b/atuin-server/src/handlers/status.rs
index 97c02886..d9b6afaf 100644
--- a/atuin-server/src/handlers/status.rs
+++ b/atuin-server/src/handlers/status.rs
@@ -3,7 +3,8 @@ use http::StatusCode;
use tracing::instrument;
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
-use crate::{database::Database, models::User, router::AppState};
+use crate::router::{AppState, UserAuth};
+use atuin_server_database::Database;
use atuin_common::api::*;
@@ -11,7 +12,7 @@ const VERSION: &str = env!("CARGO_PKG_VERSION");
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn status<DB: Database>(
- user: User,
+ UserAuth(user): UserAuth,
state: State<AppState<DB>>,
) -> Result<Json<StatusResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.database;
diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs
index e67828e4..75081155 100644
--- a/atuin-server/src/handlers/user.rs
+++ b/atuin-server/src/handlers/user.rs
@@ -16,10 +16,10 @@ use tracing::{debug, error, info, instrument};
use uuid::Uuid;
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
-use crate::{
- database::Database,
- models::{NewSession, NewUser, User},
- router::AppState,
+use crate::router::{AppState, UserAuth};
+use atuin_server_database::{
+ models::{NewSession, NewUser},
+ Database, DbError,
};
use reqwest::header::CONTENT_TYPE;
@@ -64,11 +64,11 @@ pub async fn get<DB: Database>(
let db = &state.0.database;
let user = match db.get_user(username.as_ref()).await {
Ok(user) => user,
- Err(sqlx::Error::RowNotFound) => {
+ Err(DbError::NotFound) => {
debug!("user not found: {}", username);
return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
}
- Err(err) => {
+ Err(DbError::Other(err)) => {
error!("database error: {}", err);
return Err(ErrorResponse::reply("database error")
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
@@ -152,7 +152,7 @@ pub async fn register<DB: Database>(
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn delete<DB: Database>(
- user: User,
+ UserAuth(user): UserAuth,
state: State<AppState<DB>>,
) -> Result<Json<DeleteUserResponse>, ErrorResponseStatus<'static>> {
debug!("request to delete user {}", user.id);
@@ -175,10 +175,10 @@ pub async fn login<DB: Database>(
let db = &state.0.database;
let user = match db.get_user(login.username.borrow()).await {
Ok(u) => u,
- Err(sqlx::Error::RowNotFound) => {
+ Err(DbError::NotFound) => {
return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
}
- Err(e) => {
+ Err(DbError::Other(e)) => {
error!("failed to get user {}: {}", login.username.clone(), e);
return Err(ErrorResponse::reply("database error")
@@ -188,11 +188,11 @@ pub async fn login<DB: Database>(
let session = match db.get_user_session(&user).await {
Ok(u) => u,
- Err(sqlx::Error::RowNotFound) => {
+ Err(DbError::NotFound) => {
debug!("user session not found for user id={}", user.id);
return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
}
- Err(err) => {
+ Err(DbError::Other(err)) => {
error!("database error for user {}: {}", login.username, err);
return Err(ErrorResponse::reply("database error")
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
diff --git a/atuin-server/src/lib.rs b/atuin-server/src/lib.rs
index 01873af9..aa2250d3 100644
--- a/atuin-server/src/lib.rs
+++ b/atuin-server/src/lib.rs
@@ -2,45 +2,38 @@
use std::net::{IpAddr, SocketAddr};
+use atuin_server_database::Database;
use axum::Server;
-use database::Postgres;
use eyre::{Context, Result};
-use crate::settings::Settings;
+mod handlers;
+mod router;
+mod settings;
+mod utils;
+pub use settings::Settings;
use tokio::signal;
-pub mod auth;
-pub mod calendar;
-pub mod database;
-pub mod handlers;
-pub mod models;
-pub mod router;
-pub mod settings;
-pub mod utils;
-
async fn shutdown_signal() {
- let terminate = async {
- signal::unix::signal(signal::unix::SignalKind::terminate())
- .expect("failed to register signal handler")
- .recv()
- .await;
- };
-
- tokio::select! {
- _ = terminate => (),
- }
+ signal::unix::signal(signal::unix::SignalKind::terminate())
+ .expect("failed to register signal handler")
+ .recv()
+ .await;
eprintln!("Shutting down gracefully...");
}
-pub async fn launch(settings: Settings, host: String, port: u16) -> Result<()> {
+pub async fn launch<Db: Database>(
+ settings: Settings<Db::Settings>,
+ host: String,
+ port: u16,
+) -> Result<()> {
let host = host.parse::<IpAddr>()?;
- let postgres = Postgres::new(settings.clone())
+ let db = Db::new(&settings.db_settings)
.await
- .wrap_err_with(|| format!("failed to connect to db: {}", settings.db_uri))?;
+ .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
- let r = router::router(postgres, settings);
+ let r = router::router(db, settings);
Server::bind(&SocketAddr::new(host, port))
.serve(r.into_make_service())
diff --git a/atuin-server/src/router.rs b/atuin-server/src/router.rs
index 20b11f45..ec558e78 100644
--- a/atuin-server/src/router.rs
+++ b/atuin-server/src/router.rs
@@ -10,11 +10,14 @@ use http::request::Parts;
use tower::ServiceBuilder;
use tower_http::trace::TraceLayer;
-use super::{database::Database, handlers};
-use crate::{models::User, settings::Settings};
+use super::handlers;
+use crate::settings::Settings;
+use atuin_server_database::{models::User, Database};
+
+pub struct UserAuth(pub User);
#[async_trait]
-impl<DB: Send + Sync> FromRequestParts<AppState<DB>> for User
+impl<DB: Send + Sync> FromRequestParts<AppState<DB>> for UserAuth
where
DB: Database,
{
@@ -45,7 +48,7 @@ where
.await
.map_err(|_| http::StatusCode::FORBIDDEN)?;
- Ok(user)
+ Ok(UserAuth(user))
}
}
@@ -54,15 +57,12 @@ async fn teapot() -> impl IntoResponse {
}
#[derive(Clone)]
-pub struct AppState<DB> {
+pub struct AppState<DB: Database> {
pub database: DB,
- pub settings: Settings,
+ pub settings: Settings<DB::Settings>,
}
-pub fn router<DB: Database + Clone + Send + Sync + 'static>(
- database: DB,
- settings: Settings,
-) -> Router {
+pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> Router {
let routes = Router::new()
.route("/", get(handlers::index))
.route("/sync/count", get(handlers::history::count))
diff --git a/atuin-server/src/settings.rs b/atuin-server/src/settings.rs
index 981d239f..fb5325d4 100644
--- a/atuin-server/src/settings.rs
+++ b/atuin-server/src/settings.rs
@@ -3,24 +3,24 @@ use std::{io::prelude::*, path::PathBuf};
use config::{Config, Environment, File as ConfigFile, FileFormat};
use eyre::{eyre, Result};
use fs_err::{create_dir_all, File};
-use serde::{Deserialize, Serialize};
-
-pub const HISTORY_PAGE_SIZE: i64 = 100;
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
#[derive(Clone, Debug, Deserialize, Serialize)]
-pub struct Settings {
+pub struct Settings<DbSettings> {
pub host: String,
pub port: u16,
pub path: String,
- pub db_uri: String,
pub open_registration: bool,
pub max_history_length: usize,
pub page_size: i64,
pub register_webhook_url: Option<String>,
pub register_webhook_username: String,
+
+ #[serde(flatten)]
+ pub db_settings: DbSettings,
}
-impl Settings {
+impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
pub fn new() -> Result<Self> {
let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") {
PathBuf::from(p)
diff --git a/atuin/Cargo.toml b/atuin/Cargo.toml
index 9085623c..bf038050 100644
--- a/atuin/Cargo.toml
+++ b/atuin/Cargo.toml
@@ -33,15 +33,13 @@ buildflags = ["--release"]
atuin = { path = "/usr/bin/atuin" }
[features]
-# TODO(conradludgate)
-# Currently, this keeps the same default built behaviour for v0.8
-# We should rethink this by the time we hit a new breaking change
default = ["client", "sync", "server"]
client = ["atuin-client"]
sync = ["atuin-client/sync"]
-server = ["atuin-server", "tracing-subscriber"]
+server = ["atuin-server", "atuin-server-postgres", "tracing-subscriber"]
[dependencies]
+atuin-server-postgres = { path = "../atuin-server-postgres", version = "15.0.0", optional = true }
atuin-server = { path = "../atuin-server", version = "15.0.0", optional = true }
atuin-client = { path = "../atuin-client", version = "15.0.0", optional = true, default-features = false }
atuin-common = { path = "../atuin-common", version = "15.0.0" }
@@ -61,7 +59,6 @@ tokio = { workspace = true }
async-trait = { workspace = true }
interim = { workspace = true }
base64 = { workspace = true }
-crossbeam-channel = "0.5.1"
clap = { workspace = true }
clap_complete = "4.0.3"
fs-err = { workspace = true }
diff --git a/atuin/src/command/server.rs b/atuin/src/command/server.rs
index 495f85d0..c65cb505 100644
--- a/atuin/src/command/server.rs
+++ b/atuin/src/command/server.rs
@@ -1,9 +1,10 @@
+use atuin_server_postgres::Postgres;
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
use clap::Parser;
use eyre::{Context, Result};
-use atuin_server::{launch, settings::Settings};
+use atuin_server::{launch, Settings};
#[derive(Parser)]
#[clap(infer_subcommands = true)]
@@ -37,7 +38,7 @@ impl Cmd {
.map_or(settings.host.clone(), std::string::ToString::to_string);
let port = port.map_or(settings.port, |p| p);
- launch(settings, host, port).await
+ launch::<Postgres>(settings, host, port).await
}
}
}