use crate::{ atuin_common::api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ErrorResponse}, atuin_server::database::{db::ServerPostgres, models::User}, }; use axum::{ Router, extract::{FromRequestParts, Path, Request}, http::{self, request::Parts}, middleware::Next, response::{IntoResponse, Response}, routing::{get, post}, }; use eyre::Result; use tower::ServiceBuilder; use tower_http::trace::TraceLayer; use uuid::Uuid; use super::handlers; use crate::atuin_server::{ handlers::{ErrorResponseStatus, RespExt}, metrics, settings::Settings, }; pub(crate) struct UserAuth(pub(crate) User); impl FromRequestParts for UserAuth { type Rejection = ErrorResponseStatus<'static>; async fn from_request_parts( req: &mut Parts, state: &AppState, ) -> Result { let user_id = { let Path(user_id) = as FromRequestParts>::from_request_parts(req, state) .await .map_err(|_| { ErrorResponse::reply("invalid user_id path param") .with_status(http::StatusCode::BAD_REQUEST) })?; user_id }; let user = User { id: user_id }; Ok(Self(user)) } } async fn teapot() -> impl IntoResponse { // This used to return 418: 🫖 // Much as it was fun, it wasn't as useful or informative as it should be (http::StatusCode::NOT_FOUND, "404 not found") } /// Ensure that we only try and sync with clients on the same major version async fn semver(request: Request, next: Next) -> Response { let mut response = next.run(request).await; response .headers_mut() .insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse().unwrap()); response } #[derive(Clone)] pub(crate) struct AppState { pub(crate) database: ServerPostgres, pub(crate) settings: Settings, } pub(crate) fn router(database: ServerPostgres, settings: Settings) -> Router { let routes = Router::new() .route("/", get(handlers::index)) .route("/api/v0/{user_id}/record", post(handlers::v0::record::post)) .route("/api/v0/{user_id}/record", get(handlers::v0::record::index)) .route( "/api/v0/{user_id}/record/next", get(handlers::v0::record::next), ); let path = settings.path.as_str(); if path.is_empty() { routes } else { Router::new().nest(path, routes) } .fallback(teapot) .with_state(AppState { database, settings }) .layer( ServiceBuilder::new() .layer(TraceLayer::new_for_http()) .layer(axum::middleware::from_fn(metrics::track_metrics)) .layer(axum::middleware::from_fn(semver)), ) }