diff options
| author | Michelle Tilley <michelle@michelletilley.net> | 2026-04-10 13:24:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-04-10 20:24:57 +0000 |
| commit | 09279a428659cf41824737d3e0c97bcc19a8885a (patch) | |
| tree | 64731502c065df2483e8dd680d46c5559f3094f2 /crates/atuin-ai/src/stream.rs | |
| parent | feat: add strip_trailing_whitespace, on by default (#3390) (diff) | |
| download | atuin-09279a428659cf41824737d3e0c97bcc19a8885a.zip | |
feat: Client-tool execution + permission system (#3370)
Adds client-side tool execution to Atuin AI, starting with
`atuin_history`. The server can request tool calls, which are executed
locally with a permission system, and results are sent back to continue
the conversation.
Diffstat (limited to 'crates/atuin-ai/src/stream.rs')
| -rw-r--r-- | crates/atuin-ai/src/stream.rs | 372 |
1 files changed, 372 insertions, 0 deletions
diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs new file mode 100644 index 00000000..4673f2cd --- /dev/null +++ b/crates/atuin-ai/src/stream.rs @@ -0,0 +1,372 @@ +// ─────────────────────────────────────────────────────────────────── +// SSE streaming +// ─────────────────────────────────────────────────────────────────── + +use std::sync::mpsc; + +use atuin_client::settings::AiCapabilities; +use atuin_common::tls::ensure_crypto_provider; + +use eventsource_stream::Eventsource; +use eye_declare::Handle; +use eyre::{Context, Result}; +use futures::StreamExt; +use reqwest::Url; + +use crate::{ + context::{AppContext, ClientContext}, + tools::ClientToolCall, + tui::{Session, events::AiTuiEvent}, +}; + +/// Frames that alter the stream lifecycle — terminal or state-changing. +#[derive(Debug, Clone)] +pub(crate) enum StreamControl { + Done { session_id: String }, + Error(String), + StatusChanged(String), +} + +/// Frames that carry conversation content — they mutate the event log. +#[derive(Debug, Clone)] +pub(crate) enum StreamContent { + TextChunk(String), + ToolCall { + id: String, + name: String, + input: serde_json::Value, + }, + ToolResult { + tool_use_id: String, + content: String, + is_error: bool, + }, +} + +/// A frame from the SSE stream, classified as control or content. +#[derive(Debug, Clone)] +pub(crate) enum StreamFrame { + Content(StreamContent), + Control(StreamControl), +} + +/// Per-turn request payload for the chat API. +pub(crate) struct ChatRequest { + pub messages: Vec<serde_json::Value>, + pub session_id: Option<String>, + pub capabilities: Vec<String>, +} + +impl ChatRequest { + pub(crate) fn new( + messages: Vec<serde_json::Value>, + session_id: Option<String>, + capabilities: &AiCapabilities, + ) -> Self { + let mut caps = vec![]; + if capabilities.enable_history_search.unwrap_or(true) { + caps.push("client_v1_atuin_history".to_string()); + } + if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { + caps.extend( + extra + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()), + ); + } + + Self { + messages, + session_id, + capabilities: caps, + } + } +} + +fn create_chat_stream( + hub_address: String, + token: String, + request: ChatRequest, + client_ctx: ClientContext, + send_cwd: bool, + last_command: Option<String>, +) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<StreamFrame>> + Send>> { + Box::pin(async_stream::stream! { + ensure_crypto_provider(); + let endpoint = match hub_url(&hub_address, "/api/cli/chat") { + Ok(url) => url, + Err(e) => { + yield Err(e); + return; + } + }; + + tracing::debug!("Sending SSE request to {endpoint}"); + + let context = client_ctx.to_json(send_cwd, last_command.as_deref()); + + let mut request_body = serde_json::json!({ + "messages": request.messages, + "context": context, + "capabilities": request.capabilities, + }); + + if let Some(ref sid) = request.session_id { + tracing::trace!("Including session_id in request: {sid}"); + request_body["session_id"] = serde_json::json!(sid); + } + + let client = reqwest::Client::new(); + let response = match client + .post(endpoint.clone()) + .header("Accept", "text/event-stream") + .bearer_auth(&token) + .json(&request_body) + .send() + .await + { + Ok(resp) => resp, + Err(e) => { + yield Err(eyre::eyre!("Failed to send SSE request: {}", e)); + return; + } + }; + + let status = response.status(); + if status == reqwest::StatusCode::UNAUTHORIZED { + tracing::error!("SSE request failed with status: {status}, clearing session"); + let _ = atuin_client::hub::delete_session().await; + yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again.")); + return; + } + if !status.is_success() { + let body = response.text().await.unwrap_or_default(); + tracing::error!("SSE request failed ({}): {}", status, body); + yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body)); + return; + } + + let byte_stream = response.bytes_stream(); + let mut stream = byte_stream.eventsource(); + + while let Some(event) = stream.next().await { + match event { + Ok(sse_event) => { + let event_type = sse_event.event.as_str(); + let data = sse_event.data.clone(); + + tracing::debug!(event_type = %event_type, "SSE event received"); + + match event_type { + "text" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) + && let Some(content) = json.get("content").and_then(|v| v.as_str()) + { + yield Ok(StreamFrame::Content(StreamContent::TextChunk(content.to_string()))); + } + } + "tool_call" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { + let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let input = json.get("input").cloned().unwrap_or(serde_json::json!({})); + yield Ok(StreamFrame::Content(StreamContent::ToolCall { id, name, input })); + } + } + "tool_result" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { + let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false); + yield Ok(StreamFrame::Content(StreamContent::ToolResult { tool_use_id, content, is_error })); + } + } + "status" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) + && let Some(state) = json.get("state").and_then(|v| v.as_str()) + { + yield Ok(StreamFrame::Control(StreamControl::StatusChanged(state.to_string()))); + } + } + "done" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { + let session_id = json.get("session_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + yield Ok(StreamFrame::Control(StreamControl::Done { session_id })); + } else { + yield Ok(StreamFrame::Control(StreamControl::Done { session_id: String::new() })); + } + break; + } + "error" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { + let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string(); + tracing::error!("SSE error: {}", message); + yield Ok(StreamFrame::Control(StreamControl::Error(message))); + } else { + tracing::error!("SSE error: {}", data); + yield Ok(StreamFrame::Control(StreamControl::Error(data))); + } + break; + } + _ => {} + } + } + Err(e) => { + yield Err(eyre::eyre!("SSE error: {}", e)); + break; + } + } + } + }) +} + +// ─────────────────────────────────────────────────────────────────── +// Async streaming task — pushes updates to app state via Handle +// ─────────────────────────────────────────────────────────────────── + +pub(crate) async fn run_chat_stream( + handle: Handle<Session>, + tx: mpsc::Sender<AiTuiEvent>, + app_ctx: AppContext, + client_ctx: ClientContext, + request: ChatRequest, +) { + let capabilities = request.capabilities.clone(); + let stream = create_chat_stream( + app_ctx.endpoint.clone(), + app_ctx.token.clone(), + request, + client_ctx, + app_ctx.send_cwd, + app_ctx.last_command.clone(), + ); + futures::pin_mut!(stream); + + while let Some(event) = stream.next().await { + match event { + Ok(StreamFrame::Content(content)) => { + apply_content_frame(&handle, &tx, &capabilities, content); + } + Ok(StreamFrame::Control(control)) => { + let terminal = apply_control_frame(&handle, control); + if terminal { + break; + } + } + Err(e) => { + let msg = e.to_string(); + handle.update(move |state| { + state.streaming_error(msg); + }); + break; + } + } + } +} + +/// Apply a content frame to session state. +/// Control flow: always continues the stream. +fn apply_content_frame( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + capabilities: &[String], + content: StreamContent, +) { + match content { + StreamContent::TextChunk(text) => { + handle.update(move |state| { + state.conversation.append_streaming_text(&text); + }); + } + StreamContent::ToolCall { id, name, input } => { + if let Ok(tool) = ClientToolCall::try_from((name.as_str(), &input)) { + // Enforce capability gating: reject tool calls the client didn't advertise. + if let Some(required_cap) = tool.descriptor().capability + && !capabilities.iter().any(|c| c == required_cap) + { + tracing::warn!( + tool = name, + capability = required_cap, + "Rejecting tool call: capability not advertised" + ); + handle.update(move |state| { + state.add_tool_call(id.clone(), name, input.clone()); + state.conversation.add_tool_result( + id, + format!("Tool not enabled: capability '{required_cap}' was not advertised by this client"), + true, + ); + }); + return; + } + + // Client-side tool — add to tracker and conversation, queue permission check + let id_for_event = id.clone(); + handle.update(move |state| { + state.handle_client_tool_call(id_for_event, tool, input); + }); + let _ = tx.send(AiTuiEvent::CheckToolCallPermission(id)); + } else { + // Server-side tool — just add to conversation events + handle.update(move |state| { + state.add_tool_call(id, name, input); + }); + } + } + StreamContent::ToolResult { + tool_use_id, + content, + is_error, + } => { + handle.update(move |state| { + state + .conversation + .add_tool_result(tool_use_id, content, is_error); + }); + } + } +} + +/// Apply a control frame to session state. +/// Returns true if the stream should terminate. +fn apply_control_frame(handle: &Handle<Session>, control: StreamControl) -> bool { + match control { + StreamControl::StatusChanged(status) => { + handle.update(move |state| { + state.update_streaming_status(&status); + }); + false + } + StreamControl::Done { session_id } => { + handle.update(move |state| { + if !session_id.is_empty() { + state.conversation.store_session_id(session_id); + } + state.finalize_streaming(); + }); + true + } + StreamControl::Error(msg) => { + handle.update(move |state| { + state.streaming_error(msg); + }); + true + } + } +} + +fn hub_url(base: &str, path: &str) -> Result<Url> { + let base_with_slash = if base.ends_with('/') { + base.to_string() + } else { + format!("{base}/") + }; + let stripped = path.strip_prefix('/').unwrap_or(path); + Url::parse(&base_with_slash)? + .join(stripped) + .context("failed to build hub URL") +} |
