aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-client/src/import/powershell.rs
blob: 86fd007d0b507b2b4cb621c4604c5cbffb270626 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
use async_trait::async_trait;
use directories::BaseDirs;
use eyre::{Result, eyre};
use std::path::PathBuf;
use time::{Duration, OffsetDateTime};

use super::{Importer, Loader, count_lines, unix_byte_lines};
use crate::history::History;
use crate::import::read_to_end;

/// Importer for PowerShell command history as recorded by the PSReadLine module.
#[derive(Debug)]
pub struct PowerShell {
    /// Raw, undecoded contents of the history file.
    bytes: Vec<u8>,
    /// Lazily computed line count of `bytes`; `None` until `entries` is first called.
    line_count: Option<usize>,
}

/// Locates the PSReadLine history file of the default PowerShell console host.
///
/// # Errors
///
/// Returns an error if the user's base directories cannot be determined, or if
/// `ConsoleHost_history.txt` does not exist as a regular file in the expected directory.
fn get_history_path() -> Result<PathBuf> {
    let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?;

    // The command line history in PowerShell is maintained by the PSReadLine module:
    // https://learn.microsoft.com/en-us/powershell/module/psreadline/about/about_psreadline#command-history
    //
    // > PSReadLine maintains a history file containing all the commands and data you've entered from the command line.
    // > The history files are a file named `$($Host.Name)_history.txt`.
    // > On Windows systems the history file is stored at `$Env:APPDATA\Microsoft\Windows\PowerShell\PSReadLine`.
    // > On non-Windows systems, the history files are stored at `$Env:XDG_DATA_HOME/powershell/PSReadLine`
    // > or `$Env:HOME/.local/share/powershell/PSReadLine`.

    let dir = if cfg!(windows) {
        // On Windows `BaseDirs::data_dir` is the roaming AppData folder, i.e. `$Env:APPDATA`.
        base.data_dir()
            .join("Microsoft")
            .join("Windows")
            .join("PowerShell")
            .join("PSReadLine")
    } else {
        // Per the XDG Base Directory spec, an unset *or empty* `XDG_DATA_HOME` means
        // "use `$HOME/.local/share`". An empty value must not be turned into `PathBuf::from("")`,
        // which would silently resolve relative to the current working directory.
        // `var_os` is used so a non-UTF-8 path is honoured instead of being treated as unset,
        // which is what `std::env::var` (failing on non-Unicode values) would do.
        std::env::var_os("XDG_DATA_HOME")
            .filter(|value| !value.is_empty())
            .map_or_else(
                || base.home_dir().join(".local").join("share"),
                PathBuf::from,
            )
            .join("powershell")
            .join("PSReadLine")
    };

    // The history is stored in a file named `$($Host.Name)_history.txt`.
    // For the default console host shipped by Microsoft, `$Host.Name` is `ConsoleHost`:
    // https://learn.microsoft.com/en-us/dotnet/api/system.management.automation.host.pshost.name#remarks

    let file = dir.join("ConsoleHost_history.txt");

    if file.is_file() {
        Ok(file)
    } else {
        Err(eyre!("Could not find history file: {}", file.display()))
    }
}

#[async_trait]
impl Importer for PowerShell {
    const NAME: &'static str = "PowerShell";

    /// Reads the history file located by `get_history_path` into memory.
    ///
    /// The line count is left uncomputed until `entries` is called.
    async fn new() -> Result<Self> {
        let bytes = read_to_end(get_history_path()?)?;
        Ok(Self {
            bytes,
            line_count: None,
        })
    }

    /// Returns the number of lines in the history, computing it once and caching it.
    ///
    /// Commands can be split over multiple lines, so this may over-estimate the number of
    /// commands. It is only used for a progress bar, and multi-line commands should be quite
    /// rare, so this is not an issue in practice.
    async fn entries(&mut self) -> Result<usize> {
        // Borrow the bytes separately so the closure does not need all of `self`.
        let bytes = &self.bytes;
        let count = *self.line_count.get_or_insert_with(|| count_lines(bytes));
        Ok(count)
    }

    /// Parses the history and pushes one [`History`] entry per non-empty command into `h`.
    ///
    /// Lines ending in a backtick are joined with the following line (separated by `\n`) to
    /// rebuild multi-line commands. Lines that are not valid UTF-8 are skipped.
    async fn load(mut self, h: &mut impl Loader) -> Result<()> {
        // No timestamps are read from the file. Synthesize increasing ones instead: start
        // `line_count` milliseconds before now and advance 1ms per command to keep file order.
        let line_count = self.entries().await?;
        let start = OffsetDateTime::now_utc() - Duration::milliseconds(line_count as i64);

        let mut counter = 0;
        let mut iter = unix_byte_lines(&self.bytes);

        // A `while let` (not `for`) because the continuation loop below also advances `iter`.
        while let Some(s) = iter.next() {
            let Ok(s) = read_line(s) else {
                continue; // We can skip past things like invalid utf8
            };

            let mut cmd = s.to_string();

            // Multi-line commands end with a backtick, append the following lines.
            while cmd.ends_with('`') {
                cmd.pop();

                // End of input: keep what has been collected so far.
                let Some(next) = iter.next() else {
                    break;
                };
                // An undecodable continuation line is consumed and the command ends here.
                let Ok(next) = read_line(next) else {
                    break;
                };

                cmd.push('\n');
                cmd.push_str(next);
            }

            if cmd.is_empty() {
                continue;
            }

            let offset = Duration::milliseconds(counter);
            counter += 1;

            let entry = History::import().timestamp(start + offset).command(cmd);
            h.push(entry.build().into()).await?;
        }

        Ok(())
    }
}

/// Decodes one raw history line as UTF-8 and drops a single trailing `\r`, if any.
///
/// History is stored in CRLF on Windows. Removing the carriage return here normalizes the
/// input to LF on all platforms.
///
/// # Errors
///
/// Fails when `raw` is not valid UTF-8.
fn read_line(raw: &[u8]) -> Result<&str> {
    std::str::from_utf8(raw)
        .map(|text| text.strip_suffix('\r').unwrap_or(text))
        .map_err(Into::into)
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::import::tests::TestLoader;
    use itertools::assert_equal;

    const INPUT: &str = r#"cargo install atuin
cargo update
echo "first line`
second line`
`
last line"
echo foo

echo bar
echo baz
"#;

    const EXPECTED: &[&str] = &[
        "cargo install atuin",
        "cargo update",
        "echo \"first line\nsecond line\n\nlast line\"",
        "echo foo",
        "echo bar",
        "echo baz",
    ];

    #[tokio::test]
    async fn test_import() {
        let loader = import(INPUT).await;
        assert_commands(&loader, EXPECTED);
    }

    #[tokio::test]
    async fn test_crlf() {
        let input = INPUT.replace('\n', "\r\n");
        let loader = import(&input).await;
        assert_commands(&loader, EXPECTED);
    }

    #[tokio::test]
    async fn test_timestamps() {
        let loader = import(INPUT).await;

        assert!(!loader.buf.is_empty());
        for pair in loader.buf.windows(2) {
            assert!(pair[1].timestamp > pair[0].timestamp);
        }
    }

    #[tokio::test]
    async fn test_empty_input() {
        let loader = import("").await;
        assert!(loader.buf.is_empty());
    }

    #[tokio::test]
    async fn test_skips_invalid_utf8() {
        // The middle line is not valid UTF-8 and must be skipped without aborting the import.
        let loader = import_bytes(b"echo a\n\xff\xfe\necho b\n").await;
        assert_commands(&loader, &["echo a", "echo b"]);
    }

    /// Asserts that the loader received exactly `expected` commands, in order.
    fn assert_commands(loader: &TestLoader, expected: &[&str]) {
        let actual = loader.buf.iter().map(|h| h.command.as_str());
        assert_equal(actual, expected.iter().copied());
    }

    /// Runs the importer over UTF-8 text input.
    async fn import(input: &str) -> TestLoader {
        import_bytes(input.as_bytes()).await
    }

    /// Runs the importer over raw bytes, which may include invalid UTF-8.
    async fn import_bytes(input: &[u8]) -> TestLoader {
        let powershell = PowerShell {
            bytes: input.to_vec(),
            line_count: None,
        };

        let mut loader = TestLoader::default();
        powershell.load(&mut loader).await.unwrap();
        loader
    }
}