aboutsummaryrefslogtreecommitdiffstats
path: root/atuin-server
diff options
context:
space:
mode:
Diffstat (limited to 'atuin-server')
-rw-r--r--atuin-server/Cargo.toml4
-rw-r--r--atuin-server/src/handlers/history.rs21
-rw-r--r--atuin-server/src/handlers/user.rs19
-rw-r--r--atuin-server/src/router.rs43
4 files changed, 49 insertions, 38 deletions
diff --git a/atuin-server/Cargo.toml b/atuin-server/Cargo.toml
index bf416b0b..b0857984 100644
--- a/atuin-server/Cargo.toml
+++ b/atuin-server/Cargo.toml
@@ -20,7 +20,7 @@ config = { version = "0.13", default-features = false, features = ["toml"] }
serde = { version = "1.0.145", features = ["derive"] }
serde_json = "1.0.86"
sodiumoxide = "0.2.6"
-base64 = "0.20.0"
+base64 = "0.21.0"
rand = "0.8.4"
tokio = { version = "1", features = ["full"] }
sqlx = { version = "0.6", features = [
@@ -29,7 +29,7 @@ sqlx = { version = "0.6", features = [
"postgres",
] }
async-trait = "0.1.58"
-axum = "0.5"
+axum = "0.6.4"
http = "0.2"
fs-err = "2.9"
chronoutil = "0.2.3"
diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs
index d2fda772..9ee13e16 100644
--- a/atuin-server/src/handlers/history.rs
+++ b/atuin-server/src/handlers/history.rs
@@ -1,8 +1,8 @@
use std::collections::HashMap;
use axum::{
- extract::{Path, Query},
- Extension, Json,
+ extract::{Path, Query, State},
+ Json,
};
use http::StatusCode;
use tracing::{debug, error, instrument};
@@ -10,8 +10,9 @@ use tracing::{debug, error, instrument};
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
use crate::{
calendar::{TimePeriod, TimePeriodInfo},
- database::{Database, Postgres},
+ database::Database,
models::{NewHistory, User},
+ router::AppState,
};
use atuin_common::api::*;
@@ -19,8 +20,9 @@ use atuin_common::api::*;
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn count(
user: User,
- db: Extension<Postgres>,
+ state: State<AppState>,
) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
+ let db = &state.0.postgres;
match db.count_history_cached(&user).await {
// By default read out the cached value
Ok(count) => Ok(Json(CountResponse { count })),
@@ -39,8 +41,9 @@ pub async fn count(
pub async fn list(
req: Query<SyncHistoryRequest>,
user: User,
- db: Extension<Postgres>,
+ state: State<AppState>,
) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
+ let db = &state.0.postgres;
let history = db
.list_history(
&user,
@@ -73,9 +76,9 @@ pub async fn list(
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn add(
- Json(req): Json<Vec<AddHistoryRequest>>,
user: User,
- db: Extension<Postgres>,
+ state: State<AppState>,
+ Json(req): Json<Vec<AddHistoryRequest>>,
) -> Result<(), ErrorResponseStatus<'static>> {
debug!("request to add {} history items", req.len());
@@ -90,6 +93,7 @@ pub async fn add(
})
.collect();
+ let db = &state.0.postgres;
if let Err(e) = db.add_history(&history).await {
error!("failed to add history: {}", e);
@@ -105,13 +109,14 @@ pub async fn calendar(
Path(focus): Path<String>,
Query(params): Query<HashMap<String, u64>>,
user: User,
- db: Extension<Postgres>,
+ state: State<AppState>,
) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
let focus = focus.as_str();
let year = params.get("year").unwrap_or(&0);
let month = params.get("month").unwrap_or(&1);
+ let db = &state.0.postgres;
let focus = match focus {
"year" => db
.calendar(&user, TimePeriod::YEAR, *year, *month)
diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs
index 1bc9178b..761724c5 100644
--- a/atuin-server/src/handlers/user.rs
+++ b/atuin-server/src/handlers/user.rs
@@ -1,6 +1,9 @@
use std::borrow::Borrow;
-use axum::{extract::Path, Extension, Json};
+use axum::{
+ extract::{Path, State},
+ Extension, Json,
+};
use http::StatusCode;
use sodiumoxide::crypto::pwhash::argon2id13;
use tracing::{debug, error, instrument};
@@ -8,8 +11,9 @@ use uuid::Uuid;
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
use crate::{
- database::{Database, Postgres},
+ database::Database,
models::{NewSession, NewUser},
+ router::AppState,
settings::Settings,
};
@@ -32,8 +36,9 @@ pub fn verify_str(secret: &str, verify: &str) -> bool {
#[instrument(skip_all, fields(user.username = username.as_str()))]
pub async fn get(
Path(username): Path<String>,
- db: Extension<Postgres>,
+ state: State<AppState>,
) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> {
+ let db = &state.0.postgres;
let user = match db.get_user(username.as_ref()).await {
Ok(user) => user,
Err(sqlx::Error::RowNotFound) => {
@@ -54,9 +59,9 @@ pub async fn get(
#[instrument(skip_all)]
pub async fn register(
- Json(register): Json<RegisterRequest>,
settings: Extension<Settings>,
- db: Extension<Postgres>,
+ state: State<AppState>,
+ Json(register): Json<RegisterRequest>,
) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> {
if !settings.open_registration {
return Err(
@@ -73,6 +78,7 @@ pub async fn register(
password: hashed,
};
+ let db = &state.0.postgres;
let user_id = match db.add_user(&new_user).await {
Ok(id) => id,
Err(e) => {
@@ -102,9 +108,10 @@ pub async fn register(
#[instrument(skip_all, fields(user.username = login.username.as_str()))]
pub async fn login(
+ state: State<AppState>,
login: Json<LoginRequest>,
- db: Extension<Postgres>,
) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> {
+ let db = &state.0.postgres;
let user = match db.get_user(login.username.borrow()).await {
Ok(u) => u,
Err(sqlx::Error::RowNotFound) => {
diff --git a/atuin-server/src/router.rs b/atuin-server/src/router.rs
index 08eea996..c4c15f1f 100644
--- a/atuin-server/src/router.rs
+++ b/atuin-server/src/router.rs
@@ -1,12 +1,12 @@
use async_trait::async_trait;
use axum::{
- extract::{FromRequest, RequestParts},
- handler::Handler,
+ extract::FromRequestParts,
response::IntoResponse,
routing::{get, post},
- Extension, Router,
+ Router,
};
use eyre::Result;
+use http::request::Parts;
use tower::ServiceBuilder;
use tower_http::trace::TraceLayer;
@@ -17,20 +17,15 @@ use super::{
use crate::{models::User, settings::Settings};
#[async_trait]
-impl<B> FromRequest<B> for User
-where
- B: Send,
-{
+impl FromRequestParts<AppState> for User {
type Rejection = http::StatusCode;
- async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
- let postgres = req
- .extensions()
- .get::<Postgres>()
- .ok_or(http::StatusCode::INTERNAL_SERVER_ERROR)?;
-
+ async fn from_request_parts(
+ req: &mut Parts,
+ state: &AppState,
+ ) -> Result<Self, Self::Rejection> {
let auth_header = req
- .headers()
+ .headers
.get(http::header::AUTHORIZATION)
.ok_or(http::StatusCode::FORBIDDEN)?;
let auth_header = auth_header
@@ -44,7 +39,8 @@ where
return Err(http::StatusCode::FORBIDDEN);
}
- let user = postgres
+ let user = state
+ .postgres
.get_session_user(token)
.await
.map_err(|_| http::StatusCode::FORBIDDEN)?;
@@ -56,6 +52,13 @@ where
async fn teapot() -> impl IntoResponse {
(http::StatusCode::IM_A_TEAPOT, "☕")
}
+
+#[derive(Clone)]
+pub struct AppState {
+ pub postgres: Postgres,
+ pub settings: Settings,
+}
+
pub fn router(postgres: Postgres, settings: Settings) -> Router {
let routes = Router::new()
.route("/", get(handlers::index))
@@ -73,11 +76,7 @@ pub fn router(postgres: Postgres, settings: Settings) -> Router {
} else {
Router::new().nest(path, routes)
}
- .fallback(teapot.into_service())
- .layer(
- ServiceBuilder::new()
- .layer(TraceLayer::new_for_http())
- .layer(Extension(postgres))
- .layer(Extension(settings)),
- )
+ .fallback(teapot)
+ .with_state(AppState { postgres, settings })
+ .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
}