diff --git a/docs/reference/tui-gateway-contract-matrix.md b/docs/reference/tui-gateway-contract-matrix.md new file mode 100644 index 00000000..6e62abac --- /dev/null +++ b/docs/reference/tui-gateway-contract-matrix.md @@ -0,0 +1,66 @@ +# TUI-Gateway Contract Matrix (Single-Version Baseline) + +This document freezes the contract that TUI consumes from gateway. +It is intentionally single-version and fail-fast by design. + +## Scope + +- Transport contract: JSON-RPC 2.0 (`internal/gateway/protocol`) +- Runtime contract: gateway DTOs (`internal/gateway/contracts.go`) +- Event payload version source of truth: `internal/runtime/controlplane/envelope.go` + +## RPC Methods Used By TUI + +| Method | Params Type | Result Payload | Notes | +| --- | --- | --- | --- | +| `gateway.authenticate` | `protocol.AuthenticateParams` | frame ack | Must succeed before runtime actions | +| `gateway.bindStream` | `protocol.BindStreamParams` | frame ack | Binds session/run event stream | +| `gateway.run` | `protocol.RunParams` | frame ack with `session_id`/`run_id` | Async acceptance only | +| `gateway.compact` | `protocol.CompactParams` | `gateway.CompactResult` | Manual compact | +| `gateway.executeSystemTool` | `protocol.ExecuteSystemToolParams` | `tools.ToolResult` | Tool execution passthrough | +| `gateway.resolvePermission` | `protocol.ResolvePermissionParams` | frame ack | Permission decision submit | +| `gateway.cancel` | `protocol.CancelParams` | frame ack | Cancels run by run/session binding | +| `gateway.listSessions` | none | `[]gateway.SessionSummary` | Session list | +| `gateway.loadSession` | `protocol.LoadSessionParams` | `gateway.Session` | Full session snapshot | +| `gateway.activateSessionSkill` | `protocol.ActivateSessionSkillParams` | frame ack | Activate skill in session | +| `gateway.deactivateSessionSkill` | `protocol.DeactivateSessionSkillParams` | frame ack | Deactivate skill in session | +| `gateway.listSessionSkills` | `protocol.ListSessionSkillsParams` | `[]gateway.SessionSkillState` | Active skill states | +| `gateway.listAvailableSkills` | `protocol.ListAvailableSkillsParams` | `[]gateway.AvailableSkillState` | Available skill catalog | + +## Runtime Event Contract + +- Notification method: `gateway.event` +- TUI only accepts a runtime envelope payload with these required keys: + - `runtime_event_type` (string) + - `turn` (number) + - `phase` (string) + - `timestamp` (RFC3339 or RFC3339Nano) + - `payload_version` (number) + - `payload` (event-specific object/string) +- `payload_version` must equal `controlplane.PayloadVersion`. +- Version mismatch is treated as a hard incompatibility and must fail fast. + +## Error Contract + +TUI consumes standard JSON-RPC errors and gateway extended error codes from +`protocol.JSONRPCError.Data.GatewayCode`. + +Primary gateway codes used for UI mapping: + +- `invalid_frame` +- `invalid_action` +- `invalid_multimodal_payload` +- `missing_required_field` +- `unsupported_action` +- `internal_error` +- `timeout` +- `unsafe_path` +- `unauthorized` +- `access_denied` +- `resource_not_found` + +## Non-Goals + +- No multi-version payload decoding. +- No alias method fallback. +- No legacy field fallback in event payload. diff --git a/internal/cli/gateway_runtime_bridge_test.go b/internal/cli/gateway_runtime_bridge_test.go index 283d1695..8a4b4ef8 100644 --- a/internal/cli/gateway_runtime_bridge_test.go +++ b/internal/cli/gateway_runtime_bridge_test.go @@ -2627,3 +2627,22 @@ func TestGatewayRuntimePortBridgeDeleteMCPServerSuccess(t *testing.T) { t.Fatalf("servers = %+v, want [srv-2]", cfgMgr.cfg.Tools.MCP.Servers) } } + +func TestDefaultBuildGatewayRuntimePortListSessionsWithoutExplicitWorkdir(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) + t.Setenv("XDG_CONFIG_HOME", filepath.Join(home, ".config")) + + port, cleanup, err := defaultBuildGatewayRuntimePort(context.Background(), "") + if err != nil { + t.Fatalf("defaultBuildGatewayRuntimePort() error = %v", err) + } + if cleanup != nil { + defer func() { _ = cleanup() }() + } + + if _, err := port.ListSessions(context.Background()); err != nil { + t.Fatalf("ListSessions() with empty cli workdir should succeed, got %v", err) + } +} diff --git a/internal/gateway/multi_workspace_runtime.go b/internal/gateway/multi_workspace_runtime.go index 7f1e5e4a..2a8a7b2c 100644 --- a/internal/gateway/multi_workspace_runtime.go +++ b/internal/gateway/multi_workspace_runtime.go @@ -14,11 +14,11 @@ import ( // MultiWorkspaceRuntime 将多个工作区的 runtime 聚合为单个 gateway.RuntimePort。 // 根据连接上下文中的 workspaceHash 路由到对应工作区的 runtime。 type MultiWorkspaceRuntime struct { - index *agentsession.WorkspaceIndex - bundles map[string]*workspaceBundle - mu sync.RWMutex - buildPort func(ctx context.Context, workdir string) (RuntimePort, func() error, error) - defaultHash string + index *agentsession.WorkspaceIndex + bundles map[string]*workspaceBundle + mu sync.RWMutex + buildPort func(ctx context.Context, workdir string) (RuntimePort, func() error, error) + defaultHash string managementPort ManagementRuntimePort events chan RuntimeEvent @@ -55,16 +55,28 @@ func NewMultiWorkspaceRuntime( func (m *MultiWorkspaceRuntime) getPort(ctx context.Context) (RuntimePort, error) { hash := WorkspaceHashFromContext(ctx) if hash == "" { + m.mu.RLock() hash = m.defaultHash + m.mu.RUnlock() } if hash == "" { + // Support startup flows where gateway preloads a default runtime bundle + // but no explicit workspace hash has been persisted yet. + m.mu.RLock() + if preloaded := m.bundles[""]; preloaded != nil { + port := preloaded.port + m.mu.RUnlock() + return port, nil + } + m.mu.RUnlock() + records := m.index.List() if len(records) > 0 { hash = records[0].Hash } } if hash == "" { - return nil, fmt.Errorf("workspace hash is empty and no default configured") + return nil, fmt.Errorf("%w: workspace hash is empty and no default configured", ErrRuntimeResourceNotFound) } return m.getPortForHash(hash) } @@ -86,7 +98,7 @@ func (m *MultiWorkspaceRuntime) getPortForHash(hash string) (RuntimePort, error) record, ok := m.index.Get(hash) if !ok { - return nil, fmt.Errorf("workspace %s not found", hash) + return nil, fmt.Errorf("%w: workspace %s not found", ErrRuntimeResourceNotFound, hash) } port, cleanup, err := m.buildPort(context.Background(), record.Path) @@ -161,7 +173,7 @@ func (m *MultiWorkspaceRuntime) Close() error { func (m *MultiWorkspaceRuntime) SwitchWorkspace(ctx context.Context, hash string) error { _, ok := m.index.Get(hash) if !ok { - return fmt.Errorf("workspace %s not found", hash) + return fmt.Errorf("%w: workspace %s not found", ErrRuntimeResourceNotFound, hash) } // 预加载对应 runtime,确保后续请求可用 if _, err := m.getPortForHash(hash); err != nil { @@ -210,6 +222,13 @@ func (m *MultiWorkspaceRuntime) DeleteWorkspace(hash string, removeData bool) er if ok { delete(m.bundles, hash) } + if strings.EqualFold(strings.TrimSpace(hash), strings.TrimSpace(m.defaultHash)) { + m.defaultHash = "" + records := m.index.List() + if len(records) > 0 { + m.defaultHash = strings.TrimSpace(records[0].Hash) + } + } m.mu.Unlock() if ok && b != nil && b.cleanup != nil { diff --git a/internal/gateway/multi_workspace_runtime_test.go b/internal/gateway/multi_workspace_runtime_test.go index b0702eed..b1152c8d 100644 --- a/internal/gateway/multi_workspace_runtime_test.go +++ b/internal/gateway/multi_workspace_runtime_test.go @@ -286,12 +286,35 @@ func TestMultiWorkspaceRuntime_NoHashConfigured(t *testing.T) { if _, err := mw.ListSessions(context.Background()); err == nil { t.Fatalf("expected error when no hash is configured") + } else if !errors.Is(err, ErrRuntimeResourceNotFound) { + t.Fatalf("expected ErrRuntimeResourceNotFound, got %v", err) } if got := builder.callCount(); got != 0 { t.Fatalf("buildPort should not be called when no hash, got %d", got) } } +func TestMultiWorkspaceRuntime_NoHashUsesPreloadedAnonymousBundle(t *testing.T) { + idx := agentsession.NewWorkspaceIndex(t.TempDir()) + builder := newTestBuilder() + + mw := NewMultiWorkspaceRuntime(idx, "", builder.build) + t.Cleanup(func() { _ = mw.Close() }) + + preloaded := newRecordingPort("anonymous-default") + mw.PreloadWorkspaceBundle("", preloaded, preloaded.cleanup) + + if _, err := mw.ListSessions(context.Background()); err != nil { + t.Fatalf("ListSessions with anonymous preloaded bundle: %v", err) + } + if got := preloaded.listSessionsCalls.Load(); got != 1 { + t.Fatalf("anonymous preloaded listSessions calls = %d, want 1", got) + } + if got := builder.callCount(); got != 0 { + t.Fatalf("buildPort should not be called when anonymous preloaded bundle exists; got %d", got) + } +} + func TestMultiWorkspaceRuntime_ContextHashOverridesDefault(t *testing.T) { idx, alpha, beta := setupIndex(t) builder := newTestBuilder() @@ -356,6 +379,8 @@ func TestMultiWorkspaceRuntime_UnknownHashErrors(t *testing.T) { _, err := mw.ListSessions(ctxWithHash(t, "deadbeef")) if err == nil { t.Fatalf("expected error for unknown hash") + } else if !errors.Is(err, ErrRuntimeResourceNotFound) { + t.Fatalf("expected ErrRuntimeResourceNotFound, got %v", err) } if got := builder.callCount(); got != 0 { t.Fatalf("buildPort should not be invoked for unknown hash; got %d", got) @@ -510,6 +535,27 @@ func TestMultiWorkspaceRuntime_RenameAndDeletePersist(t *testing.T) { } } +func TestMultiWorkspaceRuntime_DeleteDefaultHashFallsBackToRemainingWorkspace(t *testing.T) { + idx, alpha, beta := setupIndex(t) + builder := newTestBuilder() + mw := NewMultiWorkspaceRuntime(idx, alpha.Hash, builder.build) + t.Cleanup(func() { _ = mw.Close() }) + + if err := mw.DeleteWorkspace(alpha.Hash, false); err != nil { + t.Fatalf("Delete default workspace: %v", err) + } + + if _, err := mw.ListSessions(context.Background()); err != nil { + t.Fatalf("ListSessions fallback after deleting default: %v", err) + } + if builder.portFor(alpha.Path) != nil { + t.Fatalf("alpha port should not be rebuilt after delete") + } + if builder.portFor(beta.Path) == nil { + t.Fatalf("expected fallback to remaining workspace beta") + } +} + func TestMultiWorkspaceRuntime_DeleteUnknownErrors(t *testing.T) { idx, alpha, _ := setupIndex(t) builder := newTestBuilder() diff --git a/internal/gateway/workspace_handlers.go b/internal/gateway/workspace_handlers.go index ffcfc93f..f10ed453 100644 --- a/internal/gateway/workspace_handlers.go +++ b/internal/gateway/workspace_handlers.go @@ -153,6 +153,17 @@ func handleWorkspaceDeleteFrame(ctx context.Context, frame MessageFrame, runtime if deleteErr := mw.DeleteWorkspace(params.Hash, params.RemoveData); deleteErr != nil { return errorFrame(frame, NewFrameError(ErrorCodeInternalError, deleteErr.Error())) } + if wsState, ok := ConnectionWorkspaceStateFromContext(ctx); ok { + activeHash := strings.TrimSpace(wsState.GetWorkspaceHash()) + if strings.EqualFold(activeHash, strings.TrimSpace(params.Hash)) { + wsState.SetWorkspaceHash("") + if relay, relayOK := StreamRelayFromContext(ctx); relayOK { + if connID, connOK := ConnectionIDFromContext(ctx); connOK { + relay.ClearConnectionBindings(connID) + } + } + } + } return MessageFrame{ Type: FrameTypeAck, @@ -270,4 +281,3 @@ func decodeWorkspaceDeletePayload(payload any) (workspaceDeleteParams, *FrameErr return workspaceDeleteParams{}, NewFrameError(ErrorCodeInvalidFrame, "invalid workspace.delete payload") } } - diff --git a/internal/gateway/workspace_handlers_test.go b/internal/gateway/workspace_handlers_test.go new file mode 100644 index 00000000..83aab74b --- /dev/null +++ b/internal/gateway/workspace_handlers_test.go @@ -0,0 +1,73 @@ +package gateway + +import ( + "context" + "testing" + + "neo-code/internal/gateway/protocol" +) + +func TestHandleWorkspaceDeleteFrameClearsActiveWorkspaceStateAndBindings(t *testing.T) { + idx, alpha, _ := setupIndex(t) + builder := newTestBuilder() + mw := NewMultiWorkspaceRuntime(idx, alpha.Hash, builder.build) + t.Cleanup(func() { _ = mw.Close() }) + + relay := NewStreamRelay(StreamRelayOptions{}) + connID := NewConnectionID() + registerErr := relay.RegisterConnection(ConnectionRegistration{ + ConnectionID: connID, + Channel: StreamChannelIPC, + Context: context.Background(), + Cancel: func() {}, + Write: func(message RelayMessage) error { + return nil + }, + Close: func() {}, + }) + if registerErr != nil { + t.Fatalf("register connection: %v", registerErr) + } + t.Cleanup(func() { relay.dropConnection(connID) }) + + if bindErr := relay.BindConnection(connID, StreamBinding{ + SessionID: "session-delete-check", + Channel: StreamChannelAll, + Explicit: true, + }); bindErr != nil { + t.Fatalf("bind connection: %v", bindErr) + } + + wsState := NewConnectionWorkspaceState() + wsState.SetWorkspaceHash(alpha.Hash) + ctx := WithConnectionID( + WithStreamRelay( + WithConnectionWorkspaceState(context.Background(), wsState), + relay, + ), + connID, + ) + + response := handleWorkspaceDeleteFrame(ctx, MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionWorkspaceDelete, + RequestID: "workspace-delete-active", + Payload: protocol.DeleteWorkspaceParams{ + WorkspaceHash: alpha.Hash, + }, + }, mw) + if response.Type != FrameTypeAck { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeAck) + } + + if got := wsState.GetWorkspaceHash(); got != "" { + t.Fatalf("workspace hash should be cleared after deleting active workspace, got %q", got) + } + + relay.mu.RLock() + _, exists := relay.connectionBindings[NormalizeConnectionID(connID)] + relay.mu.RUnlock() + if exists { + t.Fatalf("connection bindings should be cleared after deleting active workspace") + } +} diff --git a/internal/runtime/session_logs.go b/internal/runtime/session_logs.go index 49a32ffb..b203adf2 100644 --- a/internal/runtime/session_logs.go +++ b/internal/runtime/session_logs.go @@ -21,6 +21,7 @@ type SessionLogEntry struct { Level string `json:"level"` Source string `json:"source"` Message string `json:"message"` + Inline string `json:"inline_message,omitempty"` } // LoadSessionLogEntries 按会话 ID 读取日志查看器持久化数据。 diff --git a/internal/tui/core/app/app.go b/internal/tui/core/app/app.go index 7a0f40ca..f82a9f16 100644 --- a/internal/tui/core/app/app.go +++ b/internal/tui/core/app/app.go @@ -28,6 +28,7 @@ type logEntry struct { Level string Source string Message string + Inline string } type panel = tuistate.Panel @@ -160,6 +161,9 @@ type appRuntimeState struct { logPersistDirty bool logPersistVersion int transcriptContent string + transcriptProcessFoldAvailable bool + transcriptProcessExpanded bool + transcriptProcessExpandedOrdinal int transcriptScrollbarDrag bool startupScreenLocked bool suppressAssistantForRun string diff --git a/internal/tui/core/app/checkpoint_commands.go b/internal/tui/core/app/checkpoint_commands.go new file mode 100644 index 00000000..c145dd5a --- /dev/null +++ b/internal/tui/core/app/checkpoint_commands.go @@ -0,0 +1,172 @@ +package tui + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + tea "github.com/charmbracelet/bubbletea" + + tuiservices "neo-code/internal/tui/services" +) + +const maxCheckpointPatchPreviewChars = 4000 + +type checkpointCommandRuntime interface { + ListCheckpoints(ctx context.Context, input tuiservices.CheckpointListInput) ([]tuiservices.CheckpointEntry, error) + RestoreCheckpoint(ctx context.Context, input tuiservices.CheckpointRestoreInput) (tuiservices.CheckpointRestoreResult, error) + UndoRestoreCheckpoint(ctx context.Context, sessionID string) (tuiservices.CheckpointRestoreResult, error) + CheckpointDiff(ctx context.Context, sessionID string, checkpointID string) (tuiservices.CheckpointDiffResult, error) +} + +func (a *App) handleCheckpointCommand(rest string) tea.Cmd { + sessionID := strings.TrimSpace(a.state.ActiveSessionID) + if sessionID == "" { + a.applyInlineCommandError("checkpoint command requires an active session; send one message first or switch session via /session") + return nil + } + runtime, ok := a.runtime.(checkpointCommandRuntime) + if !ok { + a.applyInlineCommandError("checkpoint command is unavailable in current runtime mode") + return nil + } + + action, argument := splitFirstWord(strings.TrimSpace(rest)) + switch strings.ToLower(strings.TrimSpace(action)) { + case "", "list": + if strings.TrimSpace(argument) != "" { + a.applyInlineCommandError(fmt.Sprintf("usage: %s", slashUsageCheckpoint)) + return nil + } + return a.runCheckpointCommand(func(ctx context.Context) (string, error) { + entries, err := runtime.ListCheckpoints(ctx, tuiservices.CheckpointListInput{ + SessionID: sessionID, + Limit: 20, + RestorableOnly: true, + }) + if err != nil { + return "", normalizeCheckpointCommandError(err) + } + return formatCheckpointList(entries), nil + }) + case "restore": + checkpointID, tail := splitFirstWord(strings.TrimSpace(argument)) + checkpointID = strings.TrimSpace(checkpointID) + if checkpointID == "" || isCommandPlaceholder(checkpointID) || strings.TrimSpace(tail) != "" { + a.applyInlineCommandError(fmt.Sprintf("usage: %s", slashUsageCheckpointRestore)) + return nil + } + return a.runCheckpointCommand(func(ctx context.Context) (string, error) { + result, err := runtime.RestoreCheckpoint(ctx, tuiservices.CheckpointRestoreInput{ + SessionID: sessionID, + CheckpointID: checkpointID, + }) + if err != nil { + return "", normalizeCheckpointCommandError(err) + } + id := fallbackText(strings.TrimSpace(result.CheckpointID), checkpointID) + if result.HasConflict { + return fmt.Sprintf("Checkpoint restored with conflicts: %s", id), nil + } + return fmt.Sprintf("Checkpoint restored: %s", id), nil + }) + case "undo": + if strings.TrimSpace(argument) != "" { + a.applyInlineCommandError(fmt.Sprintf("usage: %s", slashUsageCheckpointUndo)) + return nil + } + return a.runCheckpointCommand(func(ctx context.Context) (string, error) { + result, err := runtime.UndoRestoreCheckpoint(ctx, sessionID) + if err != nil { + return "", normalizeCheckpointCommandError(err) + } + id := fallbackText(strings.TrimSpace(result.CheckpointID), "guard checkpoint") + return fmt.Sprintf("Checkpoint restore undo applied: %s", id), nil + }) + case "diff": + checkpointID, tail := splitFirstWord(strings.TrimSpace(argument)) + checkpointID = strings.TrimSpace(checkpointID) + if checkpointID == "" || isCommandPlaceholder(checkpointID) || strings.TrimSpace(tail) != "" { + a.applyInlineCommandError(fmt.Sprintf("usage: %s", slashUsageCheckpointDiff)) + return nil + } + return a.runCheckpointCommand(func(ctx context.Context) (string, error) { + diff, err := runtime.CheckpointDiff(ctx, sessionID, checkpointID) + if err != nil { + return "", normalizeCheckpointCommandError(err) + } + return formatCheckpointDiff(diff), nil + }) + default: + a.applyInlineCommandError("usage: /checkpoint | /checkpoint restore | /checkpoint undo | /checkpoint diff ") + return nil + } +} + +func (a *App) runCheckpointCommand(run func(context.Context) (string, error)) tea.Cmd { + return tuiservices.RunLocalCommandCmd( + run, + func(notice string, err error) tea.Msg { + return localCommandResultMsg{Notice: notice, Err: err} + }, + ) +} + +func normalizeCheckpointCommandError(err error) error { + if err == nil { + return nil + } + if isGatewayUnsupportedActionError(err) { + return errors.New("gateway does not support checkpoint commands; please upgrade gateway and client to the latest version") + } + return err +} + +func formatCheckpointList(entries []tuiservices.CheckpointEntry) string { + if len(entries) == 0 { + return "No restorable checkpoints in current session." + } + rows := make([]string, 0, len(entries)+1) + rows = append(rows, "Restorable checkpoints:") + for _, entry := range entries { + id := fallbackText(strings.TrimSpace(entry.CheckpointID), "(unknown)") + reason := fallbackText(strings.TrimSpace(entry.Reason), "-") + status := fallbackText(strings.TrimSpace(entry.Status), "-") + createdAt := "-" + if entry.CreatedAtMS > 0 { + createdAt = time.UnixMilli(entry.CreatedAtMS).Local().Format(time.RFC3339) + } + rows = append(rows, fmt.Sprintf("- %s | reason=%s | status=%s | created=%s", id, reason, status, createdAt)) + } + return strings.Join(rows, "\n") +} + +func formatCheckpointDiff(result tuiservices.CheckpointDiffResult) string { + checkpointID := fallbackText(strings.TrimSpace(result.CheckpointID), "(unknown)") + rows := []string{ + fmt.Sprintf( + "Checkpoint diff: %s (added=%d modified=%d deleted=%d)", + checkpointID, + len(result.Files.Added), + len(result.Files.Modified), + len(result.Files.Deleted), + ), + } + patch := strings.TrimSpace(result.Patch) + if patch == "" { + return strings.Join(rows, "\n") + } + runes := []rune(patch) + if len(runes) > maxCheckpointPatchPreviewChars { + patch = string(runes[:maxCheckpointPatchPreviewChars]) + "\n...(truncated)" + } + rows = append(rows, patch) + return strings.Join(rows, "\n") +} + +func isCommandPlaceholder(value string) bool { + trimmed := strings.TrimSpace(value) + return strings.HasPrefix(trimmed, "<") && strings.HasSuffix(trimmed, ">") +} diff --git a/internal/tui/core/app/checkpoint_commands_test.go b/internal/tui/core/app/checkpoint_commands_test.go new file mode 100644 index 00000000..c3afffd3 --- /dev/null +++ b/internal/tui/core/app/checkpoint_commands_test.go @@ -0,0 +1,212 @@ +package tui + +import ( + "context" + "errors" + "strings" + "testing" + + tuiservices "neo-code/internal/tui/services" +) + +type checkpointRuntimeStub struct { + tuiservices.Runtime + + listInput tuiservices.CheckpointListInput + listResult []tuiservices.CheckpointEntry + listErr error + + restoreInput tuiservices.CheckpointRestoreInput + restoreResult tuiservices.CheckpointRestoreResult + restoreErr error + + undoSessionID string + undoResult tuiservices.CheckpointRestoreResult + undoErr error + + diffSessionID string + diffID string + diffResult tuiservices.CheckpointDiffResult + diffErr error +} + +func (s *checkpointRuntimeStub) ListCheckpoints( + _ context.Context, + input tuiservices.CheckpointListInput, +) ([]tuiservices.CheckpointEntry, error) { + s.listInput = input + if s.listErr != nil { + return nil, s.listErr + } + return append([]tuiservices.CheckpointEntry(nil), s.listResult...), nil +} + +func (s *checkpointRuntimeStub) RestoreCheckpoint( + _ context.Context, + input tuiservices.CheckpointRestoreInput, +) (tuiservices.CheckpointRestoreResult, error) { + s.restoreInput = input + if s.restoreErr != nil { + return tuiservices.CheckpointRestoreResult{}, s.restoreErr + } + return s.restoreResult, nil +} + +func (s *checkpointRuntimeStub) UndoRestoreCheckpoint( + _ context.Context, + sessionID string, +) (tuiservices.CheckpointRestoreResult, error) { + s.undoSessionID = strings.TrimSpace(sessionID) + if s.undoErr != nil { + return tuiservices.CheckpointRestoreResult{}, s.undoErr + } + return s.undoResult, nil +} + +func (s *checkpointRuntimeStub) CheckpointDiff( + _ context.Context, + sessionID string, + checkpointID string, +) (tuiservices.CheckpointDiffResult, error) { + s.diffSessionID = strings.TrimSpace(sessionID) + s.diffID = strings.TrimSpace(checkpointID) + if s.diffErr != nil { + return tuiservices.CheckpointDiffResult{}, s.diffErr + } + return s.diffResult, nil +} + +func TestHandleCheckpointCommandRequiresActiveSession(t *testing.T) { + app, _ := newTestApp(t) + handled, cmd := app.handleImmediateSlashCommand("/checkpoint") + if !handled { + t.Fatalf("expected /checkpoint to be recognized") + } + if cmd != nil { + t.Fatalf("expected no cmd when active session is missing") + } + if !strings.Contains(app.state.StatusText, "requires an active session") { + t.Fatalf("expected active session hint, got %q", app.state.StatusText) + } +} + +func TestHandleCheckpointCommandUnsupportedRuntime(t *testing.T) { + app, _ := newTestApp(t) + app.state.ActiveSessionID = "session-1" + handled, cmd := app.handleImmediateSlashCommand("/checkpoint") + if !handled { + t.Fatalf("expected /checkpoint to be recognized") + } + if cmd != nil { + t.Fatalf("expected no cmd when checkpoint runtime is unavailable") + } + if !strings.Contains(app.state.StatusText, "unavailable") { + t.Fatalf("expected unavailable hint, got %q", app.state.StatusText) + } +} + +func TestHandleCheckpointSlashCommands(t *testing.T) { + app, runtime := newTestApp(t) + app.state.ActiveSessionID = "session-1" + + checkpointRuntime := &checkpointRuntimeStub{ + Runtime: runtime, + listResult: []tuiservices.CheckpointEntry{ + {CheckpointID: "cp-1", Reason: "pre_write", Status: "ready", Restorable: true, CreatedAtMS: 1700000000000}, + }, + restoreResult: tuiservices.CheckpointRestoreResult{CheckpointID: "cp-1", SessionID: "session-1"}, + undoResult: tuiservices.CheckpointRestoreResult{CheckpointID: "cp-guard", SessionID: "session-1"}, + diffResult: tuiservices.CheckpointDiffResult{ + CheckpointID: "cp-1", + Files: tuiservices.CheckpointDiffFiles{Modified: []string{"a.txt"}}, + Patch: "diff --git a/a.txt b/a.txt", + }, + } + app.runtime = checkpointRuntime + + handled, cmd := app.handleImmediateSlashCommand("/checkpoint") + if !handled || cmd == nil { + t.Fatalf("expected /checkpoint to return async cmd") + } + model, _ := app.Update(cmd()) + app = model.(App) + if !strings.Contains(app.state.StatusText, "Restorable checkpoints:") { + t.Fatalf("unexpected list status text: %q", app.state.StatusText) + } + if checkpointRuntime.listInput.SessionID != "session-1" || checkpointRuntime.listInput.Limit != 20 || !checkpointRuntime.listInput.RestorableOnly { + t.Fatalf("unexpected list input: %#v", checkpointRuntime.listInput) + } + + handled, cmd = app.handleImmediateSlashCommand("/checkpoint restore cp-1") + if !handled || cmd == nil { + t.Fatalf("expected /checkpoint restore to return async cmd") + } + model, _ = app.Update(cmd()) + app = model.(App) + if !strings.Contains(app.state.StatusText, "Checkpoint restored: cp-1") { + t.Fatalf("unexpected restore status text: %q", app.state.StatusText) + } + if checkpointRuntime.restoreInput.SessionID != "session-1" || checkpointRuntime.restoreInput.CheckpointID != "cp-1" { + t.Fatalf("unexpected restore input: %#v", checkpointRuntime.restoreInput) + } + + handled, cmd = app.handleImmediateSlashCommand("/checkpoint diff cp-1") + if !handled || cmd == nil { + t.Fatalf("expected /checkpoint diff to return async cmd") + } + model, _ = app.Update(cmd()) + app = model.(App) + if !strings.Contains(app.state.StatusText, "Checkpoint diff: cp-1") { + t.Fatalf("unexpected diff status text: %q", app.state.StatusText) + } + if checkpointRuntime.diffSessionID != "session-1" || checkpointRuntime.diffID != "cp-1" { + t.Fatalf("unexpected diff inputs: sid=%q checkpoint=%q", checkpointRuntime.diffSessionID, checkpointRuntime.diffID) + } + + handled, cmd = app.handleImmediateSlashCommand("/checkpoint undo") + if !handled || cmd == nil { + t.Fatalf("expected /checkpoint undo to return async cmd") + } + model, _ = app.Update(cmd()) + app = model.(App) + if !strings.Contains(app.state.StatusText, "Checkpoint restore undo applied") { + t.Fatalf("unexpected undo status text: %q", app.state.StatusText) + } + if checkpointRuntime.undoSessionID != "session-1" { + t.Fatalf("unexpected undo session id: %q", checkpointRuntime.undoSessionID) + } +} + +func TestHandleCheckpointCommandUsageAndErrorBranches(t *testing.T) { + app, runtime := newTestApp(t) + app.state.ActiveSessionID = "session-1" + checkpointRuntime := &checkpointRuntimeStub{ + Runtime: runtime, + } + app.runtime = checkpointRuntime + + if handled, cmd := app.handleImmediateSlashCommand("/checkpoint restore"); !handled || cmd != nil { + t.Fatalf("expected restore usage branch") + } + if !strings.Contains(app.state.StatusText, slashUsageCheckpointRestore) { + t.Fatalf("expected restore usage text, got %q", app.state.StatusText) + } + + if handled, cmd := app.handleImmediateSlashCommand("/checkpoint undo now"); !handled || cmd != nil { + t.Fatalf("expected undo usage branch") + } + if !strings.Contains(app.state.StatusText, slashUsageCheckpointUndo) { + t.Fatalf("expected undo usage text, got %q", app.state.StatusText) + } + + checkpointRuntime.listErr = errors.New("list failed") + handled, cmd := app.handleImmediateSlashCommand("/checkpoint") + if !handled || cmd == nil { + t.Fatalf("expected /checkpoint to return cmd") + } + model, _ := app.Update(cmd()) + app = model.(App) + if !strings.Contains(app.state.StatusText, "list failed") { + t.Fatalf("expected list error passthrough, got %q", app.state.StatusText) + } +} diff --git a/internal/tui/core/app/commands.go b/internal/tui/core/app/commands.go index 33da877c..2d66ed42 100644 --- a/internal/tui/core/app/commands.go +++ b/internal/tui/core/app/commands.go @@ -16,6 +16,10 @@ import ( tuiservices "neo-code/internal/tui/services" ) +type runtimeModelCatalogSource interface { + ListModels(ctx context.Context, sessionID string) ([]providertypes.ModelDescriptor, string, error) +} + const ( slashPrefix = "/" slashCommandHelp = "/help" @@ -31,22 +35,27 @@ const ( slashCommandForget = "/forget" slashCommandSkills = "/skills" slashCommandSkill = "/skill" - - slashUsageHelp = "/help" - slashUsageExit = "/exit" - slashUsageClear = "/clear" - slashUsageCompact = "/compact" - slashUsageProvider = "/provider" - slashUsageProviderAdd = "/provider add" - slashUsageModel = "/model" - slashUsageSession = "/session" - slashUsageMemo = "/memo" - slashUsageRemember = "/remember " - slashUsageForget = "/forget " - slashUsageSkills = "/skills" - slashUsageSkillUse = "/skill use " - slashUsageSkillOff = "/skill off " - slashUsageSkillActive = "/skill active" + slashCommandCheckpoint = "/checkpoint" + + slashUsageHelp = "/help" + slashUsageExit = "/exit" + slashUsageClear = "/clear" + slashUsageCompact = "/compact" + slashUsageProvider = "/provider" + slashUsageProviderAdd = "/provider add" + slashUsageModel = "/model" + slashUsageSession = "/session" + slashUsageMemo = "/memo" + slashUsageRemember = "/remember " + slashUsageForget = "/forget " + slashUsageSkills = "/skills" + slashUsageSkillUse = "/skill use " + slashUsageSkillOff = "/skill off " + slashUsageSkillActive = "/skill active" + slashUsageCheckpoint = "/checkpoint" + slashUsageCheckpointRestore = "/checkpoint restore " + slashUsageCheckpointUndo = "/checkpoint undo" + slashUsageCheckpointDiff = "/checkpoint diff " commandMenuTitle = "Suggestions" providerPickerTitle = "Select Provider" @@ -138,6 +147,10 @@ var builtinSlashCommands = []slashCommand{ {Usage: slashUsageSkillUse, Description: "Activate one skill in current session"}, {Usage: slashUsageSkillOff, Description: "Deactivate one skill in current session"}, {Usage: slashUsageSkillActive, Description: "Show active skills in current session"}, + {Usage: slashUsageCheckpoint, Description: "List checkpoints of current session"}, + {Usage: slashUsageCheckpointRestore, Description: "Restore session to one checkpoint"}, + {Usage: slashUsageCheckpointUndo, Description: "Undo the latest checkpoint restore"}, + {Usage: slashUsageCheckpointDiff, Description: "Show diff for one checkpoint"}, {Usage: slashUsageProvider, Description: "Open the interactive provider picker"}, {Usage: slashUsageProviderAdd, Description: "Add a new custom provider"}, {Usage: slashUsageModel, Description: "Open the interactive model picker"}, @@ -251,6 +264,23 @@ func (a *App) refreshProviderPicker() error { } func (a *App) refreshModelPicker() error { + if source, ok := a.runtime.(runtimeModelCatalogSource); ok { + models, selectedModelID, err := source.ListModels(context.Background(), strings.TrimSpace(a.state.ActiveSessionID)) + if err != nil { + return err + } + replacePickerItems(&a.modelPicker, mapModelItems(models)) + selectedModelID = strings.TrimSpace(selectedModelID) + if selectedModelID == "" { + selectedModelID = strings.TrimSpace(a.state.CurrentModel) + } + if selectedModelID != "" { + a.state.CurrentModel = selectedModelID + } + selectPickerItemByID(&a.modelPicker, selectedModelID) + return nil + } + models, err := a.providerSvc.ListModelsSnapshot(context.Background()) if err != nil { return err diff --git a/internal/tui/core/app/commands_test.go b/internal/tui/core/app/commands_test.go index 70fd49f9..6c956686 100644 --- a/internal/tui/core/app/commands_test.go +++ b/internal/tui/core/app/commands_test.go @@ -10,6 +10,7 @@ import ( configstate "neo-code/internal/config/state" providertypes "neo-code/internal/provider/types" + tuiservices "neo-code/internal/tui/services" ) func TestBuiltinSlashCommands(t *testing.T) { @@ -21,6 +22,7 @@ func TestBuiltinSlashCommands(t *testing.T) { foundTodo := false foundSkills := false foundSkillUse := false + foundCheckpoint := false foundStatus := false for _, cmd := range builtinSlashCommands { if cmd.Usage == slashUsageHelp { @@ -35,6 +37,9 @@ func TestBuiltinSlashCommands(t *testing.T) { if cmd.Usage == slashUsageSkillUse { foundSkillUse = true } + if cmd.Usage == slashUsageCheckpoint { + foundCheckpoint = true + } if strings.EqualFold(cmd.Usage, "/status") { foundStatus = true } @@ -51,6 +56,9 @@ func TestBuiltinSlashCommands(t *testing.T) { if !foundSkillUse { t.Error("expected to find /skill use command") } + if !foundCheckpoint { + t.Error("expected to find /checkpoint command") + } if foundStatus { t.Error("did not expect /status command in builtin slash commands") } @@ -346,6 +354,51 @@ func TestRunModelCatalogRefreshCmd(t *testing.T) { } } +func TestRefreshModelPickerPrefersRuntimeModelCatalog(t *testing.T) { + app, _ := newTestApp(t) + app.state.ActiveSessionID = "session-1" + app.state.CurrentModel = "local-model" + + runtimeSource := &runtimeModelCatalogSourceStub{ + Runtime: app.runtime, + models: []providertypes.ModelDescriptor{ + {ID: "remote-a", Name: "Remote A"}, + {ID: "remote-b", Name: "Remote B"}, + }, + selectedModelID: "remote-b", + } + app.runtime = runtimeSource + + if err := app.refreshModelPicker(); err != nil { + t.Fatalf("refreshModelPicker() error = %v", err) + } + if runtimeSource.lastSessionID != "session-1" { + t.Fatalf("expected runtime model query to use active session id, got %q", runtimeSource.lastSessionID) + } + if app.state.CurrentModel != "remote-b" { + t.Fatalf("expected selected_model_id to update current model, got %q", app.state.CurrentModel) + } + selected, ok := app.modelPicker.SelectedItem().(selectionItem) + if !ok { + t.Fatalf("expected selected picker item, got %T", app.modelPicker.SelectedItem()) + } + if selected.id != "remote-b" { + t.Fatalf("expected picker selection to follow selected_model_id, got %q", selected.id) + } +} + +func TestRefreshModelPickerPropagatesRuntimeCatalogError(t *testing.T) { + app, _ := newTestApp(t) + app.runtime = &runtimeModelCatalogSourceStub{ + Runtime: app.runtime, + err: errors.New("list models failed"), + } + + if err := app.refreshModelPicker(); err == nil || !strings.Contains(err.Error(), "list models failed") { + t.Fatalf("expected runtime model catalog error, got %v", err) + } +} + func TestRefreshHelpPicker(t *testing.T) { app, _ := newTestApp(t) app.refreshHelpPicker() @@ -371,3 +424,22 @@ func TestOpenHelpPicker(t *testing.T) { t.Fatalf("expected help picker search box to be focused") } } + +type runtimeModelCatalogSourceStub struct { + tuiservices.Runtime + models []providertypes.ModelDescriptor + selectedModelID string + err error + lastSessionID string +} + +func (s *runtimeModelCatalogSourceStub) ListModels( + _ context.Context, + sessionID string, +) ([]providertypes.ModelDescriptor, string, error) { + s.lastSessionID = strings.TrimSpace(sessionID) + if s.err != nil { + return nil, "", s.err + } + return append([]providertypes.ModelDescriptor(nil), s.models...), strings.TrimSpace(s.selectedModelID), nil +} diff --git a/internal/tui/core/app/gateway_error_mapping.go b/internal/tui/core/app/gateway_error_mapping.go index bd348bbd..ad5a3ce5 100644 --- a/internal/tui/core/app/gateway_error_mapping.go +++ b/internal/tui/core/app/gateway_error_mapping.go @@ -13,17 +13,8 @@ func isGatewayUnsupportedActionError(err error) bool { if err == nil { return false } - if errors.Is(err, tuiservices.ErrUnsupportedActionInGatewayMode) { - return true - } - var rpcErr *tuiservices.GatewayRPCError - if !errors.As(err, &rpcErr) || rpcErr == nil { - return false - } - - if strings.EqualFold(strings.TrimSpace(rpcErr.GatewayCode), protocol.GatewayCodeUnsupportedAction) { - return true - } - return rpcErr.Code == protocol.JSONRPCCodeMethodNotFound + return errors.As(err, &rpcErr) && + rpcErr != nil && + strings.EqualFold(strings.TrimSpace(rpcErr.GatewayCode), protocol.GatewayCodeUnsupportedAction) } diff --git a/internal/tui/core/app/hydrate_test.go b/internal/tui/core/app/hydrate_test.go index 051bc8a2..0e187bcf 100644 --- a/internal/tui/core/app/hydrate_test.go +++ b/internal/tui/core/app/hydrate_test.go @@ -6,9 +6,11 @@ import ( "path/filepath" "strings" "testing" + "time" providertypes "neo-code/internal/provider/types" agentsession "neo-code/internal/session" + agentruntime "neo-code/internal/tui/services" ) func TestHydrateSessionLoadsHistoryAndWorkdir(t *testing.T) { @@ -89,3 +91,131 @@ func TestHydrateSessionReturnsLoadError(t *testing.T) { t.Fatalf("HydrateSession() error = %v, want contains %q", err, "load failed") } } + +func TestHydrateSessionReplaysFoldRelatedPersistedLogs(t *testing.T) { + app, runtime := newTestApp(t) + sessionID := "session-hydrate-logs" + runtime.loadSessions = map[string]agentsession.Session{ + sessionID: { + ID: sessionID, + Title: "Hydrated Logs Session", + Workdir: t.TempDir(), + Messages: []providertypes.Message{ + { + Role: roleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, + }, + { + Role: roleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("final answer")}, + }, + }, + }, + } + runtime.logEntriesBySID[sessionID] = []agentruntime.SessionLogEntry{ + {Timestamp: time.Unix(1_700_020_001, 0), Level: "info", Source: "verify", Message: "Verification started: completion_passed=true"}, + {Timestamp: time.Unix(1_700_020_002, 0), Level: "info", Source: "provider", Message: "Provider switched: openai"}, + } + + if err := app.HydrateSession(context.Background(), sessionID); err != nil { + t.Fatalf("HydrateSession() error = %v", err) + } + + joined := "" + for _, message := range app.activeMessages { + text := strings.TrimSpace(renderMessagePartsForDisplay(message.Parts)) + if joined != "" { + joined += "\n" + } + joined += text + } + if !strings.Contains(joined, inlineLogMarker+"verify: Verification started: completion_passed=true") { + t.Fatalf("expected verify session log to be replayed into transcript, got %q", joined) + } + if strings.Contains(joined, inlineLogMarker+"provider: Provider switched: openai") { + t.Fatalf("expected non-fold provider log to stay out of transcript replay, got %q", joined) + } +} + +func TestHydrateSessionLogReplaySkipsDuplicateInlineMessages(t *testing.T) { + app, runtime := newTestApp(t) + sessionID := "session-hydrate-log-dedup" + inline := inlineLogMarker + "verify: Verification finished: accepted" + runtime.loadSessions = map[string]agentsession.Session{ + sessionID: { + ID: sessionID, + Title: "Hydrated Dedup Session", + Workdir: t.TempDir(), + Messages: []providertypes.Message{ + { + Role: roleSystem, + Parts: []providertypes.ContentPart{providertypes.NewTextPart(inline)}, + }, + { + Role: roleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("final answer")}, + }, + }, + }, + } + runtime.logEntriesBySID[sessionID] = []agentruntime.SessionLogEntry{ + {Timestamp: time.Unix(1_700_020_011, 0), Level: "info", Source: "verify", Message: "Verification finished: accepted"}, + } + + if err := app.HydrateSession(context.Background(), sessionID); err != nil { + t.Fatalf("HydrateSession() error = %v", err) + } + + count := 0 + for _, message := range app.activeMessages { + text := strings.TrimSpace(renderMessagePartsForDisplay(message.Parts)) + if text == inline { + count++ + } + } + if count != 1 { + t.Fatalf("expected replay dedup to keep one inline message, got %d", count) + } +} + +func TestHydrateSessionReplaysPersistedInlineMessageRegardlessOfSource(t *testing.T) { + app, runtime := newTestApp(t) + sessionID := "session-hydrate-inline-source" + runtime.loadSessions = map[string]agentsession.Session{ + sessionID: { + ID: sessionID, + Title: "Hydrated Inline Source Session", + Workdir: t.TempDir(), + Messages: []providertypes.Message{ + { + Role: roleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("final answer")}, + }, + }, + }, + } + runtime.logEntriesBySID[sessionID] = []agentruntime.SessionLogEntry{ + { + Timestamp: time.Unix(1_700_020_021, 0), + Level: "info", + Source: "provider", + Message: "Provider switched: openai", + Inline: inlineLogMarker + "verify: Verification finished: accepted", + }, + } + + if err := app.HydrateSession(context.Background(), sessionID); err != nil { + t.Fatalf("HydrateSession() error = %v", err) + } + joined := "" + for _, message := range app.activeMessages { + text := strings.TrimSpace(renderMessagePartsForDisplay(message.Parts)) + if joined != "" { + joined += "\n" + } + joined += text + } + if !strings.Contains(joined, inlineLogMarker+"verify: Verification finished: accepted") { + t.Fatalf("expected persisted inline message to be replayed regardless of source, got %q", joined) + } +} diff --git a/internal/tui/core/app/skills_commands_test.go b/internal/tui/core/app/skills_commands_test.go index f62ee022..430cfb14 100644 --- a/internal/tui/core/app/skills_commands_test.go +++ b/internal/tui/core/app/skills_commands_test.go @@ -65,8 +65,8 @@ func TestSkillCommandErrorAndPlaceholderHelpers(t *testing.T) { } unsupported := normalizeSkillCommandError(tuiservices.ErrUnsupportedActionInGatewayMode) - if unsupported == nil || !strings.Contains(strings.ToLower(unsupported.Error()), "gateway") { - t.Fatalf("expected gateway hint, got %v", unsupported) + if unsupported != tuiservices.ErrUnsupportedActionInGatewayMode { + t.Fatalf("expected legacy sentinel passthrough, got %v", unsupported) } unsupportedRPCByGatewayCode := normalizeSkillCommandError(&tuiservices.GatewayRPCError{ Method: protocol.MethodGatewayListAvailableSkills, @@ -82,8 +82,8 @@ func TestSkillCommandErrorAndPlaceholderHelpers(t *testing.T) { Code: protocol.JSONRPCCodeMethodNotFound, Message: "method not found", }) - if unsupportedRPCByCodeOnly == nil || !strings.Contains(strings.ToLower(unsupportedRPCByCodeOnly.Error()), "gateway") { - t.Fatalf("expected gateway hint for method_not_found rpc error, got %v", unsupportedRPCByCodeOnly) + if unsupportedRPCByCodeOnly == nil || !strings.Contains(strings.ToLower(unsupportedRPCByCodeOnly.Error()), "method not found") { + t.Fatalf("expected method_not_found passthrough, got %v", unsupportedRPCByCodeOnly) } containsButNotSentinel := errors.New("skill id unsupported_action_in_gateway_mode is invalid") if normalizeSkillCommandError(containsButNotSentinel) != containsButNotSentinel { diff --git a/internal/tui/core/app/todo.go b/internal/tui/core/app/todo.go index 91815564..483dc162 100644 --- a/internal/tui/core/app/todo.go +++ b/internal/tui/core/app/todo.go @@ -6,6 +6,8 @@ import ( "strings" "time" + "github.com/charmbracelet/lipgloss" + agentsession "neo-code/internal/session" ) @@ -45,6 +47,7 @@ const ( todoDefaultExpandedLimit = 14 todoMaxExpandedLimit = 24 todoHeaderLines = 1 + todoTitleMaxDefault = 84 ) type todoViewItem struct { @@ -154,6 +157,113 @@ func formatTodoUpdatedAt(ts time.Time) string { return ts.Format("2006-01-02 15:04:05") } +func isMarkdownTableSeparatorLine(line string) bool { + trimmed := strings.TrimSpace(line) + trimmed = strings.Trim(trimmed, "|") + trimmed = strings.ReplaceAll(trimmed, "|", "") + if trimmed == "" { + return false + } + hasDash := false + for _, r := range trimmed { + switch r { + case '-', ':': + hasDash = true + case ' ', '\t': + default: + return false + } + } + return hasDash +} + +func normalizeTodoTitle(title string, maxLen int) string { + raw := strings.TrimSpace(title) + if raw == "" { + return "(empty)" + } + if maxLen <= 0 { + maxLen = todoTitleMaxDefault + } + raw = strings.ReplaceAll(raw, "\r\n", "\n") + raw = strings.ReplaceAll(raw, "\r", "\n") + lines := strings.Split(raw, "\n") + parts := make([]string, 0, len(lines)) + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" || isMarkdownTableSeparatorLine(line) { + continue + } + if strings.Contains(line, "|") { + cells := strings.Split(strings.Trim(line, "|"), "|") + cleanCells := make([]string, 0, len(cells)) + for _, cell := range cells { + cell = strings.TrimSpace(cell) + if cell == "" || isMarkdownTableSeparatorLine(cell) { + continue + } + cleanCells = append(cleanCells, cell) + } + if len(cleanCells) > 0 { + line = strings.Join(cleanCells, " / ") + } + } + line = strings.Join(strings.Fields(line), " ") + if line != "" { + parts = append(parts, line) + } + } + if len(parts) == 0 { + return "(empty)" + } + joined := strings.Join(parts, " | ") + runes := []rune(joined) + if len(runes) <= maxLen { + return joined + } + if maxLen < 4 { + return string(runes[:maxLen]) + } + return string(runes[:maxLen-3]) + "..." +} + +func formatTodoStatusLabel(status string) string { + normalized := strings.ToUpper(strings.TrimSpace(status)) + switch normalized { + case "PENDING": + return "PENDING" + case "IN_PROGRESS": + return "ACTIVE" + case "BLOCKED": + return "BLOCKED" + case "COMPLETED": + return "DONE" + case "FAILED": + return "FAILED" + case "CANCELED": + return "CANCELED" + default: + if normalized == "" { + return "UNKNOWN" + } + return normalized + } +} + +func (a App) todoStatusStyle(status string) lipgloss.Style { + normalized := strings.ToUpper(strings.TrimSpace(status)) + switch normalized { + case "IN_PROGRESS", "COMPLETED": + return a.styles.badgeSuccess + case "BLOCKED": + return a.styles.badgeWarning + case "FAILED", "CANCELED": + return a.styles.badgeError + default: + return a.styles.badgeMuted + } +} + func clampTodoSelection(index int, length int) int { if length <= 0 { return 0 @@ -290,39 +400,81 @@ func (a *App) rebuildTodo() { visible := a.visibleTodoItems() a.todoSelectedIndex = clampTodoSelection(a.todoSelectedIndex, len(visible)) - lines := []string{ - "ID Title Status Executor Priority Owner Updated At", - } + lines := []string{a.styles.panelSubtitle.Render("State Task")} if len(visible) == 0 { lines = append(lines, fmt.Sprintf("No todos for filter %q.", a.todoFilter)) } else { + titleMax := todoTitleMaxDefault + if a.todo.Width > 0 { + titleMax = max(20, a.todo.Width-16) + } for i, item := range visible { prefix := " " if i == a.todoSelectedIndex { prefix = ">" } - title := item.Title - if title == "" { - title = "(empty)" + title := normalizeTodoTitle(item.Title, titleMax) + statusLabel := fmt.Sprintf("%-9s", formatTodoStatusLabel(item.Status)) + statusStyled := a.todoStatusStyle(item.Status).Render(statusLabel) + + titleStyle := lipgloss.NewStyle().Foreground(lipgloss.Color(lightText)) + if i == a.todoSelectedIndex { + titleStyle = titleStyle. + Bold(true). + Foreground(lipgloss.Color(selectionFg)). + Background(lipgloss.Color(selectionBg)) + } + + metaParts := make([]string, 0, 2) + if id := strings.TrimSpace(item.ID); id != "" { + metaParts = append(metaParts, "#"+id) + } + if item.Priority > 0 { + metaParts = append(metaParts, fmt.Sprintf("P%d", item.Priority)) } + meta := "" + if len(metaParts) > 0 { + meta = " " + a.styles.panelSubtitle.Render("("+strings.Join(metaParts, " · ")+")") + } + lines = append(lines, fmt.Sprintf( - "%s %s | %s | %s | %s | P%d | %s | %s", + "%s %s %s%s", prefix, - item.ID, - title, - item.Status, - fallbackText(item.Executor, "-"), - item.Priority, - item.Owner, - formatTodoUpdatedAt(item.UpdatedAt), + statusStyled, + titleStyle.Render(title), + meta, )) } + + selected := visible[clampTodoSelection(a.todoSelectedIndex, len(visible))] + details := make([]string, 0, 5) + if id := strings.TrimSpace(selected.ID); id != "" { + details = append(details, "id="+id) + } + if selected.Priority > 0 { + details = append(details, fmt.Sprintf("priority=%d", selected.Priority)) + } + if exec := strings.TrimSpace(selected.Executor); exec != "" && exec != "-" { + details = append(details, "executor="+exec) + } + if owner := strings.TrimSpace(selected.Owner); owner != "" && owner != "-" { + details = append(details, "owner="+owner) + } + if updated := formatTodoUpdatedAt(selected.UpdatedAt); updated != "-" { + details = append(details, "updated="+updated) + } + if len(details) > 0 { + lines = append(lines, a.styles.panelSubtitle.Render("Selected: "+strings.Join(details, " · "))) + } + lines = append( lines, - fmt.Sprintf( - "Selected %d/%d | Up/Down move | Enter detail | c collapse", - a.todoSelectedIndex+1, - len(visible), + a.styles.panelSubtitle.Render( + fmt.Sprintf( + "Selected %d/%d · Up/Down move · Enter detail · c collapse", + a.todoSelectedIndex+1, + len(visible), + ), ), ) } diff --git a/internal/tui/core/app/todo_test.go b/internal/tui/core/app/todo_test.go index 3ebd70db..159cf7b5 100644 --- a/internal/tui/core/app/todo_test.go +++ b/internal/tui/core/app/todo_test.go @@ -189,6 +189,31 @@ func TestRenderTodoPreviewAndEmptyRebuild(t *testing.T) { } } +func TestRebuildTodoSanitizesMarkdownTableLikeTitle(t *testing.T) { + app, _ := newTestApp(t) + app.todoPanelVisible = true + app.todoFilter = todoFilterAll + app.todo.Width = 100 + app.todo.Height = 10 + app.todoItems = []todoViewItem{ + { + ID: "todo-md", + Status: "pending", + Priority: 2, + Title: "| col1 | col2 |\n| --- | --- |\n| value-a | value-b |", + }, + } + + app.rebuildTodo() + view := app.todo.View() + if strings.Contains(view, "| --- |") { + t.Fatalf("expected markdown table separators to be sanitized, got %q", view) + } + if !strings.Contains(view, "col1 / col2") || !strings.Contains(view, "value-a / value-b") { + t.Fatalf("expected markdown table cells to be preserved as readable text, got %q", view) + } +} + func TestSetTodoFilterAndRebuild(t *testing.T) { app, _ := newTestApp(t) app.todo.Width = 100 diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index 5a5c13b5..64aa2311 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -3,7 +3,9 @@ package tui import ( "bytes" "context" + "crypto/sha256" "encoding/json" + "encoding/hex" "errors" "fmt" "io" @@ -65,12 +67,17 @@ const pasteSessionMinGuard = 2 * time.Second const pasteSessionPerLineGuard = 8 * time.Millisecond const inlineLogMarker = "[[neo-log]] " const sessionWorkdirMissingWarning = "Session workspace not found, keeping current workspace." +const localLogViewerPersistDir = "log-viewer" type sessionLogPersistenceRuntime interface { LoadSessionLogEntries(ctx context.Context, sessionID string) ([]tuiservices.SessionLogEntry, error) SaveSessionLogEntries(ctx context.Context, sessionID string, entries []tuiservices.SessionLogEntry) error } +type localSessionLogStore struct { + baseDir string +} + var supportsUserEnvPersistence = config.SupportsUserEnvPersistence var persistProviderUserEnvVar = config.PersistUserEnvVar var deleteProviderUserEnvVar = config.DeleteUserEnvVar @@ -453,10 +460,13 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, batchUpdateCmds() } - switch a.focus { - case panelTranscript: - a.handleViewportKeys(&a.transcript, typed) - return a, batchUpdateCmds() + switch a.focus { + case panelTranscript: + if key.Matches(typed, a.keys.Send) && a.toggleTranscriptProcessExpansion() { + return a, batchUpdateCmds() + } + a.handleViewportKeys(&a.transcript, typed) + return a, batchUpdateCmds() case panelActivity: a.handleViewportKeys(&a.activity, typed) return a, batchUpdateCmds() @@ -2562,6 +2572,7 @@ func (a *App) applySessionSnapshot(session agentsession.Session, warnOnMissingWo a.setCurrentAgentMode(string(session.AgentMode)) a.syncSessionWorkdir(session.Workdir, warnOnMissingWorkdir) a.loadLogEntriesForSession(session.ID) + a.replayFoldRelatedSessionLogsIntoTranscript() a.refreshRuntimeSourceSnapshot() } @@ -2791,6 +2802,12 @@ var runtimeEventHandlerRegistry = map[tuiservices.EventType]func(*App, tuiservic tuiservices.EventRepoHooksLoaded: runtimeEventRepoHooksLoadedHandler, tuiservices.EventRepoHooksSkippedUntrusted: runtimeEventRepoHooksSkippedUntrustedHandler, tuiservices.EventRepoHooksTrustStoreInvalid: runtimeEventRepoHooksTrustStoreInvalidHandler, + tuiservices.EventCheckpointCreated: runtimeEventCheckpointCreatedHandler, + tuiservices.EventCheckpointWarning: runtimeEventCheckpointWarningHandler, + tuiservices.EventCheckpointRestored: runtimeEventCheckpointRestoredHandler, + tuiservices.EventCheckpointUndoRestore: runtimeEventCheckpointUndoRestoreHandler, + tuiservices.EventToolDiff: runtimeEventToolDiffHandler, + tuiservices.EventBashSideEffect: runtimeEventBashSideEffectHandler, tuiservices.EventSubAgentStarted: runtimeEventSubAgentLifecycleHandler, tuiservices.EventSubAgentProgress: runtimeEventSubAgentLifecycleHandler, tuiservices.EventSubAgentRetried: runtimeEventSubAgentLifecycleHandler, @@ -2965,6 +2982,170 @@ func runtimeEventRepoHooksTrustStoreInvalidHandler(a *App, event tuiservices.Run return false } +func runtimeEventCheckpointCreatedHandler(a *App, event tuiservices.RuntimeEvent) bool { + payload, ok := event.Payload.(tuiservices.CheckpointCreatedPayload) + if !ok { + return false + } + checkpointID := strings.TrimSpace(payload.CheckpointID) + if checkpointID == "" { + checkpointID = "(unknown)" + } + details := []string{ + "checkpoint_id=" + checkpointID, + } + if reason := strings.TrimSpace(payload.Reason); reason != "" { + details = append(details, "reason="+reason) + } + if commit := strings.TrimSpace(payload.CommitHash); commit != "" { + details = append(details, "commit="+commit) + } + if codeRef := strings.TrimSpace(payload.CodeCheckpointRef); codeRef != "" { + details = append(details, "code_ref="+codeRef) + } + a.appendActivity("checkpoint", "Checkpoint created", strings.Join(details, ", "), false) + return false +} + +func runtimeEventCheckpointWarningHandler(a *App, event tuiservices.RuntimeEvent) bool { + payload, ok := event.Payload.(tuiservices.CheckpointWarningPayload) + if !ok { + return false + } + message := strings.TrimSpace(payload.Error) + if message == "" { + message = "checkpoint warning" + } + if phase := strings.TrimSpace(payload.Phase); phase != "" { + message = phase + ": " + message + } + a.appendActivity("checkpoint", "Checkpoint warning", message, true) + return false +} + +func runtimeEventCheckpointRestoredHandler(a *App, event tuiservices.RuntimeEvent) bool { + payload, ok := event.Payload.(tuiservices.CheckpointRestoredPayload) + if !ok { + return false + } + if sessionID := strings.TrimSpace(payload.SessionID); sessionID != "" { + a.setActiveSessionID(sessionID) + } + detail := strings.TrimSpace(payload.CheckpointID) + if detail == "" { + detail = "(unknown)" + } + if guard := strings.TrimSpace(payload.GuardCheckpointID); guard != "" { + detail = fmt.Sprintf("%s (guard=%s)", detail, guard) + } + a.appendActivity("checkpoint", "Checkpoint restored", detail, false) + if err := a.refreshMessages(); err != nil && strings.TrimSpace(a.state.ActiveSessionID) != "" { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendInlineMessage(roleError, err.Error()) + a.appendActivity("checkpoint", "Failed to refresh session after restore", err.Error(), true) + return true + } + a.syncTodosFromRun() + a.refreshRuntimeSourceSnapshot() + a.state.ExecutionError = "" + a.state.StatusText = "Checkpoint restored" + return true +} + +func runtimeEventCheckpointUndoRestoreHandler(a *App, event tuiservices.RuntimeEvent) bool { + payload, ok := event.Payload.(tuiservices.CheckpointUndoRestorePayload) + if !ok { + return false + } + if sessionID := strings.TrimSpace(payload.SessionID); sessionID != "" { + a.setActiveSessionID(sessionID) + } + detail := strings.TrimSpace(payload.GuardCheckpointID) + if detail == "" { + detail = "restore guard checkpoint" + } + a.appendActivity("checkpoint", "Checkpoint restore undo", detail, false) + if err := a.refreshMessages(); err != nil && strings.TrimSpace(a.state.ActiveSessionID) != "" { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendInlineMessage(roleError, err.Error()) + a.appendActivity("checkpoint", "Failed to refresh session after undo", err.Error(), true) + return true + } + a.syncTodosFromRun() + a.refreshRuntimeSourceSnapshot() + a.state.ExecutionError = "" + a.state.StatusText = "Checkpoint restore undo applied" + return true +} + +func runtimeEventToolDiffHandler(a *App, event tuiservices.RuntimeEvent) bool { + payload, ok := event.Payload.(tuiservices.ToolDiffPayload) + if !ok { + return false + } + files := make([]string, 0, len(payload.Files)+1) + if len(payload.Files) > 0 { + for _, file := range payload.Files { + path := strings.TrimSpace(file.Path) + if path == "" { + continue + } + kind := strings.TrimSpace(file.Kind) + if kind == "" { + files = append(files, path) + } else { + files = append(files, fmt.Sprintf("%s(%s)", path, kind)) + } + } + } else { + path := strings.TrimSpace(payload.FilePath) + if path != "" { + if payload.WasNew { + files = append(files, path+"(added)") + } else { + files = append(files, path) + } + } + } + if len(files) == 0 { + files = append(files, "(no file paths)") + } + detail := fmt.Sprintf("tool=%s, files=%s", fallbackText(strings.TrimSpace(payload.ToolName), "unknown"), strings.Join(files, ", ")) + a.appendActivity("tool", "Tool diff captured", detail, false) + return false +} + +func runtimeEventBashSideEffectHandler(a *App, event tuiservices.RuntimeEvent) bool { + payload, ok := event.Payload.(tuiservices.BashSideEffectPayload) + if !ok { + return false + } + changes := make([]string, 0, len(payload.Changes)) + for _, file := range payload.Changes { + path := strings.TrimSpace(file.Path) + if path == "" { + continue + } + kind := strings.TrimSpace(file.Kind) + if kind == "" { + changes = append(changes, path) + } else { + changes = append(changes, fmt.Sprintf("%s(%s)", path, kind)) + } + } + if len(changes) == 0 { + changes = append(changes, "(no tracked changes)") + } + detail := fmt.Sprintf("changes=%s", strings.Join(changes, ", ")) + if len(payload.UncoveredPaths) > 0 { + detail += fmt.Sprintf("; uncovered=%s", strings.Join(payload.UncoveredPaths, ", ")) + } + a.appendActivity("tool", "Bash side effects detected", detail, false) + return false +} + // runtimeEventSubAgentLifecycleHandler 统一处理 subagent 生命周期事件并写入活动区/日志。 func runtimeEventSubAgentLifecycleHandler(a *App, event tuiservices.RuntimeEvent) bool { payload, ok := event.Payload.(tuiservices.SubAgentEventPayload) @@ -4389,8 +4570,9 @@ func (a *App) appendActivity(kind string, title string, detail string, isError b } a.syncActivityViewport(previousCount) a.viewDirty = true - a.addLogEntry(kind, title, detail) - a.appendInlineMessage(roleSystem, formatActivityInlineLog(kind, title, detail)) + inline := formatActivityInlineLog(kind, title, detail) + a.addLogEntryWithInline(kind, title, detail, inline) + a.appendInlineMessage(roleSystem, inline) a.rebuildTranscript() } @@ -4456,6 +4638,10 @@ func (a *App) clearActivities() { } func (a *App) addLogEntry(kind string, title string, detail string) { + a.addLogEntryWithInline(kind, title, detail, "") +} + +func (a *App) addLogEntryWithInline(kind string, title string, detail string, inline string) { level := "info" if strings.Contains(title, "error") || strings.Contains(title, "Error") || strings.Contains(title, "failed") { level = "error" @@ -4468,6 +4654,7 @@ func (a *App) addLogEntry(kind string, title string, detail string) { Level: level, Source: kind, Message: title + ": " + detail, + Inline: strings.TrimSpace(inline), }) a.logEntries = clampLogEntries(a.logEntries) @@ -4627,6 +4814,12 @@ func (a *App) handleTranscriptMouse(msg tea.MouseMsg) bool { return false } + if msg.Button == tea.MouseButtonLeft && msg.Action == tea.MouseActionPress { + if a.toggleTranscriptProcessExpansionOnMouse(msg) { + return true + } + } + switch { case msg.Button == tea.MouseButtonLeft && msg.Action == tea.MouseActionPress: return a.beginTextSelection(msg) @@ -4645,6 +4838,52 @@ func (a *App) handleTranscriptMouse(msg tea.MouseMsg) bool { } } +func (a *App) toggleTranscriptProcessExpansionOnMouse(msg tea.MouseMsg) bool { + if !a.transcriptProcessFoldAvailable { + return false + } + line, ok := a.transcriptLineAtMouse(msg) + if !ok { + return false + } + line = ansi.Strip(strings.TrimSpace(line)) + lower := strings.ToLower(line) + if !strings.Contains(lower, "process output hidden") && !strings.Contains(lower, "process output expanded") { + return false + } + _, y, _, _ := a.transcriptBounds() + anchorRow := msg.Y - y + if anchorRow < 0 { + anchorRow = 0 + } + contentLine := a.transcript.YOffset + anchorRow + controlOrdinal := transcriptProcessControlOrdinalAtLine(a.transcriptContent, contentLine) + return a.toggleTranscriptProcessExpansionWithAnchor(anchorRow, controlOrdinal) +} + +func (a App) transcriptLineAtMouse(msg tea.MouseMsg) (string, bool) { + x, y, width, height := a.transcriptBounds() + if width <= 0 || height <= 0 { + return "", false + } + if msg.X < x || msg.X >= x+width || msg.Y < y || msg.Y >= y+height { + return "", false + } + bodyRow := msg.Y - y + if bodyRow < 0 { + return "", false + } + contentLine := a.transcript.YOffset + bodyRow + if contentLine < 0 { + return "", false + } + lines := strings.Split(a.transcriptContent, "\n") + if contentLine >= len(lines) { + return "", false + } + return lines[contentLine], true +} + func (a App) isMouseWithinTranscript(msg tea.MouseMsg) bool { x, y, width, height := a.transcriptBounds() if width <= 0 || height <= 0 { @@ -5057,10 +5296,63 @@ func (a *App) rebuildTranscript() { } atBottom := a.transcript.AtBottom() + foldSegments := findTranscriptProcessFoldSegments(a.activeMessages) + foldExists := len(foldSegments) > 0 + a.transcriptProcessFoldAvailable = foldExists + if !foldExists { + a.transcriptProcessExpanded = false + a.transcriptProcessExpandedOrdinal = -1 + } + if a.transcriptProcessExpanded { + if a.transcriptProcessExpandedOrdinal < 0 || a.transcriptProcessExpandedOrdinal >= len(foldSegments) { + a.transcriptProcessExpandedOrdinal = 0 + } + } + applyProcessFold := foldExists + foldControl := make(map[int]int, len(foldSegments)) + foldExpandedStart := make(map[int]bool, len(foldSegments)) + foldHidden := make(map[int]struct{}) + for segIdx, seg := range foldSegments { + foldControl[seg.Start] = seg.HiddenCount + segmentExpanded := a.transcriptProcessExpanded && segIdx == a.transcriptProcessExpandedOrdinal + foldExpandedStart[seg.Start] = segmentExpanded + if applyProcessFold && !segmentExpanded { + for idx := seg.Start; idx <= seg.End; idx++ { + if idx == seg.FinalAssistant { + continue + } + foldHidden[idx] = struct{}{} + } + } + } var builder strings.Builder hasBlock := false lastRenderedRole := "" - for _, message := range a.activeMessages { + for idx := 0; idx < len(a.activeMessages); idx++ { + if hiddenCount, exists := foldControl[idx]; exists { + control := "" + if foldExpandedStart[idx] { + control = a.renderTranscriptProcessExpandedBlock(width) + } else if applyProcessFold { + control = a.renderTranscriptProcessFoldBlock(width, hiddenCount) + } else { + control = a.renderTranscriptProcessExpandedBlock(width) + } + if strings.TrimSpace(control) != "" { + if hasBlock { + builder.WriteString("\n\n") + } + builder.WriteString(control) + hasBlock = true + lastRenderedRole = "" + } + } + if applyProcessFold { + if _, hidden := foldHidden[idx]; hidden { + continue + } + } + message := a.activeMessages[idx] inlineLog := isInlineLogMessage(message) continuation := message.Role == roleAssistant && lastRenderedRole == roleAssistant if inlineLog && lastRenderedRole == roleAssistant { @@ -5105,6 +5397,186 @@ func (a *App) rebuildTranscript() { } } +func (a *App) toggleTranscriptProcessExpansion() bool { + return a.toggleTranscriptProcessExpansionWithAnchor(-1, -1) +} + +func (a *App) toggleTranscriptProcessExpansionWithAnchor(anchorViewportRow int, controlOrdinal int) bool { + if !a.transcriptProcessFoldAvailable { + return false + } + if controlOrdinal >= 0 { + if a.transcriptProcessExpanded && a.transcriptProcessExpandedOrdinal == controlOrdinal { + a.transcriptProcessExpanded = false + a.transcriptProcessExpandedOrdinal = -1 + } else { + a.transcriptProcessExpanded = true + a.transcriptProcessExpandedOrdinal = controlOrdinal + } + } else { + if a.transcriptProcessExpanded { + a.transcriptProcessExpanded = false + a.transcriptProcessExpandedOrdinal = -1 + } else { + a.transcriptProcessExpanded = true + a.transcriptProcessExpandedOrdinal = 0 + } + } + if a.transcriptProcessExpanded { + a.state.StatusText = "Process output expanded" + } else { + a.state.StatusText = "Process output collapsed" + } + a.rebuildTranscript() + if anchorViewportRow >= 0 { + a.pinTranscriptProcessControlRow(anchorViewportRow, controlOrdinal) + } + return true +} + +func (a *App) pinTranscriptProcessControlRow(anchorViewportRow int, controlOrdinal int) { + if anchorViewportRow < 0 { + return + } + target := "process output hidden" + if a.transcriptProcessExpanded { + target = "process output expanded" + } + lines := strings.Split(a.transcriptContent, "\n") + targetLine := -1 + seen := 0 + for idx, line := range lines { + plain := strings.ToLower(ansi.Strip(strings.TrimSpace(line))) + if strings.Contains(plain, target) { + // Expanded mode has only one expanded-control line; collapsed mode may have many. + if a.transcriptProcessExpanded || controlOrdinal < 0 || seen == controlOrdinal { + targetLine = idx + break + } + seen++ + } + } + if targetLine < 0 { + return + } + desired := targetLine - anchorViewportRow + if desired < 0 { + desired = 0 + } + maxOffset := a.transcriptMaxOffset() + if desired > maxOffset { + desired = maxOffset + } + a.transcript.SetYOffset(desired) +} + +func transcriptProcessControlOrdinalAtLine(content string, contentLine int) int { + if contentLine < 0 { + return -1 + } + lines := strings.Split(content, "\n") + if contentLine >= len(lines) { + return -1 + } + ordinal := 0 + for idx := 0; idx <= contentLine; idx++ { + plain := strings.ToLower(ansi.Strip(strings.TrimSpace(lines[idx]))) + if strings.Contains(plain, "process output hidden") || strings.Contains(plain, "process output expanded") { + if idx == contentLine { + return ordinal + } + ordinal++ + } + } + return -1 +} + +func (a App) renderTranscriptProcessFoldBlock(width int, hiddenCount int) string { + if hiddenCount < 1 { + hiddenCount = 1 + } + detail := fmt.Sprintf("Process output hidden (%d messages).", hiddenCount) + if a.focus == panelTranscript { + detail += " Press Enter to expand." + } else { + detail += " Focus transcript and press Enter to expand." + } + message := providertypes.Message{ + Role: roleSystem, + Parts: []providertypes.ContentPart{providertypes.NewTextPart(detail)}, + } + rendered, _ := a.renderMessageBlockWithCopy(message, width, 0, true) + return rendered +} + +func (a App) renderTranscriptProcessExpandedBlock(width int) string { + detail := "Process output expanded. Click this line or press Enter to collapse." + message := providertypes.Message{ + Role: roleSystem, + Parts: []providertypes.ContentPart{providertypes.NewTextPart(detail)}, + } + rendered, _ := a.renderMessageBlockWithCopy(message, width, 0, true) + return rendered +} + +type transcriptProcessFoldSegment struct { + Start int + End int + FinalAssistant int + HiddenCount int +} + +func findTranscriptProcessFoldSegments(messages []providertypes.Message) []transcriptProcessFoldSegment { + segments := make([]transcriptProcessFoldSegment, 0, 4) + turnStart := 0 + buildSegment := func(start int, end int) { + if start < 0 || end < start || end >= len(messages) { + return + } + finalAssistant := -1 + for idx := end; idx >= start; idx-- { + msg := messages[idx] + if msg.Role != roleAssistant { + continue + } + if strings.TrimSpace(renderMessagePartsForDisplay(msg.Parts)) == "" { + continue + } + finalAssistant = idx + break + } + if finalAssistant < 0 { + return + } + hiddenCount := 0 + for idx := start; idx <= end; idx++ { + if idx == finalAssistant { + continue + } + hiddenCount++ + } + if hiddenCount < 1 { + return + } + segments = append(segments, transcriptProcessFoldSegment{ + Start: start, + End: end, + FinalAssistant: finalAssistant, + HiddenCount: hiddenCount, + }) + } + + for idx := 0; idx < len(messages); idx++ { + if messages[idx].Role != roleUser { + continue + } + buildSegment(turnStart, idx-1) + turnStart = idx + 1 + } + buildSegment(turnStart, len(messages)-1) + return segments +} + func (a *App) setTranscriptContent(content string) { normalized := normalizeTranscriptForDisplay(content) contentChanged := a.transcriptContent != normalized @@ -5241,6 +5713,8 @@ func (a *App) handleImmediateSlashCommand(input string) (bool, tea.Cmd) { return true, a.handleSkillsCommand() case slashCommandSkill: return true, a.handleSkillCommand(rest) + case slashCommandCheckpoint: + return true, a.handleCheckpointCommand(rest) case slashCommandSession: if err := a.ensureSessionSwitchAllowed(""); err != nil { a.state.ExecutionError = err.Error() @@ -5318,6 +5792,8 @@ func (a *App) startDraftSession() { a.startupScreenLocked = false a.state.ActiveSessionTitle = draftSessionTitle a.activeMessages = nil + a.transcriptProcessFoldAvailable = false + a.transcriptProcessExpanded = false a.clearActivities() a.clearTodos() a.state.IsCompacting = false @@ -5492,6 +5968,67 @@ func (a *App) readLogEntriesForSession(sessionID string) []logEntry { return clampLogEntries(fromRuntimeSessionLogEntries(entries)) } +func (a *App) replayFoldRelatedSessionLogsIntoTranscript() { + if len(a.logEntries) == 0 { + return + } + existing := make(map[string]struct{}, len(a.activeMessages)) + for _, message := range a.activeMessages { + if !isInlineLogMessage(message) { + continue + } + content := strings.TrimSpace(renderMessagePartsForDisplay(message.Parts)) + if content == "" { + continue + } + existing[content] = struct{}{} + } + for _, entry := range a.logEntries { + inline := strings.TrimSpace(entry.Inline) + if inline != "" && !strings.HasPrefix(inline, inlineLogMarker) { + inline = "" + } + if inline == "" { + if !isFoldRelatedSessionLogSource(entry.Source) { + continue + } + inline = formatSessionLogEntryInlineMessage(entry) + } + if inline == "" { + continue + } + if _, duplicated := existing[inline]; duplicated { + continue + } + a.activeMessages = append(a.activeMessages, providertypes.Message{ + Role: roleSystem, + Parts: []providertypes.ContentPart{providertypes.NewTextPart(inline)}, + }) + existing[inline] = struct{}{} + } +} + +func isFoldRelatedSessionLogSource(source string) bool { + switch strings.ToLower(strings.TrimSpace(source)) { + case "tool", "verify", "acceptance", "decision", "runtime", "facts", "subagent", "todo", "run", "checkpoint": + return true + default: + return false + } +} + +func formatSessionLogEntryInlineMessage(entry logEntry) string { + source := strings.TrimSpace(entry.Source) + if source == "" { + source = "log" + } + message := strings.TrimSpace(entry.Message) + if message == "" { + return "" + } + return inlineLogMarker + source + ": " + message +} + func (a *App) persistLogEntriesForActiveSession() { sessionID := strings.TrimSpace(a.state.ActiveSessionID) if sessionID == "" { @@ -5521,11 +6058,103 @@ func (a *App) persistLogEntriesForActiveSession() { func (a *App) sessionLogRuntime() sessionLogPersistenceRuntime { runtimeWithPersistence, ok := a.runtime.(sessionLogPersistenceRuntime) if !ok { - return nil + baseDir := "" + if a.configManager != nil { + baseDir = strings.TrimSpace(a.configManager.BaseDir()) + } + if baseDir == "" { + return nil + } + return localSessionLogStore{baseDir: baseDir} } return runtimeWithPersistence } +func (s localSessionLogStore) LoadSessionLogEntries(ctx context.Context, sessionID string) ([]tuiservices.SessionLogEntry, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + path, err := s.sessionLogEntriesPath(sessionID) + if err != nil || path == "" { + return nil, err + } + payload, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + return nil, fmt.Errorf("tui: read session log entries: %w", err) + } + entries := make([]tuiservices.SessionLogEntry, 0) + if err := json.Unmarshal(payload, &entries); err != nil { + return nil, fmt.Errorf("tui: decode session log entries: %w", err) + } + return entries, nil +} + +func (s localSessionLogStore) SaveSessionLogEntries(ctx context.Context, sessionID string, entries []tuiservices.SessionLogEntry) error { + if err := ctx.Err(); err != nil { + return err + } + path, err := s.sessionLogEntriesPath(sessionID) + if err != nil || path == "" { + return err + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("tui: ensure session log directory: %w", err) + } + payload, err := json.Marshal(entries) + if err != nil { + return fmt.Errorf("tui: encode session log entries: %w", err) + } + if err := os.WriteFile(path, payload, 0o600); err != nil { + return fmt.Errorf("tui: write session log entries: %w", err) + } + return nil +} + +func (s localSessionLogStore) sessionLogEntriesPath(sessionID string) (string, error) { + normalizedSessionID := strings.TrimSpace(sessionID) + if normalizedSessionID == "" { + return "", nil + } + baseDir := strings.TrimSpace(s.baseDir) + if baseDir == "" { + return "", errors.New("tui: config base directory is empty") + } + sum := sha256.Sum256([]byte(normalizedSessionID)) + fileName := fmt.Sprintf("%s_%s.json", sanitizeLocalSessionLogPrefix(normalizedSessionID), hex.EncodeToString(sum[:8])) + return filepath.Join(baseDir, localLogViewerPersistDir, fileName), nil +} + +func sanitizeLocalSessionLogPrefix(sessionID string) string { + var b strings.Builder + for _, r := range sessionID { + switch { + case r >= 'a' && r <= 'z': + b.WriteRune(r) + case r >= 'A' && r <= 'Z': + b.WriteRune(r) + case r >= '0' && r <= '9': + b.WriteRune(r) + case r == '_' || r == '-': + b.WriteRune(r) + default: + if b.Len() > 0 { + b.WriteByte('_') + } + } + if b.Len() >= 24 { + break + } + } + prefix := strings.Trim(b.String(), "_") + if prefix == "" { + return "session" + } + return prefix +} + // reportLogPersistenceError 统一处理日志持久化失败提示,避免错误被静默吞掉。 func (a *App) reportLogPersistenceError(action string, err error) { if err == nil { @@ -5571,6 +6200,7 @@ func toRuntimeSessionLogEntries(entries []logEntry) []tuiservices.SessionLogEntr Level: entry.Level, Source: entry.Source, Message: entry.Message, + Inline: entry.Inline, }) } return converted @@ -5585,6 +6215,7 @@ func fromRuntimeSessionLogEntries(entries []tuiservices.SessionLogEntry) []logEn Level: entry.Level, Source: entry.Source, Message: entry.Message, + Inline: strings.TrimSpace(entry.Inline), }) } return converted diff --git a/internal/tui/core/app/update_runtime_events_test.go b/internal/tui/core/app/update_runtime_events_test.go index 9ca285fc..98a797b8 100644 --- a/internal/tui/core/app/update_runtime_events_test.go +++ b/internal/tui/core/app/update_runtime_events_test.go @@ -5,6 +5,7 @@ import ( "testing" providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" agentruntime "neo-code/internal/tui/services" ) @@ -210,6 +211,24 @@ func TestRuntimeEventHandlerRegistryContainsRenamedEvents(t *testing.T) { if _, ok := runtimeEventHandlerRegistry[agentruntime.EventRepoHooksTrustStoreInvalid]; !ok { t.Fatalf("expected repo_hooks_trust_store_invalid handler to be registered") } + if _, ok := runtimeEventHandlerRegistry[agentruntime.EventCheckpointCreated]; !ok { + t.Fatalf("expected checkpoint_created handler to be registered") + } + if _, ok := runtimeEventHandlerRegistry[agentruntime.EventCheckpointWarning]; !ok { + t.Fatalf("expected checkpoint_warning handler to be registered") + } + if _, ok := runtimeEventHandlerRegistry[agentruntime.EventCheckpointRestored]; !ok { + t.Fatalf("expected checkpoint_restored handler to be registered") + } + if _, ok := runtimeEventHandlerRegistry[agentruntime.EventCheckpointUndoRestore]; !ok { + t.Fatalf("expected checkpoint_undo_restore handler to be registered") + } + if _, ok := runtimeEventHandlerRegistry[agentruntime.EventToolDiff]; !ok { + t.Fatalf("expected tool_diff handler to be registered") + } + if _, ok := runtimeEventHandlerRegistry[agentruntime.EventBashSideEffect]; !ok { + t.Fatalf("expected bash_side_effect handler to be registered") + } if _, ok := runtimeEventHandlerRegistry[agentruntime.EventSubAgentStarted]; !ok { t.Fatalf("expected subagent_started handler to be registered") } @@ -447,6 +466,86 @@ func TestRuntimeEventRepoHookLifecycleHandlers(t *testing.T) { } } +func TestRuntimeEventCheckpointAndToolDiffHandlers(t *testing.T) { + app, runtime := newTestApp(t) + app.state.ActiveSessionID = "session-1" + runtime.loadSessions = map[string]agentsession.Session{ + "session-1": agentsession.NewWithWorkdir("session-1", ""), + } + + if runtimeEventCheckpointCreatedHandler(&app, agentruntime.RuntimeEvent{Payload: "bad"}) { + t.Fatalf("expected invalid checkpoint_created payload to return false") + } + runtimeEventCheckpointCreatedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.CheckpointCreatedPayload{ + CheckpointID: "cp-1", + Reason: "pre-write", + CommitHash: "abc123", + CodeCheckpointRef: "code-ref-1", + }, + }) + last := app.activities[len(app.activities)-1] + if last.Title != "Checkpoint created" || !strings.Contains(last.Detail, "cp-1") { + t.Fatalf("unexpected checkpoint created activity: %+v", last) + } + + runtimeEventCheckpointWarningHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.CheckpointWarningPayload{Phase: "persist", Error: "disk busy"}, + }) + last = app.activities[len(app.activities)-1] + if last.Title != "Checkpoint warning" || !last.IsError { + t.Fatalf("unexpected checkpoint warning activity: %+v", last) + } + + runtimeEventToolDiffHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.ToolDiffPayload{ + ToolName: "edit", + Files: []agentruntime.FileChange{ + {Path: "a.txt", Kind: "modified"}, + {Path: "b.txt", Kind: "added"}, + }, + }, + }) + last = app.activities[len(app.activities)-1] + if last.Title != "Tool diff captured" || !strings.Contains(last.Detail, "a.txt(modified)") { + t.Fatalf("unexpected tool diff activity: %+v", last) + } + + runtimeEventBashSideEffectHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.BashSideEffectPayload{ + Changes: []agentruntime.FileChange{{Path: "c.txt", Kind: "deleted"}}, + UncoveredPaths: []string{ + "tmp.log", + }, + }, + }) + last = app.activities[len(app.activities)-1] + if last.Title != "Bash side effects detected" || !strings.Contains(last.Detail, "tmp.log") { + t.Fatalf("unexpected bash side effect activity: %+v", last) + } + + runtimeEventCheckpointRestoredHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.CheckpointRestoredPayload{ + CheckpointID: "cp-restore", + SessionID: "session-1", + GuardCheckpointID: "cp-guard", + }, + }) + if app.state.StatusText != "Checkpoint restored" || app.state.ActiveSessionID != "session-1" { + t.Fatalf("expected checkpoint restored status/session update, got status=%q session=%q", app.state.StatusText, app.state.ActiveSessionID) + } + + runtimeEventCheckpointUndoRestoreHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.CheckpointUndoRestorePayload{ + GuardCheckpointID: "cp-guard", + SessionID: "session-1", + }, + }) + if app.state.StatusText != "Checkpoint restore undo applied" { + t.Fatalf("expected undo restore status, got %q", app.state.StatusText) + } +} + func TestRuntimeEventSubAgentHandlers(t *testing.T) { app, _ := newTestApp(t) if runtimeEventSubAgentLifecycleHandler(&app, agentruntime.RuntimeEvent{ diff --git a/internal/tui/core/app/update_test.go b/internal/tui/core/app/update_test.go index 136111d1..da255f50 100644 --- a/internal/tui/core/app/update_test.go +++ b/internal/tui/core/app/update_test.go @@ -156,6 +156,10 @@ type stubRuntime struct { availableSkillsErr error } +type noLogPersistenceRuntime struct { + agentruntime.Runtime +} + type snapshotRuntime struct { *stubRuntime sessionContext any @@ -3399,13 +3403,10 @@ func TestHandleMemoCommandsRouteToSystemTools(t *testing.T) { func TestHandleMemoCommandMapsUnsupportedActionErrorToUserFriendlyMessage(t *testing.T) { tests := []struct { - name string - err error + name string + err error + expectGatewayUI bool }{ - { - name: "legacy sentinel", - err: agentruntime.ErrUnsupportedActionInGatewayMode, - }, { name: "gateway rpc unsupported_action", err: &agentruntime.GatewayRPCError{ @@ -3414,6 +3415,7 @@ func TestHandleMemoCommandMapsUnsupportedActionErrorToUserFriendlyMessage(t *tes GatewayCode: protocol.GatewayCodeUnsupportedAction, Message: "method not found", }, + expectGatewayUI: true, }, { name: "gateway rpc method_not_found", @@ -3422,6 +3424,7 @@ func TestHandleMemoCommandMapsUnsupportedActionErrorToUserFriendlyMessage(t *tes Code: protocol.JSONRPCCodeMethodNotFound, Message: "method not found", }, + expectGatewayUI: false, }, } @@ -3442,9 +3445,12 @@ func TestHandleMemoCommandMapsUnsupportedActionErrorToUserFriendlyMessage(t *tes if strings.Contains(status, "unsupported_action_in_gateway_mode") { t.Fatalf("expected sentinel to be hidden from UI, got %q", app.state.StatusText) } - if !strings.Contains(status, "gateway") { + if tt.expectGatewayUI && !strings.Contains(status, "gateway") { t.Fatalf("expected gateway upgrade hint, got %q", app.state.StatusText) } + if !tt.expectGatewayUI && strings.Contains(status, "gateway does not support memo commands") { + t.Fatalf("did not expect gateway unsupported hint, got %q", app.state.StatusText) + } }) } } @@ -5953,8 +5959,294 @@ func TestRebuildTranscriptCollapsesConsecutiveAssistantTags(t *testing.T) { if count := strings.Count(plain, messageTagAgent); count != 1 { t.Fatalf("expected one agent tag for consecutive assistant chunks, got %d in %q", count, plain) } - if !strings.Contains(plain, "first chunk") || !strings.Contains(plain, "second chunk") || !strings.Contains(plain, "third chunk") { - t.Fatalf("expected all assistant chunks to be present, got %q", plain) + if !strings.Contains(plain, "third chunk") { + t.Fatalf("expected final assistant chunk to stay visible, got %q", plain) + } + if strings.Contains(plain, "first chunk") || strings.Contains(plain, "second chunk") { + t.Fatalf("expected non-final assistant chunks to be folded, got %q", plain) + } +} + +func TestRebuildTranscriptAutoFoldsDuplicateAssistantProcess(t *testing.T) { + app, _ := newTestApp(t) + app.width = 120 + app.height = 32 + app.applyComponentLayout(true) + app.activeMessages = []providertypes.Message{ + {Role: roleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("final answer")}}, + {Role: roleSystem, Parts: []providertypes.ContentPart{providertypes.NewTextPart(inlineLogMarker + "verify: verification completed")}}, + {Role: roleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("final answer")}}, + } + + app.rebuildTranscript() + plain := copyCodeANSIPattern.ReplaceAllString(app.transcriptContent, "") + if !strings.Contains(plain, "Process output hidden") { + t.Fatalf("expected folded process placeholder, got %q", plain) + } + if strings.Contains(plain, "verification completed") { + t.Fatalf("expected process details hidden when folded, got %q", plain) + } + if count := strings.Count(plain, "final answer"); count != 1 { + t.Fatalf("expected exactly one visible final answer when folded, got %d in %q", count, plain) + } + if !app.transcriptProcessFoldAvailable { + t.Fatalf("expected transcript fold availability to be true") + } + if app.transcriptProcessExpanded { + t.Fatalf("expected transcript process to be collapsed by default") + } +} + +func TestTranscriptEnterTogglesFoldedProcessVisibility(t *testing.T) { + app, _ := newTestApp(t) + app.width = 120 + app.height = 32 + app.applyComponentLayout(true) + app.activeMessages = []providertypes.Message{ + {Role: roleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello world")}}, + {Role: roleSystem, Parts: []providertypes.ContentPart{providertypes.NewTextPart(inlineLogMarker + "decision: accepted")}}, + {Role: roleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello world")}}, + } + app.rebuildTranscript() + app.focus = panelTranscript + app.applyFocus() + + model, _ := app.Update(tea.KeyMsg{Type: tea.KeyEnter}) + app = model.(App) + expandedPlain := copyCodeANSIPattern.ReplaceAllString(app.transcriptContent, "") + if !app.transcriptProcessExpanded { + t.Fatalf("expected enter on transcript to expand folded process") + } + if !strings.Contains(expandedPlain, "decision: accepted") { + t.Fatalf("expected expanded transcript to show process details, got %q", expandedPlain) + } + if strings.Contains(expandedPlain, "Process output hidden") { + t.Fatalf("expected folded placeholder to disappear after expand, got %q", expandedPlain) + } + + model, _ = app.Update(tea.KeyMsg{Type: tea.KeyEnter}) + app = model.(App) + collapsedPlain := copyCodeANSIPattern.ReplaceAllString(app.transcriptContent, "") + if app.transcriptProcessExpanded { + t.Fatalf("expected second enter to collapse process output") + } + if !strings.Contains(collapsedPlain, "Process output hidden") { + t.Fatalf("expected folded placeholder after collapsing, got %q", collapsedPlain) + } +} + +func TestRebuildTranscriptFoldsWhenFinalContainsStreamingDraft(t *testing.T) { + app, _ := newTestApp(t) + app.width = 120 + app.height = 32 + app.applyComponentLayout(true) + app.activeMessages = []providertypes.Message{ + {Role: roleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("你好!我是 NeoCode,一个本地编程助手。")}}, + {Role: roleSystem, Parts: []providertypes.ContentPart{providertypes.NewTextPart(inlineLogMarker + "verify: verification completed")}}, + {Role: roleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("你好!我是 NeoCode,一个本地编程助手。\n请问有什么具体任务需要我协助吗?")}}, + } + + app.rebuildTranscript() + plain := copyCodeANSIPattern.ReplaceAllString(app.transcriptContent, "") + if !strings.Contains(plain, "Process output hidden") { + t.Fatalf("expected folded placeholder for overlapping final answer, got %q", plain) + } + if strings.Contains(plain, "verification completed") { + t.Fatalf("expected process details to be hidden in collapsed mode, got %q", plain) + } +} + +func TestRebuildTranscriptFoldsAllNonFinalOutputsInTurn(t *testing.T) { + app, _ := newTestApp(t) + app.width = 120 + app.height = 32 + app.applyComponentLayout(true) + app.activeMessages = []providertypes.Message{ + {Role: roleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("帮我分析代码")}}, + {Role: roleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("先看一下目录结构")}}, + {Role: roleSystem, Parts: []providertypes.ContentPart{providertypes.NewTextPart(inlineLogMarker + "tool: list files")}}, + {Role: roleTool, Parts: []providertypes.ContentPart{providertypes.NewTextPart("main.go\nREADME.md")}}, + {Role: roleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("这是最终结论")}}, + } + + app.rebuildTranscript() + plain := copyCodeANSIPattern.ReplaceAllString(app.transcriptContent, "") + if !strings.Contains(plain, "这是最终结论") { + t.Fatalf("expected final answer to remain visible, got %q", plain) + } + if strings.Contains(plain, "先看一下目录结构") || strings.Contains(plain, "tool: list files") || strings.Contains(plain, "main.go") { + t.Fatalf("expected non-final outputs in the turn to be folded, got %q", plain) + } + if !strings.Contains(plain, "Process output hidden") { + t.Fatalf("expected folded process placeholder, got %q", plain) + } +} + +func TestTranscriptMouseClickFoldPlaceholderExpandsProcess(t *testing.T) { + app, _ := newTestApp(t) + app.width = 120 + app.height = 32 + app.applyComponentLayout(true) + app.activeMessages = []providertypes.Message{ + {Role: roleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("分析一下")}}, + {Role: roleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("先检查项目结构")}}, + {Role: roleSystem, Parts: []providertypes.ContentPart{providertypes.NewTextPart(inlineLogMarker + "tool: list files")}}, + {Role: roleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("最终回答")}}, + } + app.rebuildTranscript() + if !app.transcriptProcessFoldAvailable || app.transcriptProcessExpanded { + t.Fatalf("expected collapsed process fold before mouse click") + } + + lines := strings.Split(copyCodeANSIPattern.ReplaceAllString(app.transcriptContent, ""), "\n") + targetRow := -1 + for i, line := range lines { + if strings.Contains(line, "Process output hidden") { + targetRow = i + break + } + } + if targetRow < 0 { + t.Fatalf("expected fold placeholder line in transcript content, got %q", app.transcriptContent) + } + + x, y, w, h := app.transcriptBounds() + if w <= 0 || h <= 0 { + t.Fatalf("expected transcript bounds to be drawable, got w=%d h=%d", w, h) + } + clickY := y + targetRow - app.transcript.YOffset + if clickY < y || clickY >= y+h { + t.Fatalf("expected fold placeholder row to be visible, row=%d y=%d h=%d", targetRow, y, h) + } + clickBodyRow := clickY - y + click := tea.MouseMsg{ + X: x + 1, + Y: clickY, + Button: tea.MouseButtonLeft, + Action: tea.MouseActionPress, + } + if !app.handleTranscriptMouse(click) { + t.Fatalf("expected fold-placeholder click to be handled") + } + if !app.transcriptProcessExpanded { + t.Fatalf("expected fold-placeholder click to expand process") + } + plain := copyCodeANSIPattern.ReplaceAllString(app.transcriptContent, "") + if strings.Contains(plain, "Process output hidden") { + t.Fatalf("expected placeholder hidden after expand, got %q", plain) + } + if !strings.Contains(plain, "tool: list files") { + t.Fatalf("expected expanded transcript to reveal process details, got %q", plain) + } + if !strings.Contains(plain, "Process output expanded") { + t.Fatalf("expected expanded control line to be visible, got %q", plain) + } + if strings.Index(plain, "Process output expanded") < strings.Index(plain, "分析一下") { + t.Fatalf("expected expanded control to stay in-place near folded segment, got %q", plain) + } + + expandedLines := strings.Split(plain, "\n") + expandedRow := -1 + for i, line := range expandedLines { + if strings.Contains(line, "Process output expanded") { + expandedRow = i + break + } + } + if expandedRow < 0 { + t.Fatalf("expected expanded control row in transcript") + } + expandedBodyRow := expandedRow - app.transcript.YOffset + if expandedBodyRow != clickBodyRow { + t.Fatalf("expected expanded control anchor row to remain stable, got %d want %d", expandedBodyRow, clickBodyRow) + } + collapseClick := tea.MouseMsg{ + X: x + 1, + Y: y + expandedRow - app.transcript.YOffset, + Button: tea.MouseButtonLeft, + Action: tea.MouseActionPress, + } + if !app.handleTranscriptMouse(collapseClick) { + t.Fatalf("expected expanded-control click to be handled") + } + if app.transcriptProcessExpanded { + t.Fatalf("expected expanded-control click to collapse process") + } + collapsed := copyCodeANSIPattern.ReplaceAllString(app.transcriptContent, "") + if !strings.Contains(collapsed, "Process output hidden") { + t.Fatalf("expected collapsed placeholder after clicking expanded control, got %q", collapsed) + } +} + +func TestTranscriptMouseClickKeepsClickedFoldSegmentAnchorWithMultipleSegments(t *testing.T) { + app, _ := newTestApp(t) + app.width = 120 + app.height = 32 + app.applyComponentLayout(true) + app.activeMessages = []providertypes.Message{ + {Role: roleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("任务A")}}, + {Role: roleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("A-过程")}}, + {Role: roleSystem, Parts: []providertypes.ContentPart{providertypes.NewTextPart(inlineLogMarker + "A-log")}}, + {Role: roleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("A-最终")}}, + {Role: roleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("任务B")}}, + {Role: roleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("B-过程")}}, + {Role: roleSystem, Parts: []providertypes.ContentPart{providertypes.NewTextPart(inlineLogMarker + "B-log")}}, + {Role: roleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("B-最终")}}, + } + app.rebuildTranscript() + plain := copyCodeANSIPattern.ReplaceAllString(app.transcriptContent, "") + lines := strings.Split(plain, "\n") + placeholderRows := make([]int, 0, 2) + for i, line := range lines { + if strings.Contains(line, "Process output hidden") { + placeholderRows = append(placeholderRows, i) + } + } + if len(placeholderRows) < 2 { + t.Fatalf("expected at least two folded placeholders, got %q", plain) + } + targetRow := placeholderRows[1] + + x, y, w, h := app.transcriptBounds() + if w <= 0 || h <= 0 { + t.Fatalf("expected transcript bounds to be drawable, got w=%d h=%d", w, h) + } + clickY := y + targetRow - app.transcript.YOffset + if clickY < y || clickY >= y+h { + t.Fatalf("expected second fold placeholder row to be visible, row=%d y=%d h=%d", targetRow, y, h) + } + clickBodyRow := clickY - y + click := tea.MouseMsg{ + X: x + 1, + Y: clickY, + Button: tea.MouseButtonLeft, + Action: tea.MouseActionPress, + } + if !app.handleTranscriptMouse(click) { + t.Fatalf("expected second fold-placeholder click to be handled") + } + if !app.transcriptProcessExpanded { + t.Fatalf("expected click to expand process output") + } + + expanded := copyCodeANSIPattern.ReplaceAllString(app.transcriptContent, "") + expandedLines := strings.Split(expanded, "\n") + expandedRows := make([]int, 0, 1) + for i, line := range expandedLines { + if strings.Contains(line, "Process output expanded") { + expandedRows = append(expandedRows, i) + } + } + if len(expandedRows) != 1 { + t.Fatalf("expected exactly one expanded control line, got %q", expanded) + } + expandedBodyRow := expandedRows[0] - app.transcript.YOffset + diff := expandedBodyRow - clickBodyRow + if diff < -1 || diff > 1 { + t.Fatalf("expected clicked segment anchor row to remain stable, got %d want %d", expandedBodyRow, clickBodyRow) + } + if !strings.Contains(expanded, "B-过程") || strings.Contains(expanded, "A-过程") { + t.Fatalf("expected only clicked segment process details to expand, got %q", expanded) } } @@ -5970,6 +6262,7 @@ func TestTranscriptManualScrollPersistsWhileBusy(t *testing.T) { Parts: []providertypes.ContentPart{providertypes.NewTextPart(fmt.Sprintf("assistant-line-%03d", i))}, }) } + app.transcriptProcessExpanded = true app.rebuildTranscript() if app.transcriptMaxOffset() <= 6 { t.Fatalf("expected transcript to be scrollable, max offset=%d", app.transcriptMaxOffset()) @@ -6064,6 +6357,31 @@ func TestSessionLogViewerPersistenceAndCap(t *testing.T) { } } +func TestSessionLogFallbackPersistenceWithoutRuntimeLogInterface(t *testing.T) { + app, runtime := newTestApp(t) + app.runtime = noLogPersistenceRuntime{Runtime: runtime} + app.setActiveSessionID("session-fallback") + + app.appendActivity("verify", "Verification started", "completion_passed=true", false) + app.persistLogEntriesForActiveSession() + + app.logEntries = nil + app.loadLogEntriesForSession("session-fallback") + if len(app.logEntries) == 0 { + t.Fatalf("expected fallback store to reload persisted log entries") + } + foundInline := false + for _, entry := range app.logEntries { + if strings.Contains(entry.Inline, inlineLogMarker+"verify: Verification started | completion_passed=true") { + foundInline = true + break + } + } + if !foundInline { + t.Fatalf("expected persisted inline_message in fallback log store, got %#v", app.logEntries) + } +} + func TestSanitizeProviderAddJSONInputRunes(t *testing.T) { input := []rune{'a', '\u200b', '\n', '\t', '\r', 0x01, 'b'} got := sanitizeProviderAddJSONInputRunes(input) diff --git a/internal/tui/core/app/view.go b/internal/tui/core/app/view.go index 756ebd25..202164c5 100644 --- a/internal/tui/core/app/view.go +++ b/internal/tui/core/app/view.go @@ -357,7 +357,7 @@ func (a App) renderPicker(width int, height int) string { if a.state.ActivePicker == pickerProviderAdd { title = providerAddTitle subtitle = providerAddSubtitle - body = a.renderProviderAddForm() + body = a.renderProviderAddForm(max(32, width-6)) } if a.state.ActivePicker == pickerModelScope { title = modelScopeGuideTitle @@ -430,14 +430,14 @@ func (a App) renderModelScopeGuide() string { return sb.String() } -func (a App) renderProviderAddForm() string { +func (a App) renderProviderAddForm(bodyWidth int) string { if a.providerAddForm == nil { return "No form active" } if a.providerAddForm.Stage == providerAddFormStageManualModels { var sb strings.Builder - sb.WriteString("Manual Model JSON(id/name 必填)\n") - sb.WriteString("[Shift+Tab] 返回字段页 [Enter] 提交 [Esc] 取消\n\n") + sb.WriteString("Manual Model JSON (id/name required)\n") + sb.WriteString("[Shift+Tab] back to fields [Enter] confirm [Esc] cancel\n\n") content := strings.TrimSpace(a.providerAddForm.ManualModelsJSON) if content == "" { content = providerAddManualModelsJSONTemplate @@ -471,29 +471,41 @@ func (a App) renderProviderAddForm() string { for _, fieldID := range visible { switch fieldID { case providerAddFieldName: - fields = append(fields, renderField{label: "Name", value: a.providerAddForm.Name, required: true}) + fields = append(fields, renderField{ + label: "Name", + value: a.providerAddForm.Name, + required: true, + note: "Unique local provider name. Use letters/numbers/-/_. Example: team-gateway.", + }) case providerAddFieldDriver: - fields = append(fields, renderField{label: "Driver", value: a.providerAddForm.Driver, required: true}) + fields = append(fields, renderField{ + label: "Driver", + value: a.providerAddForm.Driver, + required: true, + note: "Protocol adapter. openaicompat for most gateways; gemini/anthropic for native APIs.", + }) case providerAddFieldModelSource: - note := "discover: 远端发现模型;manual: 手工 JSON 模型列表" fields = append(fields, renderField{ label: "Model Source", value: a.providerAddForm.ModelSource, required: true, - note: note, + note: "discover = fetch models from remote endpoint; manual = paste custom model JSON in next step.", }) case providerAddFieldChatAPIMode: - note := "仅 openaicompat 生效;chat_completions 或 responses" fields = append(fields, renderField{ label: "Chat API Mode", value: a.providerAddForm.ChatAPIMode, - note: note, + note: "openaicompat only. chat_completions uses /chat/completions; responses uses /responses style.", }) case providerAddFieldBaseURL: note := "" if strings.TrimSpace(a.providerAddForm.BaseURL) == "" && (driver == provider.DriverOpenAICompat || driver == provider.DriverGemini || driver == provider.DriverAnthropic) { - note = "留空会自动填充默认地址" + note = "Server base address. Empty = built-in default for this driver." + } else if baseURLRequired { + note = "Required for custom drivers. Example: https://api.example.com/v1" + } else { + note = "Override the default base URL for this driver." } fields = append(fields, renderField{ label: "Base URL", @@ -505,44 +517,71 @@ func (a App) renderProviderAddForm() string { note := "" trimmedPath := strings.TrimSpace(a.providerAddForm.ChatEndpointPath) if trimmedPath == "" { - note = "留空会按 Chat API Mode 自动回填默认端点" + note = "Chat endpoint path. Empty = auto default from driver/mode." } else if trimmedPath == "/" { - note = "\"/\" 使用直连 base_url" + note = "\"/\" means call base URL directly (no extra path)." } else { - note = "以 \"/\" 开头的端点路径" + note = "Must start with '/'. Example: /chat/completions" } fields = append(fields, renderField{label: "Chat Endpoint", value: a.providerAddForm.ChatEndpointPath, note: note}) case providerAddFieldDiscoveryEndpointPath: note := "" if strings.TrimSpace(a.providerAddForm.DiscoveryEndpointPath) == "" { - note = "OpenAI-compatible 默认 /models" + note = "Used by discover mode to fetch model list. Empty = /models." + } else { + note = "Path used for remote model discovery. Usually /models." } + fields = append(fields, renderField{label: "Discovery Endpoint", value: a.providerAddForm.DiscoveryEndpointPath, note: note}) + case providerAddFieldAPIKeyEnv: fields = append(fields, renderField{ - label: "Discovery Endpoint", - value: a.providerAddForm.DiscoveryEndpointPath, - note: note, + label: "API Key Env", + value: a.providerAddForm.APIKeyEnv, + required: true, + note: "Environment variable name to store key. Example: OPENAI_API_KEY. Must be a valid env var name.", }) - case providerAddFieldAPIKeyEnv: - fields = append(fields, renderField{label: "API Key Env", value: a.providerAddForm.APIKeyEnv, required: true}) case providerAddFieldAPIKey: - fields = append(fields, renderField{label: "API Key", value: maskedSecret(a.providerAddForm.APIKey), required: true}) + fields = append(fields, renderField{ + label: "API Key", + value: maskedSecret(a.providerAddForm.APIKey), + required: true, + note: "Secret token for provider auth. Input is masked and applied to current process env.", + }) + } + } + + labelWidth := 0 + for _, field := range fields { + displayLabel := field.label + if field.required { + displayLabel += " *" } + labelWidth = max(labelWidth, len(displayLabel)) + } + if labelWidth < 8 { + labelWidth = 8 } + noteWidth := max(20, bodyWidth-labelWidth-8) + currentHint := "" + currentHintLabel := "" for i, field := range fields { prefix := " " if i == a.providerAddForm.Step { prefix = "> " + if strings.TrimSpace(field.note) != "" { + currentHint = strings.TrimSpace(field.note) + currentHintLabel = field.label + } } - sb.WriteString(prefix + field.label + ": ") - sb.WriteString(field.value) + displayLabel := field.label if field.required { - sb.WriteString(" *") + displayLabel += " *" } - if field.note != "" { - sb.WriteString(" (" + field.note + ")") + value := strings.TrimSpace(field.value) + if value == "" { + value = "-" } - sb.WriteString("\n") + sb.WriteString(fmt.Sprintf("%s%-*s : %s\n", prefix, labelWidth, displayLabel, value)) } if a.providerAddForm.Error != "" { @@ -553,7 +592,20 @@ func (a App) renderProviderAddForm() string { sb.WriteString("\n" + label + " " + a.providerAddForm.Error + "\n") } - sb.WriteString("\n[Tab] switch field [Up/Down or K/J] change option [Enter] confirm [Esc] cancel") + if currentHint != "" { + sb.WriteString(fmt.Sprintf("\nHint (%s):\n", currentHintLabel)) + wrapped := wrapPlain(currentHint, noteWidth) + for _, line := range strings.Split(wrapped, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + sb.WriteString(" " + line + "\n") + } + } + + sb.WriteString("\n* required\n") + sb.WriteString("[Tab/Shift+Tab] switch field [Up/Down or K/J] change option [Enter] confirm [Esc] cancel") return sb.String() } diff --git a/internal/tui/core/app/view_test.go b/internal/tui/core/app/view_test.go index 3d1f756e..0c10a77a 100644 --- a/internal/tui/core/app/view_test.go +++ b/internal/tui/core/app/view_test.go @@ -634,22 +634,23 @@ func TestRenderProviderAddFormMasksAPIKeyAndShowsHints(t *testing.T) { app.providerAddForm.ChatEndpointPath = "" app.providerAddForm.Error = "input invalid" app.providerAddForm.ErrorIsHard = true + app.providerAddForm.Step = 4 // Base URL - form := app.renderProviderAddForm() + form := app.renderProviderAddForm(72) if strings.Contains(form, "sk-secret-98765") { t.Fatalf("expected api key to be masked, got %q", form) } - if !strings.Contains(form, "API Key: ******") { + if !strings.Contains(form, "API Key") || !strings.Contains(form, "******") { t.Fatalf("expected masked api key, got %q", form) } - if !strings.Contains(form, "Model Source: discover") { + if !strings.Contains(form, "Model Source") || !strings.Contains(form, "discover") { t.Fatalf("expected model source field, got %q", form) } - if !strings.Contains(form, "Base URL: (") { + if !strings.Contains(form, "Hint (Base URL):") || !strings.Contains(form, "Server base address.") || !strings.Contains(form, "built-in default") { t.Fatalf("expected base url hint, got %q", form) } - if !strings.Contains(form, "Chat Endpoint: (") || !strings.Contains(form, "Chat API Mode") { - t.Fatalf("expected chat endpoint auto-fill hint, got %q", form) + if !strings.Contains(form, "Chat API Mode") { + t.Fatalf("expected chat api mode field, got %q", form) } if !strings.Contains(form, "[Error] input invalid") { t.Fatalf("expected hard error label, got %q", form) @@ -663,7 +664,7 @@ func TestRenderProviderAddFormPromptLabel(t *testing.T) { app.providerAddForm.Error = "continue input" app.providerAddForm.ErrorIsHard = false - form := app.renderProviderAddForm() + form := app.renderProviderAddForm(72) if !strings.Contains(form, "[Prompt] continue input") { t.Fatalf("expected prompt label, got %q", form) } @@ -675,7 +676,7 @@ func TestRenderProviderAddFormManualModelsStage(t *testing.T) { app.providerAddForm.Stage = providerAddFormStageManualModels app.providerAddForm.ManualModelsJSON = "" - form := app.renderProviderAddForm() + form := app.renderProviderAddForm(72) if !strings.Contains(form, "Manual Model JSON") { t.Fatalf("expected manual model json title, got %q", form) } @@ -969,12 +970,12 @@ func TestRenderMessageBlockWithCopyExtraBranches(t *testing.T) { func TestRenderProviderAddFormNoFormAndChatEndpointField(t *testing.T) { app, _ := newTestApp(t) - if got := app.renderProviderAddForm(); got != "No form active" { + if got := app.renderProviderAddForm(72); got != "No form active" { t.Fatalf("unexpected no-form output: %q", got) } app.startProviderAddForm() - form := app.renderProviderAddForm() + form := app.renderProviderAddForm(72) if !strings.Contains(form, "Chat Endpoint") { t.Fatalf("expected chat endpoint field in add form") } diff --git a/internal/tui/services/gateway_stream_client.go b/internal/tui/services/gateway_stream_client.go index cf403715..7aaae045 100644 --- a/internal/tui/services/gateway_stream_client.go +++ b/internal/tui/services/gateway_stream_client.go @@ -152,44 +152,30 @@ func decodeRuntimeEventFromGatewayNotification(notification gatewayRPCNotificati return event, nil } -// extractRuntimeEnvelope 从网关 payload 中提取事件包裹层。 +// extractRuntimeEnvelope 从网关 payload 中提取 runtime envelope。 +// 支持两种结构: +// 1) 直接 envelope: {"runtime_event_type": "...", ...} +// 2) gateway 包裹层: {"event_type": "...", "payload": {"runtime_event_type": "...", ...}} func extractRuntimeEnvelope(payload any) (map[string]any, bool) { - switch typed := payload.(type) { - case map[string]any: - if _, exists := streamReadMapValue(typed, "runtime_event_type"); exists { - return typed, true - } - if nested, exists := streamReadMapValue(typed, "payload"); exists { - if nestedMap, ok := nested.(map[string]any); ok { - if _, hasEventType := streamReadMapValue(nestedMap, "runtime_event_type"); hasEventType { - return nestedMap, true - } - } - } - case nil: - return nil, false - } - - raw, err := json.Marshal(payload) - if err != nil { - return nil, false - } - - var asMap map[string]any - if err := json.Unmarshal(raw, &asMap); err != nil { + typed, ok := payload.(map[string]any) + if !ok { return nil, false } - if _, exists := streamReadMapValue(asMap, "runtime_event_type"); exists { - return asMap, true - } - if nested, exists := streamReadMapValue(asMap, "payload"); exists { - if nestedMap, ok := nested.(map[string]any); ok { - if _, hasEventType := streamReadMapValue(nestedMap, "runtime_event_type"); hasEventType { - return nestedMap, true - } + if _, exists := streamReadMapValue(typed, "runtime_event_type"); !exists { + nested, nestedExists := streamReadMapValue(typed, "payload") + if !nestedExists || nested == nil { + return nil, false + } + nestedMap, nestedOK := nested.(map[string]any) + if !nestedOK { + return nil, false } + if _, nestedEnvelopeExists := streamReadMapValue(nestedMap, "runtime_event_type"); !nestedEnvelopeExists { + return nil, false + } + return nestedMap, true } - return nil, false + return typed, true } // restoreRuntimePayload 按事件类型将 payload 恢复为 TUI 可消费的强类型结构。 @@ -243,6 +229,18 @@ func restoreRuntimePayload(eventType EventType, payload any) (any, error) { return decodeRuntimePayload[RepoHooksLifecyclePayload](payload) case EventRepoHooksTrustStoreInvalid: return decodeRuntimePayload[RepoHooksTrustStoreInvalidPayload](payload) + case EventCheckpointCreated: + return decodeRuntimePayload[CheckpointCreatedPayload](payload) + case EventCheckpointWarning: + return decodeRuntimePayload[CheckpointWarningPayload](payload) + case EventCheckpointRestored: + return decodeRuntimePayload[CheckpointRestoredPayload](payload) + case EventCheckpointUndoRestore: + return decodeRuntimePayload[CheckpointUndoRestorePayload](payload) + case EventToolDiff: + return decodeRuntimePayload[ToolDiffPayload](payload) + case EventBashSideEffect: + return decodeRuntimePayload[BashSideEffectPayload](payload) case EventTodoUpdated, EventTodoConflict: return decodeRuntimePayload[TodoEventPayload](payload) case EventTodoSnapshotUpdated: diff --git a/internal/tui/services/gateway_stream_client_additional_test.go b/internal/tui/services/gateway_stream_client_additional_test.go index 49caa6d7..97012e73 100644 --- a/internal/tui/services/gateway_stream_client_additional_test.go +++ b/internal/tui/services/gateway_stream_client_additional_test.go @@ -12,14 +12,6 @@ import ( providertypes "neo-code/internal/provider/types" ) -type streamInvalidJSONMarshaler struct { - raw []byte -} - -func (m streamInvalidJSONMarshaler) MarshalJSON() ([]byte, error) { - return m.raw, nil -} - func TestDecodeRuntimeEventFromGatewayNotificationErrorBranches(t *testing.T) { t.Parallel() @@ -91,26 +83,44 @@ func TestDecodeRuntimeEventFromGatewayNotificationRejectsPayloadVersionMismatch( } } -func TestExtractRuntimeEnvelopeFallbackMarshalling(t *testing.T) { +func TestExtractRuntimeEnvelopeSupportsGatewayWrappedPayload(t *testing.T) { t.Parallel() type payloadEnvelope struct { Payload map[string]any `json:"payload"` } - envelope, ok := extractRuntimeEnvelope(payloadEnvelope{Payload: map[string]any{ + if _, ok := extractRuntimeEnvelope(payloadEnvelope{Payload: map[string]any{ "RuntimeEventType": string(EventError), "payload": "x", - }}) - if !ok { - t.Fatalf("expected envelope to be detected") - } - if got := streamReadMapString(envelope, "runtime_event_type"); got != string(EventError) { - t.Fatalf("runtime_event_type = %q", got) + }}); ok { + t.Fatalf("struct payload should not be treated as runtime envelope") } if _, ok := extractRuntimeEnvelope(nil); ok { t.Fatalf("nil payload should not decode") } + + if _, ok := extractRuntimeEnvelope(map[string]any{ + "payload_version": runtimeEventPayloadVersion, + "payload": "x", + }); ok { + t.Fatalf("map without runtime_event_type should not decode") + } + + envelope, ok := extractRuntimeEnvelope(map[string]any{ + "event_type": "run_progress", + "payload": map[string]any{ + "runtime_event_type": string(EventAgentChunk), + "payload_version": runtimeEventPayloadVersion, + "payload": "chunk", + }, + }) + if !ok { + t.Fatalf("expected wrapped runtime envelope to decode") + } + if streamReadMapString(envelope, "runtime_event_type") != string(EventAgentChunk) { + t.Fatalf("runtime_event_type = %q, want %q", streamReadMapString(envelope, "runtime_event_type"), EventAgentChunk) + } } func TestRestoreRuntimePayloadCoversSpecializedTypes(t *testing.T) { @@ -410,6 +420,12 @@ func TestRestoreRuntimePayloadAdditionalBranches(t *testing.T) { {eventType: EventInputNormalized, payload: map[string]any{"text_length": 3}}, {eventType: EventAssetSaved, payload: map[string]any{"asset_id": "asset-1"}}, {eventType: EventAssetSaveFailed, payload: map[string]any{"message": "x"}}, + {eventType: EventCheckpointCreated, payload: map[string]any{"checkpoint_id": "cp-1"}}, + {eventType: EventCheckpointWarning, payload: map[string]any{"error": "warn"}}, + {eventType: EventCheckpointRestored, payload: map[string]any{"checkpoint_id": "cp-1", "session_id": "s-1"}}, + {eventType: EventCheckpointUndoRestore, payload: map[string]any{"guard_checkpoint_id": "cp-guard", "session_id": "s-1"}}, + {eventType: EventToolDiff, payload: map[string]any{"tool_call_id": "call-1", "tool_name": "edit", "file_path": "a.txt"}}, + {eventType: EventBashSideEffect, payload: map[string]any{"tool_call_id": "call-2", "changes": []map[string]any{{"path": "a.txt", "kind": "modified"}}}}, {eventType: EventTodoUpdated, payload: map[string]any{"action": "replace"}}, {eventType: EventTodoConflict, payload: map[string]any{"action": "conflict"}}, {eventType: EventTodoSnapshotUpdated, payload: map[string]any{"action": "snapshot"}}, @@ -534,16 +550,10 @@ func TestGatewayStreamDecodeAndEnvelopeExtraBranches(t *testing.T) { t.Fatalf("expected restore payload decode error") } - if _, ok := extractRuntimeEnvelope(streamInvalidJSONMarshaler{raw: []byte("{")}); ok { - t.Fatalf("expected marshal error path to fail envelope extraction") - } - if _, ok := extractRuntimeEnvelope(streamInvalidJSONMarshaler{raw: []byte("[]")}); ok { - t.Fatalf("expected unmarshal-to-map error path to fail envelope extraction") - } if envelope, ok := extractRuntimeEnvelope(struct { RuntimeEventType string `json:"runtime_event_type"` - }{RuntimeEventType: string(EventError)}); !ok || streamReadMapString(envelope, "runtime_event_type") == "" { - t.Fatalf("expected runtime_event_type detection after marshal/unmarshal") + }{RuntimeEventType: string(EventError)}); ok || streamReadMapString(envelope, "runtime_event_type") != "" { + t.Fatalf("non-map payload should not decode as envelope") } if got := streamReadMapString(map[string]any{"v": 123}, "v"); got != "123" { diff --git a/internal/tui/services/gateway_stream_client_test.go b/internal/tui/services/gateway_stream_client_test.go index 11470cdd..72311c85 100644 --- a/internal/tui/services/gateway_stream_client_test.go +++ b/internal/tui/services/gateway_stream_client_test.go @@ -293,14 +293,14 @@ func TestDecodeRuntimeEventFromGatewayNotificationRestoresSubAgentPayloads(t *te } } -func TestDecodeRuntimeEventFromGatewayNotificationSupportsNestedEnvelope(t *testing.T) { +func TestDecodeRuntimeEventFromGatewayNotificationAcceptsGatewayWrappedEnvelope(t *testing.T) { notification := buildGatewayEventNotification(t, gateway.MessageFrame{ Type: gateway.FrameTypeEvent, Action: gateway.FrameActionRun, SessionID: "session-3", RunID: "run-3", Payload: map[string]any{ - "type": "run_progress", + "event_type": "run_progress", "payload": map[string]any{ "runtime_event_type": string(EventError), "payload_version": runtimeEventPayloadVersion, @@ -316,7 +316,8 @@ func TestDecodeRuntimeEventFromGatewayNotificationSupportsNestedEnvelope(t *test if event.Type != EventError { t.Fatalf("event.Type = %q, want %q", event.Type, EventError) } - if payload, ok := event.Payload.(string); !ok || payload != "boom" { + payload, ok := event.Payload.(string) + if !ok || payload != "boom" { t.Fatalf("event.Payload = %#v, want %q", event.Payload, "boom") } } diff --git a/internal/tui/services/remote_runtime_adapter.go b/internal/tui/services/remote_runtime_adapter.go index 1f479250..9aaaa2e0 100644 --- a/internal/tui/services/remote_runtime_adapter.go +++ b/internal/tui/services/remote_runtime_adapter.go @@ -21,6 +21,8 @@ import ( const ( unsupportedActionInGatewayMode = "unsupported_action_in_gateway_mode" defaultRemoteRuntimeTimeout = 8 * time.Second + startupProbeSessionPrefix = "session-startup-probe" + startupProbeRunPrefix = "run-startup-probe" ) var ( @@ -96,6 +98,10 @@ func NewRemoteRuntimeAdapter(options RemoteRuntimeAdapterOptions) (*RemoteRuntim _ = adapter.Close() return nil, err } + if err := adapter.startupHandshake(ctx); err != nil { + _ = adapter.Close() + return nil, err + } return adapter, nil } @@ -122,6 +128,7 @@ func newRemoteRuntimeAdapterWithClients( // Submit 将用户输入提交到网关:先 authenticate,再 bindStream,随后 loadSession,最后 run。 func (r *RemoteRuntimeAdapter) Submit(ctx context.Context, input PrepareInput) error { sessionID := strings.TrimSpace(input.SessionID) + requestNewSession := sessionID == "" if sessionID == "" { sessionID = agentsession.NewID("session") } @@ -136,11 +143,13 @@ func (r *RemoteRuntimeAdapter) Submit(ctx context.Context, input PrepareInput) e if err := r.bindStream(ctx, sessionID, runID); err != nil { return err } - if err := r.preloadSession(ctx, sessionID); err != nil { - return err + if !requestNewSession { + if err := r.preloadSession(ctx, sessionID); err != nil { + return err + } } - params := buildGatewayRunParams(sessionID, runID, input) + params := buildGatewayRunParams(sessionID, runID, requestNewSession, input) frame, err := r.callFrame(ctx, protocol.MethodGatewayRun, params, GatewayRPCCallOptions{ Timeout: r.timeout, Retries: 0, @@ -153,57 +162,12 @@ func (r *RemoteRuntimeAdapter) Submit(ctx context.Context, input PrepareInput) e if ackRunID == "" { ackRunID = runID } - r.setActiveRun(ackRunID, sessionID) - return nil -} - -// PrepareUserInput 在 gateway 模式下提供最小可用输入归一化结果,保持接口兼容。 -func (r *RemoteRuntimeAdapter) PrepareUserInput(ctx context.Context, input PrepareInput) (UserInput, error) { - if err := ctx.Err(); err != nil { - return UserInput{}, err - } - - sessionID := strings.TrimSpace(input.SessionID) - if sessionID == "" { - sessionID = agentsession.NewID("session") - } - runID := strings.TrimSpace(input.RunID) - if runID == "" { - runID = fmt.Sprintf("run-%d", time.Now().UnixNano()) - } - - parts := make([]providertypes.ContentPart, 0, 1+len(input.Images)) - if strings.TrimSpace(input.Text) != "" { - parts = append(parts, providertypes.NewTextPart(input.Text)) - } - for _, image := range input.Images { - path := strings.TrimSpace(image.Path) - if path == "" { - continue - } - parts = append(parts, providertypes.NewRemoteImagePart(path)) + ackSessionID := strings.TrimSpace(frame.SessionID) + if ackSessionID == "" { + ackSessionID = sessionID } - - return UserInput{ - SessionID: sessionID, - RunID: runID, - Parts: parts, - Workdir: strings.TrimSpace(input.Workdir), - Mode: strings.TrimSpace(input.Mode), - }, nil -} - -// Run 保持 runtime 接口兼容,在 gateway 模式下回落到 Submit 通道。 -func (r *RemoteRuntimeAdapter) Run(ctx context.Context, input UserInput) error { - prepareInput := PrepareInput{ - SessionID: strings.TrimSpace(input.SessionID), - RunID: strings.TrimSpace(input.RunID), - Workdir: strings.TrimSpace(input.Workdir), - Mode: strings.TrimSpace(input.Mode), - Text: renderInputTextFromParts(input.Parts), - Images: renderInputImagesFromParts(input.Parts), - } - return r.Submit(ctx, prepareInput) + r.setActiveRun(ackRunID, ackSessionID) + return nil } // Compact 转发 gateway.compact 请求并映射回 runtime CompactResult。 @@ -468,6 +432,245 @@ func (r *RemoteRuntimeAdapter) ListAvailableSkills( return mapGatewayAvailableSkillStates(payload.Skills), nil } +// ListModels 转发 gateway.listModels,并返回会话模型列表与 selected_model_id。 +func (r *RemoteRuntimeAdapter) ListModels( + ctx context.Context, + sessionID string, +) ([]providertypes.ModelDescriptor, string, error) { + if err := r.authenticate(ctx); err != nil { + return nil, "", err + } + frame, err := r.callFrame(ctx, protocol.MethodGatewayListModels, protocol.ListModelsParams{ + SessionID: strings.TrimSpace(sessionID), + }, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + if err != nil { + return nil, "", err + } + + payload := struct { + Models []gateway.ModelEntry `json:"models"` + SelectedModelID string `json:"selected_model_id"` + }{} + if err := decodeIntoValue(frame.Payload, &payload); err != nil { + return nil, "", err + } + + models := make([]providertypes.ModelDescriptor, 0, len(payload.Models)) + for _, item := range payload.Models { + modelID := strings.TrimSpace(item.ID) + if modelID == "" { + continue + } + modelName := strings.TrimSpace(item.Name) + if modelName == "" { + modelName = modelID + } + models = append(models, providertypes.ModelDescriptor{ + ID: modelID, + Name: modelName, + }) + } + return models, strings.TrimSpace(payload.SelectedModelID), nil +} + +// ListCheckpoints 转发 checkpoint.list。 +func (r *RemoteRuntimeAdapter) ListCheckpoints( + ctx context.Context, + input CheckpointListInput, +) ([]CheckpointEntry, error) { + if err := r.authenticate(ctx); err != nil { + return nil, err + } + params := protocol.ListCheckpointsParams{ + SessionID: strings.TrimSpace(input.SessionID), + Limit: input.Limit, + RestorableOnly: input.RestorableOnly, + } + frame, err := r.callFrame(ctx, protocol.MethodGatewayListCheckpoints, params, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + if err != nil { + return nil, err + } + return decodeFramePayload[[]CheckpointEntry](frame.Payload) +} + +// RestoreCheckpoint 转发 checkpoint.restore。 +func (r *RemoteRuntimeAdapter) RestoreCheckpoint( + ctx context.Context, + input CheckpointRestoreInput, +) (CheckpointRestoreResult, error) { + if err := r.authenticate(ctx); err != nil { + return CheckpointRestoreResult{}, err + } + frame, err := r.callFrame(ctx, protocol.MethodGatewayRestoreCheckpoint, protocol.RestoreCheckpointParams{ + SessionID: strings.TrimSpace(input.SessionID), + CheckpointID: strings.TrimSpace(input.CheckpointID), + Force: input.Force, + }, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + if err != nil { + return CheckpointRestoreResult{}, err + } + return decodeFramePayload[CheckpointRestoreResult](frame.Payload) +} + +// UndoRestoreCheckpoint 转发 checkpoint.undoRestore。 +func (r *RemoteRuntimeAdapter) UndoRestoreCheckpoint( + ctx context.Context, + sessionID string, +) (CheckpointRestoreResult, error) { + if err := r.authenticate(ctx); err != nil { + return CheckpointRestoreResult{}, err + } + frame, err := r.callFrame(ctx, protocol.MethodGatewayUndoRestore, protocol.UndoRestoreParams{ + SessionID: strings.TrimSpace(sessionID), + }, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + if err != nil { + return CheckpointRestoreResult{}, err + } + return decodeFramePayload[CheckpointRestoreResult](frame.Payload) +} + +// CheckpointDiff 转发 checkpoint.diff。 +func (r *RemoteRuntimeAdapter) CheckpointDiff( + ctx context.Context, + sessionID string, + checkpointID string, +) (CheckpointDiffResult, error) { + if err := r.authenticate(ctx); err != nil { + return CheckpointDiffResult{}, err + } + frame, err := r.callFrame(ctx, protocol.MethodGatewayCheckpointDiff, protocol.CheckpointDiffParams{ + SessionID: strings.TrimSpace(sessionID), + CheckpointID: strings.TrimSpace(checkpointID), + }, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + if err != nil { + return CheckpointDiffResult{}, err + } + return decodeFramePayload[CheckpointDiffResult](frame.Payload) +} + +// ListWorkspaces 转发 gateway.listWorkspaces。 +func (r *RemoteRuntimeAdapter) ListWorkspaces(ctx context.Context) ([]WorkspaceRecord, error) { + if err := r.authenticate(ctx); err != nil { + return nil, err + } + frame, err := r.callFrame(ctx, protocol.MethodGatewayListWorkspaces, protocol.ListWorkspacesParams{}, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + if err != nil { + return nil, err + } + payload := struct { + Workspaces []WorkspaceRecord `json:"workspaces"` + }{} + if err := decodeIntoValue(frame.Payload, &payload); err != nil { + return nil, err + } + return payload.Workspaces, nil +} + +// CreateWorkspace 转发 gateway.createWorkspace。 +func (r *RemoteRuntimeAdapter) CreateWorkspace( + ctx context.Context, + input WorkspaceCreateInput, +) (WorkspaceRecord, error) { + if err := r.authenticate(ctx); err != nil { + return WorkspaceRecord{}, err + } + frame, err := r.callFrame(ctx, protocol.MethodGatewayCreateWorkspace, protocol.CreateWorkspaceParams{ + Path: strings.TrimSpace(input.Path), + Name: strings.TrimSpace(input.Name), + }, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + if err != nil { + return WorkspaceRecord{}, err + } + payload := struct { + Workspace WorkspaceRecord `json:"workspace"` + }{} + if err := decodeIntoValue(frame.Payload, &payload); err != nil { + return WorkspaceRecord{}, err + } + return payload.Workspace, nil +} + +// SwitchWorkspace 转发 gateway.switchWorkspace。 +func (r *RemoteRuntimeAdapter) SwitchWorkspace(ctx context.Context, workspaceHash string) error { + if err := r.authenticate(ctx); err != nil { + return err + } + _, err := r.callFrame(ctx, protocol.MethodGatewaySwitchWorkspace, protocol.SwitchWorkspaceParams{ + WorkspaceHash: strings.TrimSpace(workspaceHash), + }, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + return err +} + +// RenameWorkspace 转发 gateway.renameWorkspace。 +func (r *RemoteRuntimeAdapter) RenameWorkspace( + ctx context.Context, + input WorkspaceRenameInput, +) (WorkspaceRecord, error) { + if err := r.authenticate(ctx); err != nil { + return WorkspaceRecord{}, err + } + frame, err := r.callFrame(ctx, protocol.MethodGatewayRenameWorkspace, protocol.RenameWorkspaceParams{ + WorkspaceHash: strings.TrimSpace(input.WorkspaceHash), + Name: strings.TrimSpace(input.Name), + }, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + if err != nil { + return WorkspaceRecord{}, err + } + payload := struct { + Hash string `json:"hash"` + Name string `json:"name"` + }{} + if err := decodeIntoValue(frame.Payload, &payload); err != nil { + return WorkspaceRecord{}, err + } + return WorkspaceRecord{ + Hash: strings.TrimSpace(payload.Hash), + Name: strings.TrimSpace(payload.Name), + }, nil +} + +// DeleteWorkspace 转发 gateway.deleteWorkspace。 +func (r *RemoteRuntimeAdapter) DeleteWorkspace(ctx context.Context, input WorkspaceDeleteInput) error { + if err := r.authenticate(ctx); err != nil { + return err + } + _, err := r.callFrame(ctx, protocol.MethodGatewayDeleteWorkspace, protocol.DeleteWorkspaceParams{ + WorkspaceHash: strings.TrimSpace(input.WorkspaceHash), + RemoveData: input.RemoveData, + }, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + return err +} + // Close 关闭远程适配器并结束事件桥接。 func (r *RemoteRuntimeAdapter) Close() error { var closeErr error @@ -491,6 +694,12 @@ func (r *RemoteRuntimeAdapter) authenticate(ctx context.Context) error { return r.rpcClient.Authenticate(ctx) } +func (r *RemoteRuntimeAdapter) startupHandshake(ctx context.Context) error { + sessionID := fmt.Sprintf("%s-%d", startupProbeSessionPrefix, time.Now().UnixNano()) + runID := fmt.Sprintf("%s-%d", startupProbeRunPrefix, time.Now().UnixNano()) + return r.bindStream(ctx, sessionID, runID) +} + func (r *RemoteRuntimeAdapter) bindStream(ctx context.Context, sessionID string, runID string) error { _, err := r.callFrame(ctx, protocol.MethodGatewayBindStream, protocol.BindStreamParams{ SessionID: strings.TrimSpace(sessionID), @@ -619,7 +828,7 @@ func normalizeSystemToolArguments(arguments []byte) json.RawMessage { return json.RawMessage(cloned) } -func buildGatewayRunParams(sessionID string, runID string, input PrepareInput) protocol.RunParams { +func buildGatewayRunParams(sessionID string, runID string, requestNewSession bool, input PrepareInput) protocol.RunParams { parts := make([]protocol.RunInputPart, 0, len(input.Images)) for _, image := range input.Images { path := strings.TrimSpace(image.Path) @@ -637,6 +846,7 @@ func buildGatewayRunParams(sessionID string, runID string, input PrepareInput) p return protocol.RunParams{ SessionID: strings.TrimSpace(sessionID), + NewSession: requestNewSession, RunID: strings.TrimSpace(runID), InputText: strings.TrimSpace(input.Text), InputParts: parts, @@ -645,43 +855,6 @@ func buildGatewayRunParams(sessionID string, runID string, input PrepareInput) p } } -func renderInputTextFromParts(parts []providertypes.ContentPart) string { - textParts := make([]string, 0, len(parts)) - for _, part := range parts { - if part.Kind != providertypes.ContentPartText { - continue - } - text := strings.TrimSpace(part.Text) - if text == "" { - continue - } - textParts = append(textParts, text) - } - return strings.Join(textParts, "\n") -} - -func renderInputImagesFromParts(parts []providertypes.ContentPart) []UserImageInput { - images := make([]UserImageInput, 0, len(parts)) - for _, part := range parts { - if part.Kind != providertypes.ContentPartImage || part.Image == nil { - continue - } - path := strings.TrimSpace(part.Image.URL) - if path == "" { - continue - } - mimeType := "" - if part.Image.Asset != nil { - mimeType = strings.TrimSpace(part.Image.Asset.MimeType) - } - images = append(images, UserImageInput{ - Path: path, - MimeType: mimeType, - }) - } - return images -} - func mapGatewaySessionToRuntimeSession(source gateway.Session) agentsession.Session { messages := make([]providertypes.Message, 0, len(source.Messages)) for _, item := range source.Messages { diff --git a/internal/tui/services/remote_runtime_adapter_additional_test.go b/internal/tui/services/remote_runtime_adapter_additional_test.go index c8cdcd1b..a9e8aaea 100644 --- a/internal/tui/services/remote_runtime_adapter_additional_test.go +++ b/internal/tui/services/remote_runtime_adapter_additional_test.go @@ -9,7 +9,6 @@ import ( "neo-code/internal/gateway" "neo-code/internal/gateway/protocol" - providertypes "neo-code/internal/provider/types" "neo-code/internal/skills" ) @@ -33,6 +32,11 @@ func TestNewRemoteRuntimeAdapterBranches(t *testing.T) { if options.ListenAddress == "dial-failed" { client.authErr = errors.New("dial failed") } + if options.ListenAddress == "bind-failed" { + client.callErrs = map[string]error{ + protocol.MethodGatewayBindStream: errors.New("bind failed"), + } + } return client, nil } newGatewayStreamClientFactory = func(source <-chan gatewayRPCNotification) *GatewayStreamClient { @@ -45,6 +49,9 @@ func TestNewRemoteRuntimeAdapterBranches(t *testing.T) { if _, err := NewRemoteRuntimeAdapter(RemoteRuntimeAdapterOptions{ListenAddress: "dial-failed", RequestTimeout: -1}); err == nil { t.Fatalf("expected authenticate fail-fast error") } + if _, err := NewRemoteRuntimeAdapter(RemoteRuntimeAdapterOptions{ListenAddress: "bind-failed", RequestTimeout: -1}); err == nil { + t.Fatalf("expected bindStream fail-fast error") + } adapter, err := NewRemoteRuntimeAdapter(RemoteRuntimeAdapterOptions{ ListenAddress: "ok", @@ -63,14 +70,13 @@ func TestNewRemoteRuntimeAdapterBranches(t *testing.T) { _ = adapter.Close() } -func TestRemoteRuntimeAdapterPrepareUserInputAndRun(t *testing.T) { +func TestRemoteRuntimeAdapterSubmitGeneratesIDsAndNormalizesInput(t *testing.T) { t.Parallel() rpcClient := &stubRemoteRPCClient{ frames: map[string]gateway.MessageFrame{ - protocol.MethodGatewayLoadSession: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionLoadSession, SessionID: "s-1"}, - protocol.MethodGatewayBindStream: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionBindStream, SessionID: "s-1", RunID: "r-1"}, - protocol.MethodGatewayRun: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionRun, SessionID: "s-1", RunID: "r-1"}, + protocol.MethodGatewayBindStream: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionBindStream, SessionID: "s-1", RunID: "r-1"}, + protocol.MethodGatewayRun: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionRun, SessionID: "s-1", RunID: "r-1"}, }, notifications: make(chan gatewayRPCNotification), } @@ -78,13 +84,7 @@ func TestRemoteRuntimeAdapterPrepareUserInputAndRun(t *testing.T) { adapter := newRemoteRuntimeAdapterWithClients(rpcClient, streamClient, time.Second, 1) t.Cleanup(func() { _ = adapter.Close() }) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - if _, err := adapter.PrepareUserInput(ctx, PrepareInput{}); err == nil { - t.Fatalf("expected context cancellation error") - } - - input, err := adapter.PrepareUserInput(context.Background(), PrepareInput{ + if err := adapter.Submit(context.Background(), PrepareInput{ SessionID: " ", RunID: "", Text: " hello ", @@ -93,23 +93,28 @@ func TestRemoteRuntimeAdapterPrepareUserInputAndRun(t *testing.T) { {Path: " /tmp/a.png ", MimeType: " image/png "}, }, Workdir: " /repo ", - }) - if err != nil { - t.Fatalf("PrepareUserInput() error = %v", err) + }); err != nil { + t.Fatalf("Submit() error = %v", err) + } + methods := rpcClient.snapshotMethods() + if len(methods) != 2 || methods[1] != protocol.MethodGatewayRun { + t.Fatalf("unexpected method chain: %#v", methods) } - if strings.TrimSpace(input.SessionID) == "" || strings.TrimSpace(input.RunID) == "" { - t.Fatalf("session/run id should be generated") + params, ok := rpcClient.snapshotParams()[protocol.MethodGatewayRun].(protocol.RunParams) + if !ok { + t.Fatalf("run params type = %T, want protocol.RunParams", rpcClient.snapshotParams()[protocol.MethodGatewayRun]) } - if len(input.Parts) != 2 { - t.Fatalf("parts len = %d, want 2", len(input.Parts)) + if !params.NewSession { + t.Fatalf("expected new_session=true when submitting draft input, params=%#v", params) } - - if err := adapter.Run(context.Background(), input); err != nil { - t.Fatalf("Run() error = %v", err) + if strings.TrimSpace(params.SessionID) == "" || strings.TrimSpace(params.RunID) == "" { + t.Fatalf("session/run id should be generated: %#v", params) } - methods := rpcClient.snapshotMethods() - if len(methods) != 3 || methods[2] != protocol.MethodGatewayRun { - t.Fatalf("unexpected method chain: %#v", methods) + if params.InputText != "hello" || params.Workdir != "/repo" { + t.Fatalf("unexpected normalized run payload: %#v", params) + } + if len(params.InputParts) != 1 || params.InputParts[0].Media == nil || params.InputParts[0].Media.URI != "/tmp/a.png" { + t.Fatalf("unexpected run input parts: %#v", params.InputParts) } } @@ -611,36 +616,209 @@ func TestRemoteRuntimeAdapterListAndLoadSessionErrorPaths(t *testing.T) { } } -func TestRemoteRuntimeAdapterRenderInputHelpers(t *testing.T) { +func TestRemoteRuntimeAdapterBuildGatewayRunParams(t *testing.T) { t.Parallel() - text := renderInputTextFromParts([]providertypes.ContentPart{ - providertypes.NewTextPart(" first "), - providertypes.NewRemoteImagePart("/tmp/a.png"), - providertypes.NewTextPart("second"), - providertypes.NewTextPart(" "), - }) - if text != "first\nsecond" { - t.Fatalf("renderInputTextFromParts() = %q", text) - } - - images := renderInputImagesFromParts([]providertypes.ContentPart{ - providertypes.NewTextPart("x"), - providertypes.NewRemoteImagePart(" "), - providertypes.ContentPart{ - Kind: providertypes.ContentPartImage, - Image: &providertypes.ImagePart{ - URL: " /tmp/b.png ", - Asset: &providertypes.AssetRef{MimeType: " image/png "}, + params := buildGatewayRunParams(" s ", " r ", true, PrepareInput{Text: " hi ", Workdir: " /w ", Mode: " plan ", Images: []UserImageInput{{Path: " /img.png ", MimeType: " image/png "}, {Path: " "}}}) + if params.SessionID != "s" || !params.NewSession || params.RunID != "r" || params.Workdir != "/w" || params.Mode != "plan" || params.InputText != "hi" || len(params.InputParts) != 1 { + t.Fatalf("buildGatewayRunParams() = %#v", params) + } +} + +func TestRemoteRuntimeAdapterListModels(t *testing.T) { + t.Parallel() + + rpcClient := &stubRemoteRPCClient{ + frames: map[string]gateway.MessageFrame{ + protocol.MethodGatewayListModels: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionListModels, + Payload: map[string]any{ + "models": []gateway.ModelEntry{ + {ID: " m-1 ", Name: " Model One "}, + {ID: " m-2 ", Name: ""}, + {ID: " ", Name: "ignored"}, + }, + "selected_model_id": " m-2 ", + }, + }, + }, + notifications: make(chan gatewayRPCNotification), + } + adapter := newRemoteRuntimeAdapterWithClients( + rpcClient, + &stubRemoteStreamClient{events: make(chan RuntimeEvent)}, + time.Second, + 1, + ) + t.Cleanup(func() { _ = adapter.Close() }) + + models, selectedID, err := adapter.ListModels(context.Background(), " session-1 ") + if err != nil { + t.Fatalf("ListModels() error = %v", err) + } + if len(models) != 2 { + t.Fatalf("expected two models after trimming/filtering, got %#v", models) + } + if models[0].ID != "m-1" || models[0].Name != "Model One" { + t.Fatalf("unexpected first model mapping: %#v", models[0]) + } + if models[1].ID != "m-2" || models[1].Name != "m-2" { + t.Fatalf("unexpected second model mapping: %#v", models[1]) + } + if selectedID != "m-2" { + t.Fatalf("selected model id = %q, want %q", selectedID, "m-2") + } + + params, ok := rpcClient.snapshotParams()[protocol.MethodGatewayListModels].(protocol.ListModelsParams) + if !ok { + t.Fatalf("listModels params type = %T, want protocol.ListModelsParams", rpcClient.snapshotParams()[protocol.MethodGatewayListModels]) + } + if params.SessionID != "session-1" { + t.Fatalf("session_id = %q, want %q", params.SessionID, "session-1") + } +} + +func TestRemoteRuntimeAdapterCheckpointAndWorkspaceMethods(t *testing.T) { + t.Parallel() + + now := time.Now().UTC() + rpcClient := &stubRemoteRPCClient{ + frames: map[string]gateway.MessageFrame{ + protocol.MethodGatewayListCheckpoints: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionListCheckpoints, + Payload: []CheckpointEntry{ + {CheckpointID: "cp-1", SessionID: "s-1", CreatedAtMS: 1700000000000, Restorable: true}, + }, + }, + protocol.MethodGatewayRestoreCheckpoint: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionRestoreCheckpoint, + Payload: CheckpointRestoreResult{ + CheckpointID: "cp-1", + SessionID: "s-1", + }, + }, + protocol.MethodGatewayUndoRestore: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionUndoRestore, + Payload: CheckpointRestoreResult{ + CheckpointID: "cp-guard", + SessionID: "s-1", + }, + }, + protocol.MethodGatewayCheckpointDiff: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionCheckpointDiff, + Payload: CheckpointDiffResult{ + CheckpointID: "cp-1", + PrevCheckpointID: "cp-0", + Files: CheckpointDiffFiles{ + Modified: []string{"a.txt"}, + }, + Patch: "diff --git a/a.txt b/a.txt", + }, + }, + protocol.MethodGatewayListWorkspaces: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionWorkspaceList, + Payload: map[string]any{ + "workspaces": []WorkspaceRecord{ + {Hash: "ws-1", Path: "/repo", Name: "repo", CreatedAt: now, UpdatedAt: now}, + }, + }, + }, + protocol.MethodGatewayCreateWorkspace: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionWorkspaceCreate, + Payload: map[string]any{ + "workspace": WorkspaceRecord{Hash: "ws-2", Path: "/repo2", Name: "repo2"}, + }, + }, + protocol.MethodGatewaySwitchWorkspace: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionWorkspaceSwitch, + Payload: map[string]any{ + "workspace_hash": "ws-1", + }, + }, + protocol.MethodGatewayRenameWorkspace: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionWorkspaceRename, + Payload: map[string]any{ + "hash": "ws-1", + "name": "renamed", + }, + }, + protocol.MethodGatewayDeleteWorkspace: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionWorkspaceDelete, + Payload: map[string]any{ + "hash": "ws-1", + }, }, }, + notifications: make(chan gatewayRPCNotification), + } + adapter := newRemoteRuntimeAdapterWithClients( + rpcClient, + &stubRemoteStreamClient{events: make(chan RuntimeEvent)}, + time.Second, + 1, + ) + t.Cleanup(func() { _ = adapter.Close() }) + + checkpoints, err := adapter.ListCheckpoints(context.Background(), CheckpointListInput{ + SessionID: " s-1 ", + Limit: 20, + RestorableOnly: true, }) - if len(images) != 1 || images[0].Path != "/tmp/b.png" || images[0].MimeType != "image/png" { - t.Fatalf("renderInputImagesFromParts() = %#v", images) + if err != nil || len(checkpoints) != 1 || checkpoints[0].CheckpointID != "cp-1" { + t.Fatalf("ListCheckpoints() = (%#v, %v)", checkpoints, err) + } + restoreResult, err := adapter.RestoreCheckpoint(context.Background(), CheckpointRestoreInput{ + SessionID: " s-1 ", + CheckpointID: " cp-1 ", + Force: true, + }) + if err != nil || restoreResult.CheckpointID != "cp-1" { + t.Fatalf("RestoreCheckpoint() = (%#v, %v)", restoreResult, err) + } + undoResult, err := adapter.UndoRestoreCheckpoint(context.Background(), " s-1 ") + if err != nil || undoResult.CheckpointID != "cp-guard" { + t.Fatalf("UndoRestoreCheckpoint() = (%#v, %v)", undoResult, err) + } + diffResult, err := adapter.CheckpointDiff(context.Background(), " s-1 ", " cp-1 ") + if err != nil || diffResult.CheckpointID != "cp-1" || len(diffResult.Files.Modified) != 1 { + t.Fatalf("CheckpointDiff() = (%#v, %v)", diffResult, err) } - params := buildGatewayRunParams(" s ", " r ", PrepareInput{Text: " hi ", Workdir: " /w ", Mode: " plan ", Images: []UserImageInput{{Path: " /img.png ", MimeType: " image/png "}, {Path: " "}}}) - if params.SessionID != "s" || params.RunID != "r" || params.Workdir != "/w" || params.Mode != "plan" || params.InputText != "hi" || len(params.InputParts) != 1 { - t.Fatalf("buildGatewayRunParams() = %#v", params) + workspaces, err := adapter.ListWorkspaces(context.Background()) + if err != nil || len(workspaces) != 1 || workspaces[0].Hash != "ws-1" { + t.Fatalf("ListWorkspaces() = (%#v, %v)", workspaces, err) + } + createdWorkspace, err := adapter.CreateWorkspace(context.Background(), WorkspaceCreateInput{ + Path: " /repo2 ", + Name: " repo2 ", + }) + if err != nil || createdWorkspace.Hash != "ws-2" { + t.Fatalf("CreateWorkspace() = (%#v, %v)", createdWorkspace, err) + } + if err := adapter.SwitchWorkspace(context.Background(), " ws-1 "); err != nil { + t.Fatalf("SwitchWorkspace() error = %v", err) + } + renamedWorkspace, err := adapter.RenameWorkspace(context.Background(), WorkspaceRenameInput{ + WorkspaceHash: " ws-1 ", + Name: " renamed ", + }) + if err != nil || renamedWorkspace.Hash != "ws-1" || renamedWorkspace.Name != "renamed" { + t.Fatalf("RenameWorkspace() = (%#v, %v)", renamedWorkspace, err) + } + if err := adapter.DeleteWorkspace(context.Background(), WorkspaceDeleteInput{ + WorkspaceHash: " ws-1 ", + RemoveData: true, + }); err != nil { + t.Fatalf("DeleteWorkspace() error = %v", err) } } diff --git a/internal/tui/services/runtime_contract.go b/internal/tui/services/runtime_contract.go index 4772495b..afa70160 100644 --- a/internal/tui/services/runtime_contract.go +++ b/internal/tui/services/runtime_contract.go @@ -13,8 +13,6 @@ import ( // Runtime 定义 TUI 与运行时交互所需的最小契约。 type Runtime interface { Submit(ctx context.Context, input PrepareInput) error - PrepareUserInput(ctx context.Context, input PrepareInput) (UserInput, error) - Run(ctx context.Context, input UserInput) error Compact(ctx context.Context, input CompactInput) (CompactResult, error) ExecuteSystemTool(ctx context.Context, input SystemToolInput) (tools.ToolResult, error) ResolvePermission(ctx context.Context, input PermissionResolutionInput) error @@ -169,12 +167,88 @@ type AvailableSkillState struct { Active bool } +// CheckpointEntry 描述 checkpoint 列表项。 +type CheckpointEntry struct { + CheckpointID string `json:"checkpoint_id"` + SessionID string `json:"session_id"` + Reason string `json:"reason"` + Status string `json:"status"` + Restorable bool `json:"restorable"` + CreatedAtMS int64 `json:"created_at_ms"` +} + +// CheckpointListInput 描述 checkpoint.list 查询参数。 +type CheckpointListInput struct { + SessionID string + Limit int + RestorableOnly bool +} + +// CheckpointRestoreInput 描述 checkpoint.restore 入参。 +type CheckpointRestoreInput struct { + SessionID string + CheckpointID string + Force bool +} + +// CheckpointRestoreResult 描述 checkpoint.restore / checkpoint.undoRestore 结果。 +type CheckpointRestoreResult struct { + CheckpointID string `json:"checkpoint_id"` + SessionID string `json:"session_id"` + HasConflict bool `json:"has_conflict,omitempty"` +} + +// CheckpointDiffFiles 描述 checkpoint diff 的文件分类。 +type CheckpointDiffFiles struct { + Added []string `json:"added,omitempty"` + Deleted []string `json:"deleted,omitempty"` + Modified []string `json:"modified,omitempty"` +} + +// CheckpointDiffResult 描述 checkpoint.diff 返回结构。 +type CheckpointDiffResult struct { + CheckpointID string `json:"checkpoint_id"` + PrevCheckpointID string `json:"prev_checkpoint_id,omitempty"` + CommitHash string `json:"commit_hash,omitempty"` + PrevCommitHash string `json:"prev_commit_hash,omitempty"` + Files CheckpointDiffFiles `json:"files"` + Patch string `json:"patch,omitempty"` +} + +// WorkspaceRecord 描述工作区登记信息。 +type WorkspaceRecord struct { + Hash string `json:"hash"` + Path string `json:"path"` + Name string `json:"name"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// WorkspaceCreateInput 描述 workspace.create 入参。 +type WorkspaceCreateInput struct { + Path string + Name string +} + +// WorkspaceRenameInput 描述 workspace.rename 入参。 +type WorkspaceRenameInput struct { + WorkspaceHash string + Name string +} + +// WorkspaceDeleteInput 描述 workspace.delete 入参。 +type WorkspaceDeleteInput struct { + WorkspaceHash string + RemoveData bool +} + // SessionLogEntry 描述日志查看器持久化条目。 type SessionLogEntry struct { Timestamp time.Time `json:"timestamp"` Level string `json:"level"` Source string `json:"source"` Message string `json:"message"` + Inline string `json:"inline_message,omitempty"` } // PhaseChangedPayload 描述阶段切换信息。 @@ -479,6 +553,68 @@ type RepoHooksLifecyclePayload struct { Reason string `json:"reason,omitempty"` } +// CheckpointCreatedPayload 描述 checkpoint 创建成功事件。 +type CheckpointCreatedPayload struct { + CheckpointID string `json:"checkpoint_id"` + CodeCheckpointRef string `json:"code_checkpoint_ref"` + SessionCheckpointRef string `json:"session_checkpoint_ref"` + CommitHash string `json:"commit_hash"` + Reason string `json:"reason"` +} + +// CheckpointWarningPayload 描述 checkpoint 创建中的非致命告警。 +type CheckpointWarningPayload struct { + Error string `json:"error"` + Phase string `json:"phase"` +} + +// CheckpointRestoredPayload 描述 checkpoint 恢复成功事件。 +type CheckpointRestoredPayload struct { + CheckpointID string `json:"checkpoint_id"` + SessionID string `json:"session_id"` + GuardCheckpointID string `json:"guard_checkpoint_id"` +} + +// CheckpointUndoRestorePayload 描述 restore 撤销事件。 +type CheckpointUndoRestorePayload struct { + GuardCheckpointID string `json:"guard_checkpoint_id"` + SessionID string `json:"session_id"` +} + +// FileChange 描述一次文件变更。 +type FileChange struct { + Path string `json:"path"` + Kind string `json:"kind"` +} + +// FileDiffEntry 描述单个文件的 diff。 +type FileDiffEntry struct { + Path string `json:"path"` + Diff string `json:"diff,omitempty"` + WasNew bool `json:"was_new,omitempty"` + Kind string `json:"kind,omitempty"` +} + +// ToolDiffPayload 描述写工具变更。 +type ToolDiffPayload struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + FilePath string `json:"file_path"` + Diff string `json:"diff,omitempty"` + WasNew bool `json:"was_new,omitempty"` + Files []FileChange `json:"files,omitempty"` + Diffs []FileDiffEntry `json:"diffs,omitempty"` +} + +// BashSideEffectPayload 描述 bash 命令文件侧效应。 +type BashSideEffectPayload struct { + ToolCallID string `json:"tool_call_id"` + Command string `json:"command,omitempty"` + Changes []FileChange `json:"changes"` + PreemptivelyCapturedPaths []string `json:"preemptively_captured_paths,omitempty"` + UncoveredPaths []string `json:"uncovered_paths,omitempty"` +} + const ( EventUserMessage EventType = "user_message" EventAgentChunk EventType = "agent_chunk" @@ -522,6 +658,12 @@ const ( EventRepoHooksLoaded EventType = "repo_hooks_loaded" EventRepoHooksSkippedUntrusted EventType = "repo_hooks_skipped_untrusted" EventRepoHooksTrustStoreInvalid EventType = "repo_hooks_trust_store_invalid" + EventCheckpointCreated EventType = "checkpoint_created" + EventCheckpointWarning EventType = "checkpoint_warning" + EventCheckpointRestored EventType = "checkpoint_restored" + EventCheckpointUndoRestore EventType = "checkpoint_undo_restore" + EventToolDiff EventType = "tool_diff" + EventBashSideEffect EventType = "bash_side_effect" EventSubAgentStarted EventType = "subagent_started" EventSubAgentProgress EventType = "subagent_progress" EventSubAgentRetried EventType = "subagent_retried" diff --git a/internal/tui/services/runtime_service.go b/internal/tui/services/runtime_service.go index 7fae8e6d..baaa3c6d 100644 --- a/internal/tui/services/runtime_service.go +++ b/internal/tui/services/runtime_service.go @@ -11,11 +11,6 @@ import ( const permissionResolveTimeout = 10 * time.Second -// Runner 定义执行 run 所需的最小能力。 -type Runner interface { - Run(ctx context.Context, input UserInput) error -} - // Submitter 定义单入口提交所需能力。 type Submitter interface { Submit(ctx context.Context, input PrepareInput) error @@ -47,14 +42,6 @@ func ListenForRuntimeEventCmd(sub <-chan RuntimeEvent, eventMsg func(RuntimeEven } } -// RunAgentCmd 执行 run 并回传结果。 -func RunAgentCmd(runtime Runner, input UserInput, doneMsg func(error) tea.Msg) tea.Cmd { - return func() tea.Msg { - err := runtime.Run(context.Background(), input) - return doneMsg(err) - } -} - // RunSubmitCmd 执行 submit 并回传结果。 func RunSubmitCmd(runtime Submitter, input PrepareInput, doneMsg func(error) tea.Msg) tea.Cmd { return func() tea.Msg { diff --git a/internal/tui/services/services_test.go b/internal/tui/services/services_test.go index db831959..070de12f 100644 --- a/internal/tui/services/services_test.go +++ b/internal/tui/services/services_test.go @@ -15,16 +15,6 @@ import ( "neo-code/internal/tools" ) -type stubRunner struct { - lastInput UserInput - err error -} - -func (s *stubRunner) Run(ctx context.Context, input UserInput) error { - s.lastInput = input - return s.err -} - type stubSubmitter struct { lastInput PrepareInput err error @@ -113,18 +103,6 @@ func TestListenForRuntimeEventCmd(t *testing.T) { } } -func TestRunAgentCmd(t *testing.T) { - runner := &stubRunner{err: errors.New("boom")} - input := UserInput{SessionID: "s1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, Workdir: "D:/"} - msg := RunAgentCmd(runner, input, func(err error) tea.Msg { return err })() - if runner.lastInput.SessionID != "s1" || renderPartsForTest(runner.lastInput.Parts) != "hello" { - t.Fatalf("unexpected runner input: %+v", runner.lastInput) - } - if err, ok := msg.(error); !ok || err == nil || err.Error() != "boom" { - t.Fatalf("expected forwarded error message, got %T %#v", msg, msg) - } -} - func TestRunSubmitCmd(t *testing.T) { runner := &stubSubmitter{err: errors.New("run failed")} prepareInput := PrepareInput{