aboutsummaryrefslogtreecommitdiffstats
path: root/atuin-server/src/handlers/user.rs
blob: 1bcfce2fde908067db884b656ece16d8748c205e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use std::borrow::Borrow;

use atuin_common::api::*;
use atuin_common::utils::hash_secret;
use axum::extract::Path;
use axum::{Extension, Json};
use http::StatusCode;
use sodiumoxide::crypto::pwhash::argon2id13;
use uuid::Uuid;

use crate::database::{Database, Postgres};
use crate::models::{NewSession, NewUser};
use crate::settings::Settings;

/// Checks the plaintext `verify` against the stored Argon2id password hash `secret`.
///
/// `secret` is the value kept in the user's `password` column (in this module it is
/// produced by `hash_secret` during registration). Returns `true` only when `secret`
/// parses as an Argon2id hash string and `verify` matches it; any malformed or
/// over-long `secret` yields `false` instead of panicking.
///
/// # Panics
///
/// Panics only if libsodium itself fails to initialise.
pub fn verify_str(secret: &str, verify: &str) -> bool {
    sodiumoxide::init().expect("failed to initialise libsodium");

    // libsodium reads the hash as a NUL-terminated C string stored in a fixed-size
    // buffer of `HASHEDPASSWORDBYTES` bytes (the terminator is included in that size).
    let mut padded = [0_u8; argon2id13::HASHEDPASSWORDBYTES];
    let bytes = secret.as_bytes();

    // A value that does not leave room for the trailing NUL cannot be a valid hash
    // string. Rejecting it avoids both an out-of-bounds panic on the copy below and an
    // unterminated buffer being handed to libsodium.
    if bytes.len() >= padded.len() {
        return false;
    }
    padded[..bytes.len()].copy_from_slice(bytes);

    match argon2id13::HashedPassword::from_slice(&padded) {
        Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()),
        None => false,
    }
}

/// Looks up a user by the `username` path segment and returns their public profile.
///
/// Any lookup failure is logged at debug level and reported to the client as
/// `404 "user not found"`.
pub async fn get(
    Path(username): Path<String>,
    db: Extension<Postgres>,
) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> {
    let found = db.get_user(username.as_ref()).await.map_err(|err| {
        debug!("user not found: {}", err);
        ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)
    })?;

    // Only the username is exposed; nothing else from the stored record leaves here.
    let response = UserResponse {
        username: found.username,
    };

    Ok(Json(response))
}

pub async fn register(
    Json(register): Json<RegisterRequest>,
    settings: Extension<Settings>,
    db: Extension<Postgres>,
) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> {
    if !settings.open_registration {
        return Err(
            ErrorResponse::reply("this server is not open for registrations")
                .with_status(StatusCode::BAD_REQUEST),
        );
    }

    let hashed = hash_secret(&register.password);

    let new_user = NewUser {
        email: register.email,
        username: register.username,
        password: hashed,
    };

    let user_id = match db.add_user(&new_user).await {
        Ok(id) => id,
        Err(e) => {
            error!("failed to add user: {}", e);
            return Err(
                ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST)
            );
        }
    };

    let token = Uuid::new_v4().to_simple().to_string();

    let new_session = NewSession {
        user_id,
        token: (&token).into(),
    };

    match db.add_session(&new_session).await {
        Ok(_) => Ok(Json(RegisterResponse { session: token })),
        Err(e) => {
            error!("failed to add session: {}", e);
            Err(ErrorResponse::reply("failed to register user")
                .with_status(StatusCode::BAD_REQUEST))
        }
    }
}

pub async fn login(
    login: Json<LoginRequest>,
    db: Extension<Postgres>,
) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> {
    let user = match db.get_user(login.username.borrow()).await {
        Ok(u) => u,
        Err(e) => {
            error!("failed to get user {}: {}", login.username.clone(), e);

            return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
        }
    };

    let session = match db.get_user_session(&user).await {
        Ok(u) => u,
        Err(e) => {
            error!("failed to get session for {}: {}", login.username, e);

            return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
        }
    };

    let verified = verify_str(user.password.as_str(), login.password.borrow());

    if !verified {
        return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
    }

    Ok(Json(LoginResponse {
        session: session.token,
    }))
}