From dfb79989b67ceb70ebd4650acfd23af1c384949e Mon Sep 17 00:00:00 2001 From: Ben Gao Date: Sat, 6 Jun 2026 09:32:08 +0800 Subject: [PATCH] feat(runtime-api): add session save, undo/retry, and snapshot endpoints for GUI --- crates/tui/src/core/engine.rs | 17 +- crates/tui/src/core/engine/handle.rs | 10 + crates/tui/src/core/ops.rs | 19 ++ crates/tui/src/runtime_api.rs | 419 ++++++++++++++++++++++++++- crates/tui/src/runtime_threads.rs | 46 +++ 5 files changed, 509 insertions(+), 2 deletions(-) diff --git a/crates/tui/src/core/engine.rs b/crates/tui/src/core/engine.rs index fa2146171..3a40d5319 100644 --- a/crates/tui/src/core/engine.rs +++ b/crates/tui/src/core/engine.rs @@ -66,7 +66,7 @@ use super::capacity_memory::{ }; use super::coherence::{CoherenceSignal, CoherenceState, next_coherence_state}; use super::events::{Event, TurnOutcomeStatus}; -use super::ops::{Op, USER_SHELL_TOOL_ID_PREFIX}; +use super::ops::{Op, SessionSnapshot, USER_SHELL_TOOL_ID_PREFIX}; use super::session::Session; use super::tool_parser; use super::turn::{TurnContext, TurnToolCall, post_turn_snapshot, pre_turn_snapshot}; @@ -1264,6 +1264,21 @@ impl Engine { Op::CompactContext => { self.handle_manual_compaction().await; } + Op::GetSessionSnapshot { tx } => { + let total_tokens = self.session.total_usage.input_tokens + + self.session.total_usage.output_tokens; + let snapshot = SessionSnapshot { + messages: self.session.messages.clone(), + total_tokens, + model: self.session.model.clone(), + workspace: self.session.workspace.clone(), + system_prompt: self.session.system_prompt.clone(), + mode: self.current_mode.as_setting().to_string(), + }; + if let Some(tx) = tx.lock().ok().and_then(|mut g| g.take()) { + let _ = tx.send(snapshot); + } + } Op::PurgeContext => { self.handle_purge().await; } diff --git a/crates/tui/src/core/engine/handle.rs b/crates/tui/src/core/engine/handle.rs index 1ed7e95d3..1ddb98b80 100644 --- a/crates/tui/src/core/engine/handle.rs +++ b/crates/tui/src/core/engine/handle.rs @@ -110,4 +110,14 @@ impl EngineHandle { self.tx_steer.send(content.into()).await?; Ok(()) } + + /// Request a snapshot of the current session state. + /// Returns the snapshot directly via a oneshot channel, avoiding + /// competition with the SSE event stream on the mpsc receiver. + pub async fn get_session_snapshot(&self) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + let tx = std::sync::Arc::new(std::sync::Mutex::new(Some(tx))); + self.send(Op::GetSessionSnapshot { tx }).await?; + rx.await.map_err(|_| anyhow::anyhow!("Engine dropped session snapshot oneshot")) + } } diff --git a/crates/tui/src/core/ops.rs b/crates/tui/src/core/ops.rs index 4260cf0c8..b51e67758 100644 --- a/crates/tui/src/core/ops.rs +++ b/crates/tui/src/core/ops.rs @@ -12,6 +12,18 @@ use std::path::PathBuf; /// Prefix used for tool-call ids created by local composer shell shortcuts. pub const USER_SHELL_TOOL_ID_PREFIX: &str = "user_shell_"; +/// Snapshot of session state for saving to disk. +/// Returned by `Op::GetSessionSnapshot` via a oneshot channel. +#[derive(Debug, Clone)] +pub struct SessionSnapshot { + pub messages: Vec, + pub total_tokens: u64, + pub model: String, + pub workspace: PathBuf, + pub system_prompt: Option, + pub mode: String, +} + /// Operations that can be submitted to the engine. #[derive(Debug, Clone)] pub enum Op { @@ -97,6 +109,13 @@ pub enum Op { /// Run context compaction immediately. CompactContext, + /// Get a snapshot of the current session state (messages, tokens, etc.) + /// for saving to disk. Returns the result via the oneshot sender so + /// the caller doesn't have to compete with the SSE event stream. + GetSessionSnapshot { + tx: std::sync::Arc>>>, + }, + /// Run agent-driven context purging. PurgeContext, diff --git a/crates/tui/src/runtime_api.rs b/crates/tui/src/runtime_api.rs index 523b7f3ac..fbcf4f239 100644 --- a/crates/tui/src/runtime_api.rs +++ b/crates/tui/src/runtime_api.rs @@ -514,7 +514,7 @@ pub async fn run_http_server( pub fn build_router(state: RuntimeApiState) -> Router { let api_routes = Router::new() - .route("/v1/sessions", get(list_sessions)) + .route("/v1/sessions", get(list_sessions).post(save_current_session)) .route("/v1/sessions/{id}", get(get_session).delete(delete_session)) .route( "/v1/sessions/{id}/resume-thread", @@ -527,6 +527,9 @@ pub fn build_router(state: RuntimeApiState) -> Router { .route("/v1/threads/{id}", get(get_thread).patch(update_thread)) .route("/v1/threads/{id}/resume", post(resume_thread)) .route("/v1/threads/{id}/fork", post(fork_thread)) + .route("/v1/threads/{id}/undo", post(undo_thread_turn)) + .route("/v1/threads/{id}/patch-undo", post(patch_undo_thread_turn)) + .route("/v1/threads/{id}/retry", post(retry_thread_turn)) .route("/v1/threads/{id}/turns", post(start_thread_turn)) .route( "/v1/threads/{id}/turns/{turn_id}/steer", @@ -565,6 +568,8 @@ pub fn build_router(state: RuntimeApiState) -> Router { .route("/v1/automations/{id}/resume", post(resume_automation)) .route("/v1/automations/{id}/runs", get(list_automation_runs)) .route("/v1/usage", get(get_usage)) + .route("/v1/snapshots", get(list_snapshots)) + .route("/v1/snapshots/{id}/restore", post(restore_snapshot)) .route_layer(middleware::from_fn_with_state( state.clone(), require_runtime_token, @@ -774,6 +779,113 @@ async fn list_sessions( Ok(Json(SessionsResponse { sessions })) } +#[derive(Debug, Deserialize)] +struct SaveSessionRequest { + /// Thread ID to save as a session. If omitted, saves the most recently + /// active thread. + #[serde(default)] + thread_id: Option, + /// If provided, update the existing session with this ID instead of + /// creating a new one. This matches TUI's `build_session_snapshot` + /// behavior where it updates the current session in-place. + #[serde(default)] + session_id: Option, +} + +#[derive(Debug, Serialize)] +struct SaveSessionResponse { + session_id: String, + session: SessionDetailResponse, +} + +async fn save_current_session( + State(state): State, + Json(req): Json, +) -> Result, ApiError> { + // Find the thread to save. + let thread_id = match req.thread_id { + Some(id) => id, + None => { + // Find the most recently updated thread. + let threads = state + .runtime_threads + .list_threads(ThreadListFilter::IncludeArchived, Some(100)) + .await + .map_err(map_thread_err)?; + threads + .into_iter() + .max_by_key(|t| t.updated_at) + .map(|t| t.id) + .ok_or_else(|| ApiError::bad_request("No threads to save"))? + } + }; + + // Get the engine handle (loads the thread into an engine if needed), + // then request a session snapshot. This reuses the same code path as + // TUI's `build_session_snapshot`: the engine holds the authoritative + // messages and token usage, so we don't need to reconstruct from turns. + let engine = state + .runtime_threads + .get_engine(&thread_id) + .await + .map_err(|e| ApiError::internal(format!("Failed to get engine for thread: {e}")))?; + + let snapshot = engine + .get_session_snapshot() + .await + .map_err(|e| ApiError::internal(format!("Failed to get session snapshot: {e}")))?; + + let manager = SessionManager::new(state.sessions_dir.clone()) + .map_err(|e| ApiError::internal(format!("Failed to open sessions dir: {e}")))?; + + // Build or update the session, mirroring TUI's `build_session_snapshot`. + let session = if let Some(ref existing_id) = req.session_id { + match manager.load_session(existing_id) { + Ok(existing) => { + let mut updated = crate::session_manager::update_session( + existing, + &snapshot.messages, + snapshot.total_tokens, + snapshot.system_prompt.as_ref(), + ); + updated.metadata.model = snapshot.model.clone(); + updated.metadata.mode = Some(snapshot.mode.clone()); + updated + } + Err(_) => { + crate::session_manager::create_saved_session_with_id_and_mode( + existing_id.clone(), + &snapshot.messages, + &snapshot.model, + &snapshot.workspace, + snapshot.total_tokens, + snapshot.system_prompt.as_ref(), + Some(snapshot.mode.as_str()), + ) + } + } + } else { + crate::session_manager::create_saved_session_with_mode( + &snapshot.messages, + &snapshot.model, + &snapshot.workspace, + snapshot.total_tokens, + snapshot.system_prompt.as_ref(), + Some(snapshot.mode.as_str()), + ) + }; + + // Save the session. + manager + .save_session(&session) + .map_err(|e| ApiError::internal(format!("Failed to save session: {e}")))?; + + Ok(Json(SaveSessionResponse { + session_id: session.metadata.id.clone(), + session: session_to_detail(session), + })) +} + async fn get_session( State(state): State, Path(id): Path, @@ -1429,6 +1541,258 @@ async fn fork_thread( Ok((StatusCode::CREATED, Json(thread))) } +#[derive(Debug, Deserialize)] +struct UndoTurnRequest { + /// How many turns back to undo (default 0 = last turn only). + #[serde(default)] + depth: Option, +} + +#[derive(Debug, Serialize)] +struct UndoTurnResponse { + /// The new forked thread (with the last N turns removed). + thread: ThreadRecord, + /// The original user message text from the first dropped turn, + /// so the GUI can pre-populate the input box. + original_user_text: Option, +} + +async fn undo_thread_turn( + State(state): State, + Path(id): Path, + Json(req): Json, +) -> Result<(StatusCode, Json), ApiError> { + let depth = req.depth.unwrap_or(0); + let (forked_thread, original_user_text) = state + .runtime_threads + .fork_at_user_message(&id, depth) + .await + .map_err(map_thread_err)?; + Ok(( + StatusCode::CREATED, + Json(UndoTurnResponse { + thread: forked_thread, + original_user_text, + }), + )) +} + +/// Full undo that mirrors TUI's `/undo` command: tries snapshot-based file +/// rollback (`patch_undo`) first, then removes the last conversation turn +/// via `fork_at_user_message`. Returns both the file-rollback result and +/// the new forked thread. +#[derive(Debug, Serialize)] +struct PatchUndoResult { + /// Whether files were restored from a snapshot. + files_restored: bool, + /// Human-readable summary of what was restored (diff stat). + summary: Option, + /// The label of the restored snapshot (e.g. "tool:apply_patch" or "pre-turn:3"). + snapshot_label: Option, +} + +#[derive(Debug, Serialize)] +struct PatchUndoResponse { + /// Result of the snapshot-based file rollback step. + patch_result: PatchUndoResult, + /// The new forked thread (with the last turn removed). + thread: ThreadRecord, + /// The original user text from the removed turn (for re-editing). + original_user_text: Option, +} + +async fn patch_undo_thread_turn( + State(state): State, + Path(id): Path, + Json(req): Json, +) -> Result<(StatusCode, Json), ApiError> { + let depth = req.depth.unwrap_or(0); + + // Step 1: Try snapshot-based file rollback (patch_undo). + let thread = state + .runtime_threads + .get_thread(&id) + .await + .map_err(map_thread_err)?; + + let workspace = PathBuf::from(&thread.workspace); + let patch_result = match crate::snapshot::SnapshotRepo::open_or_init(&workspace) { + Ok(repo) => { + match repo.list(20) { + Ok(snapshots) => { + // Find the newest tool: or pre-turn: snapshot that differs + // from the current workspace — same logic as TUI's patch_undo. + let target = snapshots + .iter() + .filter(|s| { + s.label.starts_with("tool:") || s.label.starts_with("pre-turn:") + }) + .find(|s| { + matches!( + repo.work_tree_matches_snapshot(&s.id), + Ok(false) | Err(_) + ) + }); + + match target { + Some(target) => match repo.restore(&target.id) { + Ok(()) => { + // Compute diff stat for the summary. + let diff_stat = + crate::dependencies::Git::command() + .and_then(|mut git| { + git.args(["diff", "--stat"]) + .current_dir(&workspace) + .output() + .ok() + .and_then(|o| { + let s = String::from_utf8_lossy(&o.stdout) + .trim() + .to_string(); + if s.is_empty() { + None + } else { + Some(s) + } + }) + }); + + let short = + &target.id.as_str()[..target.id.as_str().len().min(8)]; + let summary = match diff_stat { + Some(ref stat) => { + format!( + "Restored snapshot '{}' ({}). Files affected:\n{stat}", + target.label, short + ) + } + None => { + format!( + "Restored snapshot '{}' ({}). No diff changes detected.", + target.label, short + ) + } + }; + + PatchUndoResult { + files_restored: true, + summary: Some(summary), + snapshot_label: Some(target.label.clone()), + } + } + Err(e) => PatchUndoResult { + files_restored: false, + summary: Some(format!("Restore failed: {e}")), + snapshot_label: None, + }, + }, + None => PatchUndoResult { + files_restored: false, + summary: Some( + "No older tool or pre-turn snapshots differ from the current workspace.".to_string(), + ), + snapshot_label: None, + }, + } + } + Err(e) => PatchUndoResult { + files_restored: false, + summary: Some(format!("Failed to list snapshots: {e}")), + snapshot_label: None, + }, + } + } + Err(e) => PatchUndoResult { + files_restored: false, + summary: Some(format!("Snapshot repo unavailable: {e}")), + snapshot_label: None, + }, + }; + + // Step 2: Remove the last conversation turn (undo_conversation). + let (forked_thread, original_user_text) = state + .runtime_threads + .fork_at_user_message(&id, depth) + .await + .map_err(map_thread_err)?; + + Ok(( + StatusCode::CREATED, + Json(PatchUndoResponse { + patch_result, + thread: forked_thread, + original_user_text, + }), + )) +} + +#[derive(Debug, Deserialize)] +struct RetryTurnRequest { + /// How many turns back to retry (default 0 = last turn only). + #[serde(default)] + depth: Option, + /// Override the user message text. If omitted, the original text + /// from the dropped turn is re-used. + #[serde(default)] + prompt: Option, +} + +#[derive(Debug, Serialize)] +struct RetryTurnResponse { + /// The new forked thread (with the last N turns removed). + thread: ThreadRecord, + /// The turn created by the retry. + turn: TurnRecord, +} + +async fn retry_thread_turn( + State(state): State, + Path(id): Path, + Json(req): Json, +) -> Result<(StatusCode, Json), ApiError> { + let depth = req.depth.unwrap_or(0); + let (forked_thread, original_user_text) = state + .runtime_threads + .fork_at_user_message(&id, depth) + .await + .map_err(map_thread_err)?; + + let retry_prompt = req + .prompt + .or(original_user_text) + .unwrap_or_default(); + if retry_prompt.trim().is_empty() { + return Err(ApiError::bad_request( + "No user message to retry — the dropped turn had no user text", + )); + } + + let turn = state + .runtime_threads + .start_turn( + &forked_thread.id, + StartTurnRequest { + prompt: retry_prompt, + model: None, + mode: None, + allow_shell: None, + trust_mode: None, + auto_approve: None, + input_summary: None, + }, + ) + .await + .map_err(map_thread_err)?; + + Ok(( + StatusCode::CREATED, + Json(RetryTurnResponse { + thread: forked_thread, + turn, + }), + )) +} + async fn start_thread_turn( State(state): State, Path(id): Path, @@ -2080,6 +2444,59 @@ fn map_automation_err(err: anyhow::Error) -> ApiError { } } +// ── Snapshot endpoints ────────────────────────────────────────────── + +#[derive(Debug, Serialize)] +struct SnapshotEntry { + id: String, + label: String, + timestamp: i64, +} + +#[derive(Debug, Deserialize)] +struct ListSnapshotsQuery { + /// Maximum number of snapshots to return (default 20). + #[serde(default)] + limit: Option, +} + +async fn list_snapshots( + State(state): State, + Query(query): Query, +) -> Result>, ApiError> { + let workspace = &state.workspace; + let repo = crate::snapshot::SnapshotRepo::open_or_init(workspace) + .map_err(|e| ApiError::internal(format!("Snapshot repo init failed: {e}")))?; + let limit = query.limit.unwrap_or(20); + let snapshots = repo + .list(limit) + .map_err(|e| ApiError::internal(format!("Snapshot list failed: {e}")))?; + let entries: Vec = snapshots + .into_iter() + .map(|s| SnapshotEntry { + id: s.id.as_str().to_string(), + label: s.label, + timestamp: s.timestamp, + }) + .collect(); + Ok(Json(entries)) +} + +async fn restore_snapshot( + State(state): State, + Path(id): Path, +) -> Result, ApiError> { + let workspace = &state.workspace; + let repo = crate::snapshot::SnapshotRepo::open_or_init(workspace) + .map_err(|e| ApiError::internal(format!("Snapshot repo init failed: {e}")))?; + let snapshot_id = crate::snapshot::SnapshotId(id.clone()); + repo.restore(&snapshot_id) + .map_err(|e| ApiError::internal(format!("Snapshot restore failed: {e}")))?; + Ok(Json(json!({ + "restored": id, + }))) +} + fn map_thread_err(err: anyhow::Error) -> ApiError { let message = err.to_string(); if message.contains("not found") { diff --git a/crates/tui/src/runtime_threads.rs b/crates/tui/src/runtime_threads.rs index 6973a9c3d..b82d46739 100644 --- a/crates/tui/src/runtime_threads.rs +++ b/crates/tui/src/runtime_threads.rs @@ -2081,6 +2081,16 @@ impl RuntimeThreadManager { Ok(engine) } + /// Get the engine handle for a thread, loading it if necessary. + /// Public wrapper around the private `ensure_engine_loaded`. + pub async fn get_engine(&self, thread_id: &str) -> Result { + let thread = self.get_thread(thread_id).await?; + self.ensure_engine_loaded(&thread).await + } + + /// Reconstruct the full message history of a thread from its turn items. + /// Used internally by `ensure_engine_loaded` to seed the engine's + /// message buffer when loading a thread. fn reconstruct_messages_from_turns(&self, turns: &[TurnRecord]) -> Result> { let mut messages = Vec::new(); for turn in turns { @@ -2107,6 +2117,42 @@ impl RuntimeThreadManager { }], }); } + TurnItemKind::ToolCall | TurnItemKind::FileChange => { + // Reconstruct a minimal tool_use block for the session + // so the file change cards render correctly when viewing + // a resumed session. + let name = item.summary.clone(); + let input: serde_json::Value = item + .metadata + .clone() + .unwrap_or_else(|| serde_json::Value::Object(Default::default())); + let id = item.id.clone(); + let tool_use_id = id.clone(); + messages.push(Message { + role: "assistant".to_string(), + content: vec![ContentBlock::ToolUse { + id: id.clone(), + name: name.clone(), + input: input.clone(), + caller: None, + }], + }); + // Synthesize a tool_result message so downstream parsers + // (e.g. turn restoration) see a complete tool round-trip. + let output = item.detail.clone().unwrap_or_default(); + messages.push(Message { + role: "user".to_string(), + content: vec![ContentBlock::ToolResult { + tool_use_id, + content: output, + is_error: Some(matches!( + item.status, + TurnItemLifecycleStatus::Failed + )), + content_blocks: None, + }], + }); + } _ => {} } }