diff options
| author | Conrad Ludgate <conradludgate@gmail.com> | 2023-06-12 09:04:35 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-06-12 09:04:35 +0100 |
| commit | 8655c93853506acf05f6ae4e58bfc2c6198be254 (patch) | |
| tree | 22d20b35636ad2eb717d58c93ae07378adbb76eb /atuin-server/src/database.rs | |
| parent | Make Ctrl-d behaviour match other tools (#1040) (diff) | |
| download | atuin-8655c93853506acf05f6ae4e58bfc2c6198be254.zip | |
refactor server to allow pluggable db and tracing (#1036)
* refactor server to allow pluggable db and tracing
* clean up
* fix descriptions
* remove dependencies
Diffstat (limited to 'atuin-server/src/database.rs')
| -rw-r--r-- | atuin-server/src/database.rs | 510 |
1 files changed, 0 insertions, 510 deletions
diff --git a/atuin-server/src/database.rs b/atuin-server/src/database.rs deleted file mode 100644 index 894fab7b..00000000 --- a/atuin-server/src/database.rs +++ /dev/null @@ -1,510 +0,0 @@ -use std::collections::HashMap; - -use async_trait::async_trait; -use chrono::{Datelike, TimeZone}; -use chronoutil::RelativeDuration; -use sqlx::{postgres::PgPoolOptions, Result}; - -use sqlx::Row; - -use tracing::{debug, instrument, warn}; - -use super::{ - calendar::{TimePeriod, TimePeriodInfo}, - models::{History, NewHistory, NewSession, NewUser, Session, User}, -}; -use crate::settings::Settings; - -use atuin_common::utils::get_days_from_month; - -#[async_trait] -pub trait Database { - async fn get_session(&self, token: &str) -> Result<Session>; - async fn get_session_user(&self, token: &str) -> Result<User>; - async fn add_session(&self, session: &NewSession) -> Result<()>; - - async fn get_user(&self, username: &str) -> Result<User>; - async fn get_user_session(&self, u: &User) -> Result<Session>; - async fn add_user(&self, user: &NewUser) -> Result<i64>; - async fn delete_user(&self, u: &User) -> Result<()>; - - async fn count_history(&self, user: &User) -> Result<i64>; - async fn count_history_cached(&self, user: &User) -> Result<i64>; - - async fn delete_history(&self, user: &User, id: String) -> Result<()>; - async fn deleted_history(&self, user: &User) -> Result<Vec<String>>; - - async fn count_history_range( - &self, - user: &User, - start: chrono::NaiveDateTime, - end: chrono::NaiveDateTime, - ) -> Result<i64>; - async fn count_history_day(&self, user: &User, date: chrono::NaiveDate) -> Result<i64>; - async fn count_history_month(&self, user: &User, date: chrono::NaiveDate) -> Result<i64>; - async fn count_history_year(&self, user: &User, year: i32) -> Result<i64>; - - async fn list_history( - &self, - user: &User, - created_after: chrono::NaiveDateTime, - since: chrono::NaiveDateTime, - host: &str, - page_size: i64, - ) -> Result<Vec<History>>; - - async fn add_history(&self, history: &[NewHistory]) -> Result<()>; - - async fn oldest_history(&self, user: &User) -> Result<History>; - - async fn calendar( - &self, - user: &User, - period: TimePeriod, - year: u64, - month: u64, - ) -> Result<HashMap<u64, TimePeriodInfo>>; -} - -#[derive(Clone)] -pub struct Postgres { - pool: sqlx::Pool<sqlx::postgres::Postgres>, - settings: Settings, -} - -impl Postgres { - pub async fn new(settings: Settings) -> Result<Self> { - let pool = PgPoolOptions::new() - .max_connections(100) - .connect(settings.db_uri.as_str()) - .await?; - - sqlx::migrate!("./migrations").run(&pool).await?; - - Ok(Self { pool, settings }) - } -} - -#[async_trait] -impl Database for Postgres { - #[instrument(skip_all)] - async fn get_session(&self, token: &str) -> Result<Session> { - sqlx::query_as::<_, Session>("select id, user_id, token from sessions where token = $1") - .bind(token) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn get_user(&self, username: &str) -> Result<User> { - sqlx::query_as::<_, User>( - "select id, username, email, password from users where username = $1", - ) - .bind(username) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn get_session_user(&self, token: &str) -> Result<User> { - sqlx::query_as::<_, User>( - "select users.id, users.username, users.email, users.password from users - inner join sessions - on users.id = sessions.user_id - and sessions.token = $1", - ) - .bind(token) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn count_history(&self, user: &User) -> Result<i64> { - // The cache is new, and the user might not yet have a cache value. - // They will have one as soon as they post up some new history, but handle that - // edge case. - - let res: (i64,) = sqlx::query_as( - "select count(1) from history - where user_id = $1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn count_history_cached(&self, user: &User) -> Result<i64> { - let res: (i32,) = sqlx::query_as( - "select total from total_history_count_user - where user_id = $1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await?; - - Ok(res.0 as i64) - } - - async fn delete_history(&self, user: &User, id: String) -> Result<()> { - sqlx::query( - "update history - set deleted_at = $3 - where user_id = $1 - and client_id = $2 - and deleted_at is null", // don't just keep setting it - ) - .bind(user.id) - .bind(id) - .bind(chrono::Utc::now().naive_utc()) - .fetch_all(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn deleted_history(&self, user: &User) -> Result<Vec<String>> { - // The cache is new, and the user might not yet have a cache value. - // They will have one as soon as they post up some new history, but handle that - // edge case. - - let res = sqlx::query( - "select client_id from history - where user_id = $1 - and deleted_at is not null", - ) - .bind(user.id) - .fetch_all(&self.pool) - .await?; - - let res = res - .iter() - .map(|row| row.get::<String, _>("client_id")) - .collect(); - - Ok(res) - } - - #[instrument(skip_all)] - async fn count_history_range( - &self, - user: &User, - start: chrono::NaiveDateTime, - end: chrono::NaiveDateTime, - ) -> Result<i64> { - let res: (i64,) = sqlx::query_as( - "select count(1) from history - where user_id = $1 - and timestamp >= $2::date - and timestamp < $3::date", - ) - .bind(user.id) - .bind(start) - .bind(end) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - // Count the history for a given year - #[instrument(skip_all)] - async fn count_history_year(&self, user: &User, year: i32) -> Result<i64> { - let start = chrono::Utc.ymd(year, 1, 1).and_hms_nano(0, 0, 0, 0); - let end = start + RelativeDuration::years(1); - - let res = self - .count_history_range(user, start.naive_utc(), end.naive_utc()) - .await?; - Ok(res) - } - - // Count the history for a given month - #[instrument(skip_all)] - async fn count_history_month(&self, user: &User, month: chrono::NaiveDate) -> Result<i64> { - let start = chrono::Utc - .ymd(month.year(), month.month(), 1) - .and_hms_nano(0, 0, 0, 0); - - // ofc... - let end = if month.month() < 12 { - chrono::Utc - .ymd(month.year(), month.month() + 1, 1) - .and_hms_nano(0, 0, 0, 0) - } else { - chrono::Utc - .ymd(month.year() + 1, 1, 1) - .and_hms_nano(0, 0, 0, 0) - }; - - debug!("start: {}, end: {}", start, end); - - let res = self - .count_history_range(user, start.naive_utc(), end.naive_utc()) - .await?; - Ok(res) - } - - // Count the history for a given day - #[instrument(skip_all)] - async fn count_history_day(&self, user: &User, day: chrono::NaiveDate) -> Result<i64> { - let start = chrono::Utc - .ymd(day.year(), day.month(), day.day()) - .and_hms_nano(0, 0, 0, 0); - let end = chrono::Utc - .ymd(day.year(), day.month(), day.day() + 1) - .and_hms_nano(0, 0, 0, 0); - - let res = self - .count_history_range(user, start.naive_utc(), end.naive_utc()) - .await?; - Ok(res) - } - - #[instrument(skip_all)] - async fn list_history( - &self, - user: &User, - created_after: chrono::NaiveDateTime, - since: chrono::NaiveDateTime, - host: &str, - page_size: i64, - ) -> Result<Vec<History>> { - let res = sqlx::query_as::<_, History>( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history - where user_id = $1 - and hostname != $2 - and created_at >= $3 - and timestamp >= $4 - order by timestamp asc - limit $5", - ) - .bind(user.id) - .bind(host) - .bind(created_after) - .bind(since) - .bind(page_size) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - #[instrument(skip_all)] - async fn add_history(&self, history: &[NewHistory]) -> Result<()> { - let mut tx = self.pool.begin().await?; - - for i in history { - let client_id: &str = &i.client_id; - let hostname: &str = &i.hostname; - let data: &str = &i.data; - - if data.len() > self.settings.max_history_length - && self.settings.max_history_length != 0 - { - // Don't return an error here. We want to insert as much of the - // history list as we can, so log the error and continue going. - - warn!( - "history too long, got length {}, max {}", - data.len(), - self.settings.max_history_length - ); - - continue; - } - - sqlx::query( - "insert into history - (client_id, user_id, hostname, timestamp, data) - values ($1, $2, $3, $4, $5) - on conflict do nothing - ", - ) - .bind(client_id) - .bind(i.user_id) - .bind(hostname) - .bind(i.timestamp) - .bind(data) - .execute(&mut tx) - .await?; - } - - tx.commit().await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn delete_user(&self, u: &User) -> Result<()> { - sqlx::query("delete from sessions where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from users where id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from history where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn add_user(&self, user: &NewUser) -> Result<i64> { - let email: &str = &user.email; - let username: &str = &user.username; - let password: &str = &user.password; - - let res: (i64,) = sqlx::query_as( - "insert into users - (username, email, password) - values($1, $2, $3) - returning id", - ) - .bind(username) - .bind(email) - .bind(password) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn add_session(&self, session: &NewSession) -> Result<()> { - let token: &str = &session.token; - - sqlx::query( - "insert into sessions - (user_id, token) - values($1, $2)", - ) - .bind(session.user_id) - .bind(token) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn get_user_session(&self, u: &User) -> Result<Session> { - sqlx::query_as::<_, Session>("select id, user_id, token from sessions where user_id = $1") - .bind(u.id) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn oldest_history(&self, user: &User) -> Result<History> { - let res = sqlx::query_as::<_, History>( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history - where user_id = $1 - order by timestamp asc - limit 1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await?; - - Ok(res) - } - - #[instrument(skip_all)] - async fn calendar( - &self, - user: &User, - period: TimePeriod, - year: u64, - month: u64, - ) -> Result<HashMap<u64, TimePeriodInfo>> { - // TODO: Support different timezones. Right now we assume UTC and - // everything is stored as such. But it _should_ be possible to - // interpret the stored date with a different TZ - - match period { - TimePeriod::YEAR => { - let mut ret = HashMap::new(); - // First we need to work out how far back to calculate. Get the - // oldest history item - let oldest = self.oldest_history(user).await?.timestamp.year(); - let current_year = chrono::Utc::now().year(); - - // All the years we need to get data for - // The upper bound is exclusive, so include current +1 - let years = oldest..current_year + 1; - - for year in years { - let count = self.count_history_year(user, year).await?; - - ret.insert( - year as u64, - TimePeriodInfo { - count: count as u64, - hash: "".to_string(), - }, - ); - } - - Ok(ret) - } - - TimePeriod::MONTH => { - let mut ret = HashMap::new(); - - for month in 1..13 { - let count = self - .count_history_month( - user, - chrono::Utc.ymd(year as i32, month, 1).naive_utc(), - ) - .await?; - - ret.insert( - month as u64, - TimePeriodInfo { - count: count as u64, - hash: "".to_string(), - }, - ); - } - - Ok(ret) - } - - TimePeriod::DAY => { - let mut ret = HashMap::new(); - - for day in 1..get_days_from_month(year as i32, month as u32) { - let count = self - .count_history_day( - user, - chrono::Utc - .ymd(year as i32, month as u32, day as u32) - .naive_utc(), - ) - .await?; - - ret.insert( - day as u64, - TimePeriodInfo { - count: count as u64, - hash: "".to_string(), - }, - ); - } - - Ok(ret) - } - } - } -} |
