diff options
| author | Conrad Ludgate <conradludgate@gmail.com> | 2023-02-10 09:45:20 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-10 09:45:20 +0000 |
| commit | 0acdb99eb3f962b6428b2c4e88fe111324ade85b (patch) | |
| tree | 1c59bfa5eb1763a6c2e22e0e37f64de1c1cd6b4d | |
| parent | Bump fs-err from 2.8.1 to 2.9.0 (#604) (diff) | |
| download | atuin-0acdb99eb3f962b6428b2c4e88fe111324ade85b.zip | |
axum6 with typesafe state (#674)
Diffstat (limited to '')
| -rw-r--r-- | Cargo.lock | 42 | ||||
| -rw-r--r-- | atuin-server/Cargo.toml | 4 | ||||
| -rw-r--r-- | atuin-server/src/handlers/history.rs | 21 | ||||
| -rw-r--r-- | atuin-server/src/handlers/user.rs | 19 | ||||
| -rw-r--r-- | atuin-server/src/router.rs | 43 |
5 files changed, 82 insertions, 47 deletions
@@ -158,7 +158,7 @@ dependencies = [ "async-trait", "atuin-common", "axum", - "base64 0.20.0", + "base64 0.21.0", "chrono", "chronoutil", "config", @@ -186,9 +186,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "axum" -version = "0.5.16" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9e3356844c4d6a6d6467b8da2cffb4a2820be256f50a3a386c9d152bab31043" +checksum = "e5694b64066a2459918d8074c2ce0d5a88f409431994c2356617c8ae0c4721fc" dependencies = [ "async-trait", "axum-core", @@ -204,8 +204,10 @@ dependencies = [ "mime", "percent-encoding", "pin-project-lite", + "rustversion", "serde", "serde_json", + "serde_path_to_error", "serde_urlencoded", "sync_wrapper", "tokio", @@ -217,9 +219,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.2.8" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9f0c0a60006f2a293d82d571f635042a72edf927539b7685bd62d361963839b" +checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34" dependencies = [ "async-trait", "bytes", @@ -227,6 +229,7 @@ dependencies = [ "http", "http-body", "mime", + "rustversion", "tower-layer", "tower-service", ] @@ -244,6 +247,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5" [[package]] +name = "base64" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a" + +[[package]] name = "beef" version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1123,9 +1132,9 @@ dependencies = [ [[package]] name = "matchit" -version = "0.5.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" +checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40" [[package]] name = "md-5" @@ -1721,6 +1730,12 @@ dependencies = [ ] [[package]] +name = "rustversion" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70" + +[[package]] name = "ryu" version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1822,6 +1837,15 @@ dependencies = [ ] [[package]] +name = "serde_path_to_error" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341" +dependencies = [ + "serde", +] + +[[package]] name = "serde_urlencoded" version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2328,9 +2352,9 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" [[package]] name = "tower-service" diff --git a/atuin-server/Cargo.toml b/atuin-server/Cargo.toml index bf416b0b..b0857984 100644 --- a/atuin-server/Cargo.toml +++ b/atuin-server/Cargo.toml @@ -20,7 +20,7 @@ config = { version = "0.13", default-features = false, features = ["toml"] } serde = { version = "1.0.145", features = ["derive"] } serde_json = "1.0.86" sodiumoxide = "0.2.6" -base64 = "0.20.0" +base64 = "0.21.0" rand = "0.8.4" tokio = { version = "1", features = ["full"] } sqlx = { version = "0.6", features = [ @@ -29,7 +29,7 @@ sqlx = { version = "0.6", features = [ "postgres", ] } async-trait = "0.1.58" -axum = "0.5" +axum = "0.6.4" http = "0.2" fs-err = "2.9" chronoutil = "0.2.3" diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs index d2fda772..9ee13e16 100644 --- a/atuin-server/src/handlers/history.rs +++ b/atuin-server/src/handlers/history.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; use axum::{ - extract::{Path, Query}, - Extension, Json, + extract::{Path, Query, State}, + Json, }; use http::StatusCode; use tracing::{debug, error, instrument}; @@ -10,8 +10,9 @@ use tracing::{debug, error, instrument}; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; use crate::{ calendar::{TimePeriod, TimePeriodInfo}, - database::{Database, Postgres}, + database::Database, models::{NewHistory, User}, + router::AppState, }; use atuin_common::api::*; @@ -19,8 +20,9 @@ use atuin_common::api::*; #[instrument(skip_all, fields(user.id = user.id))] pub async fn count( user: User, - db: Extension<Postgres>, + state: State<AppState>, ) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.postgres; match db.count_history_cached(&user).await { // By default read out the cached value Ok(count) => Ok(Json(CountResponse { count })), @@ -39,8 +41,9 @@ pub async fn count( pub async fn list( req: Query<SyncHistoryRequest>, user: User, - db: Extension<Postgres>, + state: State<AppState>, ) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.postgres; let history = db .list_history( &user, @@ -73,9 +76,9 @@ pub async fn list( #[instrument(skip_all, fields(user.id = user.id))] pub async fn add( - Json(req): Json<Vec<AddHistoryRequest>>, user: User, - db: Extension<Postgres>, + state: State<AppState>, + Json(req): Json<Vec<AddHistoryRequest>>, ) -> Result<(), ErrorResponseStatus<'static>> { debug!("request to add {} history items", req.len()); @@ -90,6 +93,7 @@ pub async fn add( }) .collect(); + let db = &state.0.postgres; if let Err(e) = db.add_history(&history).await { error!("failed to add history: {}", e); @@ -105,13 +109,14 @@ pub async fn calendar( Path(focus): Path<String>, Query(params): Query<HashMap<String, u64>>, user: User, - db: Extension<Postgres>, + state: State<AppState>, ) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> { let focus = focus.as_str(); let year = params.get("year").unwrap_or(&0); let month = params.get("month").unwrap_or(&1); + let db = &state.0.postgres; let focus = match focus { "year" => db .calendar(&user, TimePeriod::YEAR, *year, *month) diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs index 1bc9178b..761724c5 100644 --- a/atuin-server/src/handlers/user.rs +++ b/atuin-server/src/handlers/user.rs @@ -1,6 +1,9 @@ use std::borrow::Borrow; -use axum::{extract::Path, Extension, Json}; +use axum::{ + extract::{Path, State}, + Extension, Json, +}; use http::StatusCode; use sodiumoxide::crypto::pwhash::argon2id13; use tracing::{debug, error, instrument}; @@ -8,8 +11,9 @@ use uuid::Uuid; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; use crate::{ - database::{Database, Postgres}, + database::Database, models::{NewSession, NewUser}, + router::AppState, settings::Settings, }; @@ -32,8 +36,9 @@ pub fn verify_str(secret: &str, verify: &str) -> bool { #[instrument(skip_all, fields(user.username = username.as_str()))] pub async fn get( Path(username): Path<String>, - db: Extension<Postgres>, + state: State<AppState>, ) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.postgres; let user = match db.get_user(username.as_ref()).await { Ok(user) => user, Err(sqlx::Error::RowNotFound) => { @@ -54,9 +59,9 @@ pub async fn get( #[instrument(skip_all)] pub async fn register( - Json(register): Json<RegisterRequest>, settings: Extension<Settings>, - db: Extension<Postgres>, + state: State<AppState>, + Json(register): Json<RegisterRequest>, ) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> { if !settings.open_registration { return Err( @@ -73,6 +78,7 @@ pub async fn register( password: hashed, }; + let db = &state.0.postgres; let user_id = match db.add_user(&new_user).await { Ok(id) => id, Err(e) => { @@ -102,9 +108,10 @@ pub async fn register( #[instrument(skip_all, fields(user.username = login.username.as_str()))] pub async fn login( + state: State<AppState>, login: Json<LoginRequest>, - db: Extension<Postgres>, ) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.postgres; let user = match db.get_user(login.username.borrow()).await { Ok(u) => u, Err(sqlx::Error::RowNotFound) => { diff --git a/atuin-server/src/router.rs b/atuin-server/src/router.rs index 08eea996..c4c15f1f 100644 --- a/atuin-server/src/router.rs +++ b/atuin-server/src/router.rs @@ -1,12 +1,12 @@ use async_trait::async_trait; use axum::{ - extract::{FromRequest, RequestParts}, - handler::Handler, + extract::FromRequestParts, response::IntoResponse, routing::{get, post}, - Extension, Router, + Router, }; use eyre::Result; +use http::request::Parts; use tower::ServiceBuilder; use tower_http::trace::TraceLayer; @@ -17,20 +17,15 @@ use super::{ use crate::{models::User, settings::Settings}; #[async_trait] -impl<B> FromRequest<B> for User -where - B: Send, -{ +impl FromRequestParts<AppState> for User { type Rejection = http::StatusCode; - async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { - let postgres = req - .extensions() - .get::<Postgres>() - .ok_or(http::StatusCode::INTERNAL_SERVER_ERROR)?; - + async fn from_request_parts( + req: &mut Parts, + state: &AppState, + ) -> Result<Self, Self::Rejection> { let auth_header = req - .headers() + .headers .get(http::header::AUTHORIZATION) .ok_or(http::StatusCode::FORBIDDEN)?; let auth_header = auth_header @@ -44,7 +39,8 @@ where return Err(http::StatusCode::FORBIDDEN); } - let user = postgres + let user = state + .postgres .get_session_user(token) .await .map_err(|_| http::StatusCode::FORBIDDEN)?; @@ -56,6 +52,13 @@ where async fn teapot() -> impl IntoResponse { (http::StatusCode::IM_A_TEAPOT, "☕") } + +#[derive(Clone)] +pub struct AppState { + pub postgres: Postgres, + pub settings: Settings, +} + pub fn router(postgres: Postgres, settings: Settings) -> Router { let routes = Router::new() .route("/", get(handlers::index)) @@ -73,11 +76,7 @@ pub fn router(postgres: Postgres, settings: Settings) -> Router { } else { Router::new().nest(path, routes) } - .fallback(teapot.into_service()) - .layer( - ServiceBuilder::new() - .layer(TraceLayer::new_for_http()) - .layer(Extension(postgres)) - .layer(Extension(settings)), - ) + .fallback(teapot) + .with_state(AppState { postgres, settings }) + .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http())) } |
