Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions internal/checkpoint/checkpoint_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package checkpoint

import (
"context"
"database/sql"
"encoding/json"
"path/filepath"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -415,3 +417,81 @@ func TestSQLiteCheckpointStoreUsesSessionDatabasePath(t *testing.T) {
t.Fatalf("dbPath = %q, want %q", fixture.checkpointStore.dbPath, expected)
}
}

func TestSQLiteCheckpointStoreSharedDBAndHelpers(t *testing.T) {
t.Parallel()

fixture := newCheckpointStoreFixture(t)
loaded := createCheckpointTestSession(t, fixture.sessionStore, "session_shared_db", fixture.workspaceRoot)
db, err := fixture.checkpointStore.ensureDB(context.Background())
if err != nil {
t.Fatalf("ensureDB() error = %v", err)
}

shared := NewSQLiteCheckpointStoreWithDB(db)
if shared.ownsDB {
t.Fatal("shared checkpoint store should not own injected db")
}
if err := shared.Close(); err != nil {
t.Fatalf("Close(shared) error = %v", err)
}
if err := db.PingContext(context.Background()); err != nil {
t.Fatalf("db should remain open after shared Close(), got %v", err)
}
if _, err := shared.ListCheckpoints(context.Background(), loaded.ID, ListCheckpointOpts{}); err != nil {
t.Fatalf("shared ListCheckpoints() error = %v", err)
}

if got := marshalPlanField(nil); got != "" {
t.Fatalf("marshalPlanField(nil) = %q, want empty", got)
}
var nilPlan *session.PlanArtifact
if got := marshalPlanField(nilPlan); got != "" {
t.Fatalf("marshalPlanField(nil pointer) = %q, want empty", got)
}
if got := marshalPlanField(map[string]any{"step": "verify"}); !strings.Contains(got, `"step":"verify"`) {
t.Fatalf("marshalPlanField(map) = %q", got)
}
if got := marshalPlanField(func() {}); got != "" {
t.Fatalf("marshalPlanField(unmarshalable) = %q, want empty", got)
}
if got := marshalHeadField(func() {}); got != "null" {
t.Fatalf("marshalHeadField(unmarshalable) = %q, want null", got)
}
}

func TestSQLiteCheckpointStoreErrorsAndEmptyResults(t *testing.T) {
t.Parallel()

fixture := newCheckpointStoreFixture(t)
loaded := createCheckpointTestSession(t, fixture.sessionStore, "session_empty_resume", fixture.workspaceRoot)
if err := fixture.checkpointStore.UpdateCheckpointStatus(context.Background(), "missing", session.CheckpointStatusAvailable); err == nil {
t.Fatal("expected UpdateCheckpointStatus() to fail for missing checkpoint")
}

rc, err := fixture.checkpointStore.GetLatestResumeCheckpoint(context.Background(), loaded.ID)
if err != nil {
t.Fatalf("GetLatestResumeCheckpoint(missing) error = %v", err)
}
if rc != nil {
t.Fatalf("GetLatestResumeCheckpoint(missing) = %#v, want nil", rc)
}

ctx, cancel := context.WithCancel(context.Background())
cancel()
if _, err := fixture.checkpointStore.ensureDB(context.Background()); err != nil {
t.Fatalf("ensureDB() error = %v", err)
}
if _, err := fixture.checkpointStore.CreateCheckpoint(ctx, CreateCheckpointInput{}); err == nil {
t.Fatal("expected CreateCheckpoint() to honor canceled context")
}
}

func TestNewSQLiteCheckpointStoreWithNilDBClose(t *testing.T) {
t.Parallel()

store := NewSQLiteCheckpointStoreWithDB((*sql.DB)(nil))
if err := store.Close(); err != nil {
t.Fatalf("Close(nil db) error = %v", err)
}
}
138 changes: 135 additions & 3 deletions internal/gateway/bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ type bootstrapRuntimeStub struct {
upsertMCPServerFn func(ctx context.Context, input UpsertMCPServerInput) error
setMCPEnabledFn func(ctx context.Context, input SetMCPServerEnabledInput) error
deleteMCPServerFn func(ctx context.Context, input DeleteMCPServerInput) error
listCheckpointsFn func(ctx context.Context, input ListCheckpointsInput) ([]CheckpointEntry, error)
restoreCheckpointFn func(ctx context.Context, input CheckpointRestoreInput) (CheckpointRestoreResult, error)
undoRestoreFn func(ctx context.Context, input UndoRestoreInput) (CheckpointRestoreResult, error)
}

func (s *bootstrapRuntimeStub) Run(ctx context.Context, input RunInput) error {
Expand Down Expand Up @@ -249,15 +252,24 @@ func (s *bootstrapRuntimeStub) CreateSession(ctx context.Context, input CreateSe
return strings.TrimSpace(input.SessionID), nil
}

func (s *bootstrapRuntimeStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) {
func (s *bootstrapRuntimeStub) ListCheckpoints(ctx context.Context, input ListCheckpointsInput) ([]CheckpointEntry, error) {
if s != nil && s.listCheckpointsFn != nil {
return s.listCheckpointsFn(ctx, input)
}
return nil, nil
}

func (s *bootstrapRuntimeStub) RestoreCheckpoint(_ context.Context, _ CheckpointRestoreInput) (CheckpointRestoreResult, error) {
func (s *bootstrapRuntimeStub) RestoreCheckpoint(ctx context.Context, input CheckpointRestoreInput) (CheckpointRestoreResult, error) {
if s != nil && s.restoreCheckpointFn != nil {
return s.restoreCheckpointFn(ctx, input)
}
return CheckpointRestoreResult{}, nil
}

func (s *bootstrapRuntimeStub) UndoRestore(_ context.Context, _ UndoRestoreInput) (CheckpointRestoreResult, error) {
func (s *bootstrapRuntimeStub) UndoRestore(ctx context.Context, input UndoRestoreInput) (CheckpointRestoreResult, error) {
if s != nil && s.undoRestoreFn != nil {
return s.undoRestoreFn(ctx, input)
}
return CheckpointRestoreResult{}, nil
}

Expand Down Expand Up @@ -431,6 +443,126 @@ func TestDecodeSessionSkillAndSnapshotPayloadBranches(t *testing.T) {
}
}

func TestCheckpointFrameHandlers(t *testing.T) {
t.Run("list checkpoints success", func(t *testing.T) {
runtime := &bootstrapRuntimeStub{
listCheckpointsFn: func(_ context.Context, input ListCheckpointsInput) ([]CheckpointEntry, error) {
if input.SubjectID != "subject-1" || input.SessionID != "session-1" {
t.Fatalf("input = %#v", input)
}
return []CheckpointEntry{{CheckpointID: "cp-1", SessionID: "session-1"}}, nil
},
}
authState := NewConnectionAuthState()
authState.MarkAuthenticated("subject-1")
ctx := WithConnectionAuthState(context.Background(), authState)

response := handleListCheckpointsFrame(ctx, MessageFrame{
Type: FrameTypeRequest,
Action: FrameActionListCheckpoints,
RequestID: "req-checkpoint-list",
SessionID: " session-1 ",
}, runtime)

if response.Type != FrameTypeAck || response.Action != FrameActionListCheckpoints {
t.Fatalf("response = %#v", response)
}
entries, ok := response.Payload.([]CheckpointEntry)
if !ok || len(entries) != 1 || entries[0].CheckpointID != "cp-1" {
t.Fatalf("payload = %#v", response.Payload)
}
})

t.Run("restore checkpoint success", func(t *testing.T) {
runtime := &bootstrapRuntimeStub{
restoreCheckpointFn: func(_ context.Context, input CheckpointRestoreInput) (CheckpointRestoreResult, error) {
if input.SubjectID != "subject-1" || input.SessionID != "session-1" || input.CheckpointID != "cp-1" || !input.Force {
t.Fatalf("input = %#v", input)
}
return CheckpointRestoreResult{CheckpointID: input.CheckpointID, SessionID: input.SessionID}, nil
},
}
authState := NewConnectionAuthState()
authState.MarkAuthenticated("subject-1")
ctx := WithConnectionAuthState(context.Background(), authState)

response := handleRestoreCheckpointFrame(ctx, MessageFrame{
Type: FrameTypeRequest,
Action: FrameActionRestoreCheckpoint,
RequestID: "req-checkpoint-restore",
SessionID: " session-1 ",
Payload: map[string]any{
"checkpoint_id": " cp-1 ",
"force": true,
},
}, runtime)

if response.Type != FrameTypeAck || response.Action != FrameActionRestoreCheckpoint || response.SessionID != "session-1" {
t.Fatalf("response = %#v", response)
}
result, ok := response.Payload.(CheckpointRestoreResult)
if !ok || result.CheckpointID != "cp-1" {
t.Fatalf("payload = %#v", response.Payload)
}
})

t.Run("undo restore success", func(t *testing.T) {
runtime := &bootstrapRuntimeStub{
undoRestoreFn: func(_ context.Context, input UndoRestoreInput) (CheckpointRestoreResult, error) {
if input.SubjectID != "subject-1" || input.SessionID != "session-1" {
t.Fatalf("input = %#v", input)
}
return CheckpointRestoreResult{CheckpointID: "cp-guard", SessionID: input.SessionID}, nil
},
}
authState := NewConnectionAuthState()
authState.MarkAuthenticated("subject-1")
ctx := WithConnectionAuthState(context.Background(), authState)

response := handleUndoRestoreFrame(ctx, MessageFrame{
Type: FrameTypeRequest,
Action: FrameActionUndoRestore,
RequestID: "req-checkpoint-undo",
SessionID: " session-1 ",
}, runtime)

if response.Type != FrameTypeAck || response.Action != FrameActionUndoRestore || response.SessionID != "session-1" {
t.Fatalf("response = %#v", response)
}
result, ok := response.Payload.(CheckpointRestoreResult)
if !ok || result.CheckpointID != "cp-guard" {
t.Fatalf("payload = %#v", response.Payload)
}
})
}

func TestDecodeCheckpointRestorePayloadBranches(t *testing.T) {
t.Parallel()

params := decodeCheckpointRestorePayload(map[string]any{
"session_id": " session-1 ",
"checkpoint_id": " cp-1 ",
"force": true,
})
if params.SessionID != "session-1" || params.CheckpointID != "cp-1" || !params.Force {
t.Fatalf("decode map payload = %#v", params)
}

params = decodeCheckpointRestorePayload(CheckpointRestoreInput{
SessionID: "session-2",
CheckpointID: "cp-2",
Force: true,
})
if params.SessionID != "session-2" || params.CheckpointID != "cp-2" || !params.Force {
t.Fatalf("decode struct payload = %#v", params)
}

params = decodeCheckpointRestorePayload(invalidJSONMarshaler{})
if params != (CheckpointRestoreInput{}) {
t.Fatalf("marshal failure should return zero input, got %#v", params)
}
}

func TestDispatchRequestFrameWakeOpenURLReviewSuccess(t *testing.T) {
createInputs := make(chan CreateSessionInput, 1)
stub := &bootstrapRuntimeStub{
Expand Down
58 changes: 55 additions & 3 deletions internal/runtime/checkpoint_flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,21 @@ import (
)

type checkpointStoreSpy struct {
lastResume agentsession.ResumeCheckpoint
lastResume agentsession.ResumeCheckpoint
listRecords []agentsession.CheckpointRecord
listSessionID string
listOpts checkpoint.ListCheckpointOpts
listErr error
}

func (s *checkpointStoreSpy) CreateCheckpoint(context.Context, checkpoint.CreateCheckpointInput) (agentsession.CheckpointRecord, error) {
return agentsession.CheckpointRecord{}, nil
}

func (s *checkpointStoreSpy) ListCheckpoints(context.Context, string, checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) {
return nil, nil
func (s *checkpointStoreSpy) ListCheckpoints(_ context.Context, sessionID string, opts checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) {
s.listSessionID = sessionID
s.listOpts = opts
return s.listRecords, s.listErr
}

func (s *checkpointStoreSpy) GetCheckpoint(context.Context, string) (agentsession.CheckpointRecord, *agentsession.SessionCheckpoint, error) {
Expand Down Expand Up @@ -290,6 +296,52 @@ func TestCreateCompactCheckpointAndResumeCheckpoint(t *testing.T) {
}
}

func TestRuntimeCheckpointFacadeMethods(t *testing.T) {
t.Run("list checkpoints delegates to store", func(t *testing.T) {
spy := &checkpointStoreSpy{
listRecords: []agentsession.CheckpointRecord{{CheckpointID: "cp-1"}},
}
service := &Service{checkpointStore: spy}

records, err := service.ListCheckpoints(context.Background(), "session-1", checkpoint.ListCheckpointOpts{
Limit: 5,
RestorableOnly: true,
})
if err != nil {
t.Fatalf("ListCheckpoints() error = %v", err)
}
if spy.listSessionID != "session-1" || spy.listOpts.Limit != 5 || !spy.listOpts.RestorableOnly {
t.Fatalf("spy captured session=%q opts=%#v", spy.listSessionID, spy.listOpts)
}
if len(records) != 1 || records[0].CheckpointID != "cp-1" {
t.Fatalf("records = %#v", records)
}
})

t.Run("list checkpoints reports unavailable store", func(t *testing.T) {
service := &Service{}
if _, err := service.ListCheckpoints(context.Background(), "session-1", checkpoint.ListCheckpointOpts{}); err == nil {
t.Fatal("expected error when checkpoint store is unavailable")
}
})

t.Run("set checkpoint dependencies stores references", func(t *testing.T) {
service := &Service{}
store := &checkpointStoreSpy{}
repo := checkpoint.NewShadowRepo(t.TempDir(), t.TempDir())

service.SetCheckpointDependencies(store, repo)
if service.checkpointStore != store || service.shadowRepo != repo {
t.Fatalf("service checkpoint dependencies not set correctly")
}
})

t.Run("update runtime session after restore is no-op", func(t *testing.T) {
service := &Service{}
service.updateRuntimeSessionAfterRestore("session-1", agentsession.SessionHead{}, nil)
})
}

func TestRestoreCheckpointAndUndoRestore(t *testing.T) {
t.Parallel()

Expand Down
Loading