diff options
| author | Ellie Huxtable <e@elm.sh> | 2021-04-20 17:07:11 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-04-20 16:07:11 +0000 |
| commit | 34888827f8a06de835cbe5833a06914f28cce514 (patch) | |
| tree | 8b56f20e50065cd2c222d5e8e067ec55cf1947a1 /src/server | |
| parent | Optimise docker (#34) (diff) | |
| download | atuin-34888827f8a06de835cbe5833a06914f28cce514.zip | |
Switch to Warp + SQLx, use async, switch to Rust stable (#36)
* Switch to warp + sql, use async and stable rust
* Update CI to use stable
Diffstat (limited to 'src/server')
| -rw-r--r-- | src/server/auth.rs | 222 | ||||
| -rw-r--r-- | src/server/database.rs | 202 | ||||
| -rw-r--r-- | src/server/handlers/history.rs | 89 | ||||
| -rw-r--r-- | src/server/handlers/mod.rs | 6 | ||||
| -rw-r--r-- | src/server/handlers/user.rs | 140 | ||||
| -rw-r--r-- | src/server/mod.rs | 23 | ||||
| -rw-r--r-- | src/server/models.rs | 49 | ||||
| -rw-r--r-- | src/server/router.rs | 121 |
8 files changed, 852 insertions, 0 deletions
diff --git a/src/server/auth.rs b/src/server/auth.rs new file mode 100644 index 00000000..52a73108 --- /dev/null +++ b/src/server/auth.rs @@ -0,0 +1,222 @@ +/* +use self::diesel::prelude::*; +use eyre::Result; +use rocket::http::Status; +use rocket::request::{self, FromRequest, Outcome, Request}; +use rocket::State; +use rocket_contrib::databases::diesel; +use sodiumoxide::crypto::pwhash::argon2id13; + +use rocket_contrib::json::Json; +use uuid::Uuid; + +use super::models::{NewSession, NewUser, Session, User}; +use super::views::ApiResponse; + +use crate::api::{LoginRequest, RegisterRequest}; +use crate::schema::{sessions, users}; +use crate::settings::Settings; +use crate::utils::hash_secret; + +use super::database::AtuinDbConn; + +#[derive(Debug)] +pub enum KeyError { + Missing, + Invalid, +} + +pub fn verify_str(secret: &str, verify: &str) -> bool { + sodiumoxide::init().unwrap(); + + let mut padded = [0_u8; 128]; + secret.as_bytes().iter().enumerate().for_each(|(i, val)| { + padded[i] = *val; + }); + + match argon2id13::HashedPassword::from_slice(&padded) { + Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()), + None => false, + } +} + +impl<'a, 'r> FromRequest<'a, 'r> for User { + type Error = KeyError; + + fn from_request(request: &'a Request<'r>) -> request::Outcome<User, Self::Error> { + let session: Vec<_> = request.headers().get("authorization").collect(); + + if session.is_empty() { + return Outcome::Failure((Status::BadRequest, KeyError::Missing)); + } else if session.len() > 1 { + return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); + } + + let session: Vec<_> = session[0].split(' ').collect(); + + if session.len() != 2 { + return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); + } + + if session[0] != "Token" { + return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); + } + + let session = session[1]; + + let db = request + .guard::<AtuinDbConn>() + .succeeded() + .expect("failed to load database"); + + let session = sessions::table + .filter(sessions::token.eq(session)) + .first::<Session>(&*db); + + if session.is_err() { + return Outcome::Failure((Status::Unauthorized, KeyError::Invalid)); + } + + let session = session.unwrap(); + + let user = users::table.find(session.user_id).first(&*db); + + match user { + Ok(user) => Outcome::Success(user), + Err(_) => Outcome::Failure((Status::Unauthorized, KeyError::Invalid)), + } + } +} + +#[get("/user/<user>")] +#[allow(clippy::clippy::needless_pass_by_value)] +pub fn get_user(user: String, conn: AtuinDbConn) -> ApiResponse { + use crate::schema::users::dsl::{username, users}; + + let user: Result<String, diesel::result::Error> = users + .select(username) + .filter(username.eq(user)) + .first(&*conn); + + if user.is_err() { + return ApiResponse { + json: json!({ + "message": "could not find user", + }), + status: Status::NotFound, + }; + } + + let user = user.unwrap(); + + ApiResponse { + json: json!({ "username": user.as_str() }), + status: Status::Ok, + } +} + +#[post("/register", data = "<register>")] +#[allow(clippy::clippy::needless_pass_by_value)] +pub fn register( + conn: AtuinDbConn, + register: Json<RegisterRequest>, + settings: State<Settings>, +) -> ApiResponse { + if !settings.server.open_registration { + return ApiResponse { + status: Status::BadRequest, + json: json!({ + "message": "registrations are not open" + }), + }; + } + + let hashed = hash_secret(register.password.as_str()); + + let new_user = NewUser { + email: register.email.as_str(), + username: register.username.as_str(), + password: hashed.as_str(), + }; + + let user = diesel::insert_into(users::table) + .values(&new_user) + .get_result(&*conn); + + if user.is_err() { + return ApiResponse { + status: Status::BadRequest, + json: json!({ + "message": "failed to create user - username or email in use?", + }), + }; + } + + let user: User = user.unwrap(); + let token = Uuid::new_v4().to_simple().to_string(); + + let new_session = NewSession { + user_id: user.id, + token: token.as_str(), + }; + + match diesel::insert_into(sessions::table) + .values(&new_session) + .execute(&*conn) + { + Ok(_) => ApiResponse { + status: Status::Ok, + json: json!({"message": "user created!", "session": token}), + }, + Err(_) => ApiResponse { + status: Status::BadRequest, + json: json!({ "message": "failed to create user"}), + }, + } +} + +#[post("/login", data = "<login>")] +#[allow(clippy::clippy::needless_pass_by_value)] +pub fn login(conn: AtuinDbConn, login: Json<LoginRequest>) -> ApiResponse { + let user = users::table + .filter(users::username.eq(login.username.as_str())) + .first(&*conn); + + if user.is_err() { + return ApiResponse { + status: Status::NotFound, + json: json!({"message": "user not found"}), + }; + } + + let user: User = user.unwrap(); + + let session = sessions::table + .filter(sessions::user_id.eq(user.id)) + .first(&*conn); + + // a session should exist... + if session.is_err() { + return ApiResponse { + status: Status::InternalServerError, + json: json!({"message": "something went wrong"}), + }; + } + + let verified = verify_str(user.password.as_str(), login.password.as_str()); + + if !verified { + return ApiResponse { + status: Status::NotFound, + json: json!({"message": "user not found"}), + }; + } + + let session: Session = session.unwrap(); + + ApiResponse { + status: Status::Ok, + json: json!({"session": session.token}), + } +} +*/ diff --git a/src/server/database.rs b/src/server/database.rs new file mode 100644 index 00000000..5945baaf --- /dev/null +++ b/src/server/database.rs @@ -0,0 +1,202 @@ +use async_trait::async_trait; + +use eyre::{eyre, Result}; +use sqlx::postgres::PgPoolOptions; + +use crate::settings::HISTORY_PAGE_SIZE; + +use super::models::{History, NewHistory, NewSession, NewUser, Session, User}; + +#[async_trait] +pub trait Database { + async fn get_session(&self, token: &str) -> Result<Session>; + async fn get_session_user(&self, token: &str) -> Result<User>; + async fn add_session(&self, session: &NewSession) -> Result<()>; + + async fn get_user(&self, username: String) -> Result<User>; + async fn get_user_session(&self, u: &User) -> Result<Session>; + async fn add_user(&self, user: NewUser) -> Result<i64>; + + async fn count_history(&self, user: &User) -> Result<i64>; + async fn list_history( + &self, + user: &User, + created_since: chrono::NaiveDateTime, + since: chrono::NaiveDateTime, + host: String, + ) -> Result<Vec<History>>; + async fn add_history(&self, history: &[NewHistory]) -> Result<()>; +} + +#[derive(Clone)] +pub struct Postgres { + pool: sqlx::Pool<sqlx::postgres::Postgres>, +} + +impl Postgres { + pub async fn new(uri: &str) -> Result<Self, sqlx::Error> { + let pool = PgPoolOptions::new() + .max_connections(100) + .connect(uri) + .await?; + + Ok(Self { pool }) + } +} + +#[async_trait] +impl Database for Postgres { + async fn get_session(&self, token: &str) -> Result<Session> { + let res: Option<Session> = + sqlx::query_as::<_, Session>("select * from sessions where token = $1") + .bind(token) + .fetch_optional(&self.pool) + .await?; + + if let Some(s) = res { + Ok(s) + } else { + Err(eyre!("could not find session")) + } + } + + async fn get_user(&self, username: String) -> Result<User> { + let res: Option<User> = + sqlx::query_as::<_, User>("select * from users where username = $1") + .bind(username) + .fetch_optional(&self.pool) + .await?; + + if let Some(u) = res { + Ok(u) + } else { + Err(eyre!("could not find user")) + } + } + + async fn get_session_user(&self, token: &str) -> Result<User> { + let res: Option<User> = sqlx::query_as::<_, User>( + "select * from users + inner join sessions + on users.id = sessions.user_id + and sessions.token = $1", + ) + .bind(token) + .fetch_optional(&self.pool) + .await?; + + if let Some(u) = res { + Ok(u) + } else { + Err(eyre!("could not find user")) + } + } + + async fn count_history(&self, user: &User) -> Result<i64> { + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await?; + + Ok(res.0) + } + + async fn list_history( + &self, + user: &User, + created_since: chrono::NaiveDateTime, + since: chrono::NaiveDateTime, + host: String, + ) -> Result<Vec<History>> { + let res = sqlx::query_as::<_, History>( + "select * from history + where user_id = $1 + and hostname != $2 + and created_at >= $3 + and timestamp >= $4 + order by timestamp asc + limit $5", + ) + .bind(user.id) + .bind(host) + .bind(created_since) + .bind(since) + .bind(HISTORY_PAGE_SIZE) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn add_history(&self, history: &[NewHistory]) -> Result<()> { + let mut tx = self.pool.begin().await?; + + for i in history { + sqlx::query( + "insert into history + (client_id, user_id, hostname, timestamp, data) + values ($1, $2, $3, $4, $5) + on conflict do nothing + ", + ) + .bind(i.client_id) + .bind(i.user_id) + .bind(i.hostname) + .bind(i.timestamp) + .bind(i.data) + .execute(&mut tx) + .await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn add_user(&self, user: NewUser) -> Result<i64> { + let res: (i64,) = sqlx::query_as( + "insert into users + (username, email, password) + values($1, $2, $3) + returning id", + ) + .bind(user.username.as_str()) + .bind(user.email.as_str()) + .bind(user.password) + .fetch_one(&self.pool) + .await?; + + Ok(res.0) + } + + async fn add_session(&self, session: &NewSession) -> Result<()> { + sqlx::query( + "insert into sessions + (user_id, token) + values($1, $2)", + ) + .bind(session.user_id) + .bind(session.token) + .execute(&self.pool) + .await?; + + Ok(()) + } + + async fn get_user_session(&self, u: &User) -> Result<Session> { + let res: Option<Session> = + sqlx::query_as::<_, Session>("select * from sessions where user_id = $1") + .bind(u.id) + .fetch_optional(&self.pool) + .await?; + + if let Some(s) = res { + Ok(s) + } else { + Err(eyre!("could not find session")) + } + } +} diff --git a/src/server/handlers/history.rs b/src/server/handlers/history.rs new file mode 100644 index 00000000..4fd6f03f --- /dev/null +++ b/src/server/handlers/history.rs @@ -0,0 +1,89 @@ +use std::convert::Infallible; + +use warp::{http::StatusCode, reply::json}; + +use crate::api::{ + AddHistoryRequest, CountResponse, ErrorResponse, SyncHistoryRequest, SyncHistoryResponse, +}; +use crate::server::database::Database; +use crate::server::models::{NewHistory, User}; + +pub async fn count( + user: User, + db: impl Database + Clone + Send + Sync, +) -> Result<Box<dyn warp::Reply>, Infallible> { + db.count_history(&user).await.map_or( + Ok(Box::new(ErrorResponse::reply( + "failed to query history count", + StatusCode::INTERNAL_SERVER_ERROR, + ))), + |count| Ok(Box::new(json(&CountResponse { count }))), + ) +} + +pub async fn list( + req: SyncHistoryRequest, + user: User, + db: impl Database + Clone + Send + Sync, +) -> Result<Box<dyn warp::Reply>, Infallible> { + let history = db + .list_history( + &user, + req.sync_ts.naive_utc(), + req.history_ts.naive_utc(), + req.host, + ) + .await; + + if let Err(e) = history { + error!("failed to load history: {}", e); + let resp = + ErrorResponse::reply("failed to load history", StatusCode::INTERNAL_SERVER_ERROR); + let resp = Box::new(resp); + return Ok(resp); + } + + let history: Vec<String> = history + .unwrap() + .iter() + .map(|i| i.data.to_string()) + .collect(); + + debug!( + "loaded {} items of history for user {}", + history.len(), + user.id + ); + + Ok(Box::new(json(&SyncHistoryResponse { history }))) +} + +pub async fn add( + req: Vec<AddHistoryRequest>, + user: User, + db: impl Database + Clone + Send + Sync, +) -> Result<Box<dyn warp::Reply>, Infallible> { + debug!("request to add {} history items", req.len()); + + let history: Vec<NewHistory> = req + .iter() + .map(|h| NewHistory { + client_id: h.id.as_str(), + user_id: user.id, + hostname: h.hostname.as_str(), + timestamp: h.timestamp.naive_utc(), + data: h.data.as_str(), + }) + .collect(); + + if let Err(e) = db.add_history(&history).await { + error!("failed to add history: {}", e); + + return Ok(Box::new(ErrorResponse::reply( + "failed to add history", + StatusCode::INTERNAL_SERVER_ERROR, + ))); + }; + + Ok(Box::new(warp::reply())) +} diff --git a/src/server/handlers/mod.rs b/src/server/handlers/mod.rs new file mode 100644 index 00000000..3c20538c --- /dev/null +++ b/src/server/handlers/mod.rs @@ -0,0 +1,6 @@ +pub mod history; +pub mod user; + +pub const fn index() -> &'static str { + "\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett" +} diff --git a/src/server/handlers/user.rs b/src/server/handlers/user.rs new file mode 100644 index 00000000..782d7dbd --- /dev/null +++ b/src/server/handlers/user.rs @@ -0,0 +1,140 @@ +use std::convert::Infallible; + +use sodiumoxide::crypto::pwhash::argon2id13; +use uuid::Uuid; +use warp::http::StatusCode; +use warp::reply::json; + +use crate::api::{ + ErrorResponse, LoginRequest, LoginResponse, RegisterRequest, RegisterResponse, UserResponse, +}; +use crate::server::database::Database; +use crate::server::models::{NewSession, NewUser}; +use crate::settings::Settings; +use crate::utils::hash_secret; + +pub fn verify_str(secret: &str, verify: &str) -> bool { + sodiumoxide::init().unwrap(); + + let mut padded = [0_u8; 128]; + secret.as_bytes().iter().enumerate().for_each(|(i, val)| { + padded[i] = *val; + }); + + match argon2id13::HashedPassword::from_slice(&padded) { + Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()), + None => false, + } +} + +pub async fn get( + username: String, + db: impl Database + Clone + Send + Sync, +) -> Result<Box<dyn warp::Reply>, Infallible> { + let user = match db.get_user(username).await { + Ok(user) => user, + Err(e) => { + debug!("user not found: {}", e); + return Ok(Box::new(ErrorResponse::reply( + "user not found", + StatusCode::NOT_FOUND, + ))); + } + }; + + Ok(Box::new(warp::reply::json(&UserResponse { + username: user.username, + }))) +} + +pub async fn register( + register: RegisterRequest, + settings: Settings, + db: impl Database + Clone + Send + Sync, +) -> Result<Box<dyn warp::Reply>, Infallible> { + if !settings.server.open_registration { + return Ok(Box::new(ErrorResponse::reply( + "this server is not open for registrations", + StatusCode::BAD_REQUEST, + ))); + } + + let hashed = hash_secret(register.password.as_str()); + + let new_user = NewUser { + email: register.email, + username: register.username, + password: hashed, + }; + + let user_id = match db.add_user(new_user).await { + Ok(id) => id, + Err(e) => { + error!("failed to add user: {}", e); + return Ok(Box::new(ErrorResponse::reply( + "failed to add user", + StatusCode::BAD_REQUEST, + ))); + } + }; + + let token = Uuid::new_v4().to_simple().to_string(); + + let new_session = NewSession { + user_id, + token: token.as_str(), + }; + + match db.add_session(&new_session).await { + Ok(_) => Ok(Box::new(json(&RegisterResponse { session: token }))), + Err(e) => { + error!("failed to add session: {}", e); + Ok(Box::new(ErrorResponse::reply( + "failed to register user", + StatusCode::BAD_REQUEST, + ))) + } + } +} + +pub async fn login( + login: LoginRequest, + db: impl Database + Clone + Send + Sync, +) -> Result<Box<dyn warp::Reply>, Infallible> { + let user = match db.get_user(login.username.clone()).await { + Ok(u) => u, + Err(e) => { + error!("failed to get user {}: {}", login.username.clone(), e); + + return Ok(Box::new(ErrorResponse::reply( + "user not found", + StatusCode::NOT_FOUND, + ))); + } + }; + + let session = match db.get_user_session(&user).await { + Ok(u) => u, + Err(e) => { + error!("failed to get session for {}: {}", login.username, e); + + return Ok(Box::new(ErrorResponse::reply( + "user not found", + StatusCode::NOT_FOUND, + ))); + } + }; + + let verified = verify_str(user.password.as_str(), login.password.as_str()); + + if !verified { + return Ok(Box::new(ErrorResponse::reply( + "user not found", + StatusCode::NOT_FOUND, + ))); + } + + Ok(Box::new(warp::reply::json(&LoginResponse { + session: session.token, + }))) +} diff --git a/src/server/mod.rs b/src/server/mod.rs new file mode 100644 index 00000000..d5e083df --- /dev/null +++ b/src/server/mod.rs @@ -0,0 +1,23 @@ +use std::net::IpAddr; + +use eyre::Result; + +use crate::settings::Settings; + +pub mod auth; +pub mod database; +pub mod handlers; +pub mod models; +pub mod router; + +pub async fn launch(settings: &Settings, host: String, port: u16) -> Result<()> { + // routes to run: + // index, register, add_history, login, get_user, sync_count, sync_list + let host = host.parse::<IpAddr>()?; + + let r = router::router(settings).await?; + + warp::serve(r).run((host, port)).await; + + Ok(()) +} diff --git a/src/server/models.rs b/src/server/models.rs new file mode 100644 index 00000000..fbf1897e --- /dev/null +++ b/src/server/models.rs @@ -0,0 +1,49 @@ +use chrono::prelude::*; + +#[derive(sqlx::FromRow)] +pub struct History { + pub id: i64, + pub client_id: String, // a client generated ID + pub user_id: i64, + pub hostname: String, + pub timestamp: NaiveDateTime, + + pub data: String, + + pub created_at: NaiveDateTime, +} + +pub struct NewHistory<'a> { + pub client_id: &'a str, + pub user_id: i64, + pub hostname: &'a str, + pub timestamp: chrono::NaiveDateTime, + + pub data: &'a str, +} + +#[derive(sqlx::FromRow)] +pub struct User { + pub id: i64, + pub username: String, + pub email: String, + pub password: String, +} + +#[derive(sqlx::FromRow)] +pub struct Session { + pub id: i64, + pub user_id: i64, + pub token: String, +} + +pub struct NewUser { + pub username: String, + pub email: String, + pub password: String, +} + +pub struct NewSession<'a> { + pub user_id: i64, + pub token: &'a str, +} diff --git a/src/server/router.rs b/src/server/router.rs new file mode 100644 index 00000000..ed317ab2 --- /dev/null +++ b/src/server/router.rs @@ -0,0 +1,121 @@ +use std::convert::Infallible; + +use eyre::Result; +use warp::Filter; + +use super::handlers; +use super::{database::Database, database::Postgres}; +use crate::server::models::User; +use crate::{api::SyncHistoryRequest, settings::Settings}; + +fn with_settings( + settings: Settings, +) -> impl Filter<Extract = (Settings,), Error = Infallible> + Clone { + warp::any().map(move || settings.clone()) +} + +fn with_db( + db: impl Database + Clone + Send + Sync, +) -> impl Filter<Extract = (impl Database + Clone,), Error = Infallible> + Clone { + warp::any().map(move || db.clone()) +} + +fn with_user( + postgres: Postgres, +) -> impl Filter<Extract = (User,), Error = warp::Rejection> + Clone { + warp::header::<String>("authorization").and_then(move |header: String| { + // async closures are still buggy :( + let postgres = postgres.clone(); + + async move { + let header: Vec<&str> = header.split(' ').collect(); + + let token; + + if header.len() == 2 { + if header[0] != "Token" { + return Err(warp::reject()); + } + + token = header[1]; + } else { + return Err(warp::reject()); + } + + let user = postgres + .get_session_user(token) + .await + .map_err(|_| warp::reject())?; + + Ok(user) + } + }) +} + +pub async fn router( + settings: &Settings, +) -> Result<impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone> { + let postgres = Postgres::new(settings.server.db_uri.as_str()).await?; + let index = warp::get().and(warp::path::end()).map(handlers::index); + + let count = warp::get() + .and(warp::path("sync")) + .and(warp::path("count")) + .and(warp::path::end()) + .and(with_user(postgres.clone())) + .and(with_db(postgres.clone())) + .and_then(handlers::history::count); + + let sync = warp::get() + .and(warp::path("sync")) + .and(warp::path("history")) + .and(warp::query::<SyncHistoryRequest>()) + .and(warp::path::end()) + .and(with_user(postgres.clone())) + .and(with_db(postgres.clone())) + .and_then(handlers::history::list); + + let add_history = warp::post() + .and(warp::path("history")) + .and(warp::path::end()) + .and(warp::body::json()) + .and(with_user(postgres.clone())) + .and(with_db(postgres.clone())) + .and_then(handlers::history::add); + + let user = warp::get() + .and(warp::path("user")) + .and(warp::path::param::<String>()) + .and(warp::path::end()) + .and(with_db(postgres.clone())) + .and_then(handlers::user::get); + + let register = warp::post() + .and(warp::path("register")) + .and(warp::path::end()) + .and(warp::body::json()) + .and(with_settings(settings.clone())) + .and(with_db(postgres.clone())) + .and_then(handlers::user::register); + + let login = warp::post() + .and(warp::path("login")) + .and(warp::path::end()) + .and(warp::body::json()) + .and(with_db(postgres)) + .and_then(handlers::user::login); + + let r = warp::any() + .and( + index + .or(count) + .or(sync) + .or(add_history) + .or(user) + .or(register) + .or(login), + ) + .with(warp::filters::log::log("atuin::api")); + + Ok(r) +} |
