diff options
| author | Conrad Ludgate <conradludgate@gmail.com> | 2023-02-10 09:45:20 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-10 09:45:20 +0000 |
| commit | 0acdb99eb3f962b6428b2c4e88fe111324ade85b (patch) | |
| tree | 1c59bfa5eb1763a6c2e22e0e37f64de1c1cd6b4d /atuin-server | |
| parent | Bump fs-err from 2.8.1 to 2.9.0 (#604) (diff) | |
| download | atuin-0acdb99eb3f962b6428b2c4e88fe111324ade85b.zip | |
axum6 with typesafe state (#674)
Diffstat (limited to '')
| -rw-r--r-- | atuin-server/Cargo.toml | 4 | ||||
| -rw-r--r-- | atuin-server/src/handlers/history.rs | 21 | ||||
| -rw-r--r-- | atuin-server/src/handlers/user.rs | 19 | ||||
| -rw-r--r-- | atuin-server/src/router.rs | 43 |
4 files changed, 49 insertions, 38 deletions
diff --git a/atuin-server/Cargo.toml b/atuin-server/Cargo.toml index bf416b0b..b0857984 100644 --- a/atuin-server/Cargo.toml +++ b/atuin-server/Cargo.toml @@ -20,7 +20,7 @@ config = { version = "0.13", default-features = false, features = ["toml"] } serde = { version = "1.0.145", features = ["derive"] } serde_json = "1.0.86" sodiumoxide = "0.2.6" -base64 = "0.20.0" +base64 = "0.21.0" rand = "0.8.4" tokio = { version = "1", features = ["full"] } sqlx = { version = "0.6", features = [ @@ -29,7 +29,7 @@ sqlx = { version = "0.6", features = [ "postgres", ] } async-trait = "0.1.58" -axum = "0.5" +axum = "0.6.4" http = "0.2" fs-err = "2.9" chronoutil = "0.2.3" diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs index d2fda772..9ee13e16 100644 --- a/atuin-server/src/handlers/history.rs +++ b/atuin-server/src/handlers/history.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; use axum::{ - extract::{Path, Query}, - Extension, Json, + extract::{Path, Query, State}, + Json, }; use http::StatusCode; use tracing::{debug, error, instrument}; @@ -10,8 +10,9 @@ use tracing::{debug, error, instrument}; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; use crate::{ calendar::{TimePeriod, TimePeriodInfo}, - database::{Database, Postgres}, + database::Database, models::{NewHistory, User}, + router::AppState, }; use atuin_common::api::*; @@ -19,8 +20,9 @@ use atuin_common::api::*; #[instrument(skip_all, fields(user.id = user.id))] pub async fn count( user: User, - db: Extension<Postgres>, + state: State<AppState>, ) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.postgres; match db.count_history_cached(&user).await { // By default read out the cached value Ok(count) => Ok(Json(CountResponse { count })), @@ -39,8 +41,9 @@ pub async fn count( pub async fn list( req: Query<SyncHistoryRequest>, user: User, - db: Extension<Postgres>, + state: State<AppState>, ) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.postgres; let history = db .list_history( &user, @@ -73,9 +76,9 @@ pub async fn list( #[instrument(skip_all, fields(user.id = user.id))] pub async fn add( - Json(req): Json<Vec<AddHistoryRequest>>, user: User, - db: Extension<Postgres>, + state: State<AppState>, + Json(req): Json<Vec<AddHistoryRequest>>, ) -> Result<(), ErrorResponseStatus<'static>> { debug!("request to add {} history items", req.len()); @@ -90,6 +93,7 @@ pub async fn add( }) .collect(); + let db = &state.0.postgres; if let Err(e) = db.add_history(&history).await { error!("failed to add history: {}", e); @@ -105,13 +109,14 @@ pub async fn calendar( Path(focus): Path<String>, Query(params): Query<HashMap<String, u64>>, user: User, - db: Extension<Postgres>, + state: State<AppState>, ) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> { let focus = focus.as_str(); let year = params.get("year").unwrap_or(&0); let month = params.get("month").unwrap_or(&1); + let db = &state.0.postgres; let focus = match focus { "year" => db .calendar(&user, TimePeriod::YEAR, *year, *month) diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs index 1bc9178b..761724c5 100644 --- a/atuin-server/src/handlers/user.rs +++ b/atuin-server/src/handlers/user.rs @@ -1,6 +1,9 @@ use std::borrow::Borrow; -use axum::{extract::Path, Extension, Json}; +use axum::{ + extract::{Path, State}, + Extension, Json, +}; use http::StatusCode; use sodiumoxide::crypto::pwhash::argon2id13; use tracing::{debug, error, instrument}; @@ -8,8 +11,9 @@ use uuid::Uuid; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; use crate::{ - database::{Database, Postgres}, + database::Database, models::{NewSession, NewUser}, + router::AppState, settings::Settings, }; @@ -32,8 +36,9 @@ pub fn verify_str(secret: &str, verify: &str) -> bool { #[instrument(skip_all, fields(user.username = username.as_str()))] pub async fn get( Path(username): Path<String>, - db: Extension<Postgres>, + state: State<AppState>, ) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.postgres; let user = match db.get_user(username.as_ref()).await { Ok(user) => user, Err(sqlx::Error::RowNotFound) => { @@ -54,9 +59,9 @@ pub async fn get( #[instrument(skip_all)] pub async fn register( - Json(register): Json<RegisterRequest>, settings: Extension<Settings>, - db: Extension<Postgres>, + state: State<AppState>, + Json(register): Json<RegisterRequest>, ) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> { if !settings.open_registration { return Err( @@ -73,6 +78,7 @@ pub async fn register( password: hashed, }; + let db = &state.0.postgres; let user_id = match db.add_user(&new_user).await { Ok(id) => id, Err(e) => { @@ -102,9 +108,10 @@ pub async fn register( #[instrument(skip_all, fields(user.username = login.username.as_str()))] pub async fn login( + state: State<AppState>, login: Json<LoginRequest>, - db: Extension<Postgres>, ) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.postgres; let user = match db.get_user(login.username.borrow()).await { Ok(u) => u, Err(sqlx::Error::RowNotFound) => { diff --git a/atuin-server/src/router.rs b/atuin-server/src/router.rs index 08eea996..c4c15f1f 100644 --- a/atuin-server/src/router.rs +++ b/atuin-server/src/router.rs @@ -1,12 +1,12 @@ use async_trait::async_trait; use axum::{ - extract::{FromRequest, RequestParts}, - handler::Handler, + extract::FromRequestParts, response::IntoResponse, routing::{get, post}, - Extension, Router, + Router, }; use eyre::Result; +use http::request::Parts; use tower::ServiceBuilder; use tower_http::trace::TraceLayer; @@ -17,20 +17,15 @@ use super::{ use crate::{models::User, settings::Settings}; #[async_trait] -impl<B> FromRequest<B> for User -where - B: Send, -{ +impl FromRequestParts<AppState> for User { type Rejection = http::StatusCode; - async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { - let postgres = req - .extensions() - .get::<Postgres>() - .ok_or(http::StatusCode::INTERNAL_SERVER_ERROR)?; - + async fn from_request_parts( + req: &mut Parts, + state: &AppState, + ) -> Result<Self, Self::Rejection> { let auth_header = req - .headers() + .headers .get(http::header::AUTHORIZATION) .ok_or(http::StatusCode::FORBIDDEN)?; let auth_header = auth_header @@ -44,7 +39,8 @@ where return Err(http::StatusCode::FORBIDDEN); } - let user = postgres + let user = state + .postgres .get_session_user(token) .await .map_err(|_| http::StatusCode::FORBIDDEN)?; @@ -56,6 +52,13 @@ where async fn teapot() -> impl IntoResponse { (http::StatusCode::IM_A_TEAPOT, "☕") } + +#[derive(Clone)] +pub struct AppState { + pub postgres: Postgres, + pub settings: Settings, +} + pub fn router(postgres: Postgres, settings: Settings) -> Router { let routes = Router::new() .route("/", get(handlers::index)) @@ -73,11 +76,7 @@ pub fn router(postgres: Postgres, settings: Settings) -> Router { } else { Router::new().nest(path, routes) } - .fallback(teapot.into_service()) - .layer( - ServiceBuilder::new() - .layer(TraceLayer::new_for_http()) - .layer(Extension(postgres)) - .layer(Extension(settings)), - ) + .fallback(teapot) + .with_state(AppState { postgres, settings }) + .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http())) } |
