aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-ai/src/tui/view/turn.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/atuin-ai/src/tui/view/turn.rs')
-rw-r--r--crates/atuin-ai/src/tui/view/turn.rs111
1 files changed, 67 insertions, 44 deletions
diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs
index 9f4460eb..c74395b8 100644
--- a/crates/atuin-ai/src/tui/view/turn.rs
+++ b/crates/atuin-ai/src/tui/view/turn.rs
@@ -204,16 +204,23 @@ pub(crate) enum ToolResultStatus {
}
#[derive(Debug)]
-pub(crate) enum UiTurn {
+pub(crate) struct UiTurn {
+ pub(crate) id: usize,
+ pub(crate) kind: UiTurnKind,
+}
+
+#[derive(Debug)]
+pub(crate) enum UiTurnKind {
User { events: Vec<UiEvent> },
Agent { events: Vec<UiEvent> },
OutOfBand { events: Vec<UiEvent> },
}
pub(crate) struct TurnBuilder<'a> {
- turns: Vec<UiTurn>,
- current_turn: Option<UiTurn>,
+ turns: Vec<UiTurnKind>,
+ current_turn: Option<UiTurnKind>,
tracker: &'a ToolManager,
+ next_id: usize,
}
/// A struct to iteratively build [UiTurn] events from [ConversationEvent]s.
@@ -223,6 +230,16 @@ impl<'a> TurnBuilder<'a> {
turns: Vec::new(),
current_turn: None,
tracker,
+ next_id: 0,
+ }
+ }
+
+ pub(crate) fn new_starting_at(tracker: &'a ToolManager, start_id: usize) -> Self {
+ Self {
+ turns: Vec::new(),
+ current_turn: None,
+ tracker,
+ next_id: start_id,
}
}
@@ -280,7 +297,7 @@ impl<'a> TurnBuilder<'a> {
// into a ToolGroup (e.g. N file reads → one group)
// - All other events pass through unchanged
for turn in &mut self.turns {
- if let UiTurn::Agent { events } = turn {
+ if let UiTurnKind::Agent { events } = turn {
let mut new_events: Vec<UiEvent> = Vec::new();
let mut pending_remote: Vec<ToolCallDetails> = Vec::new();
let mut pending_group: Option<(ToolGroupKind, Vec<ToolCallDetails>)> = None;
@@ -322,7 +339,15 @@ impl<'a> TurnBuilder<'a> {
}
}
- std::mem::take(&mut self.turns)
+ let kinds = std::mem::take(&mut self.turns);
+ kinds
+ .into_iter()
+ .enumerate()
+ .map(|(i, kind)| UiTurn {
+ id: self.next_id + i,
+ kind,
+ })
+ .collect()
}
fn commit_turn(&mut self) {
@@ -332,37 +357,39 @@ impl<'a> TurnBuilder<'a> {
}
fn start_user_turn(&mut self) {
- if !matches!(self.current_turn, Some(UiTurn::User { .. })) {
+ if !matches!(self.current_turn, Some(UiTurnKind::User { .. })) {
self.commit_turn();
- self.current_turn = Some(UiTurn::User { events: vec![] });
+ self.current_turn = Some(UiTurnKind::User { events: vec![] });
}
}
fn start_agent_turn(&mut self) {
- if !matches!(self.current_turn, Some(UiTurn::Agent { .. })) {
+ if !matches!(self.current_turn, Some(UiTurnKind::Agent { .. })) {
self.commit_turn();
- self.current_turn = Some(UiTurn::Agent { events: vec![] });
+ self.current_turn = Some(UiTurnKind::Agent { events: vec![] });
}
}
fn start_out_of_band_turn(&mut self) {
- if !matches!(self.current_turn, Some(UiTurn::OutOfBand { .. })) {
+ if !matches!(self.current_turn, Some(UiTurnKind::OutOfBand { .. })) {
self.commit_turn();
- self.current_turn = Some(UiTurn::OutOfBand { events: vec![] });
+ self.current_turn = Some(UiTurnKind::OutOfBand { events: vec![] });
}
}
- fn turn_mut_unsafe(&mut self) -> &mut UiTurn {
- self.current_turn.as_mut().unwrap()
+ fn current_events_mut(&mut self) -> &mut Vec<UiEvent> {
+ match self.current_turn.as_mut().unwrap() {
+ UiTurnKind::User { events }
+ | UiTurnKind::Agent { events }
+ | UiTurnKind::OutOfBand { events } => events,
+ }
}
fn add_user_message(&mut self, content: &str) {
self.start_user_turn();
- if let UiTurn::User { events } = self.turn_mut_unsafe() {
- events.push(UiEvent::Text {
- content: content.to_string(),
- });
- }
+ self.current_events_mut().push(UiEvent::Text {
+ content: content.to_string(),
+ });
}
fn add_agent_text(&mut self, content: &str) {
@@ -370,11 +397,9 @@ impl<'a> TurnBuilder<'a> {
return;
}
self.start_agent_turn();
- if let UiTurn::Agent { events } = self.turn_mut_unsafe() {
- events.push(UiEvent::Text {
- content: content.to_string(),
- });
- }
+ self.current_events_mut().push(UiEvent::Text {
+ content: content.to_string(),
+ });
}
fn add_suggested_command(&mut self, input: &serde_json::Value) {
@@ -389,7 +414,8 @@ impl<'a> TurnBuilder<'a> {
}
self.start_agent_turn();
- if let UiTurn::Agent { events } = self.turn_mut_unsafe() {
+ {
+ let events = self.current_events_mut();
let danger_level = input
.get("danger")
.and_then(|v| v.as_str())
@@ -433,14 +459,13 @@ impl<'a> TurnBuilder<'a> {
let render_data = self.build_render_data(id, name);
self.start_agent_turn();
- if let UiTurn::Agent { events } = self.turn_mut_unsafe() {
- events.push(UiEvent::ToolCall(ToolCallDetails {
+ self.current_events_mut()
+ .push(UiEvent::ToolCall(ToolCallDetails {
tool_use_id: id.to_string(),
name: name.to_string(),
status: ToolResultStatus::Pending,
render_data,
}));
- }
}
/// Build tool-type-specific render data from the ToolTracker.
@@ -482,31 +507,29 @@ impl<'a> TurnBuilder<'a> {
fn add_tool_result(&mut self, tool_use_id: &str, _content: &str, is_error: bool) {
self.start_agent_turn();
- if let UiTurn::Agent { events } = self.turn_mut_unsafe() {
- let event = events.iter_mut().find(|e| match e {
- UiEvent::ToolCall(ToolCallDetails {
- tool_use_id: id, ..
- }) => id == tool_use_id,
- _ => false,
- });
- if let Some(UiEvent::ToolCall(ToolCallDetails { status, .. })) = event {
- *status = if is_error {
- ToolResultStatus::Error
- } else {
- ToolResultStatus::Success
- };
- }
+ let events = self.current_events_mut();
+ let event = events.iter_mut().find(|e| match e {
+ UiEvent::ToolCall(ToolCallDetails {
+ tool_use_id: id, ..
+ }) => id == tool_use_id,
+ _ => false,
+ });
+ if let Some(UiEvent::ToolCall(ToolCallDetails { status, .. })) = event {
+ *status = if is_error {
+ ToolResultStatus::Error
+ } else {
+ ToolResultStatus::Success
+ };
}
}
fn add_out_of_band_output(&mut self, _name: &str, command: Option<&str>, content: &str) {
self.start_out_of_band_turn();
- if let UiTurn::OutOfBand { events } = self.turn_mut_unsafe() {
- events.push(UiEvent::OutOfBandOutput(OutOfBandOutputDetails {
+ self.current_events_mut()
+ .push(UiEvent::OutOfBandOutput(OutOfBandOutputDetails {
command: command.map(|c| c.to_string()),
content: content.to_string(),
}));
- }
}
}