diff options
Diffstat (limited to 'src/remote')
| -rw-r--r-- | src/remote/auth.rs | 92 | ||||
| -rw-r--r-- | src/remote/database.rs | 2 | ||||
| -rw-r--r-- | src/remote/models.rs | 16 | ||||
| -rw-r--r-- | src/remote/server.rs | 26 | ||||
| -rw-r--r-- | src/remote/views.rs | 144 |
5 files changed, 207 insertions, 73 deletions
diff --git a/src/remote/auth.rs b/src/remote/auth.rs index 8f9e9b46..cf61b077 100644 --- a/src/remote/auth.rs +++ b/src/remote/auth.rs @@ -1,6 +1,8 @@ use self::diesel::prelude::*; +use eyre::Result; use rocket::http::Status; use rocket::request::{self, FromRequest, Outcome, Request}; +use rocket::State; use rocket_contrib::databases::diesel; use sodiumoxide::crypto::pwhash::argon2id13; @@ -9,7 +11,11 @@ use uuid::Uuid; use super::models::{NewSession, NewUser, Session, User}; use super::views::ApiResponse; + +use crate::api::{LoginRequest, RegisterRequest}; use crate::schema::{sessions, users}; +use crate::settings::Settings; +use crate::utils::hash_secret; use super::database::AtuinDbConn; @@ -19,20 +25,6 @@ pub enum KeyError { Invalid, } -pub fn hash_str(secret: &str) -> String { - sodiumoxide::init().unwrap(); - let hash = argon2id13::pwhash( - secret.as_bytes(), - argon2id13::OPSLIMIT_INTERACTIVE, - argon2id13::MEMLIMIT_INTERACTIVE, - ) - .unwrap(); - let texthash = std::str::from_utf8(&hash.0).unwrap().to_string(); - - // postgres hates null chars. don't do that to postgres - texthash.trim_end_matches('\u{0}').to_string() -} - pub fn verify_str(secret: &str, verify: &str) -> bool { sodiumoxide::init().unwrap(); @@ -95,19 +87,54 @@ impl<'a, 'r> FromRequest<'a, 'r> for User { } } -#[derive(Deserialize)] -pub struct Register { - email: String, - password: String, +#[get("/user/<user>")] +#[allow(clippy::clippy::needless_pass_by_value)] +pub fn get_user(user: String, conn: AtuinDbConn) -> ApiResponse { + use crate::schema::users::dsl::{username, users}; + + let user: Result<String, diesel::result::Error> = users + .select(username) + .filter(username.eq(user)) + .first(&*conn); + + if user.is_err() { + return ApiResponse { + json: json!({ + "message": "could not find user", + }), + status: Status::NotFound, + }; + } + + let user = user.unwrap(); + + ApiResponse { + json: json!({ "username": user.as_str() }), + status: Status::Ok, + } } #[post("/register", data = "<register>")] #[allow(clippy::clippy::needless_pass_by_value)] -pub fn register(conn: AtuinDbConn, register: Json<Register>) -> ApiResponse { - let hashed = hash_str(register.password.as_str()); +pub fn register( + conn: AtuinDbConn, + register: Json<RegisterRequest>, + settings: State<Settings>, +) -> ApiResponse { + if !settings.server.open_registration { + return ApiResponse { + status: Status::BadRequest, + json: json!({ + "message": "registrations are not open" + }), + }; + } + + let hashed = hash_secret(register.password.as_str()); let new_user = NewUser { email: register.email.as_str(), + username: register.username.as_str(), password: hashed.as_str(), }; @@ -119,8 +146,7 @@ pub fn register(conn: AtuinDbConn, register: Json<Register>) -> ApiResponse { return ApiResponse { status: Status::BadRequest, json: json!({ - "status": "error", - "message": "failed to create user - is the email already in use?", + "message": "failed to create user - username or email in use?", }), }; } @@ -139,32 +165,26 @@ pub fn register(conn: AtuinDbConn, register: Json<Register>) -> ApiResponse { { Ok(_) => ApiResponse { status: Status::Ok, - json: json!({"status": "ok", "message": "user created!", "session": token}), + json: json!({"message": "user created!", "session": token}), }, Err(_) => ApiResponse { status: Status::BadRequest, - json: json!({"status": "error", "message": "failed to create user"}), + json: json!({ "message": "failed to create user"}), }, } } -#[derive(Deserialize)] -pub struct Login { - email: String, - password: String, -} - #[post("/login", data = "<login>")] #[allow(clippy::clippy::needless_pass_by_value)] -pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse { +pub fn login(conn: AtuinDbConn, login: Json<LoginRequest>) -> ApiResponse { let user = users::table - .filter(users::email.eq(login.email.as_str())) + .filter(users::username.eq(login.username.as_str())) .first(&*conn); if user.is_err() { return ApiResponse { status: Status::NotFound, - json: json!({"status": "error", "message": "user not found"}), + json: json!({"message": "user not found"}), }; } @@ -178,7 +198,7 @@ pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse { if session.is_err() { return ApiResponse { status: Status::InternalServerError, - json: json!({"status": "error", "message": "something went wrong"}), + json: json!({"message": "something went wrong"}), }; } @@ -187,7 +207,7 @@ pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse { if !verified { return ApiResponse { status: Status::NotFound, - json: json!({"status": "error", "message": "user not found"}), + json: json!({"message": "user not found"}), }; } @@ -195,6 +215,6 @@ pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse { ApiResponse { status: Status::Ok, - json: json!({"status": "ok", "token": session.token}), + json: json!({"session": session.token}), } } diff --git a/src/remote/database.rs b/src/remote/database.rs index fabd07de..ddcffda0 100644 --- a/src/remote/database.rs +++ b/src/remote/database.rs @@ -8,7 +8,7 @@ pub struct AtuinDbConn(diesel::PgConnection); // TODO: connection pooling pub fn establish_connection(settings: &Settings) -> PgConnection { - let database_url = &settings.remote.db_uri; + let database_url = &settings.server.db_uri; PgConnection::establish(database_url) .unwrap_or_else(|_| panic!("Error connecting to {}", database_url)) } diff --git a/src/remote/models.rs b/src/remote/models.rs index 058b2f0b..7f6f7766 100644 --- a/src/remote/models.rs +++ b/src/remote/models.rs @@ -1,23 +1,26 @@ -use chrono::naive::NaiveDateTime; +use chrono::prelude::*; use crate::schema::{history, sessions, users}; -#[derive(Identifiable, Queryable, Associations)] +#[derive(Deserialize, Serialize, Identifiable, Queryable, Associations)] #[table_name = "history"] #[belongs_to(User)] pub struct History { pub id: i64, - pub client_id: String, + pub client_id: String, // a client generated ID pub user_id: i64, - pub mac: String, + pub hostname: String, pub timestamp: NaiveDateTime, pub data: String, + + pub created_at: NaiveDateTime, } #[derive(Identifiable, Queryable, Associations)] pub struct User { pub id: i64, + pub username: String, pub email: String, pub password: String, } @@ -35,8 +38,8 @@ pub struct Session { pub struct NewHistory<'a> { pub client_id: &'a str, pub user_id: i64, - pub mac: &'a str, - pub timestamp: NaiveDateTime, + pub hostname: String, + pub timestamp: chrono::NaiveDateTime, pub data: &'a str, } @@ -44,6 +47,7 @@ pub struct NewHistory<'a> { #[derive(Insertable)] #[table_name = "users"] pub struct NewUser<'a> { + pub username: &'a str, pub email: &'a str, pub password: &'a str, } diff --git a/src/remote/server.rs b/src/remote/server.rs index cd2ca7b8..de58397d 100644 --- a/src/remote/server.rs +++ b/src/remote/server.rs @@ -17,13 +17,15 @@ use super::auth::*; embed_migrations!("migrations"); pub fn launch(settings: &Settings, host: String, port: u16) { + let settings: Settings = settings.clone(); // clone so rocket can manage it + let mut database_config = HashMap::new(); let mut databases = HashMap::new(); - database_config.insert("url", Value::from(settings.remote.db_uri.clone())); + database_config.insert("url", Value::from(settings.server.db_uri.clone())); databases.insert("atuin", Value::from(database_config)); - let connection = establish_connection(settings); + let connection = establish_connection(&settings); embedded_migrations::run(&connection).expect("failed to run migrations"); let config = Config::build(Environment::Production) @@ -36,8 +38,20 @@ pub fn launch(settings: &Settings, host: String, port: u16) { let app = rocket::custom(config); - app.mount("/", routes![index, register, add_history, login]) - .attach(AtuinDbConn::fairing()) - .register(catchers![internal_error, bad_request]) - .launch(); + app.mount( + "/", + routes![ + index, + register, + add_history, + login, + get_user, + sync_count, + sync_list + ], + ) + .manage(settings) + .attach(AtuinDbConn::fairing()) + .register(catchers![internal_error, bad_request]) + .launch(); } diff --git a/src/remote/views.rs b/src/remote/views.rs index 2af3f369..08dff13e 100644 --- a/src/remote/views.rs +++ b/src/remote/views.rs @@ -1,14 +1,22 @@ -use self::diesel::prelude::*; +use chrono::Utc; +use rocket::http::uri::Uri; +use rocket::http::RawStr; use rocket::http::{ContentType, Status}; +use rocket::request::FromFormValue; use rocket::request::Request; use rocket::response; use rocket::response::{Responder, Response}; use rocket_contrib::databases::diesel; use rocket_contrib::json::{Json, JsonValue}; -use super::database::AtuinDbConn; -use super::models::{NewHistory, User}; +use self::diesel::prelude::*; + +use crate::api::AddHistoryRequest; use crate::schema::history; +use crate::settings::HISTORY_PAGE_SIZE; + +use super::database::AtuinDbConn; +use super::models::{History, NewHistory, User}; #[derive(Debug)] pub struct ApiResponse { @@ -46,40 +54,36 @@ pub fn bad_request(_req: &Request) -> ApiResponse { } } -#[derive(Deserialize)] -pub struct AddHistory { - id: String, - timestamp: i64, - data: String, - mac: String, -} - #[post("/history", data = "<add_history>")] #[allow( clippy::clippy::cast_sign_loss, clippy::cast_possible_truncation, clippy::clippy::needless_pass_by_value )] -pub fn add_history(conn: AtuinDbConn, user: User, add_history: Json<AddHistory>) -> ApiResponse { - let secs: i64 = add_history.timestamp / 1_000_000_000; - let nanosecs: u32 = (add_history.timestamp - (secs * 1_000_000_000)) as u32; - let datetime = chrono::NaiveDateTime::from_timestamp(secs, nanosecs); - - let new_history = NewHistory { - client_id: add_history.id.as_str(), - user_id: user.id, - mac: add_history.mac.as_str(), - timestamp: datetime, - data: add_history.data.as_str(), - }; +pub fn add_history( + conn: AtuinDbConn, + user: User, + add_history: Json<Vec<AddHistoryRequest>>, +) -> ApiResponse { + let new_history: Vec<NewHistory> = add_history + .iter() + .map(|h| NewHistory { + client_id: h.id.as_str(), + hostname: h.hostname.to_string(), + user_id: user.id, + timestamp: h.timestamp.naive_utc(), + data: h.data.as_str(), + }) + .collect(); match diesel::insert_into(history::table) .values(&new_history) + .on_conflict_do_nothing() .execute(&*conn) { Ok(_) => ApiResponse { status: Status::Ok, - json: json!({"status": "ok", "message": "history added", "id": new_history.client_id}), + json: json!({"status": "ok", "message": "history added"}), }, Err(_) => ApiResponse { status: Status::BadRequest, @@ -87,3 +91,95 @@ pub fn add_history(conn: AtuinDbConn, user: User, add_history: Json<AddHistory>) }, } } + +#[get("/sync/count")] +#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)] +pub fn sync_count(conn: AtuinDbConn, user: User) -> ApiResponse { + use crate::schema::history::dsl::*; + + // we need to return the number of history items we have for this user + // in the future I'd like to use something like a merkel tree to calculate + // which day specifically needs syncing + let count = history + .filter(user_id.eq(user.id)) + .count() + .first::<i64>(&*conn); + + if count.is_err() { + error!("failed to count: {}", count.err().unwrap()); + + return ApiResponse { + json: json!({"message": "internal server error"}), + status: Status::InternalServerError, + }; + } + + ApiResponse { + status: Status::Ok, + json: json!({"count": count.ok()}), + } +} + +pub struct UtcDateTime(chrono::DateTime<Utc>); + +impl<'v> FromFormValue<'v> for UtcDateTime { + type Error = &'v RawStr; + + fn from_form_value(form_value: &'v RawStr) -> Result<UtcDateTime, &'v RawStr> { + let time = Uri::percent_decode(form_value.as_bytes()).map_err(|_| form_value)?; + let time = time.to_string(); + + match chrono::DateTime::parse_from_rfc3339(time.as_str()) { + Ok(t) => Ok(UtcDateTime(t.with_timezone(&Utc))), + Err(e) => { + error!("failed to parse time {}, got: {}", time, e); + Err(form_value) + } + } + } +} + +// Request a list of all history items added to the DB after a given timestamp. +// Provide the current hostname, so that we don't send the client data that +// originated from them +#[get("/sync/history?<sync_ts>&<history_ts>&<host>")] +#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)] +pub fn sync_list( + conn: AtuinDbConn, + user: User, + sync_ts: UtcDateTime, + history_ts: UtcDateTime, + host: String, +) -> ApiResponse { + use crate::schema::history::dsl::*; + + // we need to return the number of history items we have for this user + // in the future I'd like to use something like a merkel tree to calculate + // which day specifically needs syncing + // TODO: Allow for configuring the page size, both from params, and setting + // the max in config. 100 is fine for now. + let h = history + .filter(user_id.eq(user.id)) + .filter(hostname.ne(host)) + .filter(created_at.ge(sync_ts.0.naive_utc())) + .filter(timestamp.ge(history_ts.0.naive_utc())) + .order(timestamp.asc()) + .limit(HISTORY_PAGE_SIZE) + .load::<History>(&*conn); + + if let Err(e) = h { + error!("failed to load history: {}", e); + + return ApiResponse { + json: json!({"message": "internal server error"}), + status: Status::InternalServerError, + }; + } + + let user_data: Vec<String> = h.unwrap().iter().map(|i| i.data.to_string()).collect(); + + ApiResponse { + status: Status::Ok, + json: json!({ "history": user_data }), + } +} |
