use std::env; use std::time::Duration; use eyre::{Result, bail, eyre}; use reqwest::{ Response, StatusCode, Url, header::{AUTHORIZATION, HeaderMap}, }; use tracing::debug; use crate::atuin_common::{ api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION}, record::{EncryptedData, HostId, Record, RecordIdx}, tls::ensure_crypto_provider, }; use crate::atuin_common::{ api::{ErrorResponse, MeResponse}, record::RecordStatus, }; use semver::Version; static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),); /// Authentication token for sync API requests. /// /// Used with `Token ` header. #[derive(Debug, Clone)] pub(crate) struct AuthToken(pub(crate) String); impl AuthToken { /// Format the token as an Authorization header value fn to_header_value(&self) -> String { format!("Token {}", self.0) } } pub(crate) struct Client<'a> { sync_addr: &'a str, client: reqwest::Client, } fn make_url(address: &str, path: &str) -> Result { // `join()` expects a trailing `/` in order to join paths // e.g. it treats `http://host:port/subdir` as a file called `subdir` let address = if address.ends_with('/') { address } else { &format!("{address}/") }; // passing a path with a leading `/` will cause `join()` to replace the entire URL path let path = path.strip_prefix("/").unwrap_or(path); let url = Url::parse(address) .map(|url| url.join(path))? .map_err(|_| eyre!("invalid address"))?; Ok(url.to_string()) } pub(crate) fn ensure_version(response: &Response) -> Result { let version = response.headers().get(ATUIN_HEADER_VERSION); let version = if let Some(version) = version { match version.to_str() { Ok(v) => Version::parse(v), Err(e) => bail!("failed to parse server version: {:?}", e), } } else { bail!("Server not reporting its version: it is either too old or unhealthy"); }?; // If the client is newer than the server if version.major < ATUIN_VERSION.major { println!( "Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin" ); println!("Client: {ATUIN_CARGO_VERSION}"); println!("Server: {version}"); return Ok(false); } Ok(true) } async fn handle_resp_error(resp: Response) -> Result { let status = resp.status(); let url = resp.url().to_string(); if status == StatusCode::SERVICE_UNAVAILABLE { bail!( "Service unavailable: check https://status.atuin.sh (or get in touch with your host)" ); } if status == StatusCode::TOO_MANY_REQUESTS { bail!("Rate limited; please wait before doing that again"); } if !status.is_success() { if let Ok(error) = resp.json::().await { let reason = error.reason; if status.is_client_error() { bail!("Invalid request to the service at {url}, {status} - {reason}.") } bail!( "There was an error with the atuin sync service at {url}, server error {status}: {reason}.\nIf the problem persists, contact the host" ) } bail!( "There was an error with the atuin sync service at {url}, Status {status:?}.\nIf the problem persists, contact the host" ) } Ok(resp) } impl<'a> Client<'a> { pub(crate) fn new( sync_addr: &'a str, auth: AuthToken, connect_timeout: u64, timeout: u64, ) -> Result { ensure_crypto_provider(); let mut headers = HeaderMap::new(); headers.insert(AUTHORIZATION, auth.to_header_value().parse()?); // used for semver server check headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?); Ok(Client { sync_addr, client: reqwest::Client::builder() .user_agent(APP_USER_AGENT) .default_headers(headers) .connect_timeout(Duration::new(connect_timeout, 0)) .timeout(Duration::new(timeout, 0)) .build()?, }) } pub(crate) async fn me(&self) -> Result { let url = make_url(self.sync_addr, "/api/v0/me")?; let url = Url::parse(url.as_str())?; let resp = self.client.get(url).send().await?; let resp = handle_resp_error(resp).await?; let status = resp.json::().await?; Ok(status) } pub(crate) async fn delete_store(&self) -> Result<()> { let url = make_url(self.sync_addr, "/api/v0/store")?; let url = Url::parse(url.as_str())?; let resp = self.client.delete(url).send().await?; handle_resp_error(resp).await?; Ok(()) } pub(crate) async fn post_records(&self, records: &[Record]) -> Result<()> { let url = make_url(self.sync_addr, "/api/v0/record")?; let url = Url::parse(url.as_str())?; debug!("uploading {} records to {url}", records.len()); let resp = self.client.post(url).json(records).send().await?; handle_resp_error(resp).await?; Ok(()) } pub(crate) async fn next_records( &self, host: HostId, tag: String, start: RecordIdx, count: u64, ) -> Result>> { debug!("fetching record/s from host {}/{}/{}", host.0, tag, start); let url = make_url( self.sync_addr, &format!( "/api/v0/record/next?host={}&tag={}&count={}&start={}", host.0, tag, count, start ), )?; let url = Url::parse(url.as_str())?; let resp = self.client.get(url).send().await?; let resp = handle_resp_error(resp).await?; let records = resp.json::>>().await?; Ok(records) } pub(crate) async fn record_status(&self) -> Result { let url = make_url(self.sync_addr, "/api/v0/record")?; let url = Url::parse(url.as_str())?; let resp = self.client.get(url).send().await?; let resp = handle_resp_error(resp).await?; if !ensure_version(&resp)? { bail!("could not sync records due to version mismatch"); } let index = resp.json().await?; debug!("got remote index {index:?}"); Ok(index) } }