aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-ai/src/user_context/walker.rs
blob: 117bbd330718c8d8461dbd021fc79e49e26ca567 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
//! Filesystem traversal for `TERMINAL.md` context files.
//!
//! Walks from the starting directory up to the filesystem root, checking for
//! `.atuin/TERMINAL.md` and `TERMINAL.md` at each level. Then checks the global
//! config directory. Returns files ordered from shallowest (global/root) to
//! deepest (most project-specific), so that context layers naturally from
//! general to specific.

use std::path::{Path, PathBuf};

use eyre::Result;
use tokio::task::JoinSet;

const CONTEXT_FILENAME: &str = "TERMINAL.md";

/// A context file found on disk, before interpolation.
///
/// Holds the file exactly as it was read: where it lives and its full text.
#[derive(Debug)]
pub(crate) struct RawContextFile {
    /// Location the file was read from.
    pub path: PathBuf,
    /// Entire contents of the file, read as a UTF-8 string.
    pub content: String,
}

/// A loaded context file tagged with the position it was discovered at.
///
/// Used internally by [`walk`] to order results once the concurrent reads
/// have finished.
struct FoundFile {
    /// Index of the containing directory in `start.ancestors()` (0 is the
    /// starting directory itself). The global file uses the ancestor count,
    /// i.e. one past the largest directory index.
    depth: usize,
    /// The file's path and contents.
    file: RawContextFile,
}

/// Walk from `start` up to the filesystem root collecting `TERMINAL.md`
/// context files, then check the global path. Returns files shallowest-first.
///
/// At each ancestor directory, checks two locations:
/// - `.atuin/TERMINAL.md` (dotdir-scoped)
/// - `TERMINAL.md` (project root)
pub(crate) async fn walk(start: &Path, global_path: Option<&Path>) -> Result<Vec<RawContextFile>> {
    let dirs: Vec<PathBuf> = start.ancestors().map(PathBuf::from).collect();
    let dir_count = dirs.len();

    let mut set: JoinSet<Result<Option<FoundFile>>> = JoinSet::new();

    for (index, dir) in dirs.into_iter().enumerate() {
        let dir2 = dir.clone();
        set.spawn(async move {
            load_context_file(&dir.join(".atuin").join(CONTEXT_FILENAME), index).await
        });
        set.spawn(async move { load_context_file(&dir2.join(CONTEXT_FILENAME), index).await });
    }

    if let Some(global) = global_path {
        let global = global.to_path_buf();
        let depth = dir_count;
        set.spawn(async move { load_context_file(&global, depth).await });
    }

    let mut found = Vec::new();
    while let Some(result) = set.join_next().await {
        match result? {
            Ok(Some(f)) => found.push(f),
            Ok(None) => {}
            Err(e) => {
                tracing::warn!("Error reading context file, skipping: {e}");
            }
        }
    }

    // Sort shallowest-first (highest depth index = shallowest ancestor).
    // The global file has the highest depth index so it sorts last... but we
    // actually want global first, then root → cwd. Reverse the depth ordering.
    found.sort_by_key(|b| std::cmp::Reverse(b.depth));

    Ok(found.into_iter().map(|f| f.file).collect())
}

/// Location of the default global context file: `TERMINAL.md` inside the
/// directory returned by `atuin_common::utils::config_dir()` (presumably
/// `~/.config/atuin` on a typical setup — confirm against that helper).
pub(crate) fn global_context_path() -> PathBuf {
    let config_dir = atuin_common::utils::config_dir();
    config_dir.join(CONTEXT_FILENAME)
}

/// Read the context file at `path`, tagging the result with `depth`.
///
/// Returns `Ok(None)` when the file does not exist, so callers can probe
/// candidate locations without treating absence as a failure. Any other I/O
/// error is returned to the caller.
async fn load_context_file(path: &Path, depth: usize) -> Result<Option<FoundFile>> {
    let content = match tokio::fs::read_to_string(path).await {
        Ok(text) => text,
        // A missing file is the expected case for most candidate locations.
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
        Err(err) => return Err(err.into()),
    };

    let file = RawContextFile {
        path: path.to_path_buf(),
        content,
    };
    Ok(Some(FoundFile { depth, file }))
}