diff options
Diffstat (limited to 'crates/atuin-ai/src/stream.rs')
| -rw-r--r-- | crates/atuin-ai/src/stream.rs | 288 |
1 files changed, 0 insertions, 288 deletions
diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs deleted file mode 100644 index e78dc2e1..00000000 --- a/crates/atuin-ai/src/stream.rs +++ /dev/null @@ -1,288 +0,0 @@ -// ─────────────────────────────────────────────────────────────────── -// SSE streaming -// ─────────────────────────────────────────────────────────────────── - -use atuin_client::history::History; -use atuin_client::settings::AiCapabilities; - -use crate::context::history_output_capability_available; -use atuin_common::tls::ensure_crypto_provider; - -use eventsource_stream::Eventsource; -use eyre::{Context, Result}; -use futures::StreamExt; -use reqwest::Url; -use reqwest::header::USER_AGENT; - -use crate::context::ClientContext; - -static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION")); - -/// Frames that alter the stream lifecycle — terminal or state-changing. -#[derive(Debug, Clone)] -pub(crate) enum StreamControl { - Done { session_id: String }, - Error(String), - StatusChanged(String), -} - -/// Frames that carry conversation content — they mutate the event log. -#[derive(Debug, Clone)] -pub(crate) enum StreamContent { - TextChunk(String), - ToolCall { - id: String, - name: String, - input: serde_json::Value, - }, - ToolResult { - tool_use_id: String, - content: String, - is_error: bool, - remote: bool, - content_length: Option<usize>, - }, -} - -/// A frame from the SSE stream, classified as control or content. -#[derive(Debug, Clone)] -pub(crate) enum StreamFrame { - Content(StreamContent), - Control(StreamControl), -} - -/// Per-turn request payload for the chat API. -pub(crate) struct ChatRequest { - pub messages: Vec<serde_json::Value>, - pub session_id: Option<String>, - pub capabilities: Vec<String>, - pub invocation_id: String, -} - -impl ChatRequest { - pub(crate) fn new( - messages: Vec<serde_json::Value>, - session_id: Option<String>, - capabilities: &AiCapabilities, - history_output_available: bool, - invocation_id: String, - ) -> Self { - let mut caps = vec![ - "client_invocations".to_string(), - "client_v1_load_skill".to_string(), - ]; - if capabilities.enable_history_search.unwrap_or(true) { - caps.push("client_v1_atuin_history".to_string()); - } - if capabilities.enable_file_tools.unwrap_or(true) { - caps.push("client_v1_read_file".to_string()); - caps.push("client_v1_edit_file".to_string()); - caps.push("client_v1_write_file".to_string()); - } - if capabilities.enable_command_execution.unwrap_or(true) { - caps.push("client_v1_execute_shell_command".to_string()); - } - if history_output_capability_available(history_output_available) - && capabilities.enable_history_output.unwrap_or(true) - { - caps.push("client_v1_atuin_output".to_string()); - } - if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { - caps.extend( - extra - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()), - ); - } - - Self { - messages, - session_id, - capabilities: caps, - invocation_id, - } - } -} - -#[allow(clippy::too_many_arguments)] -pub(crate) fn create_chat_stream( - hub_address: String, - token: String, - request: ChatRequest, - client_ctx: ClientContext, - send_cwd: bool, - last_command: Option<History>, - user_contexts: Vec<crate::user_context::UserContext>, - skill_summaries: Vec<crate::skills::SkillSummary>, - skill_overflow: Option<String>, -) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<StreamFrame>> + Send>> { - Box::pin(async_stream::stream! { - ensure_crypto_provider(); - let endpoint = match hub_url(&hub_address, "/api/cli/chat") { - Ok(url) => url, - Err(e) => { - yield Err(e); - return; - } - }; - - tracing::debug!("Sending SSE request to {endpoint}"); - - let context = client_ctx.to_json(send_cwd, last_command.as_ref()); - - let mut config = serde_json::json!({ - "capabilities": request.capabilities, - }); - - if !user_contexts.is_empty() { - config["user_contexts"] = serde_json::json!(user_contexts); - } - - if !skill_summaries.is_empty() { - config["skills"] = serde_json::json!(skill_summaries); - if let Some(ref overflow) = skill_overflow { - config["skills_overflow"] = serde_json::json!(overflow); - } - } - - if let Ok(model) = std::env::var("ATUIN_AI__MODEL") - && !model.trim().is_empty() { - config["model"] = serde_json::json!(model.trim()); - - } - - - let mut request_body = serde_json::json!({ - "messages": request.messages, - "context": context, - "config": config, - "invocation_id": request.invocation_id - }); - - if let Some(ref sid) = request.session_id { - tracing::trace!("Including session_id in request: {sid}"); - request_body["session_id"] = serde_json::json!(sid); - } - - let client = reqwest::Client::new(); - let response = match client - .post(endpoint.clone()) - .header("Accept", "text/event-stream") - .header(USER_AGENT, APP_USER_AGENT) - .bearer_auth(&token) - .json(&request_body) - .send() - .await - { - Ok(resp) => resp, - Err(e) => { - yield Err(eyre::eyre!("Failed to send SSE request: {}", e)); - return; - } - }; - - let status = response.status(); - if status == reqwest::StatusCode::UNAUTHORIZED { - tracing::error!("SSE request failed with status: {status}, clearing session"); - let _ = atuin_client::hub::delete_session().await; - yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again.")); - return; - } - if !status.is_success() { - let body = response.text().await.unwrap_or_default(); - tracing::error!("SSE request failed ({}): {}", status, body); - yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body)); - return; - } - - let byte_stream = response.bytes_stream(); - let mut stream = byte_stream.eventsource(); - - while let Some(event) = stream.next().await { - match event { - Ok(sse_event) => { - let event_type = sse_event.event.as_str(); - let data = sse_event.data.clone(); - - tracing::debug!(event_type = %event_type, "SSE event received"); - - match event_type { - "text" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) - && let Some(content) = json.get("content").and_then(|v| v.as_str()) - { - yield Ok(StreamFrame::Content(StreamContent::TextChunk(content.to_string()))); - } - } - "tool_call" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let input = json.get("input").cloned().unwrap_or(serde_json::json!({})); - yield Ok(StreamFrame::Content(StreamContent::ToolCall { id, name, input })); - } - } - "tool_result" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false); - let remote = json.get("remote").and_then(|v| v.as_bool()).unwrap_or(false); - let content_length = json.get("content_length").and_then(|v| v.as_u64()).map(|v| v as usize); - yield Ok(StreamFrame::Content(StreamContent::ToolResult { tool_use_id, content, is_error, remote, content_length })); - } - } - "status" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) - && let Some(state) = json.get("state").and_then(|v| v.as_str()) - { - yield Ok(StreamFrame::Control(StreamControl::StatusChanged(state.to_string()))); - } - } - "done" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let session_id = json.get("session_id") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - yield Ok(StreamFrame::Control(StreamControl::Done { session_id })); - } else { - yield Ok(StreamFrame::Control(StreamControl::Done { session_id: String::new() })); - } - break; - } - "error" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string(); - tracing::error!("SSE error: {}", message); - yield Ok(StreamFrame::Control(StreamControl::Error(message))); - } else { - tracing::error!("SSE error: {}", data); - yield Ok(StreamFrame::Control(StreamControl::Error(data))); - } - break; - } - _ => {} - } - } - Err(e) => { - yield Err(eyre::eyre!("SSE error: {}", e)); - break; - } - } - } - }) -} - -fn hub_url(base: &str, path: &str) -> Result<Url> { - let base_with_slash = if base.ends_with('/') { - base.to_string() - } else { - format!("{base}/") - }; - let stripped = path.strip_prefix('/').unwrap_or(path); - Url::parse(&base_with_slash)? - .join(stripped) - .context("failed to build hub URL") -} |
