aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorEllie Huxtable <e@elm.sh>2021-04-20 17:07:11 +0100
committerGitHub <noreply@github.com>2021-04-20 16:07:11 +0000
commit34888827f8a06de835cbe5833a06914f28cce514 (patch)
tree8b56f20e50065cd2c222d5e8e067ec55cf1947a1 /src
parentOptimise docker (#34) (diff)
downloadatuin-34888827f8a06de835cbe5833a06914f28cce514.zip
Switch to Warp + SQLx, use async, switch to Rust stable (#36)
* Switch to warp + sql, use async and stable rust * Update CI to use stable
Diffstat (limited to 'src')
-rw-r--r--src/api.rs42
-rw-r--r--src/command/history.rs8
-rw-r--r--src/command/login.rs7
-rw-r--r--src/command/mod.rs8
-rw-r--r--src/command/search.rs110
-rw-r--r--src/command/server.rs6
-rw-r--r--src/command/sync.rs4
-rw-r--r--src/local/api_client.rs87
-rw-r--r--src/local/database.rs8
-rw-r--r--src/local/import.rs7
-rw-r--r--src/local/sync.rs36
-rw-r--r--src/main.rs43
-rw-r--r--src/remote/database.rs22
-rw-r--r--src/remote/mod.rs5
-rw-r--r--src/remote/server.rs61
-rw-r--r--src/remote/views.rs185
-rw-r--r--src/schema.rs30
-rw-r--r--src/server/auth.rs (renamed from src/remote/auth.rs)2
-rw-r--r--src/server/database.rs202
-rw-r--r--src/server/handlers/history.rs89
-rw-r--r--src/server/handlers/mod.rs6
-rw-r--r--src/server/handlers/user.rs140
-rw-r--r--src/server/mod.rs23
-rw-r--r--src/server/models.rs (renamed from src/remote/models.rs)43
-rw-r--r--src/server/router.rs121
-rw-r--r--src/settings.rs2
-rw-r--r--src/shell/atuin.zsh1
27 files changed, 832 insertions, 466 deletions
diff --git a/src/api.rs b/src/api.rs
index 90977404..82ee6604 100644
--- a/src/api.rs
+++ b/src/api.rs
@@ -1,8 +1,9 @@
use chrono::Utc;
-// This is shared between the client and the server, and has the data structures
-// representing the requests/responses for each method.
-// TODO: Properly define responses rather than using json!
+#[derive(Debug, Serialize, Deserialize)]
+pub struct UserResponse {
+ pub username: String,
+}
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterRequest {
@@ -12,12 +13,22 @@ pub struct RegisterRequest {
}
#[derive(Debug, Serialize, Deserialize)]
+pub struct RegisterResponse {
+ pub session: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
pub struct LoginRequest {
pub username: String,
pub password: String,
}
#[derive(Debug, Serialize, Deserialize)]
+pub struct LoginResponse {
+ pub session: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
pub struct AddHistoryRequest {
pub id: String,
pub timestamp: chrono::DateTime<Utc>,
@@ -31,6 +42,29 @@ pub struct CountResponse {
}
#[derive(Debug, Serialize, Deserialize)]
-pub struct ListHistoryResponse {
+pub struct SyncHistoryRequest {
+ pub sync_ts: chrono::DateTime<chrono::FixedOffset>,
+ pub history_ts: chrono::DateTime<chrono::FixedOffset>,
+ pub host: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct SyncHistoryResponse {
pub history: Vec<String>,
}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ErrorResponse {
+ pub reason: String,
+}
+
+impl ErrorResponse {
+ pub fn reply(reason: &str, status: warp::http::StatusCode) -> impl warp::Reply {
+ warp::reply::with_status(
+ warp::reply::json(&ErrorResponse {
+ reason: String::from(reason),
+ }),
+ status,
+ )
+ }
+}
diff --git a/src/command/history.rs b/src/command/history.rs
index 3b4a717c..627efae4 100644
--- a/src/command/history.rs
+++ b/src/command/history.rs
@@ -53,7 +53,7 @@ fn print_list(h: &[History]) {
}
impl Cmd {
- pub fn run(&self, settings: &Settings, db: &mut impl Database) -> Result<()> {
+ pub async fn run(&self, settings: &Settings, db: &mut (impl Database + Send)) -> Result<()> {
match self {
Self::Start { command: words } => {
let command = words.join(" ");
@@ -69,6 +69,10 @@ impl Cmd {
}
Self::End { id, exit } => {
+ if id.trim() == "" {
+ return Ok(());
+ }
+
let mut h = db.load(id)?;
h.exit = *exit;
h.duration = chrono::Utc::now().timestamp_nanos() - h.timestamp.timestamp_nanos();
@@ -82,7 +86,7 @@ impl Cmd {
}
Ok(Fork::Child) => {
debug!("running periodic background sync");
- sync::sync(settings, false, db)?;
+ sync::sync(settings, false, db).await?;
}
Err(_) => println!("Fork failed"),
}
diff --git a/src/command/login.rs b/src/command/login.rs
index 4f58b77f..636ac0d3 100644
--- a/src/command/login.rs
+++ b/src/command/login.rs
@@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::fs::File;
use std::io::prelude::*;
-use eyre::Result;
+use eyre::{eyre, Result};
use structopt::StructOpt;
use crate::settings::Settings;
@@ -28,8 +28,13 @@ impl Cmd {
let url = format!("{}/login", settings.local.sync_address);
let client = reqwest::blocking::Client::new();
+
let resp = client.post(url).json(&map).send()?;
+ if resp.status() != reqwest::StatusCode::OK {
+ return Err(eyre!("invalid login details"));
+ }
+
let session = resp.json::<HashMap<String, String>>()?;
let session = session["session"].clone();
diff --git a/src/command/mod.rs b/src/command/mod.rs
index eeb11a87..cd857e9f 100644
--- a/src/command/mod.rs
+++ b/src/command/mod.rs
@@ -63,16 +63,16 @@ pub fn uuid_v4() -> String {
}
impl AtuinCmd {
- pub fn run(self, db: &mut impl Database, settings: &Settings) -> Result<()> {
+ pub async fn run<T: Database + Send>(self, db: &mut T, settings: &Settings) -> Result<()> {
match self {
- Self::History(history) => history.run(settings, db),
+ Self::History(history) => history.run(settings, db).await,
Self::Import(import) => import.run(db),
- Self::Server(server) => server.run(settings),
+ Self::Server(server) => server.run(settings).await,
Self::Stats(stats) => stats.run(db, settings),
Self::Init => init::init(),
Self::Search { query } => search::run(&query, db),
- Self::Sync { force } => sync::run(settings, force, db),
+ Self::Sync { force } => sync::run(settings, force, db).await,
Self::Login(l) => l.run(settings),
Self::Register(r) => register::run(
settings,
diff --git a/src/command/search.rs b/src/command/search.rs
index b9f3987c..d7b477da 100644
--- a/src/command/search.rs
+++ b/src/command/search.rs
@@ -1,6 +1,8 @@
use eyre::Result;
use itertools::Itertools;
use std::io::stdout;
+use std::time::Duration;
+
use termion::{event::Key, input::MouseTerminal, raw::IntoRawMode, screen::AlternateScreen};
use tui::{
backend::TermionBackend,
@@ -26,6 +28,78 @@ struct State {
results_state: ListState,
}
+#[allow(clippy::clippy::cast_sign_loss)]
+impl State {
+ fn durations(&self) -> Vec<String> {
+ self.results
+ .iter()
+ .map(|h| {
+ let duration =
+ Duration::from_millis(std::cmp::max(h.duration, 0) as u64 / 1_000_000);
+ let duration = humantime::format_duration(duration).to_string();
+ let duration: Vec<&str> = duration.split(' ').collect();
+
+ duration[0].to_string()
+ })
+ .collect()
+ }
+
+ fn render_results<T: tui::backend::Backend>(
+ &mut self,
+ f: &mut tui::Frame<T>,
+ r: tui::layout::Rect,
+ ) {
+ let durations = self.durations();
+ let max_length = durations
+ .iter()
+ .fold(0, |largest, i| std::cmp::max(largest, i.len()));
+
+ let results: Vec<ListItem> = self
+ .results
+ .iter()
+ .enumerate()
+ .map(|(i, m)| {
+ let command = m.command.to_string().replace("\n", " ").replace("\t", " ");
+
+ let mut command = Span::raw(command);
+
+ let mut duration = durations[i].clone();
+
+ while duration.len() < max_length {
+ duration.push(' ');
+ }
+
+ let duration = Span::styled(
+ duration,
+ Style::default().fg(if m.exit == 0 || m.duration == -1 {
+ Color::Green
+ } else {
+ Color::Red
+ }),
+ );
+
+ if let Some(selected) = self.results_state.selected() {
+ if selected == i {
+ command.style =
+ Style::default().fg(Color::Red).add_modifier(Modifier::BOLD);
+ }
+ }
+
+ let spans = Spans::from(vec![duration, Span::raw(" "), command]);
+
+ ListItem::new(spans)
+ })
+ .collect();
+
+ let results = List::new(results)
+ .block(Block::default().borders(Borders::ALL).title("History"))
+ .start_corner(Corner::BottomLeft)
+ .highlight_symbol(">> ");
+
+ f.render_stateful_widget(results, r, &mut self.results_state);
+ }
+}
+
fn query_results(app: &mut State, db: &mut impl Database) {
let results = match app.input.as_str() {
"" => db.list(),
@@ -48,7 +122,11 @@ fn key_handler(input: Key, db: &mut impl Database, app: &mut State) -> Option<St
Key::Esc | Key::Char('\n') => {
let i = app.results_state.selected().unwrap_or(0);
- return Some(app.results.get(i).unwrap().command.clone());
+ return Some(
+ app.results
+ .get(i)
+ .map_or("".to_string(), |h| h.command.clone()),
+ );
}
Key::Char(c) => {
app.input.push(c);
@@ -163,32 +241,8 @@ fn select_history(query: &[String], db: &mut impl Database) -> Result<String> {
let help = Text::from(Spans::from(help));
let help = Paragraph::new(help);
- let input = Paragraph::new(app.input.as_ref())
- .block(Block::default().borders(Borders::ALL).title("Search"));
-
- let results: Vec<ListItem> = app
- .results
- .iter()
- .enumerate()
- .map(|(i, m)| {
- let mut content =
- Span::raw(m.command.to_string().replace("\n", " ").replace("\t", " "));
-
- if let Some(selected) = app.results_state.selected() {
- if selected == i {
- content.style =
- Style::default().fg(Color::Red).add_modifier(Modifier::BOLD);
- }
- }
-
- ListItem::new(content)
- })
- .collect();
-
- let results = List::new(results)
- .block(Block::default().borders(Borders::ALL).title("History"))
- .start_corner(Corner::BottomLeft)
- .highlight_symbol(">> ");
+ let input = Paragraph::new(app.input.clone())
+ .block(Block::default().borders(Borders::ALL).title("Query"));
let stats = Paragraph::new(Text::from(Span::raw(format!(
"history count: {}",
@@ -199,8 +253,8 @@ fn select_history(query: &[String], db: &mut impl Database) -> Result<String> {
f.render_widget(title, top_left_chunks[0]);
f.render_widget(help, top_left_chunks[1]);
+ app.render_results(f, chunks[1]);
f.render_widget(stats, top_right_chunks[0]);
- f.render_stateful_widget(results, chunks[1], &mut app.results_state);
f.render_widget(input, chunks[2]);
f.set_cursor(
diff --git a/src/command/server.rs b/src/command/server.rs
index bf757948..a7835092 100644
--- a/src/command/server.rs
+++ b/src/command/server.rs
@@ -1,7 +1,7 @@
use eyre::Result;
use structopt::StructOpt;
-use crate::remote::server;
+use crate::server;
use crate::settings::Settings;
#[derive(StructOpt)]
@@ -20,7 +20,7 @@ pub enum Cmd {
}
impl Cmd {
- pub fn run(&self, settings: &Settings) -> Result<()> {
+ pub async fn run(&self, settings: &Settings) -> Result<()> {
match self {
Self::Start { host, port } => {
let host = host.as_ref().map_or(
@@ -29,7 +29,7 @@ impl Cmd {
);
let port = port.map_or(settings.server.port, |p| p);
- server::launch(settings, host, port)
+ server::launch(settings, host, port).await
}
}
}
diff --git a/src/command/sync.rs b/src/command/sync.rs
index facbe578..88217b3c 100644
--- a/src/command/sync.rs
+++ b/src/command/sync.rs
@@ -4,8 +4,8 @@ use crate::local::database::Database;
use crate::local::sync;
use crate::settings::Settings;
-pub fn run(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> {
- sync::sync(settings, force, db)?;
+pub async fn run(settings: &Settings, force: bool, db: &mut (impl Database + Send)) -> Result<()> {
+ sync::sync(settings, force, db).await?;
println!(
"Sync complete! {} items in database, force: {}",
db.history_count()?,
diff --git a/src/local/api_client.rs b/src/local/api_client.rs
index 434c07ba..1b64a295 100644
--- a/src/local/api_client.rs
+++ b/src/local/api_client.rs
@@ -1,93 +1,94 @@
use chrono::Utc;
use eyre::Result;
-use reqwest::header::AUTHORIZATION;
+use reqwest::header::{HeaderMap, AUTHORIZATION};
+use reqwest::Url;
+use sodiumoxide::crypto::secretbox;
-use crate::api::{AddHistoryRequest, CountResponse, ListHistoryResponse};
-use crate::local::encryption::{decrypt, load_key};
+use crate::api::{AddHistoryRequest, CountResponse, SyncHistoryResponse};
+use crate::local::encryption::decrypt;
use crate::local::history::History;
-use crate::settings::Settings;
use crate::utils::hash_str;
pub struct Client<'a> {
- settings: &'a Settings,
+ sync_addr: &'a str,
+ token: &'a str,
+ key: secretbox::Key,
+ client: reqwest::Client,
}
impl<'a> Client<'a> {
- pub const fn new(settings: &'a Settings) -> Self {
- Client { settings }
+ pub fn new(sync_addr: &'a str, token: &'a str, key: secretbox::Key) -> Self {
+ Client {
+ sync_addr,
+ token,
+ key,
+ client: reqwest::Client::new(),
+ }
}
- pub fn count(&self) -> Result<i64> {
- let url = format!("{}/sync/count", self.settings.local.sync_address);
- let client = reqwest::blocking::Client::new();
+ pub async fn count(&self) -> Result<i64> {
+ let url = format!("{}/sync/count", self.sync_addr);
+ let url = Url::parse(url.as_str())?;
+ let token = format!("Token {}", self.token);
+ let token = token.parse()?;
- let resp = client
- .get(url)
- .header(
- AUTHORIZATION,
- format!("Token {}", self.settings.local.session_token),
- )
- .send()?;
+ let mut headers = HeaderMap::new();
+ headers.insert(AUTHORIZATION, token);
+
+ let resp = self.client.get(url).headers(headers).send().await?;
- let count = resp.json::<CountResponse>()?;
+ let count = resp.json::<CountResponse>().await?;
Ok(count.count)
}
- pub fn get_history(
+ pub async fn get_history(
&self,
sync_ts: chrono::DateTime<Utc>,
history_ts: chrono::DateTime<Utc>,
host: Option<String>,
) -> Result<Vec<History>> {
- let key = load_key(self.settings)?;
-
let host = match host {
None => hash_str(&format!("{}:{}", whoami::hostname(), whoami::username())),
Some(h) => h,
};
- // this allows for syncing between users on the same machine
let url = format!(
"{}/sync/history?sync_ts={}&history_ts={}&host={}",
- self.settings.local.sync_address,
- sync_ts.to_rfc3339(),
- history_ts.to_rfc3339(),
+ self.sync_addr,
+ urlencoding::encode(sync_ts.to_rfc3339().as_str()),
+ urlencoding::encode(history_ts.to_rfc3339().as_str()),
host,
);
- let client = reqwest::blocking::Client::new();
- let resp = client
+ let resp = self
+ .client
.get(url)
- .header(
- AUTHORIZATION,
- format!("Token {}", self.settings.local.session_token),
- )
- .send()?;
+ .header(AUTHORIZATION, format!("Token {}", self.token))
+ .send()
+ .await?;
- let history = resp.json::<ListHistoryResponse>()?;
+ let history = resp.json::<SyncHistoryResponse>().await?;
let history = history
.history
.iter()
.map(|h| serde_json::from_str(h).expect("invalid base64"))
- .map(|h| decrypt(&h, &key).expect("failed to decrypt history! check your key"))
+ .map(|h| decrypt(&h, &self.key).expect("failed to decrypt history! check your key"))
.collect();
Ok(history)
}
- pub fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
- let client = reqwest::blocking::Client::new();
+ pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
+ let url = format!("{}/history", self.sync_addr);
+ let url = Url::parse(url.as_str())?;
- let url = format!("{}/history", self.settings.local.sync_address);
- client
+ self.client
.post(url)
.json(history)
- .header(
- AUTHORIZATION,
- format!("Token {}", self.settings.local.session_token),
- )
- .send()?;
+ .header(AUTHORIZATION, format!("Token {}", self.token))
+ .send()
+ .await?;
Ok(())
}
diff --git a/src/local/database.rs b/src/local/database.rs
index 977f11cc..abc22bb8 100644
--- a/src/local/database.rs
+++ b/src/local/database.rs
@@ -215,9 +215,9 @@ impl Database for Sqlite {
}
fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>> {
- let mut stmt = self.conn.prepare(
- "SELECT * FROM history where timestamp <= ? order by timestamp desc limit ?",
- )?;
+ let mut stmt = self
+ .conn
+ .prepare("SELECT * FROM history where timestamp < ? order by timestamp desc limit ?")?;
let history_iter = stmt.query_map(params![timestamp.timestamp_nanos(), count], |row| {
history_from_sqlite_row(None, row)
@@ -236,7 +236,7 @@ impl Database for Sqlite {
fn prefix_search(&self, query: &str) -> Result<Vec<History>> {
self.query(
- "select * from history where command like ?1 || '%' order by timestamp asc",
+ "select * from history where command like ?1 || '%' order by timestamp asc limit 1000",
&[query],
)
}
diff --git a/src/local/import.rs b/src/local/import.rs
index d0f679c9..3b0b2a69 100644
--- a/src/local/import.rs
+++ b/src/local/import.rs
@@ -7,6 +7,7 @@ use std::{fs::File, path::Path};
use chrono::prelude::*;
use chrono::Utc;
use eyre::{eyre, Result};
+use itertools::Itertools;
use super::history::History;
@@ -42,8 +43,8 @@ impl Zsh {
fn parse_extended(line: &str, counter: i64) -> History {
let line = line.replacen(": ", "", 2);
- let (time, duration) = line.split_once(':').unwrap();
- let (duration, command) = duration.split_once(';').unwrap();
+ let (time, duration) = line.splitn(2, ':').collect_tuple().unwrap();
+ let (duration, command) = duration.splitn(2, ';').collect_tuple().unwrap();
let time = time
.parse::<i64>()
@@ -60,7 +61,7 @@ fn parse_extended(line: &str, counter: i64) -> History {
time,
command.trim_end().to_string(),
String::from("unknown"),
- -1,
+ 0, // assume 0, we have no way of knowing :(
duration,
None,
None,
diff --git a/src/local/sync.rs b/src/local/sync.rs
index c22d2f27..e0feb759 100644
--- a/src/local/sync.rs
+++ b/src/local/sync.rs
@@ -20,12 +20,12 @@ use crate::{api::AddHistoryRequest, utils::hash_str};
// Check if remote has things we don't, and if so, download them.
// Returns (num downloaded, total local)
-fn sync_download(
+async fn sync_download(
force: bool,
- client: &api_client::Client,
- db: &mut impl Database,
+ client: &api_client::Client<'_>,
+ db: &mut (impl Database + Send),
) -> Result<(i64, i64)> {
- let remote_count = client.count()?;
+ let remote_count = client.count().await?;
let initial_local = db.history_count()?;
let mut local_count = initial_local;
@@ -41,7 +41,9 @@ fn sync_download(
let host = if force { Some(String::from("")) } else { None };
while remote_count > local_count {
- let page = client.get_history(last_sync, last_timestamp, host.clone())?;
+ let page = client
+ .get_history(last_sync, last_timestamp, host.clone())
+ .await?;
if page.len() < HISTORY_PAGE_SIZE.try_into().unwrap() {
break;
@@ -71,13 +73,13 @@ fn sync_download(
}
// Check if we have things remote doesn't, and if so, upload them
-fn sync_upload(
+async fn sync_upload(
settings: &Settings,
_force: bool,
- client: &api_client::Client,
- db: &mut impl Database,
+ client: &api_client::Client<'_>,
+ db: &mut (impl Database + Send),
) -> Result<()> {
- let initial_remote_count = client.count()?;
+ let initial_remote_count = client.count().await?;
let mut remote_count = initial_remote_count;
let local_count = db.history_count()?;
@@ -111,21 +113,25 @@ fn sync_upload(
}
// anything left over outside of the 100 block size
- client.post_history(&buffer)?;
+ client.post_history(&buffer).await?;
cursor = buffer.last().unwrap().timestamp;
- remote_count = client.count()?;
+ remote_count = client.count().await?;
}
Ok(())
}
-pub fn sync(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> {
- let client = api_client::Client::new(settings);
+pub async fn sync(settings: &Settings, force: bool, db: &mut (impl Database + Send)) -> Result<()> {
+ let client = api_client::Client::new(
+ settings.local.sync_address.as_str(),
+ settings.local.session_token.as_str(),
+ load_key(settings)?,
+ );
- sync_upload(settings, force, &client, db)?;
+ sync_upload(settings, force, &client, db).await?;
- let download = sync_download(force, &client, db)?;
+ let download = sync_download(force, &client, db).await?;
debug!("sync downloaded {}", download.0);
diff --git a/src/main.rs b/src/main.rs
index 94c7366d..0045a943 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,11 +1,10 @@
-#![feature(proc_macro_hygiene)]
-#![feature(decl_macro)]
#![warn(clippy::pedantic, clippy::nursery)]
#![allow(clippy::use_self)] // not 100% reliable
use std::path::PathBuf;
use eyre::{eyre, Result};
+use fern::colors::{Color, ColoredLevelConfig};
use human_panic::setup_panic;
use structopt::{clap::AppSettings, StructOpt};
@@ -13,20 +12,8 @@ use structopt::{clap::AppSettings, StructOpt};
extern crate log;
#[macro_use]
-extern crate rocket;
-
-#[macro_use]
extern crate serde_derive;
-#[macro_use]
-extern crate diesel;
-
-#[macro_use]
-extern crate diesel_migrations;
-
-#[macro_use]
-extern crate rocket_contrib;
-
use command::AtuinCmd;
use local::database::Sqlite;
use settings::Settings;
@@ -34,12 +21,10 @@ use settings::Settings;
mod api;
mod command;
mod local;
-mod remote;
+mod server;
mod settings;
mod utils;
-pub mod schema;
-
#[derive(StructOpt)]
#[structopt(
author = "Ellie Huxtable <e@elm.sh>",
@@ -56,7 +41,7 @@ struct Atuin {
}
impl Atuin {
- fn run(self, settings: &Settings) -> Result<()> {
+ async fn run(self, settings: &Settings) -> Result<()> {
let db_path = if let Some(db_path) = self.db {
let path = db_path
.to_str()
@@ -69,26 +54,32 @@ impl Atuin {
let mut db = Sqlite::new(db_path)?;
- self.atuin.run(&mut db, settings)
+ self.atuin.run(&mut db, settings).await
}
}
-fn main() -> Result<()> {
- setup_panic!();
- let settings = Settings::new()?;
+#[tokio::main]
+async fn main() -> Result<()> {
+ let colors = ColoredLevelConfig::new()
+ .warn(Color::Yellow)
+ .error(Color::Red);
fern::Dispatch::new()
- .format(|out, message, record| {
+ .format(move |out, message, record| {
out.finish(format_args!(
"{} [{}] {}",
- chrono::Local::now().format("[%Y-%m-%d][%H:%M:%S]"),
- record.level(),
+ chrono::Local::now().to_rfc3339(),
+ colors.color(record.level()),
message
))
})
.level(log::LevelFilter::Info)
+ .level_for("sqlx", log::LevelFilter::Warn)
.chain(std::io::stdout())
.apply()?;
- Atuin::from_args().run(&settings)
+ let settings = Settings::new()?;
+ setup_panic!();
+
+ Atuin::from_args().run(&settings).await
}
diff --git a/src/remote/database.rs b/src/remote/database.rs
deleted file mode 100644
index 03973ca1..00000000
--- a/src/remote/database.rs
+++ /dev/null
@@ -1,22 +0,0 @@
-use diesel::pg::PgConnection;
-use diesel::prelude::*;
-use eyre::{eyre, Result};
-
-use crate::settings::Settings;
-
-#[database("atuin")]
-pub struct AtuinDbConn(diesel::PgConnection);
-
-// TODO: connection pooling
-pub fn establish_connection(settings: &Settings) -> Result<PgConnection> {
- if settings.server.db_uri == "default_uri" {
- Err(eyre!(
- "Please configure your database! Set db_uri in config.toml"
- ))
- } else {
- let database_url = &settings.server.db_uri;
- let conn = PgConnection::establish(database_url)?;
-
- Ok(conn)
- }
-}
diff --git a/src/remote/mod.rs b/src/remote/mod.rs
deleted file mode 100644
index 7147b88e..00000000
--- a/src/remote/mod.rs
+++ /dev/null
@@ -1,5 +0,0 @@
-pub mod auth;
-pub mod database;
-pub mod models;
-pub mod server;
-pub mod views;
diff --git a/src/remote/server.rs b/src/remote/server.rs
deleted file mode 100644
index ee481ca4..00000000
--- a/src/remote/server.rs
+++ /dev/null
@@ -1,61 +0,0 @@
-use std::collections::HashMap;
-
-use crate::remote::database::establish_connection;
-use crate::settings::Settings;
-
-use super::database::AtuinDbConn;
-
-use eyre::Result;
-use rocket::config::{Config, Environment, LoggingLevel, Value};
-
-// a bunch of these imports are generated by macros, it's easier to wildcard
-#[allow(clippy::clippy::wildcard_imports)]
-use super::views::*;
-
-#[allow(clippy::clippy::wildcard_imports)]
-use super::auth::*;
-
-embed_migrations!("migrations");
-
-pub fn launch(settings: &Settings, host: String, port: u16) -> Result<()> {
- let settings: Settings = settings.clone(); // clone so rocket can manage it
-
- let mut database_config = HashMap::new();
- let mut databases = HashMap::new();
-
- database_config.insert("url", Value::from(settings.server.db_uri.clone()));
- databases.insert("atuin", Value::from(database_config));
-
- let connection = establish_connection(&settings)?;
-
- embedded_migrations::run(&connection).expect("failed to run migrations");
-
- let config = Config::build(Environment::Production)
- .address(host)
- .log_level(LoggingLevel::Normal)
- .port(port)
- .extra("databases", databases)
- .finalize()
- .unwrap();
-
- let app = rocket::custom(config);
-
- app.mount(
- "/",
- routes![
- index,
- register,
- add_history,
- login,
- get_user,
- sync_count,
- sync_list
- ],
- )
- .manage(settings)
- .attach(AtuinDbConn::fairing())
- .register(catchers![internal_error, bad_request])
- .launch();
-
- Ok(())
-}
diff --git a/src/remote/views.rs b/src/remote/views.rs
deleted file mode 100644
index 08dff13e..00000000
--- a/src/remote/views.rs
+++ /dev/null
@@ -1,185 +0,0 @@
-use chrono::Utc;
-use rocket::http::uri::Uri;
-use rocket::http::RawStr;
-use rocket::http::{ContentType, Status};
-use rocket::request::FromFormValue;
-use rocket::request::Request;
-use rocket::response;
-use rocket::response::{Responder, Response};
-use rocket_contrib::databases::diesel;
-use rocket_contrib::json::{Json, JsonValue};
-
-use self::diesel::prelude::*;
-
-use crate::api::AddHistoryRequest;
-use crate::schema::history;
-use crate::settings::HISTORY_PAGE_SIZE;
-
-use super::database::AtuinDbConn;
-use super::models::{History, NewHistory, User};
-
-#[derive(Debug)]
-pub struct ApiResponse {
- pub json: JsonValue,
- pub status: Status,
-}
-
-impl<'r> Responder<'r> for ApiResponse {
- fn respond_to(self, req: &Request) -> response::Result<'r> {
- Response::build_from(self.json.respond_to(req).unwrap())
- .status(self.status)
- .header(ContentType::JSON)
- .ok()
- }
-}
-
-#[get("/")]
-pub const fn index() -> &'static str {
- "\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett"
-}
-
-#[catch(500)]
-pub fn internal_error(_req: &Request) -> ApiResponse {
- ApiResponse {
- status: Status::InternalServerError,
- json: json!({"status": "error", "message": "an internal server error has occured"}),
- }
-}
-
-#[catch(400)]
-pub fn bad_request(_req: &Request) -> ApiResponse {
- ApiResponse {
- status: Status::InternalServerError,
- json: json!({"status": "error", "message": "bad request. don't do that."}),
- }
-}
-
-#[post("/history", data = "<add_history>")]
-#[allow(
- clippy::clippy::cast_sign_loss,
- clippy::cast_possible_truncation,
- clippy::clippy::needless_pass_by_value
-)]
-pub fn add_history(
- conn: AtuinDbConn,
- user: User,
- add_history: Json<Vec<AddHistoryRequest>>,
-) -> ApiResponse {
- let new_history: Vec<NewHistory> = add_history
- .iter()
- .map(|h| NewHistory {
- client_id: h.id.as_str(),
- hostname: h.hostname.to_string(),
- user_id: user.id,
- timestamp: h.timestamp.naive_utc(),
- data: h.data.as_str(),
- })
- .collect();
-
- match diesel::insert_into(history::table)
- .values(&new_history)
- .on_conflict_do_nothing()
- .execute(&*conn)
- {
- Ok(_) => ApiResponse {
- status: Status::Ok,
- json: json!({"status": "ok", "message": "history added"}),
- },
- Err(_) => ApiResponse {
- status: Status::BadRequest,
- json: json!({"status": "error", "message": "failed to add history"}),
- },
- }
-}
-
-#[get("/sync/count")]
-#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)]
-pub fn sync_count(conn: AtuinDbConn, user: User) -> ApiResponse {
- use crate::schema::history::dsl::*;
-
- // we need to return the number of history items we have for this user
- // in the future I'd like to use something like a merkel tree to calculate
- // which day specifically needs syncing
- let count = history
- .filter(user_id.eq(user.id))
- .count()
- .first::<i64>(&*conn);
-
- if count.is_err() {
- error!("failed to count: {}", count.err().unwrap());
-
- return ApiResponse {
- json: json!({"message": "internal server error"}),
- status: Status::InternalServerError,
- };
- }
-
- ApiResponse {
- status: Status::Ok,
- json: json!({"count": count.ok()}),
- }
-}
-
-pub struct UtcDateTime(chrono::DateTime<Utc>);
-
-impl<'v> FromFormValue<'v> for UtcDateTime {
- type Error = &'v RawStr;
-
- fn from_form_value(form_value: &'v RawStr) -> Result<UtcDateTime, &'v RawStr> {
- let time = Uri::percent_decode(form_value.as_bytes()).map_err(|_| form_value)?;
- let time = time.to_string();
-
- match chrono::DateTime::parse_from_rfc3339(time.as_str()) {
- Ok(t) => Ok(UtcDateTime(t.with_timezone(&Utc))),
- Err(e) => {
- error!("failed to parse time {}, got: {}", time, e);
- Err(form_value)
- }
- }
- }
-}
-
-// Request a list of all history items added to the DB after a given timestamp.
-// Provide the current hostname, so that we don't send the client data that
-// originated from them
-#[get("/sync/history?<sync_ts>&<history_ts>&<host>")]
-#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)]
-pub fn sync_list(
- conn: AtuinDbConn,
- user: User,
- sync_ts: UtcDateTime,
- history_ts: UtcDateTime,
- host: String,
-) -> ApiResponse {
- use crate::schema::history::dsl::*;
-
- // we need to return the number of history items we have for this user
- // in the future I'd like to use something like a merkel tree to calculate
- // which day specifically needs syncing
- // TODO: Allow for configuring the page size, both from params, and setting
- // the max in config. 100 is fine for now.
- let h = history
- .filter(user_id.eq(user.id))
- .filter(hostname.ne(host))
- .filter(created_at.ge(sync_ts.0.naive_utc()))
- .filter(timestamp.ge(history_ts.0.naive_utc()))
- .order(timestamp.asc())
- .limit(HISTORY_PAGE_SIZE)
- .load::<History>(&*conn);
-
- if let Err(e) = h {
- error!("failed to load history: {}", e);
-
- return ApiResponse {
- json: json!({"message": "internal server error"}),
- status: Status::InternalServerError,
- };
- }
-
- let user_data: Vec<String> = h.unwrap().iter().map(|i| i.data.to_string()).collect();
-
- ApiResponse {
- status: Status::Ok,
- json: json!({ "history": user_data }),
- }
-}
diff --git a/src/schema.rs b/src/schema.rs
deleted file mode 100644
index 84bf5bab..00000000
--- a/src/schema.rs
+++ /dev/null
@@ -1,30 +0,0 @@
-table! {
- history (id) {
- id -> Int8,
- client_id -> Text,
- user_id -> Int8,
- hostname -> Text,
- timestamp -> Timestamp,
- data -> Varchar,
- created_at -> Timestamp,
- }
-}
-
-table! {
- sessions (id) {
- id -> Int8,
- user_id -> Int8,
- token -> Varchar,
- }
-}
-
-table! {
- users (id) {
- id -> Int8,
- username -> Varchar,
- email -> Varchar,
- password -> Varchar,
- }
-}
-
-allow_tables_to_appear_in_same_query!(history, sessions, users,);
diff --git a/src/remote/auth.rs b/src/server/auth.rs
index cf61b077..52a73108 100644
--- a/src/remote/auth.rs
+++ b/src/server/auth.rs
@@ -1,3 +1,4 @@
+/*
use self::diesel::prelude::*;
use eyre::Result;
use rocket::http::Status;
@@ -218,3 +219,4 @@ pub fn login(conn: AtuinDbConn, login: Json<LoginRequest>) -> ApiResponse {
json: json!({"session": session.token}),
}
}
+*/
diff --git a/src/server/database.rs b/src/server/database.rs
new file mode 100644
index 00000000..5945baaf
--- /dev/null
+++ b/src/server/database.rs
@@ -0,0 +1,202 @@
+use async_trait::async_trait;
+
+use eyre::{eyre, Result};
+use sqlx::postgres::PgPoolOptions;
+
+use crate::settings::HISTORY_PAGE_SIZE;
+
+use super::models::{History, NewHistory, NewSession, NewUser, Session, User};
+
+#[async_trait]
+pub trait Database {
+ async fn get_session(&self, token: &str) -> Result<Session>;
+ async fn get_session_user(&self, token: &str) -> Result<User>;
+ async fn add_session(&self, session: &NewSession) -> Result<()>;
+
+ async fn get_user(&self, username: String) -> Result<User>;
+ async fn get_user_session(&self, u: &User) -> Result<Session>;
+ async fn add_user(&self, user: NewUser) -> Result<i64>;
+
+ async fn count_history(&self, user: &User) -> Result<i64>;
+ async fn list_history(
+ &self,
+ user: &User,
+ created_since: chrono::NaiveDateTime,
+ since: chrono::NaiveDateTime,
+ host: String,
+ ) -> Result<Vec<History>>;
+ async fn add_history(&self, history: &[NewHistory]) -> Result<()>;
+}
+
+#[derive(Clone)]
+pub struct Postgres {
+ pool: sqlx::Pool<sqlx::postgres::Postgres>,
+}
+
+impl Postgres {
+ pub async fn new(uri: &str) -> Result<Self, sqlx::Error> {
+ let pool = PgPoolOptions::new()
+ .max_connections(100)
+ .connect(uri)
+ .await?;
+
+ Ok(Self { pool })
+ }
+}
+
+#[async_trait]
+impl Database for Postgres {
+ async fn get_session(&self, token: &str) -> Result<Session> {
+ let res: Option<Session> =
+ sqlx::query_as::<_, Session>("select * from sessions where token = $1")
+ .bind(token)
+ .fetch_optional(&self.pool)
+ .await?;
+
+ if let Some(s) = res {
+ Ok(s)
+ } else {
+ Err(eyre!("could not find session"))
+ }
+ }
+
+ async fn get_user(&self, username: String) -> Result<User> {
+ let res: Option<User> =
+ sqlx::query_as::<_, User>("select * from users where username = $1")
+ .bind(username)
+ .fetch_optional(&self.pool)
+ .await?;
+
+ if let Some(u) = res {
+ Ok(u)
+ } else {
+ Err(eyre!("could not find user"))
+ }
+ }
+
+ async fn get_session_user(&self, token: &str) -> Result<User> {
+ let res: Option<User> = sqlx::query_as::<_, User>(
+ "select * from users
+ inner join sessions
+ on users.id = sessions.user_id
+ and sessions.token = $1",
+ )
+ .bind(token)
+ .fetch_optional(&self.pool)
+ .await?;
+
+ if let Some(u) = res {
+ Ok(u)
+ } else {
+ Err(eyre!("could not find user"))
+ }
+ }
+
+ async fn count_history(&self, user: &User) -> Result<i64> {
+ let res: (i64,) = sqlx::query_as(
+ "select count(1) from history
+ where user_id = $1",
+ )
+ .bind(user.id)
+ .fetch_one(&self.pool)
+ .await?;
+
+ Ok(res.0)
+ }
+
+ async fn list_history(
+ &self,
+ user: &User,
+ created_since: chrono::NaiveDateTime,
+ since: chrono::NaiveDateTime,
+ host: String,
+ ) -> Result<Vec<History>> {
+ let res = sqlx::query_as::<_, History>(
+ "select * from history
+ where user_id = $1
+ and hostname != $2
+ and created_at >= $3
+ and timestamp >= $4
+ order by timestamp asc
+ limit $5",
+ )
+ .bind(user.id)
+ .bind(host)
+ .bind(created_since)
+ .bind(since)
+ .bind(HISTORY_PAGE_SIZE)
+ .fetch_all(&self.pool)
+ .await?;
+
+ Ok(res)
+ }
+
+ async fn add_history(&self, history: &[NewHistory]) -> Result<()> {
+ let mut tx = self.pool.begin().await?;
+
+ for i in history {
+ sqlx::query(
+ "insert into history
+ (client_id, user_id, hostname, timestamp, data)
+ values ($1, $2, $3, $4, $5)
+ on conflict do nothing
+ ",
+ )
+ .bind(i.client_id)
+ .bind(i.user_id)
+ .bind(i.hostname)
+ .bind(i.timestamp)
+ .bind(i.data)
+ .execute(&mut tx)
+ .await?;
+ }
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+
+ async fn add_user(&self, user: NewUser) -> Result<i64> {
+ let res: (i64,) = sqlx::query_as(
+ "insert into users
+ (username, email, password)
+ values($1, $2, $3)
+ returning id",
+ )
+ .bind(user.username.as_str())
+ .bind(user.email.as_str())
+ .bind(user.password)
+ .fetch_one(&self.pool)
+ .await?;
+
+ Ok(res.0)
+ }
+
+ async fn add_session(&self, session: &NewSession) -> Result<()> {
+ sqlx::query(
+ "insert into sessions
+ (user_id, token)
+ values($1, $2)",
+ )
+ .bind(session.user_id)
+ .bind(session.token)
+ .execute(&self.pool)
+ .await?;
+
+ Ok(())
+ }
+
+ async fn get_user_session(&self, u: &User) -> Result<Session> {
+ let res: Option<Session> =
+ sqlx::query_as::<_, Session>("select * from sessions where user_id = $1")
+ .bind(u.id)
+ .fetch_optional(&self.pool)
+ .await?;
+
+ if let Some(s) = res {
+ Ok(s)
+ } else {
+ Err(eyre!("could not find session"))
+ }
+ }
+}
diff --git a/src/server/handlers/history.rs b/src/server/handlers/history.rs
new file mode 100644
index 00000000..4fd6f03f
--- /dev/null
+++ b/src/server/handlers/history.rs
@@ -0,0 +1,89 @@
+use std::convert::Infallible;
+
+use warp::{http::StatusCode, reply::json};
+
+use crate::api::{
+ AddHistoryRequest, CountResponse, ErrorResponse, SyncHistoryRequest, SyncHistoryResponse,
+};
+use crate::server::database::Database;
+use crate::server::models::{NewHistory, User};
+
+pub async fn count(
+ user: User,
+ db: impl Database + Clone + Send + Sync,
+) -> Result<Box<dyn warp::Reply>, Infallible> {
+ db.count_history(&user).await.map_or(
+ Ok(Box::new(ErrorResponse::reply(
+ "failed to query history count",
+ StatusCode::INTERNAL_SERVER_ERROR,
+ ))),
+ |count| Ok(Box::new(json(&CountResponse { count }))),
+ )
+}
+
+pub async fn list(
+ req: SyncHistoryRequest,
+ user: User,
+ db: impl Database + Clone + Send + Sync,
+) -> Result<Box<dyn warp::Reply>, Infallible> {
+ let history = db
+ .list_history(
+ &user,
+ req.sync_ts.naive_utc(),
+ req.history_ts.naive_utc(),
+ req.host,
+ )
+ .await;
+
+ if let Err(e) = history {
+ error!("failed to load history: {}", e);
+ let resp =
+ ErrorResponse::reply("failed to load history", StatusCode::INTERNAL_SERVER_ERROR);
+ let resp = Box::new(resp);
+ return Ok(resp);
+ }
+
+ let history: Vec<String> = history
+ .unwrap()
+ .iter()
+ .map(|i| i.data.to_string())
+ .collect();
+
+ debug!(
+ "loaded {} items of history for user {}",
+ history.len(),
+ user.id
+ );
+
+ Ok(Box::new(json(&SyncHistoryResponse { history })))
+}
+
+pub async fn add(
+ req: Vec<AddHistoryRequest>,
+ user: User,
+ db: impl Database + Clone + Send + Sync,
+) -> Result<Box<dyn warp::Reply>, Infallible> {
+ debug!("request to add {} history items", req.len());
+
+ let history: Vec<NewHistory> = req
+ .iter()
+ .map(|h| NewHistory {
+ client_id: h.id.as_str(),
+ user_id: user.id,
+ hostname: h.hostname.as_str(),
+ timestamp: h.timestamp.naive_utc(),
+ data: h.data.as_str(),
+ })
+ .collect();
+
+ if let Err(e) = db.add_history(&history).await {
+ error!("failed to add history: {}", e);
+
+ return Ok(Box::new(ErrorResponse::reply(
+ "failed to add history",
+ StatusCode::INTERNAL_SERVER_ERROR,
+ )));
+ };
+
+ Ok(Box::new(warp::reply()))
+}
diff --git a/src/server/handlers/mod.rs b/src/server/handlers/mod.rs
new file mode 100644
index 00000000..3c20538c
--- /dev/null
+++ b/src/server/handlers/mod.rs
@@ -0,0 +1,6 @@
+pub mod history;
+pub mod user;
+
+pub const fn index() -> &'static str {
+ "\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett"
+}
diff --git a/src/server/handlers/user.rs b/src/server/handlers/user.rs
new file mode 100644
index 00000000..782d7dbd
--- /dev/null
+++ b/src/server/handlers/user.rs
@@ -0,0 +1,140 @@
+use std::convert::Infallible;
+
+use sodiumoxide::crypto::pwhash::argon2id13;
+use uuid::Uuid;
+use warp::http::StatusCode;
+use warp::reply::json;
+
+use crate::api::{
+ ErrorResponse, LoginRequest, LoginResponse, RegisterRequest, RegisterResponse, UserResponse,
+};
+use crate::server::database::Database;
+use crate::server::models::{NewSession, NewUser};
+use crate::settings::Settings;
+use crate::utils::hash_secret;
+
+pub fn verify_str(secret: &str, verify: &str) -> bool {
+ sodiumoxide::init().unwrap();
+
+ let mut padded = [0_u8; 128];
+ secret.as_bytes().iter().enumerate().for_each(|(i, val)| {
+ padded[i] = *val;
+ });
+
+ match argon2id13::HashedPassword::from_slice(&padded) {
+ Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()),
+ None => false,
+ }
+}
+
+pub async fn get(
+ username: String,
+ db: impl Database + Clone + Send + Sync,
+) -> Result<Box<dyn warp::Reply>, Infallible> {
+ let user = match db.get_user(username).await {
+ Ok(user) => user,
+ Err(e) => {
+ debug!("user not found: {}", e);
+ return Ok(Box::new(ErrorResponse::reply(
+ "user not found",
+ StatusCode::NOT_FOUND,
+ )));
+ }
+ };
+
+ Ok(Box::new(warp::reply::json(&UserResponse {
+ username: user.username,
+ })))
+}
+
+pub async fn register(
+ register: RegisterRequest,
+ settings: Settings,
+ db: impl Database + Clone + Send + Sync,
+) -> Result<Box<dyn warp::Reply>, Infallible> {
+ if !settings.server.open_registration {
+ return Ok(Box::new(ErrorResponse::reply(
+ "this server is not open for registrations",
+ StatusCode::BAD_REQUEST,
+ )));
+ }
+
+ let hashed = hash_secret(register.password.as_str());
+
+ let new_user = NewUser {
+ email: register.email,
+ username: register.username,
+ password: hashed,
+ };
+
+ let user_id = match db.add_user(new_user).await {
+ Ok(id) => id,
+ Err(e) => {
+ error!("failed to add user: {}", e);
+ return Ok(Box::new(ErrorResponse::reply(
+ "failed to add user",
+ StatusCode::BAD_REQUEST,
+ )));
+ }
+ };
+
+ let token = Uuid::new_v4().to_simple().to_string();
+
+ let new_session = NewSession {
+ user_id,
+ token: token.as_str(),
+ };
+
+ match db.add_session(&new_session).await {
+ Ok(_) => Ok(Box::new(json(&RegisterResponse { session: token }))),
+ Err(e) => {
+ error!("failed to add session: {}", e);
+ Ok(Box::new(ErrorResponse::reply(
+ "failed to register user",
+ StatusCode::BAD_REQUEST,
+ )))
+ }
+ }
+}
+
+pub async fn login(
+ login: LoginRequest,
+ db: impl Database + Clone + Send + Sync,
+) -> Result<Box<dyn warp::Reply>, Infallible> {
+ let user = match db.get_user(login.username.clone()).await {
+ Ok(u) => u,
+ Err(e) => {
+ error!("failed to get user {}: {}", login.username.clone(), e);
+
+ return Ok(Box::new(ErrorResponse::reply(
+ "user not found",
+ StatusCode::NOT_FOUND,
+ )));
+ }
+ };
+
+ let session = match db.get_user_session(&user).await {
+ Ok(u) => u,
+ Err(e) => {
+ error!("failed to get session for {}: {}", login.username, e);
+
+ return Ok(Box::new(ErrorResponse::reply(
+ "user not found",
+ StatusCode::NOT_FOUND,
+ )));
+ }
+ };
+
+ let verified = verify_str(user.password.as_str(), login.password.as_str());
+
+ if !verified {
+ return Ok(Box::new(ErrorResponse::reply(
+ "user not found",
+ StatusCode::NOT_FOUND,
+ )));
+ }
+
+ Ok(Box::new(warp::reply::json(&LoginResponse {
+ session: session.token,
+ })))
+}
diff --git a/src/server/mod.rs b/src/server/mod.rs
new file mode 100644
index 00000000..d5e083df
--- /dev/null
+++ b/src/server/mod.rs
@@ -0,0 +1,23 @@
+use std::net::IpAddr;
+
+use eyre::Result;
+
+use crate::settings::Settings;
+
+pub mod auth;
+pub mod database;
+pub mod handlers;
+pub mod models;
+pub mod router;
+
+pub async fn launch(settings: &Settings, host: String, port: u16) -> Result<()> {
+ // routes to run:
+ // index, register, add_history, login, get_user, sync_count, sync_list
+ let host = host.parse::<IpAddr>()?;
+
+ let r = router::router(settings).await?;
+
+ warp::serve(r).run((host, port)).await;
+
+ Ok(())
+}
diff --git a/src/remote/models.rs b/src/server/models.rs
index 7f6f7766..fbf1897e 100644
--- a/src/remote/models.rs
+++ b/src/server/models.rs
@@ -1,10 +1,6 @@
use chrono::prelude::*;
-use crate::schema::{history, sessions, users};
-
-#[derive(Deserialize, Serialize, Identifiable, Queryable, Associations)]
-#[table_name = "history"]
-#[belongs_to(User)]
+#[derive(sqlx::FromRow)]
pub struct History {
pub id: i64,
pub client_id: String, // a client generated ID
@@ -17,7 +13,16 @@ pub struct History {
pub created_at: NaiveDateTime,
}
-#[derive(Identifiable, Queryable, Associations)]
+pub struct NewHistory<'a> {
+ pub client_id: &'a str,
+ pub user_id: i64,
+ pub hostname: &'a str,
+ pub timestamp: chrono::NaiveDateTime,
+
+ pub data: &'a str,
+}
+
+#[derive(sqlx::FromRow)]
pub struct User {
pub id: i64,
pub username: String,
@@ -25,35 +30,19 @@ pub struct User {
pub password: String,
}
-#[derive(Queryable, Identifiable, Associations)]
-#[belongs_to(User)]
+#[derive(sqlx::FromRow)]
pub struct Session {
pub id: i64,
pub user_id: i64,
pub token: String,
}
-#[derive(Insertable)]
-#[table_name = "history"]
-pub struct NewHistory<'a> {
- pub client_id: &'a str,
- pub user_id: i64,
- pub hostname: String,
- pub timestamp: chrono::NaiveDateTime,
-
- pub data: &'a str,
-}
-
-#[derive(Insertable)]
-#[table_name = "users"]
-pub struct NewUser<'a> {
- pub username: &'a str,
- pub email: &'a str,
- pub password: &'a str,
+pub struct NewUser {
+ pub username: String,
+ pub email: String,
+ pub password: String,
}
-#[derive(Insertable)]
-#[table_name = "sessions"]
pub struct NewSession<'a> {
pub user_id: i64,
pub token: &'a str,
diff --git a/src/server/router.rs b/src/server/router.rs
new file mode 100644
index 00000000..ed317ab2
--- /dev/null
+++ b/src/server/router.rs
@@ -0,0 +1,121 @@
+use std::convert::Infallible;
+
+use eyre::Result;
+use warp::Filter;
+
+use super::handlers;
+use super::{database::Database, database::Postgres};
+use crate::server::models::User;
+use crate::{api::SyncHistoryRequest, settings::Settings};
+
+fn with_settings(
+ settings: Settings,
+) -> impl Filter<Extract = (Settings,), Error = Infallible> + Clone {
+ warp::any().map(move || settings.clone())
+}
+
+fn with_db(
+ db: impl Database + Clone + Send + Sync,
+) -> impl Filter<Extract = (impl Database + Clone,), Error = Infallible> + Clone {
+ warp::any().map(move || db.clone())
+}
+
+fn with_user(
+ postgres: Postgres,
+) -> impl Filter<Extract = (User,), Error = warp::Rejection> + Clone {
+ warp::header::<String>("authorization").and_then(move |header: String| {
+ // async closures are still buggy :(
+ let postgres = postgres.clone();
+
+ async move {
+ let header: Vec<&str> = header.split(' ').collect();
+
+ let token;
+
+ if header.len() == 2 {
+ if header[0] != "Token" {
+ return Err(warp::reject());
+ }
+
+ token = header[1];
+ } else {
+ return Err(warp::reject());
+ }
+
+ let user = postgres
+ .get_session_user(token)
+ .await
+ .map_err(|_| warp::reject())?;
+
+ Ok(user)
+ }
+ })
+}
+
+pub async fn router(
+ settings: &Settings,
+) -> Result<impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone> {
+ let postgres = Postgres::new(settings.server.db_uri.as_str()).await?;
+ let index = warp::get().and(warp::path::end()).map(handlers::index);
+
+ let count = warp::get()
+ .and(warp::path("sync"))
+ .and(warp::path("count"))
+ .and(warp::path::end())
+ .and(with_user(postgres.clone()))
+ .and(with_db(postgres.clone()))
+ .and_then(handlers::history::count);
+
+ let sync = warp::get()
+ .and(warp::path("sync"))
+ .and(warp::path("history"))
+ .and(warp::query::<SyncHistoryRequest>())
+ .and(warp::path::end())
+ .and(with_user(postgres.clone()))
+ .and(with_db(postgres.clone()))
+ .and_then(handlers::history::list);
+
+ let add_history = warp::post()
+ .and(warp::path("history"))
+ .and(warp::path::end())
+ .and(warp::body::json())
+ .and(with_user(postgres.clone()))
+ .and(with_db(postgres.clone()))
+ .and_then(handlers::history::add);
+
+ let user = warp::get()
+ .and(warp::path("user"))
+ .and(warp::path::param::<String>())
+ .and(warp::path::end())
+ .and(with_db(postgres.clone()))
+ .and_then(handlers::user::get);
+
+ let register = warp::post()
+ .and(warp::path("register"))
+ .and(warp::path::end())
+ .and(warp::body::json())
+ .and(with_settings(settings.clone()))
+ .and(with_db(postgres.clone()))
+ .and_then(handlers::user::register);
+
+ let login = warp::post()
+ .and(warp::path("login"))
+ .and(warp::path::end())
+ .and(warp::body::json())
+ .and(with_db(postgres))
+ .and_then(handlers::user::login);
+
+ let r = warp::any()
+ .and(
+ index
+ .or(count)
+ .or(sync)
+ .or(add_history)
+ .or(user)
+ .or(register)
+ .or(login),
+ )
+ .with(warp::filters::log::log("atuin::api"));
+
+ Ok(r)
+}
diff --git a/src/settings.rs b/src/settings.rs
index f3bc62e6..5325610e 100644
--- a/src/settings.rs
+++ b/src/settings.rs
@@ -161,7 +161,7 @@ impl Settings {
// Finally, set the auth token
if Path::new(session_path.to_string().as_str()).exists() {
let token = std::fs::read_to_string(session_path.to_string())?;
- s.set("local.session_token", token)?;
+ s.set("local.session_token", token.trim())?;
} else {
s.set("local.session_token", "not logged in")?;
}
diff --git a/src/shell/atuin.zsh b/src/shell/atuin.zsh
index d2abf3c1..d6d58f53 100644
--- a/src/shell/atuin.zsh
+++ b/src/shell/atuin.zsh
@@ -16,6 +16,7 @@ _atuin_precmd(){
[[ -z "${ATUIN_HISTORY_ID}" ]] && return
atuin history end $ATUIN_HISTORY_ID --exit $EXIT
+ export ATUIN_HISTORY_ID=""
}
_atuin_search(){