-
Notifications
You must be signed in to change notification settings - Fork 259
fix(cli): auto-continue max iterations in --yolo mode #1737
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
80db452
7ed808f
ea08bcc
3634385
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,39 @@ func (e RuntimeError) Unwrap() error { | |
| return e.Err | ||
| } | ||
|
|
||
| // maxAutoExtensions is the maximum number of times --yolo mode will | ||
| // auto-continue when max iterations is reached, to prevent infinite loops. | ||
| const maxAutoExtensions = 5 | ||
|
|
||
| // maxIterAction describes what the caller should do after a MaxIterationsReachedEvent. | ||
| type maxIterAction int | ||
|
|
||
| const ( | ||
| maxIterContinue maxIterAction = iota // auto-approved, keep running | ||
| maxIterStop // safety cap reached, caller should stop | ||
| maxIterPrompt // not in yolo mode, caller should prompt the user | ||
| ) | ||
|
|
||
| // handleMaxIterationsAutoApprove decides whether to auto-extend iterations in | ||
| // --yolo mode. Returns maxIterContinue (approved), maxIterStop (cap reached), | ||
| // or maxIterPrompt (not in auto-approve mode, caller should ask the user). | ||
| func handleMaxIterationsAutoApprove(autoApprove bool, autoExtensions *int, maxIter int) maxIterAction { | ||
| if !autoApprove { | ||
| return maxIterPrompt | ||
| } | ||
| *autoExtensions++ | ||
| if *autoExtensions <= maxAutoExtensions { | ||
| slog.Info("Auto-extending iterations in yolo mode", | ||
| "extension", *autoExtensions, | ||
| "max_extensions", maxAutoExtensions, | ||
| "current_max", maxIter) | ||
| return maxIterContinue | ||
| } | ||
| slog.Warn("Max auto-extensions reached in yolo mode, stopping", | ||
| "total_extensions", *autoExtensions) | ||
| return maxIterStop | ||
| } | ||
|
|
||
| // Config holds configuration for running an agent in CLI mode | ||
| type Config struct { | ||
| AppName string | ||
|
|
@@ -60,6 +93,8 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess | |
| var lastErr error | ||
|
|
||
| oneLoop := func(text string, rd io.Reader) error { | ||
| autoExtensions := 0 | ||
|
|
||
| userInput := strings.TrimSpace(text) | ||
| if userInput == "" { | ||
| return nil | ||
|
|
@@ -74,6 +109,14 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess | |
| if !cfg.AutoApprove { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔴 HIGH SEVERITY: Missing Resume call causes runtime to hang When The current logic:
Fix: Add an case *runtime.ToolCallConfirmationEvent:
if !cfg.AutoApprove {
rt.Resume(ctx, runtime.ResumeReject(""))
} else {
rt.Resume(ctx, runtime.ResumeApprove()) // <-- Add this
}This mirrors the MaxIterationsReachedEvent handling pattern on lines 112-119 and ensures consistency with non-JSON mode (lines 152-167). |
||
| rt.Resume(ctx, runtime.ResumeReject("")) | ||
| } | ||
| case *runtime.MaxIterationsReachedEvent: | ||
| switch handleMaxIterationsAutoApprove(cfg.AutoApprove, &autoExtensions, e.MaxIterations) { | ||
| case maxIterContinue: | ||
| rt.Resume(ctx, runtime.ResumeApprove()) | ||
| default: // maxIterStop or maxIterPrompt (no interactive prompt in JSON mode) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When not in yolo mode or when the safety cap is reached, this code rejects and returns with no visible indication to the user consuming the JSON output. The problem:
Impact: JSON API consumers cannot determine why execution stopped or what decision was made. The comment acknowledges "no interactive prompt in JSON mode" but doesn't address the lack of visibility. Consider: Emitting a custom event before rejecting to inform JSON consumers of the decision, similar to how interactive mode shows clear feedback. |
||
| rt.Resume(ctx, runtime.ResumeReject("")) | ||
| return nil | ||
| } | ||
| case *runtime.ErrorEvent: | ||
| return fmt.Errorf("%s", e.Error) | ||
| } | ||
|
|
@@ -153,16 +196,24 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess | |
| out.PrintError(lastErr) | ||
| } | ||
| case *runtime.MaxIterationsReachedEvent: | ||
| result := out.PromptMaxIterationsContinue(ctx, e.MaxIterations) | ||
| switch result { | ||
| case ConfirmationApprove: | ||
| switch handleMaxIterationsAutoApprove(cfg.AutoApprove, &autoExtensions, e.MaxIterations) { | ||
| case maxIterContinue: | ||
| rt.Resume(ctx, runtime.ResumeApprove()) | ||
| case ConfirmationReject: | ||
| rt.Resume(ctx, runtime.ResumeReject("")) | ||
| return nil | ||
| case ConfirmationAbort: | ||
| case maxIterStop: | ||
| rt.Resume(ctx, runtime.ResumeReject("")) | ||
| return nil | ||
| case maxIterPrompt: | ||
| result := out.PromptMaxIterationsContinue(ctx, e.MaxIterations) | ||
| switch result { | ||
| case ConfirmationApprove: | ||
| rt.Resume(ctx, runtime.ResumeApprove()) | ||
| case ConfirmationReject: | ||
| rt.Resume(ctx, runtime.ResumeReject("")) | ||
| return nil | ||
| case ConfirmationAbort: | ||
| rt.Resume(ctx, runtime.ResumeReject("")) | ||
| return nil | ||
| } | ||
| } | ||
| case *runtime.ElicitationRequestEvent: | ||
| serverURL, ok := e.Meta["cagent/server_url"].(string) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,205 @@ | ||
| package cli | ||
|
|
||
| import ( | ||
| "bytes" | ||
| "context" | ||
| "sync" | ||
| "testing" | ||
|
|
||
| "gotest.tools/v3/assert" | ||
|
|
||
| "github.com/docker/cagent/pkg/runtime" | ||
| "github.com/docker/cagent/pkg/session" | ||
| "github.com/docker/cagent/pkg/sessiontitle" | ||
| "github.com/docker/cagent/pkg/tools" | ||
| mcptools "github.com/docker/cagent/pkg/tools/mcp" | ||
| ) | ||
|
|
||
| // mockRuntime implements runtime.Runtime for testing the CLI runner. | ||
| // It emits pre-configured events from RunStream and records Resume calls. | ||
| type mockRuntime struct { | ||
| events []runtime.Event | ||
|
|
||
| mu sync.Mutex | ||
| resumes []runtime.ResumeRequest | ||
| } | ||
|
|
||
| func (m *mockRuntime) CurrentAgentName() string { return "test" } | ||
| func (m *mockRuntime) CurrentAgentInfo(context.Context) runtime.CurrentAgentInfo { | ||
| return runtime.CurrentAgentInfo{Name: "test"} | ||
| } | ||
| func (m *mockRuntime) SetCurrentAgent(string) error { return nil } | ||
| func (m *mockRuntime) CurrentAgentTools(context.Context) ([]tools.Tool, error) { return nil, nil } | ||
| func (m *mockRuntime) EmitStartupInfo(context.Context, chan runtime.Event) {} | ||
| func (m *mockRuntime) ResetStartupInfo() {} | ||
| func (m *mockRuntime) Run(context.Context, *session.Session) ([]session.Message, error) { | ||
| return nil, nil | ||
| } | ||
|
|
||
| func (m *mockRuntime) ResumeElicitation(context.Context, tools.ElicitationAction, map[string]any) error { | ||
| return nil | ||
| } | ||
| func (m *mockRuntime) SessionStore() session.Store { return nil } | ||
| func (m *mockRuntime) Summarize(context.Context, *session.Session, string, chan runtime.Event) {} | ||
| func (m *mockRuntime) PermissionsInfo() *runtime.PermissionsInfo { return nil } | ||
| func (m *mockRuntime) CurrentAgentSkillsEnabled() bool { return false } | ||
| func (m *mockRuntime) CurrentMCPPrompts(context.Context) map[string]mcptools.PromptInfo { | ||
| return nil | ||
| } | ||
|
|
||
| func (m *mockRuntime) ExecuteMCPPrompt(context.Context, string, map[string]string) (string, error) { | ||
| return "", nil | ||
| } | ||
| func (m *mockRuntime) UpdateSessionTitle(context.Context, *session.Session, string) error { return nil } | ||
| func (m *mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil } | ||
| func (m *mockRuntime) Close() error { return nil } | ||
| func (m *mockRuntime) RegenerateTitle(context.Context, *session.Session, chan runtime.Event) {} | ||
|
|
||
| func (m *mockRuntime) Resume(_ context.Context, req runtime.ResumeRequest) { | ||
| m.mu.Lock() | ||
| defer m.mu.Unlock() | ||
| m.resumes = append(m.resumes, req) | ||
| } | ||
|
|
||
| func (m *mockRuntime) RunStream(_ context.Context, _ *session.Session) <-chan runtime.Event { | ||
| ch := make(chan runtime.Event, len(m.events)) | ||
| for _, e := range m.events { | ||
| ch <- e | ||
| } | ||
| close(ch) | ||
| return ch | ||
| } | ||
|
|
||
| func (m *mockRuntime) getResumes() []runtime.ResumeRequest { | ||
| m.mu.Lock() | ||
| defer m.mu.Unlock() | ||
| result := make([]runtime.ResumeRequest, len(m.resumes)) | ||
| copy(result, m.resumes) | ||
| return result | ||
| } | ||
|
|
||
| func maxIterEvent(maxIter int) *runtime.MaxIterationsReachedEvent { | ||
| return &runtime.MaxIterationsReachedEvent{ | ||
| Type: "max_iterations_reached", | ||
| MaxIterations: maxIter, | ||
| } | ||
| } | ||
|
|
||
| func TestMaxIterationsAutoApproveInYoloMode(t *testing.T) { | ||
| t.Parallel() | ||
|
|
||
| rt := &mockRuntime{ | ||
| events: []runtime.Event{maxIterEvent(60)}, | ||
| } | ||
|
|
||
| var buf bytes.Buffer | ||
| out := NewPrinter(&buf) | ||
| sess := session.New() | ||
| cfg := Config{AutoApprove: true} | ||
|
|
||
| err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"}) | ||
| assert.NilError(t, err) | ||
|
|
||
| resumes := rt.getResumes() | ||
| assert.Equal(t, len(resumes), 1) | ||
| assert.Equal(t, resumes[0].Type, runtime.ResumeTypeApprove) | ||
| } | ||
|
|
||
| func TestMaxIterationsAutoApproveSafetyCap(t *testing.T) { | ||
| t.Parallel() | ||
|
|
||
| // Emit maxAutoExtensions+1 events to trigger the safety cap | ||
| events := make([]runtime.Event, maxAutoExtensions+1) | ||
| for i := range events { | ||
| events[i] = maxIterEvent(60 + i*10) | ||
| } | ||
|
|
||
| rt := &mockRuntime{events: events} | ||
|
|
||
| var buf bytes.Buffer | ||
| out := NewPrinter(&buf) | ||
| sess := session.New() | ||
| cfg := Config{AutoApprove: true} | ||
|
|
||
| err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"}) | ||
| assert.NilError(t, err) | ||
|
|
||
| resumes := rt.getResumes() | ||
| assert.Equal(t, len(resumes), maxAutoExtensions+1) | ||
|
|
||
| // First maxAutoExtensions should be approved | ||
| for i := range maxAutoExtensions { | ||
| assert.Equal(t, resumes[i].Type, runtime.ResumeTypeApprove, | ||
| "extension %d should be approved", i+1) | ||
| } | ||
| // Last one should be rejected (safety cap) | ||
| assert.Equal(t, resumes[maxAutoExtensions].Type, runtime.ResumeTypeReject, | ||
| "extension beyond cap should be rejected") | ||
| } | ||
|
|
||
| func TestMaxIterationsAutoApproveJSONMode(t *testing.T) { | ||
| t.Parallel() | ||
|
|
||
| rt := &mockRuntime{ | ||
| events: []runtime.Event{maxIterEvent(60)}, | ||
| } | ||
|
|
||
| var buf bytes.Buffer | ||
| out := NewPrinter(&buf) | ||
| sess := session.New() | ||
| cfg := Config{AutoApprove: true, OutputJSON: true} | ||
|
|
||
| err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"}) | ||
| assert.NilError(t, err) | ||
|
|
||
| resumes := rt.getResumes() | ||
| assert.Equal(t, len(resumes), 1) | ||
| assert.Equal(t, resumes[0].Type, runtime.ResumeTypeApprove) | ||
| } | ||
|
|
||
| func TestMaxIterationsRejectInJSONModeWithoutYolo(t *testing.T) { | ||
| t.Parallel() | ||
|
|
||
| rt := &mockRuntime{ | ||
| events: []runtime.Event{maxIterEvent(60)}, | ||
| } | ||
|
|
||
| var buf bytes.Buffer | ||
| out := NewPrinter(&buf) | ||
| sess := session.New() | ||
| cfg := Config{AutoApprove: false, OutputJSON: true} | ||
|
|
||
| err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"}) | ||
| assert.NilError(t, err) | ||
|
|
||
| resumes := rt.getResumes() | ||
| assert.Equal(t, len(resumes), 1) | ||
| assert.Equal(t, resumes[0].Type, runtime.ResumeTypeReject) | ||
| } | ||
|
|
||
| func TestMaxIterationsSafetyCapJSONMode(t *testing.T) { | ||
| t.Parallel() | ||
|
|
||
| events := make([]runtime.Event, maxAutoExtensions+1) | ||
| for i := range events { | ||
| events[i] = maxIterEvent(60 + i*10) | ||
| } | ||
|
|
||
| rt := &mockRuntime{events: events} | ||
|
|
||
| var buf bytes.Buffer | ||
| out := NewPrinter(&buf) | ||
| sess := session.New() | ||
| cfg := Config{AutoApprove: true, OutputJSON: true} | ||
|
|
||
| err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"}) | ||
| assert.NilError(t, err) | ||
|
|
||
| resumes := rt.getResumes() | ||
| assert.Equal(t, len(resumes), maxAutoExtensions+1) | ||
|
|
||
| for i := range maxAutoExtensions { | ||
| assert.Equal(t, resumes[i].Type, runtime.ResumeTypeApprove) | ||
| } | ||
| assert.Equal(t, resumes[maxAutoExtensions].Type, runtime.ResumeTypeReject) | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🔴 HIGH SEVERITY: Safety cap can be bypassed in multi-turn mode
The
autoExtensionscounter is declared insideoneLoop(), which gets reset to 0 on every call. This defeats the purpose of the safety cap:oneLoop()is called once per messageFix: Move
autoExtensions := 0outside ofoneLoop()to theRun()function level, so the counter persists across all messages in the session:This ensures the safety cap applies to the entire session, not per-message.