aboutsummaryrefslogtreecommitdiffstats
path: root/crates/turtle/src/atuin_client/api_client.rs
blob: bd5bf59e2d00edcb512f0686e5ecff5b8ed5e6ec (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
use std::env;
use std::time::Duration;

use eyre::{Result, bail, eyre};
use reqwest::{Response, StatusCode, Url, header::HeaderMap};
use tracing::debug;
use uuid::Uuid;

use crate::atuin_common::{api::ErrorResponse, record::RecordStatus};
use crate::atuin_common::{
    api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION},
    record::{EncryptedData, HostId, Record, RecordIdx},
    tls::ensure_crypto_provider,
};

use semver::Version;

/// User-Agent attached to every request made by [`Client`]: `atuin/` followed by
/// this crate's package version.
static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"));

/// HTTP client for one user's namespace on an Atuin sync server.
///
/// The server address is borrowed rather than owned, so a `Client` cannot outlive
/// the string it was created from.
pub(crate) struct Client<'addr> {
    /// Base address of the sync server, exactly as it was configured.
    sync_addr: &'addr str,
    /// User whose API namespace (`api/v0/{user_id}/`) every request is sent to.
    user_id: Uuid,
    /// Shared HTTP client carrying the default headers and timeouts set in `new`.
    inner: reqwest::Client,
}

/// Builds the absolute URL for `path` under `user_id`'s API root on `address`.
///
/// The result has the shape `{address}/api/v0/{user_id}/{path}`. A trailing `/` on
/// `address` and a leading `/` on `path` are both tolerated. `path` may carry a
/// query string.
///
/// # Errors
///
/// Returns an "invalid address" error, naming the address and the underlying
/// parse error, if the API root is not a valid base URL or `path` cannot be
/// joined onto it.
fn make_url(address: &str, path: &str, user_id: Uuid) -> Result<String> {
    let address = address.strip_suffix('/').unwrap_or(address);

    // `join()` expects a trailing `/` in order to join paths
    // e.g. it treats `http://host:port/subdir` as a file called `subdir`
    let base = format!("{address}/api/v0/{user_id}/");

    // passing a path with a leading `/` will cause `join()` to replace the entire URL path
    let path = path.strip_prefix('/').unwrap_or(path);

    // Report parse and join failures the same way, keeping the underlying reason.
    // Previously a bad base surfaced as a bare `ParseError` with no mention of the
    // address, and a join failure discarded its cause.
    let url = Url::parse(&base)
        .and_then(|root| root.join(path))
        .map_err(|e| eyre!("invalid address {address:?}: {e}"))?;

    Ok(url.to_string())
}

pub(crate) fn ensure_version(response: &Response) -> Result<bool> {
    let version = response.headers().get(ATUIN_HEADER_VERSION);

    let version = if let Some(version) = version {
        match version.to_str() {
            Ok(v) => Version::parse(v),
            Err(e) => bail!("failed to parse server version: {:?}", e),
        }
    } else {
        bail!("Server not reporting its version: it is either too old or unhealthy");
    }?;

    // If the client is newer than the server
    if version.major < ATUIN_VERSION.major {
        println!(
            "Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin"
        );
        println!("Client: {ATUIN_CARGO_VERSION}");
        println!("Server: {version}");

        return Ok(false);
    }

    Ok(true)
}

async fn handle_resp_error(resp: Response) -> Result<Response> {
    let status = resp.status();
    let url = resp.url().to_string();

    if status == StatusCode::SERVICE_UNAVAILABLE {
        bail!(
            "Service unavailable: check https://status.atuin.sh (or get in touch with your host)"
        );
    }

    if status == StatusCode::TOO_MANY_REQUESTS {
        bail!("Rate limited; please wait before doing that again");
    }

    if !status.is_success() {
        if let Ok(error) = resp.json::<ErrorResponse<'_>>().await {
            let reason = error.reason;

            if status.is_client_error() {
                bail!("Invalid request to the service at {url}, {status} - {reason}.")
            }

            bail!(
                "There was an error with the atuin sync service at {url}, server error {status}: {reason}.\nIf the problem persists, contact the host"
            )
        }

        bail!(
            "There was an error with the atuin sync service at {url}, Status {status:?}.\nIf the problem persists, contact the host"
        )
    }

    Ok(resp)
}

impl<'a> Client<'a> {
    pub(crate) fn new(
        sync_addr: &'a str,
        connect_timeout: u64,
        timeout: u64,
        user_id: Uuid,
    ) -> Result<Self> {
        ensure_crypto_provider();
        let mut headers = HeaderMap::new();

        // used for semver server check
        headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?);

        Ok(Client {
            user_id,
            sync_addr,
            inner: reqwest::Client::builder()
                .user_agent(APP_USER_AGENT)
                .default_headers(headers)
                .connect_timeout(Duration::new(connect_timeout, 0))
                .timeout(Duration::new(timeout, 0))
                .build()?,
        })
    }

    pub(crate) async fn delete_store(&self) -> Result<()> {
        let url = make_url(self.sync_addr, "/store", self.user_id)?;
        let url = Url::parse(url.as_str())?;

        let resp = self.inner.delete(url).send().await?;

        handle_resp_error(resp).await?;

        Ok(())
    }

    pub(crate) async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> {
        let url = make_url(self.sync_addr, "/record", self.user_id)?;
        let url = Url::parse(url.as_str())?;

        debug!("uploading {} records to {url}", records.len());

        let resp = self.inner.post(url).json(records).send().await?;
        handle_resp_error(resp).await?;

        Ok(())
    }

    pub(crate) async fn next_records(
        &self,
        host: HostId,
        tag: String,
        start: RecordIdx,
        count: u64,
    ) -> Result<Vec<Record<EncryptedData>>> {
        debug!("fetching record/s from host {}/{}/{}", host.0, tag, start);

        let url = make_url(
            self.sync_addr,
            &format!(
                "/record/next?host={}&tag={}&count={}&start={}",
                host.0, tag, count, start
            ),
            self.user_id,
        )?;

        let url = Url::parse(url.as_str())?;

        let resp = self.inner.get(url).send().await?;
        let resp = handle_resp_error(resp).await?;

        let records = resp.json::<Vec<Record<EncryptedData>>>().await?;

        Ok(records)
    }

    pub(crate) async fn record_status(&self) -> Result<RecordStatus> {
        let url = make_url(self.sync_addr, "/record", self.user_id)?;
        let url = Url::parse(url.as_str())?;

        let resp = self.inner.get(url).send().await?;
        let resp = handle_resp_error(resp).await?;

        if !ensure_version(&resp)? {
            bail!("could not sync records due to version mismatch");
        }

        let index = resp.json().await?;

        debug!("got remote index {index:?}");

        Ok(index)
    }
}