From 0acdb99eb3f962b6428b2c4e88fe111324ade85b Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 10 Feb 2023 09:45:20 +0000 Subject: axum6 with typesafe state (#674) --- atuin-server/src/handlers/history.rs | 21 +++++++++++++-------- atuin-server/src/handlers/user.rs | 19 +++++++++++++------ 2 files changed, 26 insertions(+), 14 deletions(-) (limited to 'atuin-server/src/handlers') diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs index d2fda772..9ee13e16 100644 --- a/atuin-server/src/handlers/history.rs +++ b/atuin-server/src/handlers/history.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; use axum::{ - extract::{Path, Query}, - Extension, Json, + extract::{Path, Query, State}, + Json, }; use http::StatusCode; use tracing::{debug, error, instrument}; @@ -10,8 +10,9 @@ use tracing::{debug, error, instrument}; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; use crate::{ calendar::{TimePeriod, TimePeriodInfo}, - database::{Database, Postgres}, + database::Database, models::{NewHistory, User}, + router::AppState, }; use atuin_common::api::*; @@ -19,8 +20,9 @@ use atuin_common::api::*; #[instrument(skip_all, fields(user.id = user.id))] pub async fn count( user: User, - db: Extension, + state: State, ) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.postgres; match db.count_history_cached(&user).await { // By default read out the cached value Ok(count) => Ok(Json(CountResponse { count })), @@ -39,8 +41,9 @@ pub async fn count( pub async fn list( req: Query, user: User, - db: Extension, + state: State, ) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.postgres; let history = db .list_history( &user, @@ -73,9 +76,9 @@ pub async fn list( #[instrument(skip_all, fields(user.id = user.id))] pub async fn add( - Json(req): Json>, user: User, - db: Extension, + state: State, + Json(req): Json>, ) -> Result<(), ErrorResponseStatus<'static>> { debug!("request to add {} history items", req.len()); @@ -90,6 +93,7 @@ pub async fn add( }) .collect(); + let db = &state.0.postgres; if let Err(e) = db.add_history(&history).await { error!("failed to add history: {}", e); @@ -105,13 +109,14 @@ pub async fn calendar( Path(focus): Path, Query(params): Query>, user: User, - db: Extension, + state: State, ) -> Result>, ErrorResponseStatus<'static>> { let focus = focus.as_str(); let year = params.get("year").unwrap_or(&0); let month = params.get("month").unwrap_or(&1); + let db = &state.0.postgres; let focus = match focus { "year" => db .calendar(&user, TimePeriod::YEAR, *year, *month) diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs index 1bc9178b..761724c5 100644 --- a/atuin-server/src/handlers/user.rs +++ b/atuin-server/src/handlers/user.rs @@ -1,6 +1,9 @@ use std::borrow::Borrow; -use axum::{extract::Path, Extension, Json}; +use axum::{ + extract::{Path, State}, + Extension, Json, +}; use http::StatusCode; use sodiumoxide::crypto::pwhash::argon2id13; use tracing::{debug, error, instrument}; @@ -8,8 +11,9 @@ use uuid::Uuid; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; use crate::{ - database::{Database, Postgres}, + database::Database, models::{NewSession, NewUser}, + router::AppState, settings::Settings, }; @@ -32,8 +36,9 @@ pub fn verify_str(secret: &str, verify: &str) -> bool { #[instrument(skip_all, fields(user.username = username.as_str()))] pub async fn get( Path(username): Path, - db: Extension, + state: State, ) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.postgres; let user = match db.get_user(username.as_ref()).await { Ok(user) => user, Err(sqlx::Error::RowNotFound) => { @@ -54,9 +59,9 @@ pub async fn get( #[instrument(skip_all)] pub async fn register( - Json(register): Json, settings: Extension, - db: Extension, + state: State, + Json(register): Json, ) -> Result, ErrorResponseStatus<'static>> { if !settings.open_registration { return Err( @@ -73,6 +78,7 @@ pub async fn register( password: hashed, }; + let db = &state.0.postgres; let user_id = match db.add_user(&new_user).await { Ok(id) => id, Err(e) => { @@ -102,9 +108,10 @@ pub async fn register( #[instrument(skip_all, fields(user.username = login.username.as_str()))] pub async fn login( + state: State, login: Json, - db: Extension, ) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.postgres; let user = match db.get_user(login.username.borrow()).await { Ok(u) => u, Err(sqlx::Error::RowNotFound) => { -- cgit v1.3.1