aboutsummaryrefslogtreecommitdiffstats
path: root/src/remote
diff options
context:
space:
mode:
authorEllie Huxtable <e@elm.sh>2021-03-21 20:04:39 +0000
committerGitHub <noreply@github.com>2021-03-21 20:04:39 +0000
commitc9579cb9ca2a6a165d10f128e0af1dfd372e0c03 (patch)
tree1d4feecb422aae3cde1cc7cad54ccc73b2dae410 /src/remote
parentAdd TUI, resolve #19, #17, #16 (#21) (diff)
downloadatuin-c9579cb9ca2a6a165d10f128e0af1dfd372e0c03.zip
Implement server (#23)
* Add initial database and server setup * Set up all routes, auth, etc * Implement sessions, password auth, hashing with argon2, and history storage
Diffstat (limited to 'src/remote')
-rw-r--r--src/remote/auth.rs200
-rw-r--r--src/remote/database.rs14
-rw-r--r--src/remote/mod.rs4
-rw-r--r--src/remote/models.rs56
-rw-r--r--src/remote/server.rs46
-rw-r--r--src/remote/views.rs89
6 files changed, 403 insertions, 6 deletions
diff --git a/src/remote/auth.rs b/src/remote/auth.rs
new file mode 100644
index 00000000..8f9e9b46
--- /dev/null
+++ b/src/remote/auth.rs
@@ -0,0 +1,200 @@
+use self::diesel::prelude::*;
+use rocket::http::Status;
+use rocket::request::{self, FromRequest, Outcome, Request};
+use rocket_contrib::databases::diesel;
+use sodiumoxide::crypto::pwhash::argon2id13;
+
+use rocket_contrib::json::Json;
+use uuid::Uuid;
+
+use super::models::{NewSession, NewUser, Session, User};
+use super::views::ApiResponse;
+use crate::schema::{sessions, users};
+
+use super::database::AtuinDbConn;
+
+#[derive(Debug)]
+pub enum KeyError {
+ Missing,
+ Invalid,
+}
+
+pub fn hash_str(secret: &str) -> String {
+ sodiumoxide::init().unwrap();
+ let hash = argon2id13::pwhash(
+ secret.as_bytes(),
+ argon2id13::OPSLIMIT_INTERACTIVE,
+ argon2id13::MEMLIMIT_INTERACTIVE,
+ )
+ .unwrap();
+ let texthash = std::str::from_utf8(&hash.0).unwrap().to_string();
+
+ // postgres hates null chars. don't do that to postgres
+ texthash.trim_end_matches('\u{0}').to_string()
+}
+
+pub fn verify_str(secret: &str, verify: &str) -> bool {
+ sodiumoxide::init().unwrap();
+
+ let mut padded = [0_u8; 128];
+ secret.as_bytes().iter().enumerate().for_each(|(i, val)| {
+ padded[i] = *val;
+ });
+
+ match argon2id13::HashedPassword::from_slice(&padded) {
+ Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()),
+ None => false,
+ }
+}
+
+impl<'a, 'r> FromRequest<'a, 'r> for User {
+ type Error = KeyError;
+
+ fn from_request(request: &'a Request<'r>) -> request::Outcome<User, Self::Error> {
+ let session: Vec<_> = request.headers().get("authorization").collect();
+
+ if session.is_empty() {
+ return Outcome::Failure((Status::BadRequest, KeyError::Missing));
+ } else if session.len() > 1 {
+ return Outcome::Failure((Status::BadRequest, KeyError::Invalid));
+ }
+
+ let session: Vec<_> = session[0].split(' ').collect();
+
+ if session.len() != 2 {
+ return Outcome::Failure((Status::BadRequest, KeyError::Invalid));
+ }
+
+ if session[0] != "Token" {
+ return Outcome::Failure((Status::BadRequest, KeyError::Invalid));
+ }
+
+ let session = session[1];
+
+ let db = request
+ .guard::<AtuinDbConn>()
+ .succeeded()
+ .expect("failed to load database");
+
+ let session = sessions::table
+ .filter(sessions::token.eq(session))
+ .first::<Session>(&*db);
+
+ if session.is_err() {
+ return Outcome::Failure((Status::Unauthorized, KeyError::Invalid));
+ }
+
+ let session = session.unwrap();
+
+ let user = users::table.find(session.user_id).first(&*db);
+
+ match user {
+ Ok(user) => Outcome::Success(user),
+ Err(_) => Outcome::Failure((Status::Unauthorized, KeyError::Invalid)),
+ }
+ }
+}
+
+#[derive(Deserialize)]
+pub struct Register {
+ email: String,
+ password: String,
+}
+
+#[post("/register", data = "<register>")]
+#[allow(clippy::clippy::needless_pass_by_value)]
+pub fn register(conn: AtuinDbConn, register: Json<Register>) -> ApiResponse {
+ let hashed = hash_str(register.password.as_str());
+
+ let new_user = NewUser {
+ email: register.email.as_str(),
+ password: hashed.as_str(),
+ };
+
+ let user = diesel::insert_into(users::table)
+ .values(&new_user)
+ .get_result(&*conn);
+
+ if user.is_err() {
+ return ApiResponse {
+ status: Status::BadRequest,
+ json: json!({
+ "status": "error",
+ "message": "failed to create user - is the email already in use?",
+ }),
+ };
+ }
+
+ let user: User = user.unwrap();
+ let token = Uuid::new_v4().to_simple().to_string();
+
+ let new_session = NewSession {
+ user_id: user.id,
+ token: token.as_str(),
+ };
+
+ match diesel::insert_into(sessions::table)
+ .values(&new_session)
+ .execute(&*conn)
+ {
+ Ok(_) => ApiResponse {
+ status: Status::Ok,
+ json: json!({"status": "ok", "message": "user created!", "session": token}),
+ },
+ Err(_) => ApiResponse {
+ status: Status::BadRequest,
+ json: json!({"status": "error", "message": "failed to create user"}),
+ },
+ }
+}
+
+#[derive(Deserialize)]
+pub struct Login {
+ email: String,
+ password: String,
+}
+
+#[post("/login", data = "<login>")]
+#[allow(clippy::clippy::needless_pass_by_value)]
+pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse {
+ let user = users::table
+ .filter(users::email.eq(login.email.as_str()))
+ .first(&*conn);
+
+ if user.is_err() {
+ return ApiResponse {
+ status: Status::NotFound,
+ json: json!({"status": "error", "message": "user not found"}),
+ };
+ }
+
+ let user: User = user.unwrap();
+
+ let session = sessions::table
+ .filter(sessions::user_id.eq(user.id))
+ .first(&*conn);
+
+ // a session should exist...
+ if session.is_err() {
+ return ApiResponse {
+ status: Status::InternalServerError,
+ json: json!({"status": "error", "message": "something went wrong"}),
+ };
+ }
+
+ let verified = verify_str(user.password.as_str(), login.password.as_str());
+
+ if !verified {
+ return ApiResponse {
+ status: Status::NotFound,
+ json: json!({"status": "error", "message": "user not found"}),
+ };
+ }
+
+ let session: Session = session.unwrap();
+
+ ApiResponse {
+ status: Status::Ok,
+ json: json!({"status": "ok", "token": session.token}),
+ }
+}
diff --git a/src/remote/database.rs b/src/remote/database.rs
new file mode 100644
index 00000000..4f386def
--- /dev/null
+++ b/src/remote/database.rs
@@ -0,0 +1,14 @@
+use diesel::pg::PgConnection;
+use diesel::prelude::*;
+
+use crate::settings::Settings;
+
+#[database("atuin")]
+pub struct AtuinDbConn(diesel::PgConnection);
+
+// TODO: connection pooling
+pub fn establish_connection(settings: &Settings) -> PgConnection {
+ let database_url = &settings.remote.db.url;
+ PgConnection::establish(database_url)
+ .unwrap_or_else(|_| panic!("Error connecting to {}", database_url))
+}
diff --git a/src/remote/mod.rs b/src/remote/mod.rs
index 74f47ad3..7147b88e 100644
--- a/src/remote/mod.rs
+++ b/src/remote/mod.rs
@@ -1 +1,5 @@
+pub mod auth;
+pub mod database;
+pub mod models;
pub mod server;
+pub mod views;
diff --git a/src/remote/models.rs b/src/remote/models.rs
new file mode 100644
index 00000000..058b2f0b
--- /dev/null
+++ b/src/remote/models.rs
@@ -0,0 +1,56 @@
+use chrono::naive::NaiveDateTime;
+
+use crate::schema::{history, sessions, users};
+
+#[derive(Identifiable, Queryable, Associations)]
+#[table_name = "history"]
+#[belongs_to(User)]
+pub struct History {
+ pub id: i64,
+ pub client_id: String,
+ pub user_id: i64,
+ pub mac: String,
+ pub timestamp: NaiveDateTime,
+
+ pub data: String,
+}
+
+#[derive(Identifiable, Queryable, Associations)]
+pub struct User {
+ pub id: i64,
+ pub email: String,
+ pub password: String,
+}
+
+#[derive(Queryable, Identifiable, Associations)]
+#[belongs_to(User)]
+pub struct Session {
+ pub id: i64,
+ pub user_id: i64,
+ pub token: String,
+}
+
+#[derive(Insertable)]
+#[table_name = "history"]
+pub struct NewHistory<'a> {
+ pub client_id: &'a str,
+ pub user_id: i64,
+ pub mac: &'a str,
+ pub timestamp: NaiveDateTime,
+
+ pub data: &'a str,
+}
+
+#[derive(Insertable)]
+#[table_name = "users"]
+pub struct NewUser<'a> {
+ pub email: &'a str,
+ pub password: &'a str,
+}
+
+#[derive(Insertable)]
+#[table_name = "sessions"]
+pub struct NewSession<'a> {
+ pub user_id: i64,
+ pub token: &'a str,
+}
diff --git a/src/remote/server.rs b/src/remote/server.rs
index bc1dc2bd..4409f646 100644
--- a/src/remote/server.rs
+++ b/src/remote/server.rs
@@ -1,8 +1,42 @@
-#[get("/")]
-const fn index() -> &'static str {
- "Hello, world!"
-}
+use rocket::config::{Config, Environment, LoggingLevel, Value};
+
+use std::collections::HashMap;
+
+use crate::remote::database::establish_connection;
+use crate::settings::Settings;
+
+use super::database::AtuinDbConn;
+
+// a bunch of these imports are generated by macros, it's easier to wildcard
+#[allow(clippy::clippy::wildcard_imports)]
+use super::views::*;
+
+#[allow(clippy::clippy::wildcard_imports)]
+use super::auth::*;
+
+embed_migrations!("migrations");
+
+pub fn launch(settings: &Settings) {
+ let mut database_config = HashMap::new();
+ let mut databases = HashMap::new();
+
+ database_config.insert("url", Value::from(settings.remote.db.url.clone()));
+ databases.insert("atuin", Value::from(database_config));
+
+ let connection = establish_connection(settings);
+ embedded_migrations::run(&connection).expect("failed to run migrations");
+
+ let config = Config::build(Environment::Production)
+ .address("0.0.0.0")
+ .log_level(LoggingLevel::Normal)
+ .port(8080)
+ .extra("databases", databases)
+ .finalize()
+ .unwrap();
-pub fn launch() {
- rocket::ignite().mount("/", routes![index]).launch();
+ let app = rocket::custom(config);
+ app.mount("/", routes![index, register, add_history, login])
+ .attach(AtuinDbConn::fairing())
+ .register(catchers![internal_error, bad_request])
+ .launch();
}
diff --git a/src/remote/views.rs b/src/remote/views.rs
new file mode 100644
index 00000000..2af3f369
--- /dev/null
+++ b/src/remote/views.rs
@@ -0,0 +1,89 @@
+use self::diesel::prelude::*;
+use rocket::http::{ContentType, Status};
+use rocket::request::Request;
+use rocket::response;
+use rocket::response::{Responder, Response};
+use rocket_contrib::databases::diesel;
+use rocket_contrib::json::{Json, JsonValue};
+
+use super::database::AtuinDbConn;
+use super::models::{NewHistory, User};
+use crate::schema::history;
+
+#[derive(Debug)]
+pub struct ApiResponse {
+ pub json: JsonValue,
+ pub status: Status,
+}
+
+impl<'r> Responder<'r> for ApiResponse {
+ fn respond_to(self, req: &Request) -> response::Result<'r> {
+ Response::build_from(self.json.respond_to(req).unwrap())
+ .status(self.status)
+ .header(ContentType::JSON)
+ .ok()
+ }
+}
+
+#[get("/")]
+pub const fn index() -> &'static str {
+ "\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett"
+}
+
+#[catch(500)]
+pub fn internal_error(_req: &Request) -> ApiResponse {
+ ApiResponse {
+ status: Status::InternalServerError,
+ json: json!({"status": "error", "message": "an internal server error has occured"}),
+ }
+}
+
+#[catch(400)]
+pub fn bad_request(_req: &Request) -> ApiResponse {
+ ApiResponse {
+ status: Status::InternalServerError,
+ json: json!({"status": "error", "message": "bad request. don't do that."}),
+ }
+}
+
+#[derive(Deserialize)]
+pub struct AddHistory {
+ id: String,
+ timestamp: i64,
+ data: String,
+ mac: String,
+}
+
+#[post("/history", data = "<add_history>")]
+#[allow(
+ clippy::clippy::cast_sign_loss,
+ clippy::cast_possible_truncation,
+ clippy::clippy::needless_pass_by_value
+)]
+pub fn add_history(conn: AtuinDbConn, user: User, add_history: Json<AddHistory>) -> ApiResponse {
+ let secs: i64 = add_history.timestamp / 1_000_000_000;
+ let nanosecs: u32 = (add_history.timestamp - (secs * 1_000_000_000)) as u32;
+ let datetime = chrono::NaiveDateTime::from_timestamp(secs, nanosecs);
+
+ let new_history = NewHistory {
+ client_id: add_history.id.as_str(),
+ user_id: user.id,
+ mac: add_history.mac.as_str(),
+ timestamp: datetime,
+ data: add_history.data.as_str(),
+ };
+
+ match diesel::insert_into(history::table)
+ .values(&new_history)
+ .execute(&*conn)
+ {
+ Ok(_) => ApiResponse {
+ status: Status::Ok,
+ json: json!({"status": "ok", "message": "history added", "id": new_history.client_id}),
+ },
+ Err(_) => ApiResponse {
+ status: Status::BadRequest,
+ json: json!({"status": "error", "message": "failed to add history"}),
+ },
+ }
+}