diff options
Diffstat (limited to 'atuin-server')
23 files changed, 81 insertions, 1037 deletions
diff --git a/atuin-server/Cargo.toml b/atuin-server/Cargo.toml index e4cbf3e0..f308fa30 100644 --- a/atuin-server/Cargo.toml +++ b/atuin-server/Cargo.toml @@ -11,20 +11,18 @@ repository = { workspace = true } [dependencies] atuin-common = { path = "../atuin-common", version = "15.0.0" } +atuin-server-database = { path = "../atuin-server-database", version = "15.0.0" } tracing = "0.1" chrono = { workspace = true } eyre = { workspace = true } uuid = { workspace = true } -whoami = { workspace = true } config = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } -sodiumoxide = { workspace = true } base64 = { workspace = true } rand = { workspace = true } tokio = { workspace = true } -sqlx = { workspace = true } async-trait = { workspace = true } axum = "0.6.4" http = "0.2" diff --git a/atuin-server/migrations/20210425153745_create_history.sql b/atuin-server/migrations/20210425153745_create_history.sql deleted file mode 100644 index 2c2d17b0..00000000 --- a/atuin-server/migrations/20210425153745_create_history.sql +++ /dev/null @@ -1,11 +0,0 @@ -create table history ( - id bigserial primary key, - client_id text not null unique, -- the client-generated ID - user_id bigserial not null, -- allow multiple users - hostname text not null, -- a unique identifier from the client (can be hashed, random, whatever) - timestamp timestamp not null, -- one of the few non-encrypted metadatas - - data varchar(8192) not null, -- store the actual history data, encrypted. I don't wanna know! - - created_at timestamp not null default current_timestamp -); diff --git a/atuin-server/migrations/20210425153757_create_users.sql b/atuin-server/migrations/20210425153757_create_users.sql deleted file mode 100644 index a25dcced..00000000 --- a/atuin-server/migrations/20210425153757_create_users.sql +++ /dev/null @@ -1,10 +0,0 @@ -create table users ( - id bigserial primary key, -- also store our own ID - username varchar(32) not null unique, -- being able to contact users is useful - email varchar(128) not null unique, -- being able to contact users is useful - password varchar(128) not null unique -); - --- the prior index is case sensitive :( -CREATE UNIQUE INDEX email_unique_idx on users (LOWER(email)); -CREATE UNIQUE INDEX username_unique_idx on users (LOWER(username)); diff --git a/atuin-server/migrations/20210425153800_create_sessions.sql b/atuin-server/migrations/20210425153800_create_sessions.sql deleted file mode 100644 index c2fb6559..00000000 --- a/atuin-server/migrations/20210425153800_create_sessions.sql +++ /dev/null @@ -1,6 +0,0 @@ --- Add migration script here -create table sessions ( - id bigserial primary key, - user_id bigserial, - token varchar(128) unique not null -); diff --git a/atuin-server/migrations/20220419082412_add_count_trigger.sql b/atuin-server/migrations/20220419082412_add_count_trigger.sql deleted file mode 100644 index dd1afa88..00000000 --- a/atuin-server/migrations/20220419082412_add_count_trigger.sql +++ /dev/null @@ -1,51 +0,0 @@ --- Prior to this, the count endpoint was super naive and just ran COUNT(1). --- This is slow asf. Now that we have an amount of actual traffic, --- stop doing that! --- This basically maintains a count, so we can read ONE row, instead of ALL the --- rows. Much better. --- Future optimisation could use some sort of cache so we don't even need to hit --- postgres at all. - -create table total_history_count_user( - id bigserial primary key, - user_id bigserial, - total integer -- try and avoid using keywords - hence total, not count -); - -create or replace function user_history_count() -returns trigger as -$func$ -begin - if (TG_OP='INSERT') then - update total_history_count_user set total = total + 1 where user_id = new.user_id; - - if not found then - insert into total_history_count_user(user_id, total) - values ( - new.user_id, - (select count(1) from history where user_id = new.user_id) - ); - end if; - - elsif (TG_OP='DELETE') then - update total_history_count_user set total = total - 1 where user_id = new.user_id; - - if not found then - insert into total_history_count_user(user_id, total) - values ( - new.user_id, - (select count(1) from history where user_id = new.user_id) - ); - end if; - end if; - - return NEW; -- this is actually ignored for an after trigger, but oh well -end; -$func$ -language plpgsql volatile -- pldfplplpflh -cost 100; -- default value - -create trigger tg_user_history_count - after insert or delete on history - for each row - execute procedure user_history_count(); diff --git a/atuin-server/migrations/20220421073605_fix_count_trigger_delete.sql b/atuin-server/migrations/20220421073605_fix_count_trigger_delete.sql deleted file mode 100644 index 6198f300..00000000 --- a/atuin-server/migrations/20220421073605_fix_count_trigger_delete.sql +++ /dev/null @@ -1,35 +0,0 @@ --- the old version of this function used NEW in the delete part when it should --- use OLD - -create or replace function user_history_count() -returns trigger as -$func$ -begin - if (TG_OP='INSERT') then - update total_history_count_user set total = total + 1 where user_id = new.user_id; - - if not found then - insert into total_history_count_user(user_id, total) - values ( - new.user_id, - (select count(1) from history where user_id = new.user_id) - ); - end if; - - elsif (TG_OP='DELETE') then - update total_history_count_user set total = total - 1 where user_id = old.user_id; - - if not found then - insert into total_history_count_user(user_id, total) - values ( - old.user_id, - (select count(1) from history where user_id = old.user_id) - ); - end if; - end if; - - return NEW; -- this is actually ignored for an after trigger, but oh well -end; -$func$ -language plpgsql volatile -- pldfplplpflh -cost 100; -- default value diff --git a/atuin-server/migrations/20220421174016_larger-commands.sql b/atuin-server/migrations/20220421174016_larger-commands.sql deleted file mode 100644 index 0ac43433..00000000 --- a/atuin-server/migrations/20220421174016_larger-commands.sql +++ /dev/null @@ -1,3 +0,0 @@ --- Make it 4x larger. Most commands are less than this, but as it's base64 --- SOME are more than 8192. Should be enough for now. -ALTER TABLE history ALTER COLUMN data TYPE varchar(32768); diff --git a/atuin-server/migrations/20220426172813_user-created-at.sql b/atuin-server/migrations/20220426172813_user-created-at.sql deleted file mode 100644 index a9138194..00000000 --- a/atuin-server/migrations/20220426172813_user-created-at.sql +++ /dev/null @@ -1 +0,0 @@ -alter table users add column created_at timestamp not null default now(); diff --git a/atuin-server/migrations/20220505082442_create-events.sql b/atuin-server/migrations/20220505082442_create-events.sql deleted file mode 100644 index 57e16ec7..00000000 --- a/atuin-server/migrations/20220505082442_create-events.sql +++ /dev/null @@ -1,14 +0,0 @@ -create type event_type as enum ('create', 'delete'); - -create table events ( - id bigserial primary key, - client_id text not null unique, -- the client-generated ID - user_id bigserial not null, -- allow multiple users - hostname text not null, -- a unique identifier from the client (can be hashed, random, whatever) - timestamp timestamp not null, -- one of the few non-encrypted metadatas - - event_type event_type, - data text not null, -- store the actual history data, encrypted. I don't wanna know! - - created_at timestamp not null default current_timestamp -); diff --git a/atuin-server/migrations/20220610074049_history-length.sql b/atuin-server/migrations/20220610074049_history-length.sql deleted file mode 100644 index b1c23016..00000000 --- a/atuin-server/migrations/20220610074049_history-length.sql +++ /dev/null @@ -1,2 +0,0 @@ --- Add migration script here -alter table history alter column data type text; diff --git a/atuin-server/migrations/20230315220537_drop-events.sql b/atuin-server/migrations/20230315220537_drop-events.sql deleted file mode 100644 index fe3cae17..00000000 --- a/atuin-server/migrations/20230315220537_drop-events.sql +++ /dev/null @@ -1,2 +0,0 @@ --- Add migration script here -drop table events; diff --git a/atuin-server/migrations/20230315224203_create-deleted.sql b/atuin-server/migrations/20230315224203_create-deleted.sql deleted file mode 100644 index 9a9e6263..00000000 --- a/atuin-server/migrations/20230315224203_create-deleted.sql +++ /dev/null @@ -1,5 +0,0 @@ --- Add migration script here -alter table history add column if not exists deleted_at timestamp; - --- queries will all be selecting the ids of history for a user, that has been deleted -create index if not exists history_deleted_index on history(client_id, user_id, deleted_at); diff --git a/atuin-server/migrations/20230515221038_trigger-delete-only.sql b/atuin-server/migrations/20230515221038_trigger-delete-only.sql deleted file mode 100644 index 3d0bba52..00000000 --- a/atuin-server/migrations/20230515221038_trigger-delete-only.sql +++ /dev/null @@ -1,30 +0,0 @@ --- We do not need to run the trigger on deletes, as the only time we are deleting history is when the user --- has already been deleted --- This actually slows down deleting all the history a good bit! - -create or replace function user_history_count() -returns trigger as -$func$ -begin - if (TG_OP='INSERT') then - update total_history_count_user set total = total + 1 where user_id = new.user_id; - - if not found then - insert into total_history_count_user(user_id, total) - values ( - new.user_id, - (select count(1) from history where user_id = new.user_id) - ); - end if; - end if; - - return NEW; -- this is actually ignored for an after trigger, but oh well -end; -$func$ -language plpgsql volatile -- pldfplplpflh -cost 100; -- default value - -create or replace trigger tg_user_history_count - after insert on history - for each row - execute procedure user_history_count(); diff --git a/atuin-server/src/auth.rs b/atuin-server/src/auth.rs deleted file mode 100644 index 52a73108..00000000 --- a/atuin-server/src/auth.rs +++ /dev/null @@ -1,222 +0,0 @@ -/* -use self::diesel::prelude::*; -use eyre::Result; -use rocket::http::Status; -use rocket::request::{self, FromRequest, Outcome, Request}; -use rocket::State; -use rocket_contrib::databases::diesel; -use sodiumoxide::crypto::pwhash::argon2id13; - -use rocket_contrib::json::Json; -use uuid::Uuid; - -use super::models::{NewSession, NewUser, Session, User}; -use super::views::ApiResponse; - -use crate::api::{LoginRequest, RegisterRequest}; -use crate::schema::{sessions, users}; -use crate::settings::Settings; -use crate::utils::hash_secret; - -use super::database::AtuinDbConn; - -#[derive(Debug)] -pub enum KeyError { - Missing, - Invalid, -} - -pub fn verify_str(secret: &str, verify: &str) -> bool { - sodiumoxide::init().unwrap(); - - let mut padded = [0_u8; 128]; - secret.as_bytes().iter().enumerate().for_each(|(i, val)| { - padded[i] = *val; - }); - - match argon2id13::HashedPassword::from_slice(&padded) { - Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()), - None => false, - } -} - -impl<'a, 'r> FromRequest<'a, 'r> for User { - type Error = KeyError; - - fn from_request(request: &'a Request<'r>) -> request::Outcome<User, Self::Error> { - let session: Vec<_> = request.headers().get("authorization").collect(); - - if session.is_empty() { - return Outcome::Failure((Status::BadRequest, KeyError::Missing)); - } else if session.len() > 1 { - return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); - } - - let session: Vec<_> = session[0].split(' ').collect(); - - if session.len() != 2 { - return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); - } - - if session[0] != "Token" { - return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); - } - - let session = session[1]; - - let db = request - .guard::<AtuinDbConn>() - .succeeded() - .expect("failed to load database"); - - let session = sessions::table - .filter(sessions::token.eq(session)) - .first::<Session>(&*db); - - if session.is_err() { - return Outcome::Failure((Status::Unauthorized, KeyError::Invalid)); - } - - let session = session.unwrap(); - - let user = users::table.find(session.user_id).first(&*db); - - match user { - Ok(user) => Outcome::Success(user), - Err(_) => Outcome::Failure((Status::Unauthorized, KeyError::Invalid)), - } - } -} - -#[get("/user/<user>")] -#[allow(clippy::clippy::needless_pass_by_value)] -pub fn get_user(user: String, conn: AtuinDbConn) -> ApiResponse { - use crate::schema::users::dsl::{username, users}; - - let user: Result<String, diesel::result::Error> = users - .select(username) - .filter(username.eq(user)) - .first(&*conn); - - if user.is_err() { - return ApiResponse { - json: json!({ - "message": "could not find user", - }), - status: Status::NotFound, - }; - } - - let user = user.unwrap(); - - ApiResponse { - json: json!({ "username": user.as_str() }), - status: Status::Ok, - } -} - -#[post("/register", data = "<register>")] -#[allow(clippy::clippy::needless_pass_by_value)] -pub fn register( - conn: AtuinDbConn, - register: Json<RegisterRequest>, - settings: State<Settings>, -) -> ApiResponse { - if !settings.server.open_registration { - return ApiResponse { - status: Status::BadRequest, - json: json!({ - "message": "registrations are not open" - }), - }; - } - - let hashed = hash_secret(register.password.as_str()); - - let new_user = NewUser { - email: register.email.as_str(), - username: register.username.as_str(), - password: hashed.as_str(), - }; - - let user = diesel::insert_into(users::table) - .values(&new_user) - .get_result(&*conn); - - if user.is_err() { - return ApiResponse { - status: Status::BadRequest, - json: json!({ - "message": "failed to create user - username or email in use?", - }), - }; - } - - let user: User = user.unwrap(); - let token = Uuid::new_v4().to_simple().to_string(); - - let new_session = NewSession { - user_id: user.id, - token: token.as_str(), - }; - - match diesel::insert_into(sessions::table) - .values(&new_session) - .execute(&*conn) - { - Ok(_) => ApiResponse { - status: Status::Ok, - json: json!({"message": "user created!", "session": token}), - }, - Err(_) => ApiResponse { - status: Status::BadRequest, - json: json!({ "message": "failed to create user"}), - }, - } -} - -#[post("/login", data = "<login>")] -#[allow(clippy::clippy::needless_pass_by_value)] -pub fn login(conn: AtuinDbConn, login: Json<LoginRequest>) -> ApiResponse { - let user = users::table - .filter(users::username.eq(login.username.as_str())) - .first(&*conn); - - if user.is_err() { - return ApiResponse { - status: Status::NotFound, - json: json!({"message": "user not found"}), - }; - } - - let user: User = user.unwrap(); - - let session = sessions::table - .filter(sessions::user_id.eq(user.id)) - .first(&*conn); - - // a session should exist... - if session.is_err() { - return ApiResponse { - status: Status::InternalServerError, - json: json!({"message": "something went wrong"}), - }; - } - - let verified = verify_str(user.password.as_str(), login.password.as_str()); - - if !verified { - return ApiResponse { - status: Status::NotFound, - json: json!({"message": "user not found"}), - }; - } - - let session: Session = session.unwrap(); - - ApiResponse { - status: Status::Ok, - json: json!({"session": session.token}), - } -} -*/ diff --git a/atuin-server/src/calendar.rs b/atuin-server/src/calendar.rs deleted file mode 100644 index 7c05dce3..00000000 --- a/atuin-server/src/calendar.rs +++ /dev/null @@ -1,17 +0,0 @@ -// Calendar data - -use serde::{Deserialize, Serialize}; - -pub enum TimePeriod { - YEAR, - MONTH, - DAY, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct TimePeriodInfo { - pub count: u64, - - // TODO: Use this for merkle tree magic - pub hash: String, -} diff --git a/atuin-server/src/database.rs b/atuin-server/src/database.rs deleted file mode 100644 index 894fab7b..00000000 --- a/atuin-server/src/database.rs +++ /dev/null @@ -1,510 +0,0 @@ -use std::collections::HashMap; - -use async_trait::async_trait; -use chrono::{Datelike, TimeZone}; -use chronoutil::RelativeDuration; -use sqlx::{postgres::PgPoolOptions, Result}; - -use sqlx::Row; - -use tracing::{debug, instrument, warn}; - -use super::{ - calendar::{TimePeriod, TimePeriodInfo}, - models::{History, NewHistory, NewSession, NewUser, Session, User}, -}; -use crate::settings::Settings; - -use atuin_common::utils::get_days_from_month; - -#[async_trait] -pub trait Database { - async fn get_session(&self, token: &str) -> Result<Session>; - async fn get_session_user(&self, token: &str) -> Result<User>; - async fn add_session(&self, session: &NewSession) -> Result<()>; - - async fn get_user(&self, username: &str) -> Result<User>; - async fn get_user_session(&self, u: &User) -> Result<Session>; - async fn add_user(&self, user: &NewUser) -> Result<i64>; - async fn delete_user(&self, u: &User) -> Result<()>; - - async fn count_history(&self, user: &User) -> Result<i64>; - async fn count_history_cached(&self, user: &User) -> Result<i64>; - - async fn delete_history(&self, user: &User, id: String) -> Result<()>; - async fn deleted_history(&self, user: &User) -> Result<Vec<String>>; - - async fn count_history_range( - &self, - user: &User, - start: chrono::NaiveDateTime, - end: chrono::NaiveDateTime, - ) -> Result<i64>; - async fn count_history_day(&self, user: &User, date: chrono::NaiveDate) -> Result<i64>; - async fn count_history_month(&self, user: &User, date: chrono::NaiveDate) -> Result<i64>; - async fn count_history_year(&self, user: &User, year: i32) -> Result<i64>; - - async fn list_history( - &self, - user: &User, - created_after: chrono::NaiveDateTime, - since: chrono::NaiveDateTime, - host: &str, - page_size: i64, - ) -> Result<Vec<History>>; - - async fn add_history(&self, history: &[NewHistory]) -> Result<()>; - - async fn oldest_history(&self, user: &User) -> Result<History>; - - async fn calendar( - &self, - user: &User, - period: TimePeriod, - year: u64, - month: u64, - ) -> Result<HashMap<u64, TimePeriodInfo>>; -} - -#[derive(Clone)] -pub struct Postgres { - pool: sqlx::Pool<sqlx::postgres::Postgres>, - settings: Settings, -} - -impl Postgres { - pub async fn new(settings: Settings) -> Result<Self> { - let pool = PgPoolOptions::new() - .max_connections(100) - .connect(settings.db_uri.as_str()) - .await?; - - sqlx::migrate!("./migrations").run(&pool).await?; - - Ok(Self { pool, settings }) - } -} - -#[async_trait] -impl Database for Postgres { - #[instrument(skip_all)] - async fn get_session(&self, token: &str) -> Result<Session> { - sqlx::query_as::<_, Session>("select id, user_id, token from sessions where token = $1") - .bind(token) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn get_user(&self, username: &str) -> Result<User> { - sqlx::query_as::<_, User>( - "select id, username, email, password from users where username = $1", - ) - .bind(username) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn get_session_user(&self, token: &str) -> Result<User> { - sqlx::query_as::<_, User>( - "select users.id, users.username, users.email, users.password from users - inner join sessions - on users.id = sessions.user_id - and sessions.token = $1", - ) - .bind(token) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn count_history(&self, user: &User) -> Result<i64> { - // The cache is new, and the user might not yet have a cache value. - // They will have one as soon as they post up some new history, but handle that - // edge case. - - let res: (i64,) = sqlx::query_as( - "select count(1) from history - where user_id = $1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn count_history_cached(&self, user: &User) -> Result<i64> { - let res: (i32,) = sqlx::query_as( - "select total from total_history_count_user - where user_id = $1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await?; - - Ok(res.0 as i64) - } - - async fn delete_history(&self, user: &User, id: String) -> Result<()> { - sqlx::query( - "update history - set deleted_at = $3 - where user_id = $1 - and client_id = $2 - and deleted_at is null", // don't just keep setting it - ) - .bind(user.id) - .bind(id) - .bind(chrono::Utc::now().naive_utc()) - .fetch_all(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn deleted_history(&self, user: &User) -> Result<Vec<String>> { - // The cache is new, and the user might not yet have a cache value. - // They will have one as soon as they post up some new history, but handle that - // edge case. - - let res = sqlx::query( - "select client_id from history - where user_id = $1 - and deleted_at is not null", - ) - .bind(user.id) - .fetch_all(&self.pool) - .await?; - - let res = res - .iter() - .map(|row| row.get::<String, _>("client_id")) - .collect(); - - Ok(res) - } - - #[instrument(skip_all)] - async fn count_history_range( - &self, - user: &User, - start: chrono::NaiveDateTime, - end: chrono::NaiveDateTime, - ) -> Result<i64> { - let res: (i64,) = sqlx::query_as( - "select count(1) from history - where user_id = $1 - and timestamp >= $2::date - and timestamp < $3::date", - ) - .bind(user.id) - .bind(start) - .bind(end) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - // Count the history for a given year - #[instrument(skip_all)] - async fn count_history_year(&self, user: &User, year: i32) -> Result<i64> { - let start = chrono::Utc.ymd(year, 1, 1).and_hms_nano(0, 0, 0, 0); - let end = start + RelativeDuration::years(1); - - let res = self - .count_history_range(user, start.naive_utc(), end.naive_utc()) - .await?; - Ok(res) - } - - // Count the history for a given month - #[instrument(skip_all)] - async fn count_history_month(&self, user: &User, month: chrono::NaiveDate) -> Result<i64> { - let start = chrono::Utc - .ymd(month.year(), month.month(), 1) - .and_hms_nano(0, 0, 0, 0); - - // ofc... - let end = if month.month() < 12 { - chrono::Utc - .ymd(month.year(), month.month() + 1, 1) - .and_hms_nano(0, 0, 0, 0) - } else { - chrono::Utc - .ymd(month.year() + 1, 1, 1) - .and_hms_nano(0, 0, 0, 0) - }; - - debug!("start: {}, end: {}", start, end); - - let res = self - .count_history_range(user, start.naive_utc(), end.naive_utc()) - .await?; - Ok(res) - } - - // Count the history for a given day - #[instrument(skip_all)] - async fn count_history_day(&self, user: &User, day: chrono::NaiveDate) -> Result<i64> { - let start = chrono::Utc - .ymd(day.year(), day.month(), day.day()) - .and_hms_nano(0, 0, 0, 0); - let end = chrono::Utc - .ymd(day.year(), day.month(), day.day() + 1) - .and_hms_nano(0, 0, 0, 0); - - let res = self - .count_history_range(user, start.naive_utc(), end.naive_utc()) - .await?; - Ok(res) - } - - #[instrument(skip_all)] - async fn list_history( - &self, - user: &User, - created_after: chrono::NaiveDateTime, - since: chrono::NaiveDateTime, - host: &str, - page_size: i64, - ) -> Result<Vec<History>> { - let res = sqlx::query_as::<_, History>( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history - where user_id = $1 - and hostname != $2 - and created_at >= $3 - and timestamp >= $4 - order by timestamp asc - limit $5", - ) - .bind(user.id) - .bind(host) - .bind(created_after) - .bind(since) - .bind(page_size) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - #[instrument(skip_all)] - async fn add_history(&self, history: &[NewHistory]) -> Result<()> { - let mut tx = self.pool.begin().await?; - - for i in history { - let client_id: &str = &i.client_id; - let hostname: &str = &i.hostname; - let data: &str = &i.data; - - if data.len() > self.settings.max_history_length - && self.settings.max_history_length != 0 - { - // Don't return an error here. We want to insert as much of the - // history list as we can, so log the error and continue going. - - warn!( - "history too long, got length {}, max {}", - data.len(), - self.settings.max_history_length - ); - - continue; - } - - sqlx::query( - "insert into history - (client_id, user_id, hostname, timestamp, data) - values ($1, $2, $3, $4, $5) - on conflict do nothing - ", - ) - .bind(client_id) - .bind(i.user_id) - .bind(hostname) - .bind(i.timestamp) - .bind(data) - .execute(&mut tx) - .await?; - } - - tx.commit().await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn delete_user(&self, u: &User) -> Result<()> { - sqlx::query("delete from sessions where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from users where id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from history where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn add_user(&self, user: &NewUser) -> Result<i64> { - let email: &str = &user.email; - let username: &str = &user.username; - let password: &str = &user.password; - - let res: (i64,) = sqlx::query_as( - "insert into users - (username, email, password) - values($1, $2, $3) - returning id", - ) - .bind(username) - .bind(email) - .bind(password) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn add_session(&self, session: &NewSession) -> Result<()> { - let token: &str = &session.token; - - sqlx::query( - "insert into sessions - (user_id, token) - values($1, $2)", - ) - .bind(session.user_id) - .bind(token) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn get_user_session(&self, u: &User) -> Result<Session> { - sqlx::query_as::<_, Session>("select id, user_id, token from sessions where user_id = $1") - .bind(u.id) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn oldest_history(&self, user: &User) -> Result<History> { - let res = sqlx::query_as::<_, History>( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history - where user_id = $1 - order by timestamp asc - limit 1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await?; - - Ok(res) - } - - #[instrument(skip_all)] - async fn calendar( - &self, - user: &User, - period: TimePeriod, - year: u64, - month: u64, - ) -> Result<HashMap<u64, TimePeriodInfo>> { - // TODO: Support different timezones. Right now we assume UTC and - // everything is stored as such. But it _should_ be possible to - // interpret the stored date with a different TZ - - match period { - TimePeriod::YEAR => { - let mut ret = HashMap::new(); - // First we need to work out how far back to calculate. Get the - // oldest history item - let oldest = self.oldest_history(user).await?.timestamp.year(); - let current_year = chrono::Utc::now().year(); - - // All the years we need to get data for - // The upper bound is exclusive, so include current +1 - let years = oldest..current_year + 1; - - for year in years { - let count = self.count_history_year(user, year).await?; - - ret.insert( - year as u64, - TimePeriodInfo { - count: count as u64, - hash: "".to_string(), - }, - ); - } - - Ok(ret) - } - - TimePeriod::MONTH => { - let mut ret = HashMap::new(); - - for month in 1..13 { - let count = self - .count_history_month( - user, - chrono::Utc.ymd(year as i32, month, 1).naive_utc(), - ) - .await?; - - ret.insert( - month as u64, - TimePeriodInfo { - count: count as u64, - hash: "".to_string(), - }, - ); - } - - Ok(ret) - } - - TimePeriod::DAY => { - let mut ret = HashMap::new(); - - for day in 1..get_days_from_month(year as i32, month as u32) { - let count = self - .count_history_day( - user, - chrono::Utc - .ymd(year as i32, month as u32, day as u32) - .naive_utc(), - ) - .await?; - - ret.insert( - day as u64, - TimePeriodInfo { - count: count as u64, - hash: "".to_string(), - }, - ); - } - - Ok(ret) - } - } - } -} diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs index 1c9dff5f..bb0aa321 100644 --- a/atuin-server/src/handlers/history.rs +++ b/atuin-server/src/handlers/history.rs @@ -10,18 +10,20 @@ use tracing::{debug, error, instrument}; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; use crate::{ - calendar::{TimePeriod, TimePeriodInfo}, - database::Database, - models::{NewHistory, User}, - router::AppState, + router::{AppState, UserAuth}, utils::client_version_min, }; +use atuin_server_database::{ + calendar::{TimePeriod, TimePeriodInfo}, + models::NewHistory, + Database, +}; use atuin_common::api::*; #[instrument(skip_all, fields(user.id = user.id))] pub async fn count<DB: Database>( - user: User, + UserAuth(user): UserAuth, state: State<AppState<DB>>, ) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> { let db = &state.0.database; @@ -42,7 +44,7 @@ pub async fn count<DB: Database>( #[instrument(skip_all, fields(user.id = user.id))] pub async fn list<DB: Database>( req: Query<SyncHistoryRequest>, - user: User, + UserAuth(user): UserAuth, headers: HeaderMap, state: State<AppState<DB>>, ) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> { @@ -101,7 +103,7 @@ pub async fn list<DB: Database>( #[instrument(skip_all, fields(user.id = user.id))] pub async fn delete<DB: Database>( - user: User, + UserAuth(user): UserAuth, state: State<AppState<DB>>, Json(req): Json<DeleteHistoryRequest>, ) -> Result<Json<MessageResponse>, ErrorResponseStatus<'static>> { @@ -123,13 +125,15 @@ pub async fn delete<DB: Database>( #[instrument(skip_all, fields(user.id = user.id))] pub async fn add<DB: Database>( - user: User, + UserAuth(user): UserAuth, state: State<AppState<DB>>, Json(req): Json<Vec<AddHistoryRequest>>, ) -> Result<(), ErrorResponseStatus<'static>> { + let State(AppState { database, settings }) = state; + debug!("request to add {} history items", req.len()); - let history: Vec<NewHistory> = req + let mut history: Vec<NewHistory> = req .into_iter() .map(|h| NewHistory { client_id: h.id, @@ -140,8 +144,24 @@ pub async fn add<DB: Database>( }) .collect(); - let db = &state.0.database; - if let Err(e) = db.add_history(&history).await { + history.retain(|h| { + // keep if within limit, or limit is 0 (unlimited) + let keep = h.data.len() <= settings.max_history_length || settings.max_history_length == 0; + + // Don't return an error here. We want to insert as much of the + // history list as we can, so log the error and continue going. + if !keep { + tracing::warn!( + "history too long, got length {}, max {}", + h.data.len(), + settings.max_history_length + ); + } + + keep + }); + + if let Err(e) = database.add_history(&history).await { error!("failed to add history: {}", e); return Err(ErrorResponse::reply("failed to add history") @@ -155,7 +175,7 @@ pub async fn add<DB: Database>( pub async fn calendar<DB: Database>( Path(focus): Path<String>, Query(params): Query<HashMap<String, u64>>, - user: User, + UserAuth(user): UserAuth, state: State<AppState<DB>>, ) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> { let focus = focus.as_str(); diff --git a/atuin-server/src/handlers/status.rs b/atuin-server/src/handlers/status.rs index 97c02886..d9b6afaf 100644 --- a/atuin-server/src/handlers/status.rs +++ b/atuin-server/src/handlers/status.rs @@ -3,7 +3,8 @@ use http::StatusCode; use tracing::instrument; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; -use crate::{database::Database, models::User, router::AppState}; +use crate::router::{AppState, UserAuth}; +use atuin_server_database::Database; use atuin_common::api::*; @@ -11,7 +12,7 @@ const VERSION: &str = env!("CARGO_PKG_VERSION"); #[instrument(skip_all, fields(user.id = user.id))] pub async fn status<DB: Database>( - user: User, + UserAuth(user): UserAuth, state: State<AppState<DB>>, ) -> Result<Json<StatusResponse>, ErrorResponseStatus<'static>> { let db = &state.0.database; diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs index e67828e4..75081155 100644 --- a/atuin-server/src/handlers/user.rs +++ b/atuin-server/src/handlers/user.rs @@ -16,10 +16,10 @@ use tracing::{debug, error, info, instrument}; use uuid::Uuid; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; -use crate::{ - database::Database, - models::{NewSession, NewUser, User}, - router::AppState, +use crate::router::{AppState, UserAuth}; +use atuin_server_database::{ + models::{NewSession, NewUser}, + Database, DbError, }; use reqwest::header::CONTENT_TYPE; @@ -64,11 +64,11 @@ pub async fn get<DB: Database>( let db = &state.0.database; let user = match db.get_user(username.as_ref()).await { Ok(user) => user, - Err(sqlx::Error::RowNotFound) => { + Err(DbError::NotFound) => { debug!("user not found: {}", username); return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } - Err(err) => { + Err(DbError::Other(err)) => { error!("database error: {}", err); return Err(ErrorResponse::reply("database error") .with_status(StatusCode::INTERNAL_SERVER_ERROR)); @@ -152,7 +152,7 @@ pub async fn register<DB: Database>( #[instrument(skip_all, fields(user.id = user.id))] pub async fn delete<DB: Database>( - user: User, + UserAuth(user): UserAuth, state: State<AppState<DB>>, ) -> Result<Json<DeleteUserResponse>, ErrorResponseStatus<'static>> { debug!("request to delete user {}", user.id); @@ -175,10 +175,10 @@ pub async fn login<DB: Database>( let db = &state.0.database; let user = match db.get_user(login.username.borrow()).await { Ok(u) => u, - Err(sqlx::Error::RowNotFound) => { + Err(DbError::NotFound) => { return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } - Err(e) => { + Err(DbError::Other(e)) => { error!("failed to get user {}: {}", login.username.clone(), e); return Err(ErrorResponse::reply("database error") @@ -188,11 +188,11 @@ pub async fn login<DB: Database>( let session = match db.get_user_session(&user).await { Ok(u) => u, - Err(sqlx::Error::RowNotFound) => { + Err(DbError::NotFound) => { debug!("user session not found for user id={}", user.id); return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } - Err(err) => { + Err(DbError::Other(err)) => { error!("database error for user {}: {}", login.username, err); return Err(ErrorResponse::reply("database error") .with_status(StatusCode::INTERNAL_SERVER_ERROR)); diff --git a/atuin-server/src/lib.rs b/atuin-server/src/lib.rs index 01873af9..aa2250d3 100644 --- a/atuin-server/src/lib.rs +++ b/atuin-server/src/lib.rs @@ -2,45 +2,38 @@ use std::net::{IpAddr, SocketAddr}; +use atuin_server_database::Database; use axum::Server; -use database::Postgres; use eyre::{Context, Result}; -use crate::settings::Settings; +mod handlers; +mod router; +mod settings; +mod utils; +pub use settings::Settings; use tokio::signal; -pub mod auth; -pub mod calendar; -pub mod database; -pub mod handlers; -pub mod models; -pub mod router; -pub mod settings; -pub mod utils; - async fn shutdown_signal() { - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to register signal handler") - .recv() - .await; - }; - - tokio::select! { - _ = terminate => (), - } + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to register signal handler") + .recv() + .await; eprintln!("Shutting down gracefully..."); } -pub async fn launch(settings: Settings, host: String, port: u16) -> Result<()> { +pub async fn launch<Db: Database>( + settings: Settings<Db::Settings>, + host: String, + port: u16, +) -> Result<()> { let host = host.parse::<IpAddr>()?; - let postgres = Postgres::new(settings.clone()) + let db = Db::new(&settings.db_settings) .await - .wrap_err_with(|| format!("failed to connect to db: {}", settings.db_uri))?; + .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?; - let r = router::router(postgres, settings); + let r = router::router(db, settings); Server::bind(&SocketAddr::new(host, port)) .serve(r.into_make_service()) diff --git a/atuin-server/src/models.rs b/atuin-server/src/models.rs deleted file mode 100644 index ee84f58a..00000000 --- a/atuin-server/src/models.rs +++ /dev/null @@ -1,49 +0,0 @@ -use chrono::prelude::*; - -#[derive(sqlx::FromRow)] -pub struct History { - pub id: i64, - pub client_id: String, // a client generated ID - pub user_id: i64, - pub hostname: String, - pub timestamp: NaiveDateTime, - - pub data: String, - - pub created_at: NaiveDateTime, -} - -pub struct NewHistory { - pub client_id: String, - pub user_id: i64, - pub hostname: String, - pub timestamp: chrono::NaiveDateTime, - - pub data: String, -} - -#[derive(sqlx::FromRow)] -pub struct User { - pub id: i64, - pub username: String, - pub email: String, - pub password: String, -} - -#[derive(sqlx::FromRow)] -pub struct Session { - pub id: i64, - pub user_id: i64, - pub token: String, -} - -pub struct NewUser { - pub username: String, - pub email: String, - pub password: String, -} - -pub struct NewSession { - pub user_id: i64, - pub token: String, -} diff --git a/atuin-server/src/router.rs b/atuin-server/src/router.rs index 20b11f45..ec558e78 100644 --- a/atuin-server/src/router.rs +++ b/atuin-server/src/router.rs @@ -10,11 +10,14 @@ use http::request::Parts; use tower::ServiceBuilder; use tower_http::trace::TraceLayer; -use super::{database::Database, handlers}; -use crate::{models::User, settings::Settings}; +use super::handlers; +use crate::settings::Settings; +use atuin_server_database::{models::User, Database}; + +pub struct UserAuth(pub User); #[async_trait] -impl<DB: Send + Sync> FromRequestParts<AppState<DB>> for User +impl<DB: Send + Sync> FromRequestParts<AppState<DB>> for UserAuth where DB: Database, { @@ -45,7 +48,7 @@ where .await .map_err(|_| http::StatusCode::FORBIDDEN)?; - Ok(user) + Ok(UserAuth(user)) } } @@ -54,15 +57,12 @@ async fn teapot() -> impl IntoResponse { } #[derive(Clone)] -pub struct AppState<DB> { +pub struct AppState<DB: Database> { pub database: DB, - pub settings: Settings, + pub settings: Settings<DB::Settings>, } -pub fn router<DB: Database + Clone + Send + Sync + 'static>( - database: DB, - settings: Settings, -) -> Router { +pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> Router { let routes = Router::new() .route("/", get(handlers::index)) .route("/sync/count", get(handlers::history::count)) diff --git a/atuin-server/src/settings.rs b/atuin-server/src/settings.rs index 981d239f..fb5325d4 100644 --- a/atuin-server/src/settings.rs +++ b/atuin-server/src/settings.rs @@ -3,24 +3,24 @@ use std::{io::prelude::*, path::PathBuf}; use config::{Config, Environment, File as ConfigFile, FileFormat}; use eyre::{eyre, Result}; use fs_err::{create_dir_all, File}; -use serde::{Deserialize, Serialize}; - -pub const HISTORY_PAGE_SIZE: i64 = 100; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Settings { +pub struct Settings<DbSettings> { pub host: String, pub port: u16, pub path: String, - pub db_uri: String, pub open_registration: bool, pub max_history_length: usize, pub page_size: i64, pub register_webhook_url: Option<String>, pub register_webhook_username: String, + + #[serde(flatten)] + pub db_settings: DbSettings, } -impl Settings { +impl<DbSettings: DeserializeOwned> Settings<DbSettings> { pub fn new() -> Result<Self> { let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { PathBuf::from(p) |
