1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
|
use crate::{
atuin_common::api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ErrorResponse},
atuin_server::database::{DbError, db::Database, models::User},
};
use axum::{
Router,
extract::{FromRequestParts, Request},
http::{self, request::Parts},
middleware::Next,
response::{IntoResponse, Response},
routing::{delete, get, patch, post},
};
use eyre::Result;
use tower::ServiceBuilder;
use tower_http::trace::TraceLayer;
use super::handlers;
use crate::atuin_server::{
handlers::{ErrorResponseStatus, RespExt},
metrics,
settings::Settings,
};
/// Request extractor that authenticates the caller from the `Authorization` header.
///
/// A handler taking `UserAuth` only runs when the request carries
/// `Authorization: Token <session token>` and that token resolves to a session through
/// `Database::get_session_user`; the wrapped [`User`] is what that lookup returned.
pub(crate) struct UserAuth(pub(crate) User);

impl FromRequestParts<AppState> for UserAuth {
    type Rejection = ErrorResponseStatus<'static>;

    /// Extracts and verifies the session token carried by the request.
    ///
    /// The header must have the shape `<scheme> <token>` with the scheme `Token`. The scheme
    /// is matched case-insensitively, and one or more spaces may separate it from the token,
    /// as the `credentials` grammar of RFC 9110 §11 allows.
    ///
    /// # Errors
    ///
    /// * `400 Bad Request` — the header is missing, is not visible ASCII, has no space
    ///   after the scheme, or uses a scheme other than `Token`.
    /// * `403 Forbidden` — the database reports no session for the token.
    /// * `500 Internal Server Error` — the session lookup failed for another reason; the
    ///   underlying error is logged and not exposed to the client.
    async fn from_request_parts(
        req: &mut Parts,
        state: &AppState,
    ) -> Result<Self, Self::Rejection> {
        // Every malformed-header case shares this client-facing message; the text is left
        // exactly as before so responses seen by clients do not change.
        let malformed = || -> ErrorResponseStatus<'static> {
            ErrorResponse::reply("invalid authorization header encoding")
                .with_status(http::StatusCode::BAD_REQUEST)
        };

        let auth_header = req
            .headers
            .get(http::header::AUTHORIZATION)
            .ok_or_else(|| {
                ErrorResponse::reply("missing authorization header")
                    .with_status(http::StatusCode::BAD_REQUEST)
            })?;
        // `HeaderValue::to_str` fails on bytes outside visible ASCII.
        let auth_header = auth_header.to_str().map_err(|_| malformed())?;

        let (scheme, credentials) = auth_header.split_once(' ').ok_or_else(malformed)?;
        // Authentication schemes are case-insensitive (RFC 9110 §11.1), so `token` and
        // `TOKEN` are accepted alongside `Token`.
        if !scheme.eq_ignore_ascii_case("Token") {
            return Err(malformed());
        }
        // `credentials = auth-scheme [ 1*SP ... ]`: extra separating spaces are not part of
        // the token and must not be sent to the session lookup.
        let token = credentials.trim_start_matches(' ');

        let user = state
            .database
            .get_session_user(token)
            .await
            .map_err(|e| match e {
                DbError::NotFound => ErrorResponse::reply("session not found")
                    .with_status(http::StatusCode::FORBIDDEN),
                DbError::Other(e) => {
                    // Log the real cause server-side; the client only gets a generic message.
                    tracing::error!(error = ?e, "could not query user session");
                    ErrorResponse::reply("could not query user session")
                        .with_status(http::StatusCode::INTERNAL_SERVER_ERROR)
                }
            })?;

        Ok(UserAuth(user))
    }
}
/// Fallback handler for any request that matches no route.
///
/// Despite its name it answers with a plain `404 Not Found`. It once replied with
/// `418 I'm a teapot`, which was fun but told clients nothing useful.
async fn teapot() -> impl IntoResponse {
    let status = http::StatusCode::NOT_FOUND;
    let body = "404 not found";
    (status, body)
}
/// Middleware that stamps every response with this server's version.
///
/// The inner service's response is passed through unchanged, except that the
/// `ATUIN_HEADER_VERSION` header is set to `ATUIN_CARGO_VERSION`. `HeaderMap::insert`
/// replaces any value already present. This middleware performs no version comparison and
/// rejects nothing. Presumably clients read the header to decide whether they can sync with
/// a server on the same major version — confirm on the client side.
///
/// # Panics
///
/// Panics if `ATUIN_CARGO_VERSION` cannot be parsed as an HTTP header value.
async fn semver(request: Request, next: Next) -> Response {
    let mut response = next.run(request).await;
    response.headers_mut().insert(
        ATUIN_HEADER_VERSION,
        ATUIN_CARGO_VERSION
            .parse()
            .expect("ATUIN_CARGO_VERSION must be a valid HTTP header value"),
    );
    response
}
/// State shared by every route, installed once in `router` via `Router::with_state`.
///
/// Extractors such as `UserAuth` receive it by reference. `Clone` is derived because axum
/// requires router state to be cloneable.
#[derive(Clone)]
pub(crate) struct AppState {
    /// Database backend; `UserAuth` resolves session tokens through `get_session_user`.
    pub(crate) database: Database,
    /// Server configuration; `router` reads `path` from it to choose the mount prefix.
    pub(crate) settings: Settings,
}
/// Turns the configured `path` setting into a prefix that `Router::nest` accepts.
///
/// Returns `None` when routes should stay at the root. That covers an empty path and one
/// made only of `/` characters; axum panics if asked to nest at `/`. Otherwise the path is
/// returned as-is, with a leading `/` added if it was missing, because axum requires every
/// route path to start with one.
fn mount_prefix(path: &str) -> Option<String> {
    if path.chars().all(|c| c == '/') {
        None
    } else if path.starts_with('/') {
        Some(path.to_owned())
    } else {
        Some(format!("/{path}"))
    }
}

/// Builds the complete HTTP application: every route plus the shared middleware stack.
///
/// If `settings.path` names a sub-path, all routes are nested beneath it (e.g. `/atuin`
/// serves `/atuin/healthz`). An empty path, or one made only of `/`, serves them from the
/// root instead of panicking. Requests matching no route fall through to `teapot`. Every
/// response passes through tracing, metrics and the version-header middleware.
pub(crate) fn router(database: Database, settings: Settings) -> Router {
    // Index and health check.
    let routes = Router::new()
        .route("/", get(handlers::index))
        .route("/healthz", get(handlers::health::health_check));

    // Account, record-sync and v0 API routes.
    let routes = routes
        .route("/user/{username}", get(handlers::user::get))
        .route("/account", delete(handlers::user::delete))
        .route("/account/password", patch(handlers::user::change_password))
        .route("/register", post(handlers::user::register))
        .route("/login", post(handlers::user::login))
        .route("/record", post(handlers::record::post))
        .route("/record", get(handlers::record::index))
        .route("/record/next", get(handlers::record::next))
        .route("/api/v0/me", get(handlers::v0::me::get))
        .route("/api/v0/record", post(handlers::v0::record::post))
        .route("/api/v0/record", get(handlers::v0::record::index))
        .route("/api/v0/record/next", get(handlers::v0::record::next))
        .route("/api/v0/store", delete(handlers::v0::store::delete));

    // `mount_prefix` returns an owned String, so no borrow of `settings` outlives this
    // statement and `settings` can be moved into the state below.
    let routes = match mount_prefix(settings.path.as_str()) {
        Some(prefix) => Router::new().nest(&prefix, routes),
        None => routes,
    };

    routes
        .fallback(teapot)
        .with_state(AppState { database, settings })
        .layer(
            // ServiceBuilder makes the first layer the outermost: tracing sees each request
            // first, and `semver` runs closest to the router.
            ServiceBuilder::new()
                .layer(TraceLayer::new_for_http())
                .layer(axum::middleware::from_fn(metrics::track_metrics))
                .layer(axum::middleware::from_fn(semver)),
        )
}
|