diff --git a/internal/checkpoint/checkpoint_manager.go b/internal/checkpoint/checkpoint_manager.go index 6dae9662..d8d4a950 100644 --- a/internal/checkpoint/checkpoint_manager.go +++ b/internal/checkpoint/checkpoint_manager.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "reflect" "sync" "time" @@ -423,7 +424,7 @@ func (s *SQLiteCheckpointStore) RestoreCheckpoint(ctx context.Context, input Res toUnixMillis(input.UpdatedAt), h.Provider, h.Model, h.Workdir, marshalHeadField(h.TaskState), marshalHeadField(h.Todos), marshalHeadField(h.ActivatedSkills), h.TokenInputTotal, h.TokenOutputTotal, boolToInt(h.HasUnknownUsage), h.AgentMode, - marshalHeadField(h.CurrentPlan), h.LastFullPlanRevision, + marshalPlanField(h.CurrentPlan), h.LastFullPlanRevision, boolToInt(h.PlanApprovalPendingFullAlign), boolToInt(h.PlanCompletionPendingFullReview), boolToInt(h.PlanContextDirty), boolToInt(h.PlanRestorePendingAlign), len(input.Messages), len(input.Messages), input.SessionID, @@ -456,12 +457,25 @@ func (s *SQLiteCheckpointStore) RestoreCheckpoint(ctx context.Context, input Res } func marshalHeadField(value any) string { + data, err := json.Marshal(value) + if err != nil { + return "null" + } + return string(data) +} + +// marshalPlanField 将可选计划字段编码为 session 兼容的持久化格式,nil 计划统一写为空串。 +func marshalPlanField(value any) string { if value == nil { - return "{}" + return "" + } + rv := reflect.ValueOf(value) + if rv.Kind() == reflect.Pointer && rv.IsNil() { + return "" } data, err := json.Marshal(value) if err != nil { - return "{}" + return "" } return string(data) } diff --git a/internal/checkpoint/checkpoint_manager_test.go b/internal/checkpoint/checkpoint_manager_test.go new file mode 100644 index 00000000..b8986d2f --- /dev/null +++ b/internal/checkpoint/checkpoint_manager_test.go @@ -0,0 +1,417 @@ +package checkpoint + +import ( + "context" + "encoding/json" + "path/filepath" + "testing" + "time" + + providertypes "neo-code/internal/provider/types" + "neo-code/internal/session" +) + +type checkpointStoreFixture struct { + sessionStore *session.SQLiteStore + checkpointStore *SQLiteCheckpointStore + baseDir string + workspaceRoot string +} + +func newCheckpointStoreFixture(t *testing.T) checkpointStoreFixture { + t.Helper() + + baseDir := t.TempDir() + workspaceRoot := t.TempDir() + + sessionStore := session.NewSQLiteStore(baseDir, workspaceRoot) + t.Cleanup(func() { + _ = sessionStore.Close() + }) + + checkpointStore := NewSQLiteCheckpointStore(session.DatabasePath(baseDir, workspaceRoot)) + t.Cleanup(func() { + _ = checkpointStore.Close() + }) + + return checkpointStoreFixture{ + sessionStore: sessionStore, + checkpointStore: checkpointStore, + baseDir: baseDir, + workspaceRoot: workspaceRoot, + } +} + +func createCheckpointTestSession(t *testing.T, store *session.SQLiteStore, id string, workdir string) session.Session { + t.Helper() + + created, err := store.CreateSession(context.Background(), session.CreateSessionInput{ + ID: id, + Title: "checkpoint test", + Head: session.SessionHead{ + Provider: "openai", + Model: "gpt-test", + Workdir: workdir, + TaskState: session.TaskState{ + Goal: "before restore", + VerificationProfile: session.VerificationProfileTaskOnly, + }, + Todos: []session.TodoItem{ + {ID: "todo-1", Content: "before restore"}, + }, + }, + }) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) + } + + messages := []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("before restore"), + }, + }, + { + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("tool planned"), + }, + ToolCalls: []providertypes.ToolCall{ + {ID: "call-1", Name: "bash", Arguments: `{"cmd":"pwd"}`}, + }, + ToolMetadata: map[string]string{"source": "test"}, + }, + } + if err := store.AppendMessages(context.Background(), session.AppendMessagesInput{ + SessionID: created.ID, + Messages: messages, + UpdatedAt: time.Now(), + Provider: "openai", + Model: "gpt-test", + Workdir: workdir, + }); err != nil { + t.Fatalf("AppendMessages() error = %v", err) + } + + loaded, err := store.LoadSession(context.Background(), created.ID) + if err != nil { + t.Fatalf("LoadSession() error = %v", err) + } + return loaded +} + +func checkpointInputFromSession(t *testing.T, loaded session.Session, checkpointID string, reason session.CheckpointReason, createdAt time.Time) CreateCheckpointInput { + t.Helper() + + headJSON, err := json.Marshal(loaded.HeadSnapshot()) + if err != nil { + t.Fatalf("Marshal(head) error = %v", err) + } + messagesJSON, err := json.Marshal(loaded.Messages) + if err != nil { + t.Fatalf("Marshal(messages) error = %v", err) + } + + return CreateCheckpointInput{ + Record: session.CheckpointRecord{ + CheckpointID: checkpointID, + WorkspaceKey: session.WorkspacePathKey(loaded.Workdir), + SessionID: loaded.ID, + RunID: "run-" + checkpointID, + Workdir: loaded.Workdir, + CreatedAt: createdAt, + Reason: reason, + CodeCheckpointRef: RefForCheckpoint(loaded.ID, checkpointID), + Restorable: true, + Status: session.CheckpointStatusCreating, + }, + SessionCP: session.SessionCheckpoint{ + ID: "sc-" + checkpointID, + SessionID: loaded.ID, + HeadJSON: string(headJSON), + MessagesJSON: string(messagesJSON), + CreatedAt: createdAt, + }, + } +} + +func TestSQLiteCheckpointStoreCreateRestoreAndResume(t *testing.T) { + t.Parallel() + + fixture := newCheckpointStoreFixture(t) + loaded := createCheckpointTestSession(t, fixture.sessionStore, "session_restore", fixture.workspaceRoot) + checkpointCreatedAt := time.Now().Add(-time.Minute) + + input := checkpointInputFromSession(t, loaded, "cp-restore", session.CheckpointReasonPreWrite, checkpointCreatedAt) + saved, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), input) + if err != nil { + t.Fatalf("CreateCheckpoint() error = %v", err) + } + if saved.SessionCheckpointRef == "" || saved.Status != session.CheckpointStatusAvailable { + t.Fatalf("CreateCheckpoint() = %#v, want available checkpoint with session ref", saved) + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), loaded.ID, ListCheckpointOpts{ + Limit: 10, + RestorableOnly: true, + }) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 || records[0].CheckpointID != saved.CheckpointID { + t.Fatalf("ListCheckpoints() = %#v, want only %q", records, saved.CheckpointID) + } + + record, sessionCP, err := fixture.checkpointStore.GetCheckpoint(context.Background(), saved.CheckpointID) + if err != nil { + t.Fatalf("GetCheckpoint() error = %v", err) + } + if record.CheckpointID != saved.CheckpointID || sessionCP == nil || sessionCP.ID != saved.SessionCheckpointRef { + t.Fatalf("GetCheckpoint() = (%#v, %#v), want saved record and session snapshot", record, sessionCP) + } + + if err := fixture.sessionStore.UpdateSessionState(context.Background(), session.UpdateSessionStateInput{ + SessionID: loaded.ID, + UpdatedAt: time.Now(), + Title: "mutated", + Head: session.SessionHead{ + Provider: "openai", + Model: "gpt-test", + Workdir: loaded.Workdir, + TaskState: session.TaskState{ + Goal: "after restore", + VerificationProfile: session.VerificationProfileTaskOnly, + }, + Todos: []session.TodoItem{ + {ID: "todo-2", Content: "after restore"}, + }, + }, + }); err != nil { + t.Fatalf("UpdateSessionState() error = %v", err) + } + if err := fixture.sessionStore.AppendMessages(context.Background(), session.AppendMessagesInput{ + SessionID: loaded.ID, + Messages: []providertypes.Message{ + { + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("after restore"), + }, + }, + }, + UpdatedAt: time.Now(), + Provider: "openai", + Model: "gpt-test", + Workdir: loaded.Workdir, + }); err != nil { + t.Fatalf("AppendMessages(after) error = %v", err) + } + + if err := fixture.checkpointStore.RestoreCheckpoint(context.Background(), RestoreCheckpointInput{ + SessionID: loaded.ID, + Head: loaded.HeadSnapshot(), + Messages: loaded.Messages, + UpdatedAt: time.Now(), + }); err != nil { + t.Fatalf("RestoreCheckpoint() error = %v", err) + } + + restored, err := fixture.sessionStore.LoadSession(context.Background(), loaded.ID) + if err != nil { + t.Fatalf("LoadSession(restored) error = %v", err) + } + if restored.TaskState.Goal != loaded.TaskState.Goal { + t.Fatalf("restored goal = %q, want %q", restored.TaskState.Goal, loaded.TaskState.Goal) + } + if len(restored.Messages) != len(loaded.Messages) { + t.Fatalf("restored message count = %d, want %d", len(restored.Messages), len(loaded.Messages)) + } + if restored.Messages[1].ToolMetadata["source"] != "test" { + t.Fatalf("restored tool metadata = %#v, want preserved metadata", restored.Messages[1].ToolMetadata) + } + + if err := fixture.checkpointStore.UpdateCheckpointStatus(context.Background(), saved.CheckpointID, session.CheckpointStatusRestored); err != nil { + t.Fatalf("UpdateCheckpointStatus() error = %v", err) + } + filtered, err := fixture.checkpointStore.ListCheckpoints(context.Background(), loaded.ID, ListCheckpointOpts{ + RestorableOnly: true, + }) + if err != nil { + t.Fatalf("ListCheckpoints(filtered) error = %v", err) + } + if len(filtered) != 0 { + t.Fatalf("expected no restorable checkpoints after status change, got %#v", filtered) + } + + firstResume := session.ResumeCheckpoint{ + ID: "rc-1", + WorkspaceKey: session.WorkspacePathKey(loaded.Workdir), + RunID: "run-1", + SessionID: loaded.ID, + Turn: 1, + Phase: "plan", + CompletionState: "running", + TranscriptRevision: 3, + UpdatedAt: time.Now().Add(-time.Minute), + } + secondResume := firstResume + secondResume.ID = "rc-2" + secondResume.RunID = "run-2" + secondResume.Turn = 2 + secondResume.Phase = "execute" + secondResume.UpdatedAt = time.Now() + + if err := fixture.checkpointStore.SetResumeCheckpoint(context.Background(), firstResume); err != nil { + t.Fatalf("SetResumeCheckpoint(first) error = %v", err) + } + if err := fixture.checkpointStore.SetResumeCheckpoint(context.Background(), secondResume); err != nil { + t.Fatalf("SetResumeCheckpoint(second) error = %v", err) + } + gotResume, err := fixture.checkpointStore.GetLatestResumeCheckpoint(context.Background(), loaded.ID) + if err != nil { + t.Fatalf("GetLatestResumeCheckpoint() error = %v", err) + } + if gotResume == nil || gotResume.ID != secondResume.ID || gotResume.Turn != secondResume.Turn { + t.Fatalf("GetLatestResumeCheckpoint() = %#v, want %#v", gotResume, secondResume) + } +} + +func TestSQLiteCheckpointStorePruneAndRepair(t *testing.T) { + t.Parallel() + + fixture := newCheckpointStoreFixture(t) + loaded := createCheckpointTestSession(t, fixture.sessionStore, "session_prune", fixture.workspaceRoot) + + createdAt := time.Now().Add(-10 * time.Minute) + for i := 0; i < 4; i++ { + checkpointID := "cp-auto-" + string(rune('a'+i)) + input := checkpointInputFromSession(t, loaded, checkpointID, session.CheckpointReasonPreWrite, createdAt.Add(time.Duration(i)*time.Minute)) + if _, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), input); err != nil { + t.Fatalf("CreateCheckpoint(%s) error = %v", checkpointID, err) + } + } + if _, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), checkpointInputFromSession(t, loaded, "cp-manual", session.CheckpointReasonManual, time.Now())); err != nil { + t.Fatalf("CreateCheckpoint(manual) error = %v", err) + } + if _, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), checkpointInputFromSession(t, loaded, "cp-guard", session.CheckpointReasonGuard, time.Now().Add(time.Minute))); err != nil { + t.Fatalf("CreateCheckpoint(guard) error = %v", err) + } + + pruned, err := fixture.checkpointStore.PruneExpiredCheckpoints(context.Background(), loaded.ID, 2) + if err != nil { + t.Fatalf("PruneExpiredCheckpoints() error = %v", err) + } + if pruned != 2 { + t.Fatalf("PruneExpiredCheckpoints() = %d, want 2", pruned) + } + + prunedRecord, prunedSessionCP, err := fixture.checkpointStore.GetCheckpoint(context.Background(), "cp-auto-a") + if err != nil { + t.Fatalf("GetCheckpoint(pruned) error = %v", err) + } + if prunedRecord.Status != session.CheckpointStatusPruned || prunedRecord.Restorable { + t.Fatalf("pruned record = %#v, want pruned and not restorable", prunedRecord) + } + if prunedSessionCP != nil { + t.Fatalf("expected pruned session snapshot to be deleted, got %#v", prunedSessionCP) + } + + manualRecord, _, err := fixture.checkpointStore.GetCheckpoint(context.Background(), "cp-manual") + if err != nil { + t.Fatalf("GetCheckpoint(manual) error = %v", err) + } + if manualRecord.Status != session.CheckpointStatusAvailable || !manualRecord.Restorable { + t.Fatalf("manual record = %#v, want still available", manualRecord) + } + + db, err := fixture.checkpointStore.ensureDB(context.Background()) + if err != nil { + t.Fatalf("ensureDB() error = %v", err) + } + withSessionCPID := "cp-creating-with-session" + if _, err := db.ExecContext(context.Background(), ` +INSERT INTO session_checkpoints (id, session_id, head_json, messages_json, created_at_ms) +VALUES (?, ?, ?, ?, ?) +`, "sc-creating", loaded.ID, `{}`, `[]`, time.Now().UnixMilli()); err != nil { + t.Fatalf("insert session_checkpoint error = %v", err) + } + if _, err := db.ExecContext(context.Background(), ` +INSERT INTO checkpoint_records ( + id, workspace_key, session_id, run_id, workdir, created_at_ms, + reason, code_checkpoint_ref, session_checkpoint_ref, resume_checkpoint_ref, + transcript_revision, restorable, status +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +`, + withSessionCPID, + session.WorkspacePathKey(loaded.Workdir), + loaded.ID, + "run-repair", + loaded.Workdir, + time.Now().UnixMilli(), + string(session.CheckpointReasonPreWrite), + "", + "sc-creating", + "", + 0, + 1, + string(session.CheckpointStatusCreating), + ); err != nil { + t.Fatalf("insert creating checkpoint with session ref error = %v", err) + } + + orphanID := "cp-creating-orphan" + if _, err := db.ExecContext(context.Background(), ` +INSERT INTO checkpoint_records ( + id, workspace_key, session_id, run_id, workdir, created_at_ms, + reason, code_checkpoint_ref, session_checkpoint_ref, resume_checkpoint_ref, + transcript_revision, restorable, status +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +`, + orphanID, + session.WorkspacePathKey(loaded.Workdir), + loaded.ID, + "run-repair", + loaded.Workdir, + time.Now().UnixMilli(), + string(session.CheckpointReasonPreWrite), + "", + "", + "", + 0, + 1, + string(session.CheckpointStatusCreating), + ); err != nil { + t.Fatalf("insert orphan checkpoint error = %v", err) + } + + repaired, err := fixture.checkpointStore.RepairCreatingCheckpoints(context.Background()) + if err != nil { + t.Fatalf("RepairCreatingCheckpoints() error = %v", err) + } + if repaired != 2 { + t.Fatalf("RepairCreatingCheckpoints() = %d, want 2", repaired) + } + + repairedRecord, _, err := fixture.checkpointStore.GetCheckpoint(context.Background(), withSessionCPID) + if err != nil { + t.Fatalf("GetCheckpoint(repaired) error = %v", err) + } + if repairedRecord.Status != session.CheckpointStatusAvailable { + t.Fatalf("repaired record status = %q, want available", repairedRecord.Status) + } + + if _, _, err := fixture.checkpointStore.GetCheckpoint(context.Background(), orphanID); err == nil { + t.Fatalf("expected orphan creating checkpoint to be deleted") + } +} + +func TestSQLiteCheckpointStoreUsesSessionDatabasePath(t *testing.T) { + t.Parallel() + + fixture := newCheckpointStoreFixture(t) + expected := filepath.Clean(session.DatabasePath(fixture.baseDir, fixture.workspaceRoot)) + if filepath.Clean(fixture.checkpointStore.dbPath) != expected { + t.Fatalf("dbPath = %q, want %q", fixture.checkpointStore.dbPath, expected) + } +} diff --git a/internal/checkpoint/shadow_repo_test.go b/internal/checkpoint/shadow_repo_test.go new file mode 100644 index 00000000..109ddce5 --- /dev/null +++ b/internal/checkpoint/shadow_repo_test.go @@ -0,0 +1,147 @@ +package checkpoint + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestShadowRepoSnapshotRestoreAndConflictDetection(t *testing.T) { + t.Parallel() + + if available, _ := CheckGitAvailability(context.Background()); !available { + t.Skip("git is not available in test environment") + } + + projectDir := t.TempDir() + workdir := t.TempDir() + repo := NewShadowRepo(projectDir, workdir) + if err := repo.Init(context.Background()); err != nil { + t.Fatalf("Init() error = %v", err) + } + + targetFile := filepath.Join(workdir, "main.go") + if err := os.WriteFile(targetFile, []byte("package main\nconst version = 1\n"), 0o644); err != nil { + t.Fatalf("WriteFile(version1) error = %v", err) + } + + refOne := RefForCheckpoint("session-1", "cp-1") + hashOne, err := repo.Snapshot(context.Background(), refOne, "snapshot one") + if err != nil { + t.Fatalf("Snapshot(first) error = %v", err) + } + if strings.TrimSpace(hashOne) == "" { + t.Fatalf("Snapshot(first) returned empty hash") + } + + if repo.HasCodeChanges(context.Background()) { + t.Fatalf("expected clean worktree after first snapshot") + } + + if err := os.WriteFile(targetFile, []byte("package main\nconst version = 2\n"), 0o644); err != nil { + t.Fatalf("WriteFile(version2) error = %v", err) + } + if !repo.HasCodeChanges(context.Background()) { + t.Fatalf("expected HasCodeChanges() to detect modified file") + } + + refTwo := RefForCheckpoint("session-1", "cp-2") + if _, err := repo.Snapshot(context.Background(), refTwo, "snapshot two"); err != nil { + t.Fatalf("Snapshot(second) error = %v", err) + } + + resolved, err := repo.ResolveRef(context.Background(), refOne) + if err != nil { + t.Fatalf("ResolveRef() error = %v", err) + } + if resolved != hashOne { + t.Fatalf("ResolveRef() = %q, want %q", resolved, hashOne) + } + + conflict, err := repo.DetectConflicts(context.Background(), hashOne) + if err != nil { + t.Fatalf("DetectConflicts() error = %v", err) + } + if !conflict.HasConflict || len(conflict.ModifiedFiles) != 1 || conflict.ModifiedFiles[0] != "main.go" { + t.Fatalf("DetectConflicts() = %#v, want modified main.go", conflict) + } + + if err := repo.Restore(context.Background(), hashOne); err != nil { + t.Fatalf("Restore() error = %v", err) + } + content, err := os.ReadFile(targetFile) + if err != nil { + t.Fatalf("ReadFile(restored) error = %v", err) + } + if !strings.Contains(string(content), "version = 1") { + t.Fatalf("restored content = %q, want version 1", string(content)) + } + + if err := repo.HealthCheck(context.Background()); err != nil { + t.Fatalf("HealthCheck() error = %v", err) + } +} + +func TestShadowRepoInitRebuildsDamagedRepository(t *testing.T) { + t.Parallel() + + if available, _ := CheckGitAvailability(context.Background()); !available { + t.Skip("git is not available in test environment") + } + + projectDir := t.TempDir() + workdir := t.TempDir() + shadowDir := filepath.Join(projectDir, ".shadow") + if err := os.MkdirAll(shadowDir, 0o755); err != nil { + t.Fatalf("MkdirAll(shadowDir) error = %v", err) + } + if err := os.WriteFile(filepath.Join(shadowDir, "corrupted"), []byte("not a git dir"), 0o644); err != nil { + t.Fatalf("WriteFile(corrupted) error = %v", err) + } + + repo := NewShadowRepo(projectDir, workdir) + if err := repo.Init(context.Background()); err != nil { + t.Fatalf("Init() error = %v", err) + } + if !repo.IsAvailable() { + t.Fatalf("expected repo to be available after rebuild") + } + + backups, err := filepath.Glob(shadowDir + ".bak.*") + if err != nil { + t.Fatalf("Glob() error = %v", err) + } + if len(backups) == 0 { + t.Fatalf("expected damaged shadow repo backup to be created") + } + + if err := repo.Rebuild(context.Background()); err != nil { + t.Fatalf("Rebuild() error = %v", err) + } + backups, err = filepath.Glob(shadowDir + ".bak.*") + if err != nil { + t.Fatalf("Glob(after rebuild) error = %v", err) + } + if len(backups) < 2 { + t.Fatalf("expected rebuild to create another backup, got %v", backups) + } +} + +func TestShadowRepoHelpers(t *testing.T) { + t.Parallel() + + ref := RefForCheckpoint("session-a", "checkpoint-b") + if ref != "refs/neocode/sessions/session-a/checkpoints/checkpoint-b" { + t.Fatalf("RefForCheckpoint() = %q", ref) + } + + repo := NewShadowRepo(t.TempDir(), t.TempDir()) + if repo.HasCodeChanges(context.Background()) != true { + t.Fatalf("expected unavailable shadow repo to conservatively report code changes") + } + if err := repo.DeleteRef(context.Background(), "refs/unused"); err != nil { + t.Fatalf("DeleteRef() on unavailable repo error = %v", err) + } +} diff --git a/internal/cli/gateway_runtime_bridge_test.go b/internal/cli/gateway_runtime_bridge_test.go index b936544d..893279bc 100644 --- a/internal/cli/gateway_runtime_bridge_test.go +++ b/internal/cli/gateway_runtime_bridge_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "neo-code/internal/checkpoint" "neo-code/internal/config" configstate "neo-code/internal/config/state" "neo-code/internal/gateway" @@ -63,6 +64,16 @@ type runtimeStub struct { getSnapshotSessionID string getSnapshotResult agentruntime.RuntimeSnapshot getSnapshotErr error + listCheckpointsID string + listCheckpointsOpts checkpoint.ListCheckpointOpts + listCheckpointsResult []agentsession.CheckpointRecord + listCheckpointsErr error + restoreCheckpointIn agentruntime.GatewayRestoreInput + restoreCheckpointOut agentruntime.RestoreResult + restoreCheckpointErr error + undoRestoreSessionID string + undoRestoreOut agentruntime.RestoreResult + undoRestoreErr error } const testBridgeSubjectID = bridgeLocalSubjectID @@ -150,6 +161,19 @@ func (s *runtimeStub) GetRuntimeSnapshot(_ context.Context, sessionID string) (a s.getSnapshotSessionID = sessionID return s.getSnapshotResult, s.getSnapshotErr } +func (s *runtimeStub) ListCheckpoints(_ context.Context, sessionID string, opts checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) { + s.listCheckpointsID = sessionID + s.listCheckpointsOpts = opts + return s.listCheckpointsResult, s.listCheckpointsErr +} +func (s *runtimeStub) RestoreCheckpoint(_ context.Context, input agentruntime.GatewayRestoreInput) (agentruntime.RestoreResult, error) { + s.restoreCheckpointIn = input + return s.restoreCheckpointOut, s.restoreCheckpointErr +} +func (s *runtimeStub) UndoRestoreCheckpoint(_ context.Context, sessionID string) (agentruntime.RestoreResult, error) { + s.undoRestoreSessionID = sessionID + return s.undoRestoreOut, s.undoRestoreErr +} func (s *runtimeStub) DeleteSession(_ context.Context, _ string) error { return nil } @@ -207,6 +231,15 @@ func (r *runtimeWithoutCreator) ListAvailableSkills( ) ([]agentruntime.AvailableSkillState, error) { return r.base.ListAvailableSkills(ctx, sessionID) } +func (r *runtimeWithoutCreator) ListCheckpoints(ctx context.Context, sessionID string, opts checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) { + return r.base.ListCheckpoints(ctx, sessionID, opts) +} +func (r *runtimeWithoutCreator) RestoreCheckpoint(ctx context.Context, input agentruntime.GatewayRestoreInput) (agentruntime.RestoreResult, error) { + return r.base.RestoreCheckpoint(ctx, input) +} +func (r *runtimeWithoutCreator) UndoRestoreCheckpoint(ctx context.Context, sessionID string) (agentruntime.RestoreResult, error) { + return r.base.UndoRestoreCheckpoint(ctx, sessionID) +} type bridgeSessionStoreStub struct { deleteFn func(ctx context.Context, id string) error @@ -226,6 +259,72 @@ func (s *bridgeSessionStoreStub) UpdateSessionState(ctx context.Context, input a return nil } +func TestGatewayRuntimePortBridgeCheckpointOperations(t *testing.T) { + stub := &runtimeStub{ + listCheckpointsResult: []agentsession.CheckpointRecord{ + { + CheckpointID: "cp-1", + SessionID: "session-1", + Reason: agentsession.CheckpointReasonCompact, + Status: agentsession.CheckpointStatusAvailable, + Restorable: true, + CreatedAt: time.UnixMilli(1234), + }, + }, + restoreCheckpointOut: agentruntime.RestoreResult{ + CheckpointID: "cp-1", + SessionID: "session-1", + }, + undoRestoreOut: agentruntime.RestoreResult{ + CheckpointID: "guard-1", + SessionID: "session-1", + }, + } + + bridge := &gatewayRuntimePortBridge{runtime: stub} + + entries, err := bridge.ListCheckpoints(context.Background(), gateway.ListCheckpointsInput{ + SessionID: " session-1 ", + Limit: 5, + RestorableOnly: true, + }) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if stub.listCheckpointsID != "session-1" || stub.listCheckpointsOpts.Limit != 5 || !stub.listCheckpointsOpts.RestorableOnly { + t.Fatalf("ListCheckpoints() forwarded (%q, %#v)", stub.listCheckpointsID, stub.listCheckpointsOpts) + } + if len(entries) != 1 || entries[0].CheckpointID != "cp-1" || entries[0].Reason != string(agentsession.CheckpointReasonCompact) { + t.Fatalf("ListCheckpoints() = %#v", entries) + } + + restoreResult, err := bridge.RestoreCheckpoint(context.Background(), gateway.CheckpointRestoreInput{ + SessionID: " session-1 ", + CheckpointID: " cp-1 ", + Force: true, + }) + if err != nil { + t.Fatalf("RestoreCheckpoint() error = %v", err) + } + if stub.restoreCheckpointIn.SessionID != "session-1" || stub.restoreCheckpointIn.CheckpointID != "cp-1" || !stub.restoreCheckpointIn.Force { + t.Fatalf("RestoreCheckpoint() forwarded %#v", stub.restoreCheckpointIn) + } + if restoreResult.CheckpointID != "cp-1" || restoreResult.SessionID != "session-1" || restoreResult.HasConflict { + t.Fatalf("RestoreCheckpoint() = %#v", restoreResult) + } + + undoResult, err := bridge.UndoRestore(context.Background(), gateway.UndoRestoreInput{SessionID: " session-1 "}) + if err != nil { + t.Fatalf("UndoRestore() error = %v", err) + } + if stub.undoRestoreSessionID != "session-1" { + t.Fatalf("UndoRestore() forwarded session %q", stub.undoRestoreSessionID) + } + if undoResult.CheckpointID != "guard-1" || undoResult.SessionID != "session-1" { + t.Fatalf("UndoRestore() = %#v", undoResult) + } +} + var testSessionStore bridgeSessionStore = &bridgeSessionStoreStub{} func TestNewGatewayRuntimePortBridgeRuntimeUnavailable(t *testing.T) { @@ -1351,7 +1450,7 @@ func TestResolveListFilesRootPriorities(t *testing.T) { // priority 2: session workdir (store implements bridgeSessionLoader) loaderStore := &bridgeSessionStoreWithLoader{ bridgeSessionStoreStub: bridgeSessionStoreStub{}, - session: agentsession.Session{Workdir: subDir}, + session: agentsession.Session{Workdir: subDir}, } bridge2, _ := newGatewayRuntimePortBridge(context.Background(), &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)}, loaderStore) defer bridge2.Close() @@ -1958,9 +2057,9 @@ func TestModelDisplayName(t *testing.T) { func TestGatewayRuntimePortBridgeCancelRunAndSnapshots(t *testing.T) { stub := &runtimeStub{ - eventsCh: make(chan agentruntime.RuntimeEvent, 1), - cancelReturn: true, - listTodosErr: errors.New("todo failed"), + eventsCh: make(chan agentruntime.RuntimeEvent, 1), + cancelReturn: true, + listTodosErr: errors.New("todo failed"), getSnapshotErr: errors.New("snapshot failed"), } bridge, err := newGatewayRuntimePortBridge(context.Background(), stub, testSessionStore) diff --git a/internal/runtime/checkpoint_flow_test.go b/internal/runtime/checkpoint_flow_test.go new file mode 100644 index 00000000..d9e81a57 --- /dev/null +++ b/internal/runtime/checkpoint_flow_test.go @@ -0,0 +1,404 @@ +package runtime + +import ( + "context" + "encoding/json" + "os" + "testing" + "time" + + "neo-code/internal/checkpoint" + providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" +) + +type checkpointStoreSpy struct { + lastResume agentsession.ResumeCheckpoint +} + +func (s *checkpointStoreSpy) CreateCheckpoint(context.Context, checkpoint.CreateCheckpointInput) (agentsession.CheckpointRecord, error) { + return agentsession.CheckpointRecord{}, nil +} + +func (s *checkpointStoreSpy) ListCheckpoints(context.Context, string, checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) { + return nil, nil +} + +func (s *checkpointStoreSpy) GetCheckpoint(context.Context, string) (agentsession.CheckpointRecord, *agentsession.SessionCheckpoint, error) { + return agentsession.CheckpointRecord{}, nil, nil +} + +func (s *checkpointStoreSpy) UpdateCheckpointStatus(context.Context, string, agentsession.CheckpointStatus) error { + return nil +} + +func (s *checkpointStoreSpy) GetLatestResumeCheckpoint(context.Context, string) (*agentsession.ResumeCheckpoint, error) { + return nil, nil +} + +func (s *checkpointStoreSpy) RestoreCheckpoint(context.Context, checkpoint.RestoreCheckpointInput) error { + return nil +} + +func (s *checkpointStoreSpy) SetResumeCheckpoint(_ context.Context, rc agentsession.ResumeCheckpoint) error { + s.lastResume = rc + return nil +} + +func (s *checkpointStoreSpy) PruneExpiredCheckpoints(context.Context, string, int) (int, error) { + return 0, nil +} + +func (s *checkpointStoreSpy) RepairCreatingCheckpoints(context.Context) (int, error) { + return 0, nil +} + +type runtimeCheckpointFixture struct { + service *Service + sessionStore *agentsession.SQLiteStore + checkpointStore *checkpoint.SQLiteCheckpointStore + shadowRepo *checkpoint.ShadowRepo + workdir string + projectDir string + session agentsession.Session +} + +func newRuntimeCheckpointFixture(t *testing.T, withShadow bool) runtimeCheckpointFixture { + t.Helper() + + baseDir := t.TempDir() + workdir := t.TempDir() + projectDir := t.TempDir() + + sessionStore := agentsession.NewSQLiteStore(baseDir, workdir) + t.Cleanup(func() { + _ = sessionStore.Close() + }) + + checkpointStore := checkpoint.NewSQLiteCheckpointStore(agentsession.DatabasePath(baseDir, workdir)) + t.Cleanup(func() { + _ = checkpointStore.Close() + }) + + created, err := sessionStore.CreateSession(context.Background(), agentsession.CreateSessionInput{ + ID: "runtime-checkpoint-session", + Title: "runtime checkpoint", + Head: agentsession.SessionHead{ + Provider: "openai", + Model: "gpt-test", + Workdir: workdir, + TaskState: agentsession.TaskState{ + Goal: "initial goal", + VerificationProfile: agentsession.VerificationProfileTaskOnly, + }, + }, + }) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) + } + if err := sessionStore.AppendMessages(context.Background(), agentsession.AppendMessagesInput{ + SessionID: created.ID, + Messages: []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("before restore"), + }, + }, + }, + UpdatedAt: time.Now(), + Provider: "openai", + Model: "gpt-test", + Workdir: workdir, + }); err != nil { + t.Fatalf("AppendMessages() error = %v", err) + } + loaded, err := sessionStore.LoadSession(context.Background(), created.ID) + if err != nil { + t.Fatalf("LoadSession() error = %v", err) + } + + var shadowRepo *checkpoint.ShadowRepo + if withShadow { + shadowRepo = checkpoint.NewShadowRepo(projectDir, workdir) + if err := shadowRepo.Init(context.Background()); err != nil { + t.Fatalf("Init shadow repo error = %v", err) + } + } + + return runtimeCheckpointFixture{ + service: &Service{ + sessionStore: sessionStore, + checkpointStore: checkpointStore, + shadowRepo: shadowRepo, + events: make(chan RuntimeEvent, 32), + }, + sessionStore: sessionStore, + checkpointStore: checkpointStore, + shadowRepo: shadowRepo, + workdir: workdir, + projectDir: projectDir, + session: loaded, + } +} + +func createStoredCheckpointFromSession( + t *testing.T, + cpStore *checkpoint.SQLiteCheckpointStore, + shadowRepo *checkpoint.ShadowRepo, + loaded agentsession.Session, + checkpointID string, +) agentsession.CheckpointRecord { + t.Helper() + + headJSON, err := json.Marshal(loaded.HeadSnapshot()) + if err != nil { + t.Fatalf("Marshal(head) error = %v", err) + } + messagesJSON, err := json.Marshal(loaded.Messages) + if err != nil { + t.Fatalf("Marshal(messages) error = %v", err) + } + + ref := checkpoint.RefForCheckpoint(loaded.ID, checkpointID) + if _, err := shadowRepo.Snapshot(context.Background(), ref, checkpointID); err != nil { + t.Fatalf("Snapshot(%s) error = %v", checkpointID, err) + } + + record, err := cpStore.CreateCheckpoint(context.Background(), checkpoint.CreateCheckpointInput{ + Record: agentsession.CheckpointRecord{ + CheckpointID: checkpointID, + WorkspaceKey: agentsession.WorkspacePathKey(loaded.Workdir), + SessionID: loaded.ID, + RunID: "run-" + checkpointID, + Workdir: loaded.Workdir, + CreatedAt: time.Now().Add(-time.Minute), + Reason: agentsession.CheckpointReasonPreWrite, + CodeCheckpointRef: ref, + Restorable: true, + Status: agentsession.CheckpointStatusCreating, + }, + SessionCP: agentsession.SessionCheckpoint{ + ID: "sc-" + checkpointID, + SessionID: loaded.ID, + HeadJSON: string(headJSON), + MessagesJSON: string(messagesJSON), + CreatedAt: time.Now().Add(-time.Minute), + }, + }) + if err != nil { + t.Fatalf("CreateCheckpoint(%s) error = %v", checkpointID, err) + } + return record +} + +func TestCreatePerTurnCheckpointVariants(t *testing.T) { + t.Run("full checkpoint when code changed", func(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t, true) + target := fixture.workdir + "/main.go" + if err := os.WriteFile(target, []byte("package main\nconst value = 1\n"), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + if _, err := fixture.shadowRepo.Snapshot(context.Background(), "refs/heads/base", "baseline"); err != nil { + t.Fatalf("Snapshot(baseline) error = %v", err) + } + if err := os.WriteFile(target, []byte("package main\nconst value = 2\n"), 0o644); err != nil { + t.Fatalf("WriteFile(modified) error = %v", err) + } + + state := newRunState("run-full", fixture.session) + if err := fixture.service.createPerTurnCheckpoint(context.Background(), &state); err != nil { + t.Fatalf("createPerTurnCheckpoint() error = %v", err) + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 || records[0].Reason != agentsession.CheckpointReasonPreWrite || records[0].CodeCheckpointRef == "" { + t.Fatalf("records = %#v, want one full checkpoint", records) + } + }) + + t.Run("degraded checkpoint when repo unavailable", func(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t, false) + state := newRunState("run-degraded", fixture.session) + if err := fixture.service.createPerTurnCheckpoint(context.Background(), &state); err != nil { + t.Fatalf("createPerTurnCheckpoint() error = %v", err) + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 || records[0].Reason != agentsession.CheckpointReasonPreWriteDegraded || records[0].CodeCheckpointRef != "" { + t.Fatalf("records = %#v, want one degraded checkpoint", records) + } + }) + + t.Run("degraded checkpoint when no code changes", func(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t, true) + target := fixture.workdir + "/main.go" + if err := os.WriteFile(target, []byte("package main\nconst value = 1\n"), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + if _, err := fixture.shadowRepo.Snapshot(context.Background(), "refs/heads/base", "baseline"); err != nil { + t.Fatalf("Snapshot(baseline) error = %v", err) + } + + state := newRunState("run-noop", fixture.session) + if err := fixture.service.createPerTurnCheckpoint(context.Background(), &state); err != nil { + t.Fatalf("createPerTurnCheckpoint() error = %v", err) + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 || records[0].Reason != agentsession.CheckpointReasonPreWriteDegraded || records[0].CodeCheckpointRef != "" { + t.Fatalf("records = %#v, want session-only checkpoint for no-op turn", records) + } + }) +} + +func TestCreateCompactCheckpointAndResumeCheckpoint(t *testing.T) { + t.Parallel() + + fixture := newRuntimeCheckpointFixture(t, true) + if err := os.WriteFile(fixture.workdir+"/compact.txt", []byte("compact"), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + fixture.service.createCompactCheckpoint(context.Background(), "run-compact", fixture.session) + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 || records[0].Reason != agentsession.CheckpointReasonCompact { + t.Fatalf("records = %#v, want compact checkpoint", records) + } + + state := newRunState("run-resume", fixture.session) + state.turn = 3 + spy := &checkpointStoreSpy{} + service := &Service{checkpointStore: spy} + service.updateResumeCheckpoint(context.Background(), &state, "verify", "running") + + if spy.lastResume.SessionID != fixture.session.ID || spy.lastResume.RunID != "run-resume" || spy.lastResume.Turn != 3 || spy.lastResume.Phase != "verify" { + t.Fatalf("SetResumeCheckpoint() captured %#v", spy.lastResume) + } +} + +func TestRestoreCheckpointAndUndoRestore(t *testing.T) { + t.Parallel() + + fixture := newRuntimeCheckpointFixture(t, true) + target := fixture.workdir + "/restore.txt" + if err := os.WriteFile(target, []byte("version one"), 0o644); err != nil { + t.Fatalf("WriteFile(version one) error = %v", err) + } + + originalSession, err := fixture.sessionStore.LoadSession(context.Background(), fixture.session.ID) + if err != nil { + t.Fatalf("LoadSession(original) error = %v", err) + } + record := createStoredCheckpointFromSession(t, fixture.checkpointStore, fixture.shadowRepo, originalSession, "cp-restore") + + if err := os.WriteFile(target, []byte("version two"), 0o644); err != nil { + t.Fatalf("WriteFile(version two) error = %v", err) + } + if err := fixture.sessionStore.UpdateSessionState(context.Background(), agentsession.UpdateSessionStateInput{ + SessionID: originalSession.ID, + UpdatedAt: time.Now(), + Title: "mutated", + Head: agentsession.SessionHead{ + Provider: "openai", + Model: "gpt-test", + Workdir: fixture.workdir, + TaskState: agentsession.TaskState{ + Goal: "mutated goal", + VerificationProfile: agentsession.VerificationProfileTaskOnly, + }, + }, + }); err != nil { + t.Fatalf("UpdateSessionState() error = %v", err) + } + if err := fixture.sessionStore.AppendMessages(context.Background(), agentsession.AppendMessagesInput{ + SessionID: originalSession.ID, + Messages: []providertypes.Message{ + { + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("after snapshot"), + }, + }, + }, + UpdatedAt: time.Now(), + Provider: "openai", + Model: "gpt-test", + Workdir: fixture.workdir, + }); err != nil { + t.Fatalf("AppendMessages(mutated) error = %v", err) + } + + conflictResult, err := fixture.service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{ + SessionID: originalSession.ID, + CheckpointID: record.CheckpointID, + }) + if err == nil || conflictResult.Conflict == nil || !conflictResult.Conflict.HasConflict { + t.Fatalf("RestoreCheckpoint(conflict) = (%#v, %v), want conflict error", conflictResult, err) + } + + restoreResult, err := fixture.service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{ + SessionID: originalSession.ID, + CheckpointID: record.CheckpointID, + Force: true, + }) + if err != nil { + t.Fatalf("RestoreCheckpoint(force) error = %v", err) + } + if restoreResult.CheckpointID != record.CheckpointID { + t.Fatalf("RestoreCheckpoint(force) = %#v", restoreResult) + } + + restoredContent, err := os.ReadFile(target) + if err != nil { + t.Fatalf("ReadFile(restored) error = %v", err) + } + if string(restoredContent) != "version one" { + t.Fatalf("restored content = %q, want version one", string(restoredContent)) + } + + restoredSession, err := fixture.sessionStore.LoadSession(context.Background(), originalSession.ID) + if err != nil { + t.Fatalf("LoadSession(restored) error = %v", err) + } + if restoredSession.TaskState.Goal != originalSession.TaskState.Goal || len(restoredSession.Messages) != len(originalSession.Messages) { + t.Fatalf("restored session = %#v, want original goal/messages", restoredSession) + } + + undoResult, err := fixture.service.UndoRestoreCheckpoint(context.Background(), originalSession.ID) + if err != nil { + t.Fatalf("UndoRestoreCheckpoint() error = %v", err) + } + if undoResult.SessionID != originalSession.ID { + t.Fatalf("UndoRestoreCheckpoint() = %#v", undoResult) + } + + undoneContent, err := os.ReadFile(target) + if err != nil { + t.Fatalf("ReadFile(undone) error = %v", err) + } + if string(undoneContent) != "version two" { + t.Fatalf("undone content = %q, want version two", string(undoneContent)) + } + + undoneSession, err := fixture.sessionStore.LoadSession(context.Background(), originalSession.ID) + if err != nil { + t.Fatalf("LoadSession(undone) error = %v", err) + } + if undoneSession.TaskState.Goal != "mutated goal" || len(undoneSession.Messages) != len(originalSession.Messages)+1 { + t.Fatalf("undone session = %#v, want mutated session restored", undoneSession) + } +}