Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 58 additions & 7 deletions pkg/cli/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,39 @@ func (e RuntimeError) Unwrap() error {
return e.Err
}

// maxAutoExtensions is the maximum number of times --yolo mode will
// auto-continue when max iterations is reached, to prevent infinite loops.
const maxAutoExtensions = 5

// maxIterAction describes what the caller should do after a MaxIterationsReachedEvent.
type maxIterAction int

const (
	maxIterContinue maxIterAction = iota // auto-approved, keep running
	maxIterStop                          // safety cap reached, caller should stop
	maxIterPrompt                        // not in yolo mode, caller should prompt the user
)

// handleMaxIterationsAutoApprove decides whether to auto-extend iterations in
// --yolo mode.
//
// Parameters:
//   - autoApprove: whether auto-approve (--yolo) mode is active.
//   - autoExtensions: caller-owned counter of extensions granted so far. It is
//     incremented only when an extension is actually granted, so it never
//     exceeds maxAutoExtensions and always equals the number of approvals.
//   - maxIter: the iteration limit that was just reached; used for logging only.
//
// Returns maxIterPrompt when autoApprove is false (the counter is left
// untouched), maxIterContinue when an extension was granted, and maxIterStop
// once maxAutoExtensions extensions have already been granted. A nil counter
// in auto-approve mode also yields maxIterStop: without a counter the safety
// cap cannot be enforced, so the safe choice is to stop.
func handleMaxIterationsAutoApprove(autoApprove bool, autoExtensions *int, maxIter int) maxIterAction {
	if !autoApprove {
		return maxIterPrompt
	}
	if autoExtensions == nil {
		slog.Warn("No auto-extension counter available in yolo mode, stopping",
			"current_max", maxIter)
		return maxIterStop
	}
	// Check the cap before incrementing so the counter saturates at the cap
	// and the logged total reflects extensions that were really granted.
	if *autoExtensions >= maxAutoExtensions {
		slog.Warn("Max auto-extensions reached in yolo mode, stopping",
			"total_extensions", *autoExtensions)
		return maxIterStop
	}
	*autoExtensions++
	slog.Info("Auto-extending iterations in yolo mode",
		"extension", *autoExtensions,
		"max_extensions", maxAutoExtensions,
		"current_max", maxIter)
	return maxIterContinue
}

// Config holds configuration for running an agent in CLI mode
type Config struct {
AppName string
Expand Down Expand Up @@ -60,6 +93,8 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess
var lastErr error

oneLoop := func(text string, rd io.Reader) error {
autoExtensions := 0
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 HIGH SEVERITY: Safety cap can be bypassed in multi-turn mode

The autoExtensions counter is declared inside oneLoop(), so it is reset to 0 on every call. This defeats the purpose of the safety cap:

  • In multi-turn mode (lines 280-281), oneLoop() is called once per message
  • Each call gets a fresh counter starting at 0
  • If a user provides 10 messages, each one can use up to 5 auto-extensions, for up to 50 extensions in total instead of 5

Fix: Move autoExtensions := 0 outside of oneLoop() to the Run() function level, so the counter persists across all messages in the session:

func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess *session.Session, userMessages []string) error {
    // ... existing code ...
    var lastErr error
    autoExtensions := 0  // <-- Move here
    
    oneLoop := func(text string, rd io.Reader) error {
        // Remove: autoExtensions := 0
        // ... rest of function ...
    }
}

This ensures the safety cap applies to the entire session, not per-message.


userInput := strings.TrimSpace(text)
if userInput == "" {
return nil
Expand All @@ -74,6 +109,14 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess
if !cfg.AutoApprove {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 HIGH SEVERITY: Missing Resume call causes runtime to hang

When cfg.AutoApprove is true in JSON mode, this code does not call rt.Resume() at all. This leaves the runtime waiting indefinitely for a confirmation that never comes.

The current logic:

  • If !cfg.AutoApprove: calls rt.Resume(ctx, runtime.ResumeReject(""))
  • If cfg.AutoApprove: does nothing

Fix: Add an else branch to approve when AutoApprove is enabled:

case *runtime.ToolCallConfirmationEvent:
    if !cfg.AutoApprove {
        rt.Resume(ctx, runtime.ResumeReject(""))
    } else {
        rt.Resume(ctx, runtime.ResumeApprove())  // <-- Add this
    }

This mirrors the MaxIterationsReachedEvent handling pattern on lines 112-119 and ensures consistency with non-JSON mode (lines 152-167).

rt.Resume(ctx, runtime.ResumeReject(""))
}
case *runtime.MaxIterationsReachedEvent:
switch handleMaxIterationsAutoApprove(cfg.AutoApprove, &autoExtensions, e.MaxIterations) {
case maxIterContinue:
rt.Resume(ctx, runtime.ResumeApprove())
default: // maxIterStop or maxIterPrompt (no interactive prompt in JSON mode)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ MEDIUM SEVERITY: Silent rejection without user feedback

When not in yolo mode or when the safety cap is reached, this code rejects and returns with no visible indication to the user consuming the JSON output.

The problem:

  • User sees MaxIterationsReachedEvent in JSON output
  • The CLI runner calls rt.Resume(ctx, runtime.ResumeReject("")) silently
  • Process terminates with no event showing what action was taken
  • Resume actions are sent through an internal channel, not emitted as events

Impact: JSON API consumers cannot determine why execution stopped or what decision was made. The comment acknowledges "no interactive prompt in JSON mode" but doesn't address the lack of visibility.

Consider: Emitting a custom event before rejecting to inform JSON consumers of the decision, similar to how interactive mode shows clear feedback.

rt.Resume(ctx, runtime.ResumeReject(""))
return nil
}
case *runtime.ErrorEvent:
return fmt.Errorf("%s", e.Error)
}
Expand Down Expand Up @@ -153,16 +196,24 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess
out.PrintError(lastErr)
}
case *runtime.MaxIterationsReachedEvent:
result := out.PromptMaxIterationsContinue(ctx, e.MaxIterations)
switch result {
case ConfirmationApprove:
switch handleMaxIterationsAutoApprove(cfg.AutoApprove, &autoExtensions, e.MaxIterations) {
case maxIterContinue:
rt.Resume(ctx, runtime.ResumeApprove())
case ConfirmationReject:
rt.Resume(ctx, runtime.ResumeReject(""))
return nil
case ConfirmationAbort:
case maxIterStop:
rt.Resume(ctx, runtime.ResumeReject(""))
return nil
case maxIterPrompt:
result := out.PromptMaxIterationsContinue(ctx, e.MaxIterations)
switch result {
case ConfirmationApprove:
rt.Resume(ctx, runtime.ResumeApprove())
case ConfirmationReject:
rt.Resume(ctx, runtime.ResumeReject(""))
return nil
case ConfirmationAbort:
rt.Resume(ctx, runtime.ResumeReject(""))
return nil
}
}
case *runtime.ElicitationRequestEvent:
serverURL, ok := e.Meta["cagent/server_url"].(string)
Expand Down
205 changes: 205 additions & 0 deletions pkg/cli/runner_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
package cli

import (
"bytes"
"context"
"sync"
"testing"

"gotest.tools/v3/assert"

"github.com/docker/cagent/pkg/runtime"
"github.com/docker/cagent/pkg/session"
"github.com/docker/cagent/pkg/sessiontitle"
"github.com/docker/cagent/pkg/tools"
mcptools "github.com/docker/cagent/pkg/tools/mcp"
)

// mockRuntime implements runtime.Runtime for testing the CLI runner.
// It emits pre-configured events from RunStream and records Resume calls.
type mockRuntime struct {
	// events is the fixed sequence that RunStream replays, in order.
	events []runtime.Event

	// mu guards resumes; Resume and getResumes both take it.
	mu sync.Mutex
	// resumes holds every request passed to Resume, in call order.
	resumes []runtime.ResumeRequest
}

// The methods from CurrentAgentName through RegenerateTitle are inert stubs:
// each returns a fixed or zero value and touches no mockRuntime state.
// NOTE(review): presumably they exist only so that mockRuntime satisfies
// runtime.Runtime — confirm against the interface definition.

func (m *mockRuntime) CurrentAgentName() string { return "test" }
func (m *mockRuntime) CurrentAgentInfo(context.Context) runtime.CurrentAgentInfo {
	return runtime.CurrentAgentInfo{Name: "test"}
}
func (m *mockRuntime) SetCurrentAgent(string) error { return nil }
func (m *mockRuntime) CurrentAgentTools(context.Context) ([]tools.Tool, error) { return nil, nil }
func (m *mockRuntime) EmitStartupInfo(context.Context, chan runtime.Event) {}
func (m *mockRuntime) ResetStartupInfo() {}
func (m *mockRuntime) Run(context.Context, *session.Session) ([]session.Message, error) {
	return nil, nil
}

func (m *mockRuntime) ResumeElicitation(context.Context, tools.ElicitationAction, map[string]any) error {
	return nil
}
func (m *mockRuntime) SessionStore() session.Store { return nil }
func (m *mockRuntime) Summarize(context.Context, *session.Session, string, chan runtime.Event) {}
func (m *mockRuntime) PermissionsInfo() *runtime.PermissionsInfo { return nil }
func (m *mockRuntime) CurrentAgentSkillsEnabled() bool { return false }
func (m *mockRuntime) CurrentMCPPrompts(context.Context) map[string]mcptools.PromptInfo {
	return nil
}

func (m *mockRuntime) ExecuteMCPPrompt(context.Context, string, map[string]string) (string, error) {
	return "", nil
}
func (m *mockRuntime) UpdateSessionTitle(context.Context, *session.Session, string) error { return nil }
func (m *mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil }
func (m *mockRuntime) Close() error { return nil }
func (m *mockRuntime) RegenerateTitle(context.Context, *session.Session, chan runtime.Event) {}

// Resume appends req to the recorded resume requests while holding the mutex,
// so tests can later inspect the decisions the runner made. The context is
// ignored.
func (m *mockRuntime) Resume(_ context.Context, req runtime.ResumeRequest) {
	m.mu.Lock()
	m.resumes = append(m.resumes, req)
	m.mu.Unlock()
}

// RunStream returns a channel pre-loaded with every configured event and
// already closed, so a consumer can drain it without any producer goroutine.
// The buffer is sized to len(m.events), so none of the sends can block.
// Both arguments are ignored.
func (m *mockRuntime) RunStream(_ context.Context, _ *session.Session) <-chan runtime.Event {
	stream := make(chan runtime.Event, len(m.events))
	defer close(stream)

	for i := range m.events {
		stream <- m.events[i]
	}
	return stream
}

// getResumes returns a copy of the recorded resume requests, taken under the
// mutex, so the caller's slice is independent of later Resume calls. The
// result is non-nil even when nothing has been recorded.
func (m *mockRuntime) getResumes() []runtime.ResumeRequest {
	m.mu.Lock()
	snapshot := append(make([]runtime.ResumeRequest, 0, len(m.resumes)), m.resumes...)
	m.mu.Unlock()
	return snapshot
}

// maxIterEvent builds a MaxIterationsReachedEvent with the type string
// "max_iterations_reached" and the given iteration limit.
func maxIterEvent(maxIter int) *runtime.MaxIterationsReachedEvent {
	ev := new(runtime.MaxIterationsReachedEvent)
	ev.Type = "max_iterations_reached"
	ev.MaxIterations = maxIter
	return ev
}

// TestMaxIterationsAutoApproveInYoloMode feeds Run a single max-iterations
// event with AutoApprove set (OutputJSON left unset) and expects exactly one
// Resume call, of type approve.
func TestMaxIterationsAutoApproveInYoloMode(t *testing.T) {
	t.Parallel()

	var output bytes.Buffer
	mock := &mockRuntime{events: []runtime.Event{maxIterEvent(60)}}

	runErr := Run(t.Context(), NewPrinter(&output), Config{AutoApprove: true}, mock, session.New(), []string{"hello"})
	assert.NilError(t, runErr)

	got := mock.getResumes()
	assert.Equal(t, len(got), 1)
	assert.Equal(t, got[0].Type, runtime.ResumeTypeApprove)
}

func TestMaxIterationsAutoApproveSafetyCap(t *testing.T) {
t.Parallel()

// Emit maxAutoExtensions+1 events to trigger the safety cap
events := make([]runtime.Event, maxAutoExtensions+1)
for i := range events {
events[i] = maxIterEvent(60 + i*10)
}

rt := &mockRuntime{events: events}

var buf bytes.Buffer
out := NewPrinter(&buf)
sess := session.New()
cfg := Config{AutoApprove: true}

err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
assert.NilError(t, err)

resumes := rt.getResumes()
assert.Equal(t, len(resumes), maxAutoExtensions+1)

// First maxAutoExtensions should be approved
for i := range maxAutoExtensions {
assert.Equal(t, resumes[i].Type, runtime.ResumeTypeApprove,
"extension %d should be approved", i+1)
}
// Last one should be rejected (safety cap)
assert.Equal(t, resumes[maxAutoExtensions].Type, runtime.ResumeTypeReject,
"extension beyond cap should be rejected")
}

// TestMaxIterationsAutoApproveJSONMode feeds Run a single max-iterations
// event with both AutoApprove and OutputJSON set and expects exactly one
// Resume call, of type approve.
func TestMaxIterationsAutoApproveJSONMode(t *testing.T) {
	t.Parallel()

	var output bytes.Buffer
	mock := &mockRuntime{events: []runtime.Event{maxIterEvent(60)}}
	opts := Config{AutoApprove: true, OutputJSON: true}

	runErr := Run(t.Context(), NewPrinter(&output), opts, mock, session.New(), []string{"hello"})
	assert.NilError(t, runErr)

	got := mock.getResumes()
	assert.Equal(t, len(got), 1)
	assert.Equal(t, got[0].Type, runtime.ResumeTypeApprove)
}

// TestMaxIterationsRejectInJSONModeWithoutYolo feeds Run a single
// max-iterations event with OutputJSON set and AutoApprove off, and expects
// exactly one Resume call, of type reject.
func TestMaxIterationsRejectInJSONModeWithoutYolo(t *testing.T) {
	t.Parallel()

	var output bytes.Buffer
	mock := &mockRuntime{events: []runtime.Event{maxIterEvent(60)}}
	opts := Config{AutoApprove: false, OutputJSON: true}

	runErr := Run(t.Context(), NewPrinter(&output), opts, mock, session.New(), []string{"hello"})
	assert.NilError(t, runErr)

	got := mock.getResumes()
	assert.Equal(t, len(got), 1)
	assert.Equal(t, got[0].Type, runtime.ResumeTypeReject)
}

func TestMaxIterationsSafetyCapJSONMode(t *testing.T) {
t.Parallel()

events := make([]runtime.Event, maxAutoExtensions+1)
for i := range events {
events[i] = maxIterEvent(60 + i*10)
}

rt := &mockRuntime{events: events}

var buf bytes.Buffer
out := NewPrinter(&buf)
sess := session.New()
cfg := Config{AutoApprove: true, OutputJSON: true}

err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
assert.NilError(t, err)

resumes := rt.getResumes()
assert.Equal(t, len(resumes), maxAutoExtensions+1)

for i := range maxAutoExtensions {
assert.Equal(t, resumes[i].Type, runtime.ResumeTypeApprove)
}
assert.Equal(t, resumes[maxAutoExtensions].Type, runtime.ResumeTypeReject)
}