diff options
Diffstat (limited to 'crates/atuin-ai/src/commands')
| -rw-r--r-- | crates/atuin-ai/src/commands/init.rs | 2 | ||||
| -rw-r--r-- | crates/atuin-ai/src/commands/inline.rs | 673 |
2 files changed, 199 insertions, 476 deletions
diff --git a/crates/atuin-ai/src/commands/init.rs b/crates/atuin-ai/src/commands/init.rs index 77abc4f4..f693d892 100644 --- a/crates/atuin-ai/src/commands/init.rs +++ b/crates/atuin-ai/src/commands/init.rs @@ -1,6 +1,6 @@ use crate::commands::detect_shell; -pub async fn run(shell: String) -> eyre::Result<()> { +pub(crate) async fn run(shell: String) -> eyre::Result<()> { let integration = match shell.as_str() { "zsh" => generate_zsh_integration(), "bash" => generate_bash_integration(), diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index aeb414fb..b37bb72f 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -1,38 +1,44 @@ use std::path::PathBuf; use std::sync::mpsc; -use crate::commands::detect_shell; +use crate::context::{AppContext, ClientContext}; +use crate::tui::dispatch; use crate::tui::events::AiTuiEvent; -use crate::tui::state::{AppState, ExitAction}; +use crate::tui::state::{ExitAction, Session}; use crate::tui::view::ai_view; use atuin_client::database::{Database, Sqlite}; -use atuin_client::distro::detect_linux_distribution; -use atuin_common::tls::ensure_crypto_provider; -use eventsource_stream::Eventsource; -use eye_declare::{Application, CtrlCBehavior, Handle}; +use eye_declare::{Application, CtrlCBehavior}; use eyre::{Context as _, Result, bail}; -use futures::StreamExt; -use reqwest::Url; -use tracing::{debug, error, info, trace}; +use tracing::{debug, info}; -pub async fn run( +pub(crate) async fn run( initial_command: Option<String>, api_endpoint: Option<String>, api_token: Option<String>, settings: &atuin_client::settings::Settings, output_for_hook: bool, ) -> Result<()> { - if !settings.ai.enabled.unwrap_or(false) { - emit_shell_result( - Action::Print( - "Atuin AI is not enabled. Please enable it in your settings or run `atuin setup`." - .to_string(), - ), - output_for_hook, - ); + if settings.ai.enabled == Some(false) { return Ok(()); } + if settings.ai.enabled.is_none() { + match prompt_ai_setup()? { + SetupChoice::EnableAi => { + set_ai_enabled(true).await?; + } + SetupChoice::DisableKeybind => { + set_ai_enabled(false).await?; + emit_shell_result(Action::Cancel, output_for_hook); + return Ok(()); + } + SetupChoice::Cancel => { + emit_shell_result(Action::Cancel, output_for_hook); + return Ok(()); + } + } + } + let endpoint = api_endpoint.as_deref().unwrap_or( settings .ai @@ -48,7 +54,36 @@ pub async fn run( ensure_hub_session(settings).await? }; - let action = run_inline_tui(endpoint.to_string(), token, initial_command, settings).await?; + let history_db_path = PathBuf::from(settings.db_path.as_str()); + let history_db = Sqlite::new(history_db_path, settings.local_timeout) + .await + .context("failed to open history database for AI")?; + + // Support both legacy [ai] send_cwd and new [ai.opening] send_cwd + let send_cwd = + settings.ai.opening.send_cwd.unwrap_or(false) || settings.ai.send_cwd.unwrap_or(false); + + let last_command = if settings.ai.opening.send_last_command.unwrap_or(false) { + history_db.last().await.ok().flatten().map(|h| h.command) + } else { + None + }; + + let git_root = std::env::current_dir() + .ok() + .and_then(|cwd| atuin_common::utils::in_git_repo(cwd.to_str()?)); + + let ctx = AppContext { + endpoint: endpoint.to_string(), + token, + send_cwd, + last_command, + history_db: std::sync::Arc::new(history_db), + git_root, + capabilities: settings.ai.capabilities.clone(), + }; + + let action = run_inline_tui(ctx, initial_command).await?; emit_shell_result(action, output_for_hook); Ok(()) @@ -69,7 +104,7 @@ async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Resu if will_sync { println!( "Once logged in, your shell history will be synchronized via Atuin Hub if auto_sync is enabled or when manually syncing." - ) + ); } println!( "If you have an existing Atuin sync account, you can log in with your existing credentials." @@ -111,279 +146,16 @@ async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Resu } // ─────────────────────────────────────────────────────────────────── -// SSE streaming -// ─────────────────────────────────────────────────────────────────── -#[derive(Debug, Clone)] -enum ChatStreamEvent { - TextChunk(String), - ToolCall { - id: String, - name: String, - input: serde_json::Value, - }, - ToolResult { - tool_use_id: String, - content: String, - is_error: bool, - }, - Status(String), - Done { - session_id: String, - }, - Error(String), -} - -fn create_chat_stream( - hub_address: String, - token: String, - session_id: Option<String>, - messages: Vec<serde_json::Value>, - send_cwd: bool, - last_command: Option<String>, -) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<ChatStreamEvent>> + Send>> { - Box::pin(async_stream::stream! { - ensure_crypto_provider(); - let endpoint = match hub_url(&hub_address, "/api/cli/chat") { - Ok(url) => url, - Err(e) => { - yield Err(e); - return; - } - }; - - debug!("Sending SSE request to {endpoint}"); - - let os = detect_os(); - let shell = detect_shell(); - - let mut context = serde_json::json!({ - "os": os, - "shell": shell, - "pwd": if send_cwd { std::env::current_dir() - .ok() - .map(|path| path.to_string_lossy().into_owned()) } else { None }, - "last_command": last_command, - }); - - if os == "linux" { - context["distro"] = serde_json::json!(detect_linux_distribution()); - } - - let mut request_body = serde_json::json!({ - "messages": messages, - "context": context, - }); - - if let Some(ref sid) = session_id { - trace!("Including session_id in request: {sid}"); - request_body["session_id"] = serde_json::json!(sid); - } - - let client = reqwest::Client::new(); - let response = match client - .post(endpoint.clone()) - .header("Accept", "text/event-stream") - .bearer_auth(&token) - .json(&request_body) - .send() - .await - { - Ok(resp) => resp, - Err(e) => { - yield Err(eyre::eyre!("Failed to send SSE request: {}", e)); - return; - } - }; +async fn run_inline_tui(ctx: AppContext, initial_prompt: Option<String>) -> Result<Action> { + let client_ctx = ClientContext::detect(); - let status = response.status(); - if status == reqwest::StatusCode::UNAUTHORIZED { - error!("SSE request failed with status: {status}, clearing session"); - let _ = atuin_client::hub::delete_session().await; - yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again.")); - return; - } - if !status.is_success() { - let body = response.text().await.unwrap_or_default(); - error!("SSE request failed ({}): {}", status, body); - yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body)); - return; - } - - let byte_stream = response.bytes_stream(); - let mut stream = byte_stream.eventsource(); - - while let Some(event) = stream.next().await { - match event { - Ok(sse_event) => { - let event_type = sse_event.event.as_str(); - let data = sse_event.data.clone(); - - debug!(event_type = %event_type, "SSE event received"); - - match event_type { - "text" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) - && let Some(content) = json.get("content").and_then(|v| v.as_str()) - { - yield Ok(ChatStreamEvent::TextChunk(content.to_string())); - } - } - "tool_call" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let input = json.get("input").cloned().unwrap_or(serde_json::json!({})); - yield Ok(ChatStreamEvent::ToolCall { id, name, input }); - } - } - "tool_result" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false); - yield Ok(ChatStreamEvent::ToolResult { tool_use_id, content, is_error }); - } - } - "status" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) - && let Some(state) = json.get("state").and_then(|v| v.as_str()) - { - yield Ok(ChatStreamEvent::Status(state.to_string())); - } - } - "done" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let session_id = json.get("session_id") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - yield Ok(ChatStreamEvent::Done { session_id }); - } else { - yield Ok(ChatStreamEvent::Done { session_id: String::new() }); - } - break; - } - "error" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string(); - error!("SSE error: {}", message); - yield Ok(ChatStreamEvent::Error(message)); - } else { - error!("SSE error: {}", data); - yield Ok(ChatStreamEvent::Error(data)); - } - break; - } - _ => {} - } - } - Err(e) => { - yield Err(eyre::eyre!("SSE error: {}", e)); - break; - } - } - } - }) -} - -// ─────────────────────────────────────────────────────────────────── -// Async streaming task — pushes updates to app state via Handle -// ─────────────────────────────────────────────────────────────────── - -async fn run_chat_stream( - handle: Handle<AppState>, - endpoint: String, - token: String, - session_id: Option<String>, - messages: Vec<serde_json::Value>, - send_cwd: bool, - last_command: Option<String>, -) { - let stream = create_chat_stream( - endpoint, - token, - session_id, - messages, - send_cwd, - last_command, - ); - futures::pin_mut!(stream); - - while let Some(event) = stream.next().await { - match event { - Ok(ChatStreamEvent::TextChunk(text)) => { - trace!(text = %text, "Processing TextChunk"); - handle.update(move |state| { - state.append_streaming_text(&text); - }); - } - Ok(ChatStreamEvent::ToolCall { id, name, input }) => { - trace!(id = %id, name = %name, "Processing ToolCall"); - handle.update(move |state| { - state.add_tool_call(id, name, input); - }); - } - Ok(ChatStreamEvent::ToolResult { - tool_use_id, - content, - is_error, - }) => { - trace!(tool_use_id = %tool_use_id, "Processing ToolResult"); - handle.update(move |state| { - state.add_tool_result(tool_use_id, content, is_error); - }); - } - Ok(ChatStreamEvent::Status(status)) => { - trace!(status = %status, "Processing Status"); - handle.update(move |state| { - state.update_streaming_status(&status); - }); - } - Ok(ChatStreamEvent::Done { session_id }) => { - trace!(session_id = %session_id, "Processing Done"); - handle.update(move |state| { - if !session_id.is_empty() { - state.store_session_id(session_id); - } - state.finalize_streaming(); - }); - break; - } - Ok(ChatStreamEvent::Error(msg)) => { - trace!(error = %msg, "Processing Error"); - handle.update(move |state| { - state.streaming_error(msg); - }); - break; - } - Err(e) => { - let msg = e.to_string(); - handle.update(move |state| { - state.streaming_error(msg); - }); - break; - } - } - } -} - -// ─────────────────────────────────────────────────────────────────── -// Main TUI entry point -// ─────────────────────────────────────────────────────────────────── + let (tx, rx) = mpsc::channel::<AiTuiEvent>(); -async fn run_inline_tui( - endpoint: String, - token: String, - initial_prompt: Option<String>, - settings: &atuin_client::settings::Settings, -) -> Result<Action> { - let initial_state = AppState::new(); + let initial_state = Session::new(ctx.git_root.is_some()); println!(); - let (tx, rx) = mpsc::channel::<AiTuiEvent>(); - // If there's an initial prompt, send it as a SubmitInput event // so it flows through the same path as user-typed input. if let Some(prompt) = initial_prompt { @@ -396,164 +168,17 @@ async fn run_inline_tui( .ctrl_c(CtrlCBehavior::Deliver) .keyboard_protocol(eye_declare::KeyboardProtocol::Enhanced) .bracketed_paste(true) - .with_context(tx) + .with_context(tx.clone()) .extra_newlines_at_exit(1) .build()?; - // Support both legacy [ai] send_cwd and new [ai.opening] send_cwd - let send_cwd = - settings.ai.opening.send_cwd.unwrap_or(false) || settings.ai.send_cwd.unwrap_or(false); - - let last_command = if settings.ai.opening.send_last_command.unwrap_or(false) { - let db_path = PathBuf::from(settings.db_path.as_str()); - match Sqlite::new(db_path, settings.local_timeout).await { - Ok(db) => db.last().await.ok().flatten().map(|h| h.command), - Err(e) => { - debug!("Failed to open history database for read_history: {e}"); - None - } - } - } else { - None - }; - // Event loop: receives AiTuiEvent from components, mutates state via Handle. let h = handle.clone(); - let ep = endpoint.clone(); - let tk = token.clone(); tokio::task::spawn_blocking(move || { + let tx = tx.clone(); + let client_ctx = client_ctx; while let Ok(event) = rx.recv() { - match event { - AiTuiEvent::InputUpdated(input) => { - let input_blank = input.trim().is_empty(); - - h.update(move |state| { - state.is_input_blank = input_blank; - }); - } - AiTuiEvent::SubmitInput(input) => { - let input = input.trim().to_string(); - if input.is_empty() { - let h2 = h.clone(); - h.update(move |state| { - if state.has_any_command() { - state.exit_action = Some(ExitAction::Execute( - state.current_command().unwrap().to_string(), - )); - } else { - state.exit_action = Some(ExitAction::Cancel); - } - h2.exit(); - }); - continue; - } - - if input.starts_with('/') { - let input_clone = input.clone(); - h.update(move |state| { - state.handle_slash_command(&input_clone); - }); - continue; - } - - // Start generation and spawn streaming task - let ep = ep.clone(); - let tk = tk.clone(); - let h2 = h.clone(); - let lc = last_command.clone(); - h.update(move |state| { - state.start_generating(input); - state.start_streaming(); - state.is_input_blank = true; - let messages = state.events_to_messages(); - let sid = state.session_id.clone(); - let task = tokio::spawn(async move { - run_chat_stream(h2, ep, tk, sid, messages, send_cwd, lc).await; - }); - state.stream_abort = Some(task.abort_handle()); - }); - } - - AiTuiEvent::SlashCommand(command) => { - h.update(move |state| { - state.handle_slash_command(&command); - }); - } - - AiTuiEvent::CancelGeneration => { - h.update(|state| match state.mode { - crate::tui::state::AppMode::Generating => { - state.cancel_generation(); - } - crate::tui::state::AppMode::Streaming => { - state.cancel_streaming(); - } - _ => {} - }); - } - - AiTuiEvent::ExecuteCommand => { - let h2 = h.clone(); - h.update(move |state| { - let cmd = state.current_command().map(|c| c.to_string()); - if let Some(cmd) = cmd { - if state.is_current_command_dangerous() && !state.confirmation_pending { - state.confirmation_pending = true; - } else { - state.confirmation_pending = false; - state.exit_action = Some(ExitAction::Execute(cmd)); - h2.exit(); - } - } - }); - } - - AiTuiEvent::CancelConfirmation => { - h.update(move |state| { - state.confirmation_pending = false; - }); - } - - AiTuiEvent::InsertCommand => { - let h2 = h.clone(); - h.update(move |state| { - let cmd = state.current_command().map(|c| c.to_string()); - if let Some(cmd) = cmd { - state.confirmation_pending = false; - state.exit_action = Some(ExitAction::Insert(cmd)); - h2.exit(); - } - }); - } - - AiTuiEvent::Retry => { - let ep = ep.clone(); - let tk = tk.clone(); - let h2 = h.clone(); - let lc = last_command.clone(); - h.update(move |state| { - state.retry(); - state.start_streaming(); - let messages = state.events_to_messages(); - let sid = state.session_id.clone(); - let task = tokio::spawn(async move { - run_chat_stream(h2, ep, tk, sid, messages, send_cwd, lc).await; - }); - state.stream_abort = Some(task.abort_handle()); - }); - } - - AiTuiEvent::Exit => { - let h2 = h.clone(); - h.update(move |state| { - if let Some(abort) = state.stream_abort.take() { - abort.abort(); - } - state.exit_action = Some(ExitAction::Cancel); - h2.exit(); - }); - } - } + dispatch::dispatch(&h, event, &tx, &ctx, &client_ctx); } }); @@ -573,51 +198,125 @@ async fn run_inline_tui( // Helpers // ─────────────────────────────────────────────────────────────────── -fn hub_url(base: &str, path: &str) -> Result<Url> { - let base_with_slash = if base.ends_with('/') { - base.to_string() - } else { - format!("{base}/") - }; - let stripped = path.strip_prefix('/').unwrap_or(path); - Url::parse(&base_with_slash)? - .join(stripped) - .context("failed to build hub URL") +enum SetupChoice { + EnableAi, + DisableKeybind, + Cancel, } -fn detect_os() -> String { - match std::env::consts::OS { - "macos" => "macos".to_string(), - "linux" => "linux".to_string(), - "windows" => "windows".to_string(), - other => format!("Other: {other}"), +fn prompt_ai_setup() -> Result<SetupChoice> { + use crossterm::{ + cursor, + event::{self, Event, KeyCode}, + terminal, + }; + + let options = ["Enable Atuin AI", "Disable ? Keybind", "Cancel"]; + let mut selected: usize = 0; + let mut stdout = std::io::stdout(); + + // Print header before raw mode so newlines render correctly. + // Use stdout because the shell hook swaps stdout/stderr — stdout goes + // to the terminal in both hook and non-hook modes. + println!(); + println!(" Atuin AI is not yet configured."); + println!(); + + terminal::enable_raw_mode().context("failed to enable raw mode")?; + struct Guard; + impl Drop for Guard { + fn drop(&mut self) { + let _ = terminal::disable_raw_mode(); + } } -} + let _guard = Guard; -#[derive(Clone)] -enum Action { - Execute(String), - Insert(String), - Print(String), - Cancel, -} + crossterm::execute!(stdout, cursor::Hide)?; -fn emit_shell_result(action: Action, output_for_hook: bool) { - if output_for_hook { - match action { - Action::Execute(output) => eprintln!("__atuin_ai_execute__:{output}"), - Action::Insert(output) => eprintln!("__atuin_ai_insert__:{output}"), - Action::Print(output) => eprintln!("__atuin_ai_print__:{output}"), - Action::Cancel => eprintln!("__atuin_ai_cancel__"), + loop { + render_setup_options(&mut stdout, &options, selected)?; + + let ev = event::read().context("failed to read key event")?; + + crossterm::execute!(stdout, cursor::MoveUp(options.len() as u16))?; + + if let Event::Key(key) = ev { + match key.code { + KeyCode::Up | KeyCode::Char('k') => { + selected = selected.saturating_sub(1); + } + KeyCode::Down | KeyCode::Char('j') => { + if selected < options.len() - 1 { + selected += 1; + } + } + KeyCode::Enter => break, + KeyCode::Esc => { + selected = 2; + break; + } + _ => {} + } } - } else { - match action { - Action::Execute(output) => eprintln!("{output}"), - Action::Insert(output) => eprintln!("{output}"), - Action::Print(output) => eprintln!("{output}"), - Action::Cancel => eprintln!(), + } + + // Final render with selection visible + render_setup_options(&mut stdout, &options, selected)?; + crossterm::execute!(stdout, cursor::Show)?; + + Ok(match selected { + 0 => SetupChoice::EnableAi, + 1 => SetupChoice::DisableKeybind, + _ => SetupChoice::Cancel, + }) +} + +fn render_setup_options( + w: &mut impl std::io::Write, + options: &[&str], + selected: usize, +) -> Result<()> { + use crossterm::{ + style::Stylize, + terminal::{Clear, ClearType}, + }; + + for (i, option) in options.iter().enumerate() { + if i == selected { + write!(w, "\r {}", format!("> {option}").bold().cyan())?; + } else { + write!(w, "\r {option}")?; } + crossterm::execute!(w, Clear(ClearType::UntilNewLine))?; + write!(w, "\r\n")?; + } + w.flush()?; + Ok(()) +} + +async fn set_ai_enabled(enabled: bool) -> Result<()> { + let config_file = atuin_client::settings::Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let mut doc = config_str.parse::<toml_edit::DocumentMut>()?; + + if !doc.contains_key("ai") { + doc["ai"] = toml_edit::table(); + } + doc["ai"]["enabled"] = toml_edit::value(enabled); + + tokio::fs::write(&config_file, doc.to_string()).await?; + + if !enabled { + println!( + "Atuin AI keybind disabled. You can re-enable with `atuin config set ai.enabled true`.", + ); + println!("Restart your shell for changes to take effect."); + // Two printlns to ensure the message is visible above the shell prompt after program ends. + println!(); + println!(); } + + Ok(()) } fn wait_for_login_confirmation() -> Result<bool> { @@ -646,3 +345,27 @@ fn wait_for_login_confirmation() -> Result<bool> { } } } + +#[derive(Clone)] +enum Action { + Execute(String), + Insert(String), + Cancel, +} + +fn emit_shell_result(action: Action, output_for_hook: bool) { + if output_for_hook { + match action { + Action::Execute(output) => eprintln!("__atuin_ai_execute__:{output}"), + Action::Insert(output) => eprintln!("__atuin_ai_insert__:{output}"), + Action::Cancel => eprintln!("__atuin_ai_cancel__"), + } + } else { + match action { + Action::Execute(output) | Action::Insert(output) => { + println!("{output}"); + } + Action::Cancel => {} + } + } +} |
