From 8655c93853506acf05f6ae4e58bfc2c6198be254 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 12 Jun 2023 09:04:35 +0100 Subject: refactor server to allow pluggable db and tracing (#1036) * refactor server to allow pluggable db and tracing * clean up * fix descriptions * remove dependencies --- atuin-server/src/auth.rs | 222 --------------- atuin-server/src/calendar.rs | 17 -- atuin-server/src/database.rs | 510 ----------------------------------- atuin-server/src/handlers/history.rs | 44 ++- atuin-server/src/handlers/status.rs | 5 +- atuin-server/src/handlers/user.rs | 22 +- atuin-server/src/lib.rs | 43 ++- atuin-server/src/models.rs | 49 ---- atuin-server/src/router.rs | 20 +- atuin-server/src/settings.rs | 12 +- 10 files changed, 80 insertions(+), 864 deletions(-) delete mode 100644 atuin-server/src/auth.rs delete mode 100644 atuin-server/src/calendar.rs delete mode 100644 atuin-server/src/database.rs delete mode 100644 atuin-server/src/models.rs (limited to 'atuin-server/src') diff --git a/atuin-server/src/auth.rs b/atuin-server/src/auth.rs deleted file mode 100644 index 52a73108..00000000 --- a/atuin-server/src/auth.rs +++ /dev/null @@ -1,222 +0,0 @@ -/* -use self::diesel::prelude::*; -use eyre::Result; -use rocket::http::Status; -use rocket::request::{self, FromRequest, Outcome, Request}; -use rocket::State; -use rocket_contrib::databases::diesel; -use sodiumoxide::crypto::pwhash::argon2id13; - -use rocket_contrib::json::Json; -use uuid::Uuid; - -use super::models::{NewSession, NewUser, Session, User}; -use super::views::ApiResponse; - -use crate::api::{LoginRequest, RegisterRequest}; -use crate::schema::{sessions, users}; -use crate::settings::Settings; -use crate::utils::hash_secret; - -use super::database::AtuinDbConn; - -#[derive(Debug)] -pub enum KeyError { - Missing, - Invalid, -} - -pub fn verify_str(secret: &str, verify: &str) -> bool { - sodiumoxide::init().unwrap(); - - let mut padded = [0_u8; 128]; - secret.as_bytes().iter().enumerate().for_each(|(i, val)| { - padded[i] = *val; - }); - - match argon2id13::HashedPassword::from_slice(&padded) { - Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()), - None => false, - } -} - -impl<'a, 'r> FromRequest<'a, 'r> for User { - type Error = KeyError; - - fn from_request(request: &'a Request<'r>) -> request::Outcome { - let session: Vec<_> = request.headers().get("authorization").collect(); - - if session.is_empty() { - return Outcome::Failure((Status::BadRequest, KeyError::Missing)); - } else if session.len() > 1 { - return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); - } - - let session: Vec<_> = session[0].split(' ').collect(); - - if session.len() != 2 { - return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); - } - - if session[0] != "Token" { - return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); - } - - let session = session[1]; - - let db = request - .guard::() - .succeeded() - .expect("failed to load database"); - - let session = sessions::table - .filter(sessions::token.eq(session)) - .first::(&*db); - - if session.is_err() { - return Outcome::Failure((Status::Unauthorized, KeyError::Invalid)); - } - - let session = session.unwrap(); - - let user = users::table.find(session.user_id).first(&*db); - - match user { - Ok(user) => Outcome::Success(user), - Err(_) => Outcome::Failure((Status::Unauthorized, KeyError::Invalid)), - } - } -} - -#[get("/user/")] -#[allow(clippy::clippy::needless_pass_by_value)] -pub fn get_user(user: String, conn: AtuinDbConn) -> ApiResponse { - use crate::schema::users::dsl::{username, users}; - - let user: Result = users - .select(username) - .filter(username.eq(user)) - .first(&*conn); - - if user.is_err() { - return ApiResponse { - json: json!({ - "message": "could not find user", - }), - status: Status::NotFound, - }; - } - - let user = user.unwrap(); - - ApiResponse { - json: json!({ "username": user.as_str() }), - status: Status::Ok, - } -} - -#[post("/register", data = "")] -#[allow(clippy::clippy::needless_pass_by_value)] -pub fn register( - conn: AtuinDbConn, - register: Json, - settings: State, -) -> ApiResponse { - if !settings.server.open_registration { - return ApiResponse { - status: Status::BadRequest, - json: json!({ - "message": "registrations are not open" - }), - }; - } - - let hashed = hash_secret(register.password.as_str()); - - let new_user = NewUser { - email: register.email.as_str(), - username: register.username.as_str(), - password: hashed.as_str(), - }; - - let user = diesel::insert_into(users::table) - .values(&new_user) - .get_result(&*conn); - - if user.is_err() { - return ApiResponse { - status: Status::BadRequest, - json: json!({ - "message": "failed to create user - username or email in use?", - }), - }; - } - - let user: User = user.unwrap(); - let token = Uuid::new_v4().to_simple().to_string(); - - let new_session = NewSession { - user_id: user.id, - token: token.as_str(), - }; - - match diesel::insert_into(sessions::table) - .values(&new_session) - .execute(&*conn) - { - Ok(_) => ApiResponse { - status: Status::Ok, - json: json!({"message": "user created!", "session": token}), - }, - Err(_) => ApiResponse { - status: Status::BadRequest, - json: json!({ "message": "failed to create user"}), - }, - } -} - -#[post("/login", data = "")] -#[allow(clippy::clippy::needless_pass_by_value)] -pub fn login(conn: AtuinDbConn, login: Json) -> ApiResponse { - let user = users::table - .filter(users::username.eq(login.username.as_str())) - .first(&*conn); - - if user.is_err() { - return ApiResponse { - status: Status::NotFound, - json: json!({"message": "user not found"}), - }; - } - - let user: User = user.unwrap(); - - let session = sessions::table - .filter(sessions::user_id.eq(user.id)) - .first(&*conn); - - // a session should exist... - if session.is_err() { - return ApiResponse { - status: Status::InternalServerError, - json: json!({"message": "something went wrong"}), - }; - } - - let verified = verify_str(user.password.as_str(), login.password.as_str()); - - if !verified { - return ApiResponse { - status: Status::NotFound, - json: json!({"message": "user not found"}), - }; - } - - let session: Session = session.unwrap(); - - ApiResponse { - status: Status::Ok, - json: json!({"session": session.token}), - } -} -*/ diff --git a/atuin-server/src/calendar.rs b/atuin-server/src/calendar.rs deleted file mode 100644 index 7c05dce3..00000000 --- a/atuin-server/src/calendar.rs +++ /dev/null @@ -1,17 +0,0 @@ -// Calendar data - -use serde::{Deserialize, Serialize}; - -pub enum TimePeriod { - YEAR, - MONTH, - DAY, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct TimePeriodInfo { - pub count: u64, - - // TODO: Use this for merkle tree magic - pub hash: String, -} diff --git a/atuin-server/src/database.rs b/atuin-server/src/database.rs deleted file mode 100644 index 894fab7b..00000000 --- a/atuin-server/src/database.rs +++ /dev/null @@ -1,510 +0,0 @@ -use std::collections::HashMap; - -use async_trait::async_trait; -use chrono::{Datelike, TimeZone}; -use chronoutil::RelativeDuration; -use sqlx::{postgres::PgPoolOptions, Result}; - -use sqlx::Row; - -use tracing::{debug, instrument, warn}; - -use super::{ - calendar::{TimePeriod, TimePeriodInfo}, - models::{History, NewHistory, NewSession, NewUser, Session, User}, -}; -use crate::settings::Settings; - -use atuin_common::utils::get_days_from_month; - -#[async_trait] -pub trait Database { - async fn get_session(&self, token: &str) -> Result; - async fn get_session_user(&self, token: &str) -> Result; - async fn add_session(&self, session: &NewSession) -> Result<()>; - - async fn get_user(&self, username: &str) -> Result; - async fn get_user_session(&self, u: &User) -> Result; - async fn add_user(&self, user: &NewUser) -> Result; - async fn delete_user(&self, u: &User) -> Result<()>; - - async fn count_history(&self, user: &User) -> Result; - async fn count_history_cached(&self, user: &User) -> Result; - - async fn delete_history(&self, user: &User, id: String) -> Result<()>; - async fn deleted_history(&self, user: &User) -> Result>; - - async fn count_history_range( - &self, - user: &User, - start: chrono::NaiveDateTime, - end: chrono::NaiveDateTime, - ) -> Result; - async fn count_history_day(&self, user: &User, date: chrono::NaiveDate) -> Result; - async fn count_history_month(&self, user: &User, date: chrono::NaiveDate) -> Result; - async fn count_history_year(&self, user: &User, year: i32) -> Result; - - async fn list_history( - &self, - user: &User, - created_after: chrono::NaiveDateTime, - since: chrono::NaiveDateTime, - host: &str, - page_size: i64, - ) -> Result>; - - async fn add_history(&self, history: &[NewHistory]) -> Result<()>; - - async fn oldest_history(&self, user: &User) -> Result; - - async fn calendar( - &self, - user: &User, - period: TimePeriod, - year: u64, - month: u64, - ) -> Result>; -} - -#[derive(Clone)] -pub struct Postgres { - pool: sqlx::Pool, - settings: Settings, -} - -impl Postgres { - pub async fn new(settings: Settings) -> Result { - let pool = PgPoolOptions::new() - .max_connections(100) - .connect(settings.db_uri.as_str()) - .await?; - - sqlx::migrate!("./migrations").run(&pool).await?; - - Ok(Self { pool, settings }) - } -} - -#[async_trait] -impl Database for Postgres { - #[instrument(skip_all)] - async fn get_session(&self, token: &str) -> Result { - sqlx::query_as::<_, Session>("select id, user_id, token from sessions where token = $1") - .bind(token) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn get_user(&self, username: &str) -> Result { - sqlx::query_as::<_, User>( - "select id, username, email, password from users where username = $1", - ) - .bind(username) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn get_session_user(&self, token: &str) -> Result { - sqlx::query_as::<_, User>( - "select users.id, users.username, users.email, users.password from users - inner join sessions - on users.id = sessions.user_id - and sessions.token = $1", - ) - .bind(token) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn count_history(&self, user: &User) -> Result { - // The cache is new, and the user might not yet have a cache value. - // They will have one as soon as they post up some new history, but handle that - // edge case. - - let res: (i64,) = sqlx::query_as( - "select count(1) from history - where user_id = $1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn count_history_cached(&self, user: &User) -> Result { - let res: (i32,) = sqlx::query_as( - "select total from total_history_count_user - where user_id = $1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await?; - - Ok(res.0 as i64) - } - - async fn delete_history(&self, user: &User, id: String) -> Result<()> { - sqlx::query( - "update history - set deleted_at = $3 - where user_id = $1 - and client_id = $2 - and deleted_at is null", // don't just keep setting it - ) - .bind(user.id) - .bind(id) - .bind(chrono::Utc::now().naive_utc()) - .fetch_all(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn deleted_history(&self, user: &User) -> Result> { - // The cache is new, and the user might not yet have a cache value. - // They will have one as soon as they post up some new history, but handle that - // edge case. - - let res = sqlx::query( - "select client_id from history - where user_id = $1 - and deleted_at is not null", - ) - .bind(user.id) - .fetch_all(&self.pool) - .await?; - - let res = res - .iter() - .map(|row| row.get::("client_id")) - .collect(); - - Ok(res) - } - - #[instrument(skip_all)] - async fn count_history_range( - &self, - user: &User, - start: chrono::NaiveDateTime, - end: chrono::NaiveDateTime, - ) -> Result { - let res: (i64,) = sqlx::query_as( - "select count(1) from history - where user_id = $1 - and timestamp >= $2::date - and timestamp < $3::date", - ) - .bind(user.id) - .bind(start) - .bind(end) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - // Count the history for a given year - #[instrument(skip_all)] - async fn count_history_year(&self, user: &User, year: i32) -> Result { - let start = chrono::Utc.ymd(year, 1, 1).and_hms_nano(0, 0, 0, 0); - let end = start + RelativeDuration::years(1); - - let res = self - .count_history_range(user, start.naive_utc(), end.naive_utc()) - .await?; - Ok(res) - } - - // Count the history for a given month - #[instrument(skip_all)] - async fn count_history_month(&self, user: &User, month: chrono::NaiveDate) -> Result { - let start = chrono::Utc - .ymd(month.year(), month.month(), 1) - .and_hms_nano(0, 0, 0, 0); - - // ofc... - let end = if month.month() < 12 { - chrono::Utc - .ymd(month.year(), month.month() + 1, 1) - .and_hms_nano(0, 0, 0, 0) - } else { - chrono::Utc - .ymd(month.year() + 1, 1, 1) - .and_hms_nano(0, 0, 0, 0) - }; - - debug!("start: {}, end: {}", start, end); - - let res = self - .count_history_range(user, start.naive_utc(), end.naive_utc()) - .await?; - Ok(res) - } - - // Count the history for a given day - #[instrument(skip_all)] - async fn count_history_day(&self, user: &User, day: chrono::NaiveDate) -> Result { - let start = chrono::Utc - .ymd(day.year(), day.month(), day.day()) - .and_hms_nano(0, 0, 0, 0); - let end = chrono::Utc - .ymd(day.year(), day.month(), day.day() + 1) - .and_hms_nano(0, 0, 0, 0); - - let res = self - .count_history_range(user, start.naive_utc(), end.naive_utc()) - .await?; - Ok(res) - } - - #[instrument(skip_all)] - async fn list_history( - &self, - user: &User, - created_after: chrono::NaiveDateTime, - since: chrono::NaiveDateTime, - host: &str, - page_size: i64, - ) -> Result> { - let res = sqlx::query_as::<_, History>( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history - where user_id = $1 - and hostname != $2 - and created_at >= $3 - and timestamp >= $4 - order by timestamp asc - limit $5", - ) - .bind(user.id) - .bind(host) - .bind(created_after) - .bind(since) - .bind(page_size) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - #[instrument(skip_all)] - async fn add_history(&self, history: &[NewHistory]) -> Result<()> { - let mut tx = self.pool.begin().await?; - - for i in history { - let client_id: &str = &i.client_id; - let hostname: &str = &i.hostname; - let data: &str = &i.data; - - if data.len() > self.settings.max_history_length - && self.settings.max_history_length != 0 - { - // Don't return an error here. We want to insert as much of the - // history list as we can, so log the error and continue going. - - warn!( - "history too long, got length {}, max {}", - data.len(), - self.settings.max_history_length - ); - - continue; - } - - sqlx::query( - "insert into history - (client_id, user_id, hostname, timestamp, data) - values ($1, $2, $3, $4, $5) - on conflict do nothing - ", - ) - .bind(client_id) - .bind(i.user_id) - .bind(hostname) - .bind(i.timestamp) - .bind(data) - .execute(&mut tx) - .await?; - } - - tx.commit().await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn delete_user(&self, u: &User) -> Result<()> { - sqlx::query("delete from sessions where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from users where id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from history where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn add_user(&self, user: &NewUser) -> Result { - let email: &str = &user.email; - let username: &str = &user.username; - let password: &str = &user.password; - - let res: (i64,) = sqlx::query_as( - "insert into users - (username, email, password) - values($1, $2, $3) - returning id", - ) - .bind(username) - .bind(email) - .bind(password) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn add_session(&self, session: &NewSession) -> Result<()> { - let token: &str = &session.token; - - sqlx::query( - "insert into sessions - (user_id, token) - values($1, $2)", - ) - .bind(session.user_id) - .bind(token) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn get_user_session(&self, u: &User) -> Result { - sqlx::query_as::<_, Session>("select id, user_id, token from sessions where user_id = $1") - .bind(u.id) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn oldest_history(&self, user: &User) -> Result { - let res = sqlx::query_as::<_, History>( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history - where user_id = $1 - order by timestamp asc - limit 1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await?; - - Ok(res) - } - - #[instrument(skip_all)] - async fn calendar( - &self, - user: &User, - period: TimePeriod, - year: u64, - month: u64, - ) -> Result> { - // TODO: Support different timezones. Right now we assume UTC and - // everything is stored as such. But it _should_ be possible to - // interpret the stored date with a different TZ - - match period { - TimePeriod::YEAR => { - let mut ret = HashMap::new(); - // First we need to work out how far back to calculate. Get the - // oldest history item - let oldest = self.oldest_history(user).await?.timestamp.year(); - let current_year = chrono::Utc::now().year(); - - // All the years we need to get data for - // The upper bound is exclusive, so include current +1 - let years = oldest..current_year + 1; - - for year in years { - let count = self.count_history_year(user, year).await?; - - ret.insert( - year as u64, - TimePeriodInfo { - count: count as u64, - hash: "".to_string(), - }, - ); - } - - Ok(ret) - } - - TimePeriod::MONTH => { - let mut ret = HashMap::new(); - - for month in 1..13 { - let count = self - .count_history_month( - user, - chrono::Utc.ymd(year as i32, month, 1).naive_utc(), - ) - .await?; - - ret.insert( - month as u64, - TimePeriodInfo { - count: count as u64, - hash: "".to_string(), - }, - ); - } - - Ok(ret) - } - - TimePeriod::DAY => { - let mut ret = HashMap::new(); - - for day in 1..get_days_from_month(year as i32, month as u32) { - let count = self - .count_history_day( - user, - chrono::Utc - .ymd(year as i32, month as u32, day as u32) - .naive_utc(), - ) - .await?; - - ret.insert( - day as u64, - TimePeriodInfo { - count: count as u64, - hash: "".to_string(), - }, - ); - } - - Ok(ret) - } - } - } -} diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs index 1c9dff5f..bb0aa321 100644 --- a/atuin-server/src/handlers/history.rs +++ b/atuin-server/src/handlers/history.rs @@ -10,18 +10,20 @@ use tracing::{debug, error, instrument}; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; use crate::{ - calendar::{TimePeriod, TimePeriodInfo}, - database::Database, - models::{NewHistory, User}, - router::AppState, + router::{AppState, UserAuth}, utils::client_version_min, }; +use atuin_server_database::{ + calendar::{TimePeriod, TimePeriodInfo}, + models::NewHistory, + Database, +}; use atuin_common::api::*; #[instrument(skip_all, fields(user.id = user.id))] pub async fn count( - user: User, + UserAuth(user): UserAuth, state: State>, ) -> Result, ErrorResponseStatus<'static>> { let db = &state.0.database; @@ -42,7 +44,7 @@ pub async fn count( #[instrument(skip_all, fields(user.id = user.id))] pub async fn list( req: Query, - user: User, + UserAuth(user): UserAuth, headers: HeaderMap, state: State>, ) -> Result, ErrorResponseStatus<'static>> { @@ -101,7 +103,7 @@ pub async fn list( #[instrument(skip_all, fields(user.id = user.id))] pub async fn delete( - user: User, + UserAuth(user): UserAuth, state: State>, Json(req): Json, ) -> Result, ErrorResponseStatus<'static>> { @@ -123,13 +125,15 @@ pub async fn delete( #[instrument(skip_all, fields(user.id = user.id))] pub async fn add( - user: User, + UserAuth(user): UserAuth, state: State>, Json(req): Json>, ) -> Result<(), ErrorResponseStatus<'static>> { + let State(AppState { database, settings }) = state; + debug!("request to add {} history items", req.len()); - let history: Vec = req + let mut history: Vec = req .into_iter() .map(|h| NewHistory { client_id: h.id, @@ -140,8 +144,24 @@ pub async fn add( }) .collect(); - let db = &state.0.database; - if let Err(e) = db.add_history(&history).await { + history.retain(|h| { + // keep if within limit, or limit is 0 (unlimited) + let keep = h.data.len() <= settings.max_history_length || settings.max_history_length == 0; + + // Don't return an error here. We want to insert as much of the + // history list as we can, so log the error and continue going. + if !keep { + tracing::warn!( + "history too long, got length {}, max {}", + h.data.len(), + settings.max_history_length + ); + } + + keep + }); + + if let Err(e) = database.add_history(&history).await { error!("failed to add history: {}", e); return Err(ErrorResponse::reply("failed to add history") @@ -155,7 +175,7 @@ pub async fn add( pub async fn calendar( Path(focus): Path, Query(params): Query>, - user: User, + UserAuth(user): UserAuth, state: State>, ) -> Result>, ErrorResponseStatus<'static>> { let focus = focus.as_str(); diff --git a/atuin-server/src/handlers/status.rs b/atuin-server/src/handlers/status.rs index 97c02886..d9b6afaf 100644 --- a/atuin-server/src/handlers/status.rs +++ b/atuin-server/src/handlers/status.rs @@ -3,7 +3,8 @@ use http::StatusCode; use tracing::instrument; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; -use crate::{database::Database, models::User, router::AppState}; +use crate::router::{AppState, UserAuth}; +use atuin_server_database::Database; use atuin_common::api::*; @@ -11,7 +12,7 @@ const VERSION: &str = env!("CARGO_PKG_VERSION"); #[instrument(skip_all, fields(user.id = user.id))] pub async fn status( - user: User, + UserAuth(user): UserAuth, state: State>, ) -> Result, ErrorResponseStatus<'static>> { let db = &state.0.database; diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs index e67828e4..75081155 100644 --- a/atuin-server/src/handlers/user.rs +++ b/atuin-server/src/handlers/user.rs @@ -16,10 +16,10 @@ use tracing::{debug, error, info, instrument}; use uuid::Uuid; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; -use crate::{ - database::Database, - models::{NewSession, NewUser, User}, - router::AppState, +use crate::router::{AppState, UserAuth}; +use atuin_server_database::{ + models::{NewSession, NewUser}, + Database, DbError, }; use reqwest::header::CONTENT_TYPE; @@ -64,11 +64,11 @@ pub async fn get( let db = &state.0.database; let user = match db.get_user(username.as_ref()).await { Ok(user) => user, - Err(sqlx::Error::RowNotFound) => { + Err(DbError::NotFound) => { debug!("user not found: {}", username); return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } - Err(err) => { + Err(DbError::Other(err)) => { error!("database error: {}", err); return Err(ErrorResponse::reply("database error") .with_status(StatusCode::INTERNAL_SERVER_ERROR)); @@ -152,7 +152,7 @@ pub async fn register( #[instrument(skip_all, fields(user.id = user.id))] pub async fn delete( - user: User, + UserAuth(user): UserAuth, state: State>, ) -> Result, ErrorResponseStatus<'static>> { debug!("request to delete user {}", user.id); @@ -175,10 +175,10 @@ pub async fn login( let db = &state.0.database; let user = match db.get_user(login.username.borrow()).await { Ok(u) => u, - Err(sqlx::Error::RowNotFound) => { + Err(DbError::NotFound) => { return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } - Err(e) => { + Err(DbError::Other(e)) => { error!("failed to get user {}: {}", login.username.clone(), e); return Err(ErrorResponse::reply("database error") @@ -188,11 +188,11 @@ pub async fn login( let session = match db.get_user_session(&user).await { Ok(u) => u, - Err(sqlx::Error::RowNotFound) => { + Err(DbError::NotFound) => { debug!("user session not found for user id={}", user.id); return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } - Err(err) => { + Err(DbError::Other(err)) => { error!("database error for user {}: {}", login.username, err); return Err(ErrorResponse::reply("database error") .with_status(StatusCode::INTERNAL_SERVER_ERROR)); diff --git a/atuin-server/src/lib.rs b/atuin-server/src/lib.rs index 01873af9..aa2250d3 100644 --- a/atuin-server/src/lib.rs +++ b/atuin-server/src/lib.rs @@ -2,45 +2,38 @@ use std::net::{IpAddr, SocketAddr}; +use atuin_server_database::Database; use axum::Server; -use database::Postgres; use eyre::{Context, Result}; -use crate::settings::Settings; +mod handlers; +mod router; +mod settings; +mod utils; +pub use settings::Settings; use tokio::signal; -pub mod auth; -pub mod calendar; -pub mod database; -pub mod handlers; -pub mod models; -pub mod router; -pub mod settings; -pub mod utils; - async fn shutdown_signal() { - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to register signal handler") - .recv() - .await; - }; - - tokio::select! { - _ = terminate => (), - } + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to register signal handler") + .recv() + .await; eprintln!("Shutting down gracefully..."); } -pub async fn launch(settings: Settings, host: String, port: u16) -> Result<()> { +pub async fn launch( + settings: Settings, + host: String, + port: u16, +) -> Result<()> { let host = host.parse::()?; - let postgres = Postgres::new(settings.clone()) + let db = Db::new(&settings.db_settings) .await - .wrap_err_with(|| format!("failed to connect to db: {}", settings.db_uri))?; + .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?; - let r = router::router(postgres, settings); + let r = router::router(db, settings); Server::bind(&SocketAddr::new(host, port)) .serve(r.into_make_service()) diff --git a/atuin-server/src/models.rs b/atuin-server/src/models.rs deleted file mode 100644 index ee84f58a..00000000 --- a/atuin-server/src/models.rs +++ /dev/null @@ -1,49 +0,0 @@ -use chrono::prelude::*; - -#[derive(sqlx::FromRow)] -pub struct History { - pub id: i64, - pub client_id: String, // a client generated ID - pub user_id: i64, - pub hostname: String, - pub timestamp: NaiveDateTime, - - pub data: String, - - pub created_at: NaiveDateTime, -} - -pub struct NewHistory { - pub client_id: String, - pub user_id: i64, - pub hostname: String, - pub timestamp: chrono::NaiveDateTime, - - pub data: String, -} - -#[derive(sqlx::FromRow)] -pub struct User { - pub id: i64, - pub username: String, - pub email: String, - pub password: String, -} - -#[derive(sqlx::FromRow)] -pub struct Session { - pub id: i64, - pub user_id: i64, - pub token: String, -} - -pub struct NewUser { - pub username: String, - pub email: String, - pub password: String, -} - -pub struct NewSession { - pub user_id: i64, - pub token: String, -} diff --git a/atuin-server/src/router.rs b/atuin-server/src/router.rs index 20b11f45..ec558e78 100644 --- a/atuin-server/src/router.rs +++ b/atuin-server/src/router.rs @@ -10,11 +10,14 @@ use http::request::Parts; use tower::ServiceBuilder; use tower_http::trace::TraceLayer; -use super::{database::Database, handlers}; -use crate::{models::User, settings::Settings}; +use super::handlers; +use crate::settings::Settings; +use atuin_server_database::{models::User, Database}; + +pub struct UserAuth(pub User); #[async_trait] -impl FromRequestParts> for User +impl FromRequestParts> for UserAuth where DB: Database, { @@ -45,7 +48,7 @@ where .await .map_err(|_| http::StatusCode::FORBIDDEN)?; - Ok(user) + Ok(UserAuth(user)) } } @@ -54,15 +57,12 @@ async fn teapot() -> impl IntoResponse { } #[derive(Clone)] -pub struct AppState { +pub struct AppState { pub database: DB, - pub settings: Settings, + pub settings: Settings, } -pub fn router( - database: DB, - settings: Settings, -) -> Router { +pub fn router(database: DB, settings: Settings) -> Router { let routes = Router::new() .route("/", get(handlers::index)) .route("/sync/count", get(handlers::history::count)) diff --git a/atuin-server/src/settings.rs b/atuin-server/src/settings.rs index 981d239f..fb5325d4 100644 --- a/atuin-server/src/settings.rs +++ b/atuin-server/src/settings.rs @@ -3,24 +3,24 @@ use std::{io::prelude::*, path::PathBuf}; use config::{Config, Environment, File as ConfigFile, FileFormat}; use eyre::{eyre, Result}; use fs_err::{create_dir_all, File}; -use serde::{Deserialize, Serialize}; - -pub const HISTORY_PAGE_SIZE: i64 = 100; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Settings { +pub struct Settings { pub host: String, pub port: u16, pub path: String, - pub db_uri: String, pub open_registration: bool, pub max_history_length: usize, pub page_size: i64, pub register_webhook_url: Option, pub register_webhook_username: String, + + #[serde(flatten)] + pub db_settings: DbSettings, } -impl Settings { +impl Settings { pub fn new() -> Result { let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { PathBuf::from(p) -- cgit v1.3.1