From 95cc472037fcb3207b510e67f1a44af4e2a2cae9 Mon Sep 17 00:00:00 2001 From: Ellie Huxtable Date: Thu, 18 Apr 2024 16:41:28 +0100 Subject: chore: move crates into crates/ dir (#1958) I'd like to tidy up the root a little, and it's nice to have all the rust crates in one place --- crates/atuin-server/Cargo.toml | 39 ++++ crates/atuin-server/server.toml | 34 ++++ crates/atuin-server/src/handlers/history.rs | 237 +++++++++++++++++++++++ crates/atuin-server/src/handlers/mod.rs | 58 ++++++ crates/atuin-server/src/handlers/record.rs | 45 +++++ crates/atuin-server/src/handlers/status.rs | 43 +++++ crates/atuin-server/src/handlers/user.rs | 258 ++++++++++++++++++++++++++ crates/atuin-server/src/handlers/v0/me.rs | 16 ++ crates/atuin-server/src/handlers/v0/mod.rs | 3 + crates/atuin-server/src/handlers/v0/record.rs | 112 +++++++++++ crates/atuin-server/src/handlers/v0/store.rs | 37 ++++ crates/atuin-server/src/lib.rs | 144 ++++++++++++++ crates/atuin-server/src/metrics.rs | 56 ++++++ crates/atuin-server/src/router.rs | 149 +++++++++++++++ crates/atuin-server/src/settings.rs | 151 +++++++++++++++ crates/atuin-server/src/utils.rs | 15 ++ 16 files changed, 1397 insertions(+) create mode 100644 crates/atuin-server/Cargo.toml create mode 100644 crates/atuin-server/server.toml create mode 100644 crates/atuin-server/src/handlers/history.rs create mode 100644 crates/atuin-server/src/handlers/mod.rs create mode 100644 crates/atuin-server/src/handlers/record.rs create mode 100644 crates/atuin-server/src/handlers/status.rs create mode 100644 crates/atuin-server/src/handlers/user.rs create mode 100644 crates/atuin-server/src/handlers/v0/me.rs create mode 100644 crates/atuin-server/src/handlers/v0/mod.rs create mode 100644 crates/atuin-server/src/handlers/v0/record.rs create mode 100644 crates/atuin-server/src/handlers/v0/store.rs create mode 100644 crates/atuin-server/src/lib.rs create mode 100644 crates/atuin-server/src/metrics.rs create mode 100644 crates/atuin-server/src/router.rs create mode 100644 crates/atuin-server/src/settings.rs create mode 100644 crates/atuin-server/src/utils.rs (limited to 'crates/atuin-server') diff --git a/crates/atuin-server/Cargo.toml b/crates/atuin-server/Cargo.toml new file mode 100644 index 00000000..a6b8a9f6 --- /dev/null +++ b/crates/atuin-server/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "atuin-server" +edition = "2021" +description = "server library for atuin" + +rust-version = { workspace = true } +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +[dependencies] +atuin-common = { path = "../atuin-common", version = "18.2.0" } +atuin-server-database = { path = "../atuin-server-database", version = "18.2.0" } + +tracing = "0.1" +time = { workspace = true } +eyre = { workspace = true } +uuid = { workspace = true } +config = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +base64 = { workspace = true } +rand = { workspace = true } +tokio = { workspace = true } +async-trait = { workspace = true } +axum = "0.7.4" +axum-server = { version = "0.6.0", features = ["tls-rustls"] } +fs-err = { workspace = true } +tower = "0.4" +tower-http = { version = "0.5.1", features = ["trace"] } +reqwest = { workspace = true } +rustls = "0.21" +rustls-pemfile = "2.1" +argon2 = "0.5.3" +semver = { workspace = true } +metrics-exporter-prometheus = "0.12.1" +metrics = "0.21.1" diff --git a/crates/atuin-server/server.toml b/crates/atuin-server/server.toml new file mode 100644 index 00000000..946769c9 --- /dev/null +++ b/crates/atuin-server/server.toml @@ -0,0 +1,34 @@ +## host to bind, can also be passed via CLI args +# host = "127.0.0.1" + +## port to bind, can also be passed via CLI args +# port = 8888 + +## whether to allow anyone to register an account +# open_registration = false + +## URI for postgres (using development creds here) +# db_uri="postgres://username:password@localhost/atuin" + +## Maximum size for one history entry +# max_history_length = 8192 + +## Maximum size for one record entry +## 1024 * 1024 * 1024 +# max_record_size = 1073741824 + +## Webhook to be called when user registers on the servers +# register_webhook_username = "" + +## Default page size for requests +# page_size = 1100 + +# [metrics] +# enable = false +# host = 127.0.0.1 +# port = 9001 + +# [tls] +# enable = false +# cert_path = "" +# pkey_path = "" diff --git a/crates/atuin-server/src/handlers/history.rs b/crates/atuin-server/src/handlers/history.rs new file mode 100644 index 00000000..05bbe740 --- /dev/null +++ b/crates/atuin-server/src/handlers/history.rs @@ -0,0 +1,237 @@ +use std::{collections::HashMap, convert::TryFrom}; + +use axum::{ + extract::{Path, Query, State}, + http::{HeaderMap, StatusCode}, + Json, +}; +use metrics::counter; +use time::{Month, UtcOffset}; +use tracing::{debug, error, instrument}; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::{ + router::{AppState, UserAuth}, + utils::client_version_min, +}; +use atuin_server_database::{ + calendar::{TimePeriod, TimePeriodInfo}, + models::NewHistory, + Database, +}; + +use atuin_common::api::*; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn count( + UserAuth(user): UserAuth, + state: State>, +) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.database; + match db.count_history_cached(&user).await { + // By default read out the cached value + Ok(count) => Ok(Json(CountResponse { count })), + + // If that fails, fallback on a full COUNT. Cache is built on a POST + // only + Err(_) => match db.count_history(&user).await { + Ok(count) => Ok(Json(CountResponse { count })), + Err(_) => Err(ErrorResponse::reply("failed to query history count") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)), + }, + } +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn list( + req: Query, + UserAuth(user): UserAuth, + headers: HeaderMap, + state: State>, +) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + let agent = headers + .get("user-agent") + .map_or("", |v| v.to_str().unwrap_or("")); + + let variable_page_size = client_version_min(agent, ">=15.0.0").unwrap_or(false); + + let page_size = if variable_page_size { + state.settings.page_size + } else { + 100 + }; + + if req.sync_ts.unix_timestamp_nanos() < 0 || req.history_ts.unix_timestamp_nanos() < 0 { + error!("client asked for history from < epoch 0"); + counter!("atuin_history_epoch_before_zero", 1); + + return Err( + ErrorResponse::reply("asked for history from before epoch 0") + .with_status(StatusCode::BAD_REQUEST), + ); + } + + let history = db + .list_history(&user, req.sync_ts, req.history_ts, &req.host, page_size) + .await; + + if let Err(e) = history { + error!("failed to load history: {}", e); + return Err(ErrorResponse::reply("failed to load history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + + let history: Vec = history + .unwrap() + .iter() + .map(|i| i.data.to_string()) + .collect(); + + debug!( + "loaded {} items of history for user {}", + history.len(), + user.id + ); + + counter!("atuin_history_returned", history.len() as u64); + + Ok(Json(SyncHistoryResponse { history })) +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn delete( + UserAuth(user): UserAuth, + state: State>, + Json(req): Json, +) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + // user_id is the ID of the history, as set by the user (the server has its own ID) + let deleted = db.delete_history(&user, req.client_id).await; + + if let Err(e) = deleted { + error!("failed to delete history: {}", e); + return Err(ErrorResponse::reply("failed to delete history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + + Ok(Json(MessageResponse { + message: String::from("deleted OK"), + })) +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn add( + UserAuth(user): UserAuth, + state: State>, + Json(req): Json>, +) -> Result<(), ErrorResponseStatus<'static>> { + let State(AppState { database, settings }) = state; + + debug!("request to add {} history items", req.len()); + counter!("atuin_history_uploaded", req.len() as u64); + + let mut history: Vec = req + .into_iter() + .map(|h| NewHistory { + client_id: h.id, + user_id: user.id, + hostname: h.hostname, + timestamp: h.timestamp, + data: h.data, + }) + .collect(); + + history.retain(|h| { + // keep if within limit, or limit is 0 (unlimited) + let keep = h.data.len() <= settings.max_history_length || settings.max_history_length == 0; + + // Don't return an error here. We want to insert as much of the + // history list as we can, so log the error and continue going. + if !keep { + counter!("atuin_history_too_long", 1); + + tracing::warn!( + "history too long, got length {}, max {}", + h.data.len(), + settings.max_history_length + ); + } + + keep + }); + + if let Err(e) = database.add_history(&history).await { + error!("failed to add history: {}", e); + + return Err(ErrorResponse::reply("failed to add history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + + Ok(()) +} + +#[derive(serde::Deserialize, Debug)] +pub struct CalendarQuery { + #[serde(default = "serde_calendar::zero")] + year: i32, + #[serde(default = "serde_calendar::one")] + month: u8, + + #[serde(default = "serde_calendar::utc")] + tz: UtcOffset, +} + +mod serde_calendar { + use time::UtcOffset; + + pub fn zero() -> i32 { + 0 + } + + pub fn one() -> u8 { + 1 + } + + pub fn utc() -> UtcOffset { + UtcOffset::UTC + } +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn calendar( + Path(focus): Path, + Query(params): Query, + UserAuth(user): UserAuth, + state: State>, +) -> Result>, ErrorResponseStatus<'static>> { + let focus = focus.as_str(); + + let year = params.year; + let month = Month::try_from(params.month).map_err(|e| ErrorResponseStatus { + error: ErrorResponse { + reason: e.to_string().into(), + }, + status: StatusCode::BAD_REQUEST, + })?; + + let period = match focus { + "year" => TimePeriod::Year, + "month" => TimePeriod::Month { year }, + "day" => TimePeriod::Day { year, month }, + _ => { + return Err(ErrorResponse::reply("invalid focus: use year/month/day") + .with_status(StatusCode::BAD_REQUEST)) + } + }; + + let db = &state.0.database; + let focus = db.calendar(&user, period, params.tz).await.map_err(|_| { + ErrorResponse::reply("failed to query calendar") + .with_status(StatusCode::INTERNAL_SERVER_ERROR) + })?; + + Ok(Json(focus)) +} diff --git a/crates/atuin-server/src/handlers/mod.rs b/crates/atuin-server/src/handlers/mod.rs new file mode 100644 index 00000000..50f82336 --- /dev/null +++ b/crates/atuin-server/src/handlers/mod.rs @@ -0,0 +1,58 @@ +use atuin_common::api::{ErrorResponse, IndexResponse}; +use atuin_server_database::Database; +use axum::{extract::State, http, response::IntoResponse, Json}; + +use crate::router::AppState; + +pub mod history; +pub mod record; +pub mod status; +pub mod user; +pub mod v0; + +const VERSION: &str = env!("CARGO_PKG_VERSION"); + +pub async fn index(state: State>) -> Json { + let homage = r#""Through the fathomless deeps of space swims the star turtle Great A'Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld." -- Sir Terry Pratchett"#; + + // Error with a -1 response + // It's super unlikely this will happen + let count = state.database.total_history().await.unwrap_or(-1); + + Json(IndexResponse { + homage: homage.to_string(), + version: VERSION.to_string(), + total_history: count, + }) +} + +impl<'a> IntoResponse for ErrorResponseStatus<'a> { + fn into_response(self) -> axum::response::Response { + (self.status, Json(self.error)).into_response() + } +} + +pub struct ErrorResponseStatus<'a> { + pub error: ErrorResponse<'a>, + pub status: http::StatusCode, +} + +pub trait RespExt<'a> { + fn with_status(self, status: http::StatusCode) -> ErrorResponseStatus<'a>; + fn reply(reason: &'a str) -> Self; +} + +impl<'a> RespExt<'a> for ErrorResponse<'a> { + fn with_status(self, status: http::StatusCode) -> ErrorResponseStatus<'a> { + ErrorResponseStatus { + error: self, + status, + } + } + + fn reply(reason: &'a str) -> ErrorResponse { + Self { + reason: reason.into(), + } + } +} diff --git a/crates/atuin-server/src/handlers/record.rs b/crates/atuin-server/src/handlers/record.rs new file mode 100644 index 00000000..bf454949 --- /dev/null +++ b/crates/atuin-server/src/handlers/record.rs @@ -0,0 +1,45 @@ +use axum::{http::StatusCode, response::IntoResponse, Json}; +use serde_json::json; +use tracing::instrument; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::router::UserAuth; +use atuin_server_database::Database; + +use atuin_common::record::{EncryptedData, Record}; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn post( + UserAuth(user): UserAuth, +) -> Result<(), ErrorResponseStatus<'static>> { + // anyone who has actually used the old record store (a very small number) will see this error + // upon trying to sync. + // 1. The status endpoint will say that the server has nothing + // 2. The client will try to upload local records + // 3. Sync will fail with this error + + // If the client has no local records, they will see the empty index and do nothing. For the + // vast majority of users, this is the case. + return Err( + ErrorResponse::reply("record store deprecated; please upgrade") + .with_status(StatusCode::BAD_REQUEST), + ); +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn index(UserAuth(user): UserAuth) -> axum::response::Response { + let ret = json!({ + "hosts": {} + }); + + ret.to_string().into_response() +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn next( + UserAuth(user): UserAuth, +) -> Result>>, ErrorResponseStatus<'static>> { + let records = Vec::new(); + + Ok(Json(records)) +} diff --git a/crates/atuin-server/src/handlers/status.rs b/crates/atuin-server/src/handlers/status.rs new file mode 100644 index 00000000..3c22232c --- /dev/null +++ b/crates/atuin-server/src/handlers/status.rs @@ -0,0 +1,43 @@ +use axum::{extract::State, http::StatusCode, Json}; +use tracing::instrument; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::router::{AppState, UserAuth}; +use atuin_server_database::Database; + +use atuin_common::api::*; + +const VERSION: &str = env!("CARGO_PKG_VERSION"); + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn status( + UserAuth(user): UserAuth, + state: State>, +) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + let deleted = db.deleted_history(&user).await.unwrap_or(vec![]); + + let count = match db.count_history_cached(&user).await { + // By default read out the cached value + Ok(count) => count, + + // If that fails, fallback on a full COUNT. Cache is built on a POST + // only + Err(_) => match db.count_history(&user).await { + Ok(count) => count, + Err(_) => { + return Err(ErrorResponse::reply("failed to query history count") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)) + } + }, + }; + + Ok(Json(StatusResponse { + count, + deleted, + username: user.username, + version: VERSION.to_string(), + page_size: state.settings.page_size, + })) +} diff --git a/crates/atuin-server/src/handlers/user.rs b/crates/atuin-server/src/handlers/user.rs new file mode 100644 index 00000000..e5651fe2 --- /dev/null +++ b/crates/atuin-server/src/handlers/user.rs @@ -0,0 +1,258 @@ +use std::borrow::Borrow; +use std::collections::HashMap; +use std::time::Duration; + +use argon2::{ + password_hash::SaltString, Algorithm, Argon2, Params, PasswordHash, PasswordHasher, + PasswordVerifier, Version, +}; +use axum::{ + extract::{Path, State}, + http::StatusCode, + Json, +}; +use metrics::counter; +use rand::rngs::OsRng; +use tracing::{debug, error, info, instrument}; +use uuid::Uuid; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::router::{AppState, UserAuth}; +use atuin_server_database::{ + models::{NewSession, NewUser}, + Database, DbError, +}; + +use reqwest::header::CONTENT_TYPE; + +use atuin_common::api::*; + +pub fn verify_str(hash: &str, password: &str) -> bool { + let arg2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, Params::default()); + let Ok(hash) = PasswordHash::new(hash) else { + return false; + }; + arg2.verify_password(password.as_bytes(), &hash).is_ok() +} + +// Try to send a Discord webhook once - if it fails, we don't retry. "At most once", and best effort. +// Don't return the status because if this fails, we don't really care. +async fn send_register_hook(url: &str, username: String, registered: String) { + let hook = HashMap::from([ + ("username", username), + ("content", format!("{registered} has just signed up!")), + ]); + + let client = reqwest::Client::new(); + + let resp = client + .post(url) + .timeout(Duration::new(5, 0)) + .header(CONTENT_TYPE, "application/json") + .json(&hook) + .send() + .await; + + match resp { + Ok(_) => info!("register webhook sent ok!"), + Err(e) => error!("failed to send register webhook: {}", e), + } +} + +#[instrument(skip_all, fields(user.username = username.as_str()))] +pub async fn get( + Path(username): Path, + state: State>, +) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.database; + let user = match db.get_user(username.as_ref()).await { + Ok(user) => user, + Err(DbError::NotFound) => { + debug!("user not found: {}", username); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); + } + Err(DbError::Other(err)) => { + error!("database error: {}", err); + return Err(ErrorResponse::reply("database error") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + Ok(Json(UserResponse { + username: user.username, + })) +} + +#[instrument(skip_all)] +pub async fn register( + state: State>, + Json(register): Json, +) -> Result, ErrorResponseStatus<'static>> { + if !state.settings.open_registration { + return Err( + ErrorResponse::reply("this server is not open for registrations") + .with_status(StatusCode::BAD_REQUEST), + ); + } + + for c in register.username.chars() { + match c { + 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' => {} + _ => { + return Err(ErrorResponse::reply( + "Only alphanumeric and hyphens (-) are allowed in usernames", + ) + .with_status(StatusCode::BAD_REQUEST)) + } + } + } + + let hashed = hash_secret(®ister.password); + + let new_user = NewUser { + email: register.email.clone(), + username: register.username.clone(), + password: hashed, + }; + + let db = &state.0.database; + let user_id = match db.add_user(&new_user).await { + Ok(id) => id, + Err(e) => { + error!("failed to add user: {}", e); + return Err( + ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST) + ); + } + }; + + let token = Uuid::new_v4().as_simple().to_string(); + + let new_session = NewSession { + user_id, + token: (&token).into(), + }; + + if let Some(url) = &state.settings.register_webhook_url { + // Could probs be run on another thread, but it's ok atm + send_register_hook( + url, + state.settings.register_webhook_username.clone(), + register.username, + ) + .await; + } + + counter!("atuin_users_registered", 1); + + match db.add_session(&new_session).await { + Ok(_) => Ok(Json(RegisterResponse { session: token })), + Err(e) => { + error!("failed to add session: {}", e); + Err(ErrorResponse::reply("failed to register user") + .with_status(StatusCode::BAD_REQUEST)) + } + } +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn delete( + UserAuth(user): UserAuth, + state: State>, +) -> Result, ErrorResponseStatus<'static>> { + debug!("request to delete user {}", user.id); + + let db = &state.0.database; + if let Err(e) = db.delete_user(&user).await { + error!("failed to delete user: {}", e); + + return Err(ErrorResponse::reply("failed to delete user") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + + counter!("atuin_users_deleted", 1); + + Ok(Json(DeleteUserResponse {})) +} + +#[instrument(skip_all, fields(user.id = user.id, change_password))] +pub async fn change_password( + UserAuth(mut user): UserAuth, + state: State>, + Json(change_password): Json, +) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + let verified = verify_str( + user.password.as_str(), + change_password.current_password.borrow(), + ); + if !verified { + return Err( + ErrorResponse::reply("password is not correct").with_status(StatusCode::UNAUTHORIZED) + ); + } + + let hashed = hash_secret(&change_password.new_password); + user.password = hashed; + + if let Err(e) = db.update_user_password(&user).await { + error!("failed to change user password: {}", e); + + return Err(ErrorResponse::reply("failed to change user password") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + Ok(Json(ChangePasswordResponse {})) +} + +#[instrument(skip_all, fields(user.username = login.username.as_str()))] +pub async fn login( + state: State>, + login: Json, +) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.database; + let user = match db.get_user(login.username.borrow()).await { + Ok(u) => u, + Err(DbError::NotFound) => { + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); + } + Err(DbError::Other(e)) => { + error!("failed to get user {}: {}", login.username.clone(), e); + + return Err(ErrorResponse::reply("database error") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + let session = match db.get_user_session(&user).await { + Ok(u) => u, + Err(DbError::NotFound) => { + debug!("user session not found for user id={}", user.id); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); + } + Err(DbError::Other(err)) => { + error!("database error for user {}: {}", login.username, err); + return Err(ErrorResponse::reply("database error") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + let verified = verify_str(user.password.as_str(), login.password.borrow()); + + if !verified { + return Err( + ErrorResponse::reply("password is not correct").with_status(StatusCode::UNAUTHORIZED) + ); + } + + Ok(Json(LoginResponse { + session: session.token, + })) +} + +fn hash_secret(password: &str) -> String { + let arg2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, Params::default()); + let salt = SaltString::generate(&mut OsRng); + let hash = arg2.hash_password(password.as_bytes(), &salt).unwrap(); + hash.to_string() +} diff --git a/crates/atuin-server/src/handlers/v0/me.rs b/crates/atuin-server/src/handlers/v0/me.rs new file mode 100644 index 00000000..7960b479 --- /dev/null +++ b/crates/atuin-server/src/handlers/v0/me.rs @@ -0,0 +1,16 @@ +use axum::Json; +use tracing::instrument; + +use crate::handlers::ErrorResponseStatus; +use crate::router::UserAuth; + +use atuin_common::api::*; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn get( + UserAuth(user): UserAuth, +) -> Result, ErrorResponseStatus<'static>> { + Ok(Json(MeResponse { + username: user.username, + })) +} diff --git a/crates/atuin-server/src/handlers/v0/mod.rs b/crates/atuin-server/src/handlers/v0/mod.rs new file mode 100644 index 00000000..d6f880f2 --- /dev/null +++ b/crates/atuin-server/src/handlers/v0/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod me; +pub(crate) mod record; +pub(crate) mod store; diff --git a/crates/atuin-server/src/handlers/v0/record.rs b/crates/atuin-server/src/handlers/v0/record.rs new file mode 100644 index 00000000..321c34c2 --- /dev/null +++ b/crates/atuin-server/src/handlers/v0/record.rs @@ -0,0 +1,112 @@ +use axum::{extract::Query, extract::State, http::StatusCode, Json}; +use metrics::counter; +use serde::Deserialize; +use tracing::{error, instrument}; + +use crate::{ + handlers::{ErrorResponse, ErrorResponseStatus, RespExt}, + router::{AppState, UserAuth}, +}; +use atuin_server_database::Database; + +use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn post( + UserAuth(user): UserAuth, + state: State>, + Json(records): Json>>, +) -> Result<(), ErrorResponseStatus<'static>> { + let State(AppState { database, settings }) = state; + + tracing::debug!( + count = records.len(), + user = user.username, + "request to add records" + ); + + counter!("atuin_record_uploaded", records.len() as u64); + + let keep = records + .iter() + .all(|r| r.data.data.len() <= settings.max_record_size || settings.max_record_size == 0); + + if !keep { + counter!("atuin_record_too_large", 1); + + return Err( + ErrorResponse::reply("could not add records; record too large") + .with_status(StatusCode::BAD_REQUEST), + ); + } + + if let Err(e) = database.add_records(&user, &records).await { + error!("failed to add record: {}", e); + + return Err(ErrorResponse::reply("failed to add record") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + + Ok(()) +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn index( + UserAuth(user): UserAuth, + state: State>, +) -> Result, ErrorResponseStatus<'static>> { + let State(AppState { + database, + settings: _, + }) = state; + + let record_index = match database.status(&user).await { + Ok(index) => index, + Err(e) => { + error!("failed to get record index: {}", e); + + return Err(ErrorResponse::reply("failed to calculate record index") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + Ok(Json(record_index)) +} + +#[derive(Deserialize)] +pub struct NextParams { + host: HostId, + tag: String, + start: Option, + count: u64, +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn next( + params: Query, + UserAuth(user): UserAuth, + state: State>, +) -> Result>>, ErrorResponseStatus<'static>> { + let State(AppState { + database, + settings: _, + }) = state; + let params = params.0; + + let records = match database + .next_records(&user, params.host, params.tag, params.start, params.count) + .await + { + Ok(records) => records, + Err(e) => { + error!("failed to get record index: {}", e); + + return Err(ErrorResponse::reply("failed to calculate record index") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + counter!("atuin_record_downloaded", records.len() as u64); + + Ok(Json(records)) +} diff --git a/crates/atuin-server/src/handlers/v0/store.rs b/crates/atuin-server/src/handlers/v0/store.rs new file mode 100644 index 00000000..941f2487 --- /dev/null +++ b/crates/atuin-server/src/handlers/v0/store.rs @@ -0,0 +1,37 @@ +use axum::{extract::Query, extract::State, http::StatusCode}; +use metrics::counter; +use serde::Deserialize; +use tracing::{error, instrument}; + +use crate::{ + handlers::{ErrorResponse, ErrorResponseStatus, RespExt}, + router::{AppState, UserAuth}, +}; +use atuin_server_database::Database; + +#[derive(Deserialize)] +pub struct DeleteParams {} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn delete( + _params: Query, + UserAuth(user): UserAuth, + state: State>, +) -> Result<(), ErrorResponseStatus<'static>> { + let State(AppState { + database, + settings: _, + }) = state; + + if let Err(e) = database.delete_store(&user).await { + counter!("atuin_store_delete_failed", 1); + error!("failed to delete store {e:?}"); + + return Err(ErrorResponse::reply("failed to delete store") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + + counter!("atuin_store_deleted", 1); + + Ok(()) +} diff --git a/crates/atuin-server/src/lib.rs b/crates/atuin-server/src/lib.rs new file mode 100644 index 00000000..a0c104dc --- /dev/null +++ b/crates/atuin-server/src/lib.rs @@ -0,0 +1,144 @@ +#![forbid(unsafe_code)] + +use std::future::Future; +use std::net::SocketAddr; +use std::sync::Arc; + +use atuin_server_database::Database; +use axum::{serve, Router}; +use axum_server::Handle; +use eyre::{Context, Result}; + +mod handlers; +mod metrics; +mod router; +mod utils; + +use rustls::ServerConfig; +pub use settings::example_config; +pub use settings::Settings; + +pub mod settings; + +use tokio::net::TcpListener; +use tokio::signal; + +#[cfg(target_family = "unix")] +async fn shutdown_signal() { + let mut term = signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to register signal handler"); + let mut interrupt = signal::unix::signal(signal::unix::SignalKind::interrupt()) + .expect("failed to register signal handler"); + + tokio::select! { + _ = term.recv() => {}, + _ = interrupt.recv() => {}, + }; + eprintln!("Shutting down gracefully..."); +} + +#[cfg(target_family = "windows")] +async fn shutdown_signal() { + signal::windows::ctrl_c() + .expect("failed to register signal handler") + .recv() + .await; + eprintln!("Shutting down gracefully..."); +} + +pub async fn launch( + settings: Settings, + addr: SocketAddr, +) -> Result<()> { + if settings.tls.enable { + launch_with_tls::(settings, addr, shutdown_signal()).await + } else { + launch_with_tcp_listener::( + settings, + TcpListener::bind(addr) + .await + .context("could not connect to socket")?, + shutdown_signal(), + ) + .await + } +} + +pub async fn launch_with_tcp_listener( + settings: Settings, + listener: TcpListener, + shutdown: impl Future + Send + 'static, +) -> Result<()> { + let r = make_router::(settings).await?; + + serve(listener, r.into_make_service()) + .with_graceful_shutdown(shutdown) + .await?; + + Ok(()) +} + +async fn launch_with_tls( + settings: Settings, + addr: SocketAddr, + shutdown: impl Future, +) -> Result<()> { + let certificates = settings.tls.certificates()?; + let pkey = settings.tls.private_key()?; + + let server_config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certificates, pkey)?; + + let server_config = Arc::new(server_config); + let rustls_config = axum_server::tls_rustls::RustlsConfig::from_config(server_config); + + let r = make_router::(settings).await?; + + let handle = Handle::new(); + + let server = axum_server::bind_rustls(addr, rustls_config) + .handle(handle.clone()) + .serve(r.into_make_service()); + + tokio::select! { + _ = server => {} + _ = shutdown => { + handle.graceful_shutdown(None); + } + } + + Ok(()) +} + +// The separate listener means it's much easier to ensure metrics are not accidentally exposed to +// the public. +pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> { + let listener = TcpListener::bind((host, port)) + .await + .context("failed to bind metrics tcp")?; + + let recorder_handle = metrics::setup_metrics_recorder(); + + let router = Router::new().route( + "/metrics", + axum::routing::get(move || std::future::ready(recorder_handle.render())), + ); + + serve(listener, router.into_make_service()) + .with_graceful_shutdown(shutdown_signal()) + .await?; + + Ok(()) +} + +async fn make_router( + settings: Settings<::Settings>, +) -> Result { + let db = Db::new(&settings.db_settings) + .await + .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?; + let r = router::router(db, settings); + Ok(r) +} diff --git a/crates/atuin-server/src/metrics.rs b/crates/atuin-server/src/metrics.rs new file mode 100644 index 00000000..0a7ac6bd --- /dev/null +++ b/crates/atuin-server/src/metrics.rs @@ -0,0 +1,56 @@ +use std::time::Instant; + +use axum::{ + extract::{MatchedPath, Request}, + middleware::Next, + response::IntoResponse, +}; +use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; + +pub fn setup_metrics_recorder() -> PrometheusHandle { + const EXPONENTIAL_SECONDS: &[f64] = &[ + 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, + ]; + + PrometheusBuilder::new() + .set_buckets_for_metric( + Matcher::Full("http_requests_duration_seconds".to_string()), + EXPONENTIAL_SECONDS, + ) + .unwrap() + .install_recorder() + .unwrap() +} + +/// Middleware to record some common HTTP metrics +/// Generic over B to allow for arbitrary body types (eg Vec, Streams, a deserialized thing, etc) +/// Someday tower-http might provide a metrics middleware: https://github.com/tower-rs/tower-http/issues/57 +pub async fn track_metrics(req: Request, next: Next) -> impl IntoResponse { + let start = Instant::now(); + + let path = if let Some(matched_path) = req.extensions().get::() { + matched_path.as_str().to_owned() + } else { + req.uri().path().to_owned() + }; + + let method = req.method().clone(); + + // Run the rest of the request handling first, so we can measure it and get response + // codes. + let response = next.run(req).await; + + let latency = start.elapsed().as_secs_f64(); + let status = response.status().as_u16().to_string(); + + let labels = [ + ("method", method.to_string()), + ("path", path), + ("status", status), + ]; + + metrics::increment_counter!("http_requests_total", &labels); + metrics::histogram!("http_requests_duration_seconds", latency, &labels); + + response +} diff --git a/crates/atuin-server/src/router.rs b/crates/atuin-server/src/router.rs new file mode 100644 index 00000000..96dff2bd --- /dev/null +++ b/crates/atuin-server/src/router.rs @@ -0,0 +1,149 @@ +use async_trait::async_trait; +use atuin_common::api::{ErrorResponse, ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION}; +use axum::{ + extract::{FromRequestParts, Request}, + http::{self, request::Parts}, + middleware::Next, + response::{IntoResponse, Response}, + routing::{delete, get, patch, post}, + Router, +}; +use eyre::Result; +use tower::ServiceBuilder; +use tower_http::trace::TraceLayer; + +use super::handlers; +use crate::{ + handlers::{ErrorResponseStatus, RespExt}, + metrics, + settings::Settings, +}; +use atuin_server_database::{models::User, Database, DbError}; + +pub struct UserAuth(pub User); + +#[async_trait] +impl FromRequestParts> for UserAuth +where + DB: Database, +{ + type Rejection = ErrorResponseStatus<'static>; + + async fn from_request_parts( + req: &mut Parts, + state: &AppState, + ) -> Result { + let auth_header = req + .headers + .get(http::header::AUTHORIZATION) + .ok_or_else(|| { + ErrorResponse::reply("missing authorization header") + .with_status(http::StatusCode::BAD_REQUEST) + })?; + let auth_header = auth_header.to_str().map_err(|_| { + ErrorResponse::reply("invalid authorization header encoding") + .with_status(http::StatusCode::BAD_REQUEST) + })?; + let (typ, token) = auth_header.split_once(' ').ok_or_else(|| { + ErrorResponse::reply("invalid authorization header encoding") + .with_status(http::StatusCode::BAD_REQUEST) + })?; + + if typ != "Token" { + return Err( + ErrorResponse::reply("invalid authorization header encoding") + .with_status(http::StatusCode::BAD_REQUEST), + ); + } + + let user = state + .database + .get_session_user(token) + .await + .map_err(|e| match e { + DbError::NotFound => ErrorResponse::reply("session not found") + .with_status(http::StatusCode::FORBIDDEN), + DbError::Other(e) => { + tracing::error!(error = ?e, "could not query user session"); + ErrorResponse::reply("could not query user session") + .with_status(http::StatusCode::INTERNAL_SERVER_ERROR) + } + })?; + + Ok(UserAuth(user)) + } +} + +async fn teapot() -> impl IntoResponse { + // This used to return 418: 🫖 + // Much as it was fun, it wasn't as useful or informative as it should be + (http::StatusCode::NOT_FOUND, "404 not found") +} + +async fn clacks_overhead(request: Request, next: Next) -> Response { + let mut response = next.run(request).await; + + let gnu_terry_value = "GNU Terry Pratchett, Kris Nova"; + let gnu_terry_header = "X-Clacks-Overhead"; + + response + .headers_mut() + .insert(gnu_terry_header, gnu_terry_value.parse().unwrap()); + response +} + +/// Ensure that we only try and sync with clients on the same major version +async fn semver(request: Request, next: Next) -> Response { + let mut response = next.run(request).await; + response + .headers_mut() + .insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse().unwrap()); + + response +} + +#[derive(Clone)] +pub struct AppState { + pub database: DB, + pub settings: Settings, +} + +pub fn router(database: DB, settings: Settings) -> Router { + let routes = Router::new() + .route("/", get(handlers::index)) + .route("/sync/count", get(handlers::history::count)) + .route("/sync/history", get(handlers::history::list)) + .route("/sync/calendar/:focus", get(handlers::history::calendar)) + .route("/sync/status", get(handlers::status::status)) + .route("/history", post(handlers::history::add)) + .route("/history", delete(handlers::history::delete)) + .route("/user/:username", get(handlers::user::get)) + .route("/account", delete(handlers::user::delete)) + .route("/account/password", patch(handlers::user::change_password)) + .route("/register", post(handlers::user::register)) + .route("/login", post(handlers::user::login)) + .route("/record", post(handlers::record::post::)) + .route("/record", get(handlers::record::index::)) + .route("/record/next", get(handlers::record::next)) + .route("/api/v0/me", get(handlers::v0::me::get)) + .route("/api/v0/record", post(handlers::v0::record::post)) + .route("/api/v0/record", get(handlers::v0::record::index)) + .route("/api/v0/record/next", get(handlers::v0::record::next)) + .route("/api/v0/store", delete(handlers::v0::store::delete)); + + let path = settings.path.as_str(); + if path.is_empty() { + routes + } else { + Router::new().nest(path, routes) + } + .fallback(teapot) + .with_state(AppState { database, settings }) + .layer( + ServiceBuilder::new() + .layer(axum::middleware::from_fn(clacks_overhead)) + .layer(TraceLayer::new_for_http()) + .layer(axum::middleware::from_fn(metrics::track_metrics)) + .layer(axum::middleware::from_fn(semver)), + ) +} diff --git a/crates/atuin-server/src/settings.rs b/crates/atuin-server/src/settings.rs new file mode 100644 index 00000000..2d00df36 --- /dev/null +++ b/crates/atuin-server/src/settings.rs @@ -0,0 +1,151 @@ +use std::{io::prelude::*, path::PathBuf}; + +use config::{Config, Environment, File as ConfigFile, FileFormat}; +use eyre::{bail, eyre, Context, Result}; +use fs_err::{create_dir_all, File}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +static EXAMPLE_CONFIG: &str = include_str!("../server.toml"); + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Metrics { + pub enable: bool, + pub host: String, + pub port: u16, +} + +impl Default for Metrics { + fn default() -> Self { + Self { + enable: false, + host: String::from("127.0.0.1"), + port: 9001, + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Settings { + pub host: String, + pub port: u16, + pub path: String, + pub open_registration: bool, + pub max_history_length: usize, + pub max_record_size: usize, + pub page_size: i64, + pub register_webhook_url: Option, + pub register_webhook_username: String, + pub metrics: Metrics, + pub tls: Tls, + + #[serde(flatten)] + pub db_settings: DbSettings, +} + +impl Settings { + pub fn new() -> Result { + let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { + PathBuf::from(p) + } else { + let mut config_file = PathBuf::new(); + let config_dir = atuin_common::utils::config_dir(); + config_file.push(config_dir); + config_file + }; + + config_file.push("server.toml"); + + // create the config file if it does not exist + let mut config_builder = Config::builder() + .set_default("host", "127.0.0.1")? + .set_default("port", 8888)? + .set_default("open_registration", false)? + .set_default("max_history_length", 8192)? + .set_default("max_record_size", 1024 * 1024 * 1024)? // pretty chonky + .set_default("path", "")? + .set_default("register_webhook_username", "")? + .set_default("page_size", 1100)? + .set_default("metrics.enable", false)? + .set_default("metrics.host", "127.0.0.1")? + .set_default("metrics.port", 9001)? + .set_default("tls.enable", false)? + .set_default("tls.cert_path", "")? + .set_default("tls.pkey_path", "")? + .add_source( + Environment::with_prefix("atuin") + .prefix_separator("_") + .separator("__"), + ); + + config_builder = if config_file.exists() { + config_builder.add_source(ConfigFile::new( + config_file.to_str().unwrap(), + FileFormat::Toml, + )) + } else { + create_dir_all(config_file.parent().unwrap())?; + let mut file = File::create(config_file)?; + file.write_all(EXAMPLE_CONFIG.as_bytes())?; + + config_builder + }; + + let config = config_builder.build()?; + + config + .try_deserialize() + .map_err(|e| eyre!("failed to deserialize: {}", e)) + } +} + +pub fn example_config() -> &'static str { + EXAMPLE_CONFIG +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct Tls { + pub enable: bool, + + pub cert_path: PathBuf, + pub pkey_path: PathBuf, +} + +impl Tls { + pub fn certificates(&self) -> Result> { + let cert_file = std::fs::File::open(&self.cert_path) + .with_context(|| format!("tls.cert_path {:?} is missing", self.cert_path))?; + let mut reader = std::io::BufReader::new(cert_file); + let certs: Vec<_> = rustls_pemfile::certs(&mut reader) + .map(|c| c.map(|c| rustls::Certificate(c.to_vec()))) + .collect::, _>>() + .with_context(|| format!("tls.cert_path {:?} is invalid", self.cert_path))?; + + if certs.is_empty() { + bail!( + "tls.cert_path {:?} must have at least one certificate", + self.cert_path + ); + } + + Ok(certs) + } + + pub fn private_key(&self) -> Result { + let pkey_file = std::fs::File::open(&self.pkey_path) + .with_context(|| format!("tls.pkey_path {:?} is missing", self.pkey_path))?; + let mut reader = std::io::BufReader::new(pkey_file); + let keys = rustls_pemfile::pkcs8_private_keys(&mut reader) + .map(|c| c.map(|c| rustls::PrivateKey(c.secret_pkcs8_der().to_vec()))) + .collect::, _>>() + .with_context(|| format!("tls.pkey_path {:?} is not PKCS8-encoded", self.pkey_path))?; + + if keys.is_empty() { + bail!( + "tls.pkey_path {:?} must have at least one private key", + self.pkey_path + ); + } + + Ok(keys[0].clone()) + } +} diff --git a/crates/atuin-server/src/utils.rs b/crates/atuin-server/src/utils.rs new file mode 100644 index 00000000..12e9ac1b --- /dev/null +++ b/crates/atuin-server/src/utils.rs @@ -0,0 +1,15 @@ +use eyre::Result; +use semver::{Version, VersionReq}; + +pub fn client_version_min(user_agent: &str, req: &str) -> Result { + if user_agent.is_empty() { + return Ok(false); + } + + let version = user_agent.replace("atuin/", ""); + + let req = VersionReq::parse(req)?; + let version = Version::parse(version.as_str())?; + + Ok(req.matches(&version)) +} -- cgit v1.3.1