aboutsummaryrefslogtreecommitdiffstats
path: root/src/remote
diff options
context:
space:
mode:
authorEllie Huxtable <e@elm.sh>2021-04-13 19:14:07 +0100
committerGitHub <noreply@github.com>2021-04-13 19:14:07 +0100
commit5751463942cc91f1f1ffaf6e2ac633d7a0085f25 (patch)
treef7b5b9a4702c4c3ef29aa60d36612f61ffeae052 /src/remote
parentUpdate config (diff)
downloadatuin-5751463942cc91f1f1ffaf6e2ac633d7a0085f25.zip
Add history sync, resolves #13 (#31)
* Add encryption * Add login and register command * Add count endpoint * Write initial sync push * Add single sync command Confirmed working for one client only * Automatically sync on a configurable frequency * Add key command, key arg to login * Only load session if it exists * Use sync and history timestamps for download * Bind other key code Seems like some systems have this code for up arrow? I'm not sure why, and it's not an easy one to google. * Simplify upload * Try and fix download sync loop * Change sync order to avoid uploading what we just downloaded * Multiline import fix * Fix time parsing * Fix importing history with no time * Add hostname to sync * Use hostname to filter sync * Fixes * Add binding * Stuff from yesterday * Set cursor modes * Make clippy happy * Bump version
Diffstat (limited to 'src/remote')
-rw-r--r--src/remote/auth.rs92
-rw-r--r--src/remote/database.rs2
-rw-r--r--src/remote/models.rs16
-rw-r--r--src/remote/server.rs26
-rw-r--r--src/remote/views.rs144
5 files changed, 207 insertions, 73 deletions
diff --git a/src/remote/auth.rs b/src/remote/auth.rs
index 8f9e9b46..cf61b077 100644
--- a/src/remote/auth.rs
+++ b/src/remote/auth.rs
@@ -1,6 +1,8 @@
use self::diesel::prelude::*;
+use eyre::Result;
use rocket::http::Status;
use rocket::request::{self, FromRequest, Outcome, Request};
+use rocket::State;
use rocket_contrib::databases::diesel;
use sodiumoxide::crypto::pwhash::argon2id13;
@@ -9,7 +11,11 @@ use uuid::Uuid;
use super::models::{NewSession, NewUser, Session, User};
use super::views::ApiResponse;
+
+use crate::api::{LoginRequest, RegisterRequest};
use crate::schema::{sessions, users};
+use crate::settings::Settings;
+use crate::utils::hash_secret;
use super::database::AtuinDbConn;
@@ -19,20 +25,6 @@ pub enum KeyError {
Invalid,
}
-pub fn hash_str(secret: &str) -> String {
- sodiumoxide::init().unwrap();
- let hash = argon2id13::pwhash(
- secret.as_bytes(),
- argon2id13::OPSLIMIT_INTERACTIVE,
- argon2id13::MEMLIMIT_INTERACTIVE,
- )
- .unwrap();
- let texthash = std::str::from_utf8(&hash.0).unwrap().to_string();
-
- // postgres hates null chars. don't do that to postgres
- texthash.trim_end_matches('\u{0}').to_string()
-}
-
pub fn verify_str(secret: &str, verify: &str) -> bool {
sodiumoxide::init().unwrap();
@@ -95,19 +87,54 @@ impl<'a, 'r> FromRequest<'a, 'r> for User {
}
}
-#[derive(Deserialize)]
-pub struct Register {
- email: String,
- password: String,
+#[get("/user/<user>")]
+#[allow(clippy::clippy::needless_pass_by_value)]
+pub fn get_user(user: String, conn: AtuinDbConn) -> ApiResponse {
+ use crate::schema::users::dsl::{username, users};
+
+ let user: Result<String, diesel::result::Error> = users
+ .select(username)
+ .filter(username.eq(user))
+ .first(&*conn);
+
+ if user.is_err() {
+ return ApiResponse {
+ json: json!({
+ "message": "could not find user",
+ }),
+ status: Status::NotFound,
+ };
+ }
+
+ let user = user.unwrap();
+
+ ApiResponse {
+ json: json!({ "username": user.as_str() }),
+ status: Status::Ok,
+ }
}
#[post("/register", data = "<register>")]
#[allow(clippy::clippy::needless_pass_by_value)]
-pub fn register(conn: AtuinDbConn, register: Json<Register>) -> ApiResponse {
- let hashed = hash_str(register.password.as_str());
+pub fn register(
+ conn: AtuinDbConn,
+ register: Json<RegisterRequest>,
+ settings: State<Settings>,
+) -> ApiResponse {
+ if !settings.server.open_registration {
+ return ApiResponse {
+ status: Status::BadRequest,
+ json: json!({
+ "message": "registrations are not open"
+ }),
+ };
+ }
+
+ let hashed = hash_secret(register.password.as_str());
let new_user = NewUser {
email: register.email.as_str(),
+ username: register.username.as_str(),
password: hashed.as_str(),
};
@@ -119,8 +146,7 @@ pub fn register(conn: AtuinDbConn, register: Json<Register>) -> ApiResponse {
return ApiResponse {
status: Status::BadRequest,
json: json!({
- "status": "error",
- "message": "failed to create user - is the email already in use?",
+ "message": "failed to create user - username or email in use?",
}),
};
}
@@ -139,32 +165,26 @@ pub fn register(conn: AtuinDbConn, register: Json<Register>) -> ApiResponse {
{
Ok(_) => ApiResponse {
status: Status::Ok,
- json: json!({"status": "ok", "message": "user created!", "session": token}),
+ json: json!({"message": "user created!", "session": token}),
},
Err(_) => ApiResponse {
status: Status::BadRequest,
- json: json!({"status": "error", "message": "failed to create user"}),
+ json: json!({ "message": "failed to create user"}),
},
}
}
-#[derive(Deserialize)]
-pub struct Login {
- email: String,
- password: String,
-}
-
#[post("/login", data = "<login>")]
#[allow(clippy::clippy::needless_pass_by_value)]
-pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse {
+pub fn login(conn: AtuinDbConn, login: Json<LoginRequest>) -> ApiResponse {
let user = users::table
- .filter(users::email.eq(login.email.as_str()))
+ .filter(users::username.eq(login.username.as_str()))
.first(&*conn);
if user.is_err() {
return ApiResponse {
status: Status::NotFound,
- json: json!({"status": "error", "message": "user not found"}),
+ json: json!({"message": "user not found"}),
};
}
@@ -178,7 +198,7 @@ pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse {
if session.is_err() {
return ApiResponse {
status: Status::InternalServerError,
- json: json!({"status": "error", "message": "something went wrong"}),
+ json: json!({"message": "something went wrong"}),
};
}
@@ -187,7 +207,7 @@ pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse {
if !verified {
return ApiResponse {
status: Status::NotFound,
- json: json!({"status": "error", "message": "user not found"}),
+ json: json!({"message": "user not found"}),
};
}
@@ -195,6 +215,6 @@ pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse {
ApiResponse {
status: Status::Ok,
- json: json!({"status": "ok", "token": session.token}),
+ json: json!({"session": session.token}),
}
}
diff --git a/src/remote/database.rs b/src/remote/database.rs
index fabd07de..ddcffda0 100644
--- a/src/remote/database.rs
+++ b/src/remote/database.rs
@@ -8,7 +8,7 @@ pub struct AtuinDbConn(diesel::PgConnection);
// TODO: connection pooling
pub fn establish_connection(settings: &Settings) -> PgConnection {
- let database_url = &settings.remote.db_uri;
+ let database_url = &settings.server.db_uri;
PgConnection::establish(database_url)
.unwrap_or_else(|_| panic!("Error connecting to {}", database_url))
}
diff --git a/src/remote/models.rs b/src/remote/models.rs
index 058b2f0b..7f6f7766 100644
--- a/src/remote/models.rs
+++ b/src/remote/models.rs
@@ -1,23 +1,26 @@
-use chrono::naive::NaiveDateTime;
+use chrono::prelude::*;
use crate::schema::{history, sessions, users};
-#[derive(Identifiable, Queryable, Associations)]
+#[derive(Deserialize, Serialize, Identifiable, Queryable, Associations)]
#[table_name = "history"]
#[belongs_to(User)]
pub struct History {
pub id: i64,
- pub client_id: String,
+ pub client_id: String, // a client generated ID
pub user_id: i64,
- pub mac: String,
+ pub hostname: String,
pub timestamp: NaiveDateTime,
pub data: String,
+
+ pub created_at: NaiveDateTime,
}
#[derive(Identifiable, Queryable, Associations)]
pub struct User {
pub id: i64,
+ pub username: String,
pub email: String,
pub password: String,
}
@@ -35,8 +38,8 @@ pub struct Session {
pub struct NewHistory<'a> {
pub client_id: &'a str,
pub user_id: i64,
- pub mac: &'a str,
- pub timestamp: NaiveDateTime,
+ pub hostname: String,
+ pub timestamp: chrono::NaiveDateTime,
pub data: &'a str,
}
@@ -44,6 +47,7 @@ pub struct NewHistory<'a> {
#[derive(Insertable)]
#[table_name = "users"]
pub struct NewUser<'a> {
+ pub username: &'a str,
pub email: &'a str,
pub password: &'a str,
}
diff --git a/src/remote/server.rs b/src/remote/server.rs
index cd2ca7b8..de58397d 100644
--- a/src/remote/server.rs
+++ b/src/remote/server.rs
@@ -17,13 +17,15 @@ use super::auth::*;
embed_migrations!("migrations");
pub fn launch(settings: &Settings, host: String, port: u16) {
+ let settings: Settings = settings.clone(); // clone so rocket can manage it
+
let mut database_config = HashMap::new();
let mut databases = HashMap::new();
- database_config.insert("url", Value::from(settings.remote.db_uri.clone()));
+ database_config.insert("url", Value::from(settings.server.db_uri.clone()));
databases.insert("atuin", Value::from(database_config));
- let connection = establish_connection(settings);
+ let connection = establish_connection(&settings);
embedded_migrations::run(&connection).expect("failed to run migrations");
let config = Config::build(Environment::Production)
@@ -36,8 +38,20 @@ pub fn launch(settings: &Settings, host: String, port: u16) {
let app = rocket::custom(config);
- app.mount("/", routes![index, register, add_history, login])
- .attach(AtuinDbConn::fairing())
- .register(catchers![internal_error, bad_request])
- .launch();
+ app.mount(
+ "/",
+ routes![
+ index,
+ register,
+ add_history,
+ login,
+ get_user,
+ sync_count,
+ sync_list
+ ],
+ )
+ .manage(settings)
+ .attach(AtuinDbConn::fairing())
+ .register(catchers![internal_error, bad_request])
+ .launch();
}
diff --git a/src/remote/views.rs b/src/remote/views.rs
index 2af3f369..08dff13e 100644
--- a/src/remote/views.rs
+++ b/src/remote/views.rs
@@ -1,14 +1,22 @@
-use self::diesel::prelude::*;
+use chrono::Utc;
+use rocket::http::uri::Uri;
+use rocket::http::RawStr;
use rocket::http::{ContentType, Status};
+use rocket::request::FromFormValue;
use rocket::request::Request;
use rocket::response;
use rocket::response::{Responder, Response};
use rocket_contrib::databases::diesel;
use rocket_contrib::json::{Json, JsonValue};
-use super::database::AtuinDbConn;
-use super::models::{NewHistory, User};
+use self::diesel::prelude::*;
+
+use crate::api::AddHistoryRequest;
use crate::schema::history;
+use crate::settings::HISTORY_PAGE_SIZE;
+
+use super::database::AtuinDbConn;
+use super::models::{History, NewHistory, User};
#[derive(Debug)]
pub struct ApiResponse {
@@ -46,40 +54,36 @@ pub fn bad_request(_req: &Request) -> ApiResponse {
}
}
-#[derive(Deserialize)]
-pub struct AddHistory {
- id: String,
- timestamp: i64,
- data: String,
- mac: String,
-}
-
#[post("/history", data = "<add_history>")]
#[allow(
clippy::clippy::cast_sign_loss,
clippy::cast_possible_truncation,
clippy::clippy::needless_pass_by_value
)]
-pub fn add_history(conn: AtuinDbConn, user: User, add_history: Json<AddHistory>) -> ApiResponse {
- let secs: i64 = add_history.timestamp / 1_000_000_000;
- let nanosecs: u32 = (add_history.timestamp - (secs * 1_000_000_000)) as u32;
- let datetime = chrono::NaiveDateTime::from_timestamp(secs, nanosecs);
-
- let new_history = NewHistory {
- client_id: add_history.id.as_str(),
- user_id: user.id,
- mac: add_history.mac.as_str(),
- timestamp: datetime,
- data: add_history.data.as_str(),
- };
+pub fn add_history(
+ conn: AtuinDbConn,
+ user: User,
+ add_history: Json<Vec<AddHistoryRequest>>,
+) -> ApiResponse {
+ let new_history: Vec<NewHistory> = add_history
+ .iter()
+ .map(|h| NewHistory {
+ client_id: h.id.as_str(),
+ hostname: h.hostname.to_string(),
+ user_id: user.id,
+ timestamp: h.timestamp.naive_utc(),
+ data: h.data.as_str(),
+ })
+ .collect();
match diesel::insert_into(history::table)
.values(&new_history)
+ .on_conflict_do_nothing()
.execute(&*conn)
{
Ok(_) => ApiResponse {
status: Status::Ok,
- json: json!({"status": "ok", "message": "history added", "id": new_history.client_id}),
+ json: json!({"status": "ok", "message": "history added"}),
},
Err(_) => ApiResponse {
status: Status::BadRequest,
@@ -87,3 +91,95 @@ pub fn add_history(conn: AtuinDbConn, user: User, add_history: Json<AddHistory>)
},
}
}
+
+#[get("/sync/count")]
+#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)]
+pub fn sync_count(conn: AtuinDbConn, user: User) -> ApiResponse {
+ use crate::schema::history::dsl::*;
+
+ // we need to return the number of history items we have for this user
+ // in the future I'd like to use something like a merkel tree to calculate
+ // which day specifically needs syncing
+ let count = history
+ .filter(user_id.eq(user.id))
+ .count()
+ .first::<i64>(&*conn);
+
+ if count.is_err() {
+ error!("failed to count: {}", count.err().unwrap());
+
+ return ApiResponse {
+ json: json!({"message": "internal server error"}),
+ status: Status::InternalServerError,
+ };
+ }
+
+ ApiResponse {
+ status: Status::Ok,
+ json: json!({"count": count.ok()}),
+ }
+}
+
+pub struct UtcDateTime(chrono::DateTime<Utc>);
+
+impl<'v> FromFormValue<'v> for UtcDateTime {
+ type Error = &'v RawStr;
+
+ fn from_form_value(form_value: &'v RawStr) -> Result<UtcDateTime, &'v RawStr> {
+ let time = Uri::percent_decode(form_value.as_bytes()).map_err(|_| form_value)?;
+ let time = time.to_string();
+
+ match chrono::DateTime::parse_from_rfc3339(time.as_str()) {
+ Ok(t) => Ok(UtcDateTime(t.with_timezone(&Utc))),
+ Err(e) => {
+ error!("failed to parse time {}, got: {}", time, e);
+ Err(form_value)
+ }
+ }
+ }
+}
+
+// Request a list of all history items added to the DB after a given timestamp.
+// Provide the current hostname, so that we don't send the client data that
+// originated from them
+#[get("/sync/history?<sync_ts>&<history_ts>&<host>")]
+#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)]
+pub fn sync_list(
+ conn: AtuinDbConn,
+ user: User,
+ sync_ts: UtcDateTime,
+ history_ts: UtcDateTime,
+ host: String,
+) -> ApiResponse {
+ use crate::schema::history::dsl::*;
+
+ // we need to return the number of history items we have for this user
+ // in the future I'd like to use something like a merkel tree to calculate
+ // which day specifically needs syncing
+ // TODO: Allow for configuring the page size, both from params, and setting
+ // the max in config. 100 is fine for now.
+ let h = history
+ .filter(user_id.eq(user.id))
+ .filter(hostname.ne(host))
+ .filter(created_at.ge(sync_ts.0.naive_utc()))
+ .filter(timestamp.ge(history_ts.0.naive_utc()))
+ .order(timestamp.asc())
+ .limit(HISTORY_PAGE_SIZE)
+ .load::<History>(&*conn);
+
+ if let Err(e) = h {
+ error!("failed to load history: {}", e);
+
+ return ApiResponse {
+ json: json!({"message": "internal server error"}),
+ status: Status::InternalServerError,
+ };
+ }
+
+ let user_data: Vec<String> = h.unwrap().iter().map(|i| i.data.to_string()).collect();
+
+ ApiResponse {
+ status: Status::Ok,
+ json: json!({ "history": user_data }),
+ }
+}