diff --git a/agent-schema.json b/agent-schema.json index 718912067..2cda0cdf7 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -14,7 +14,8 @@ "2", "3", "4", - "5" + "5", + "6" ], "examples": [ "0", @@ -22,7 +23,8 @@ "2", "3", "4", - "5" + "5", + "6" ] }, "providers": { diff --git a/pkg/config/config.go b/pkg/config/config.go index c2a86d2cd..40bae0808 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -70,7 +70,7 @@ func CheckRequiredEnvVars(ctx context.Context, cfg *latest.Config, modelsGateway } func parseCurrentVersion(data []byte, version string) (any, error) { - parsers := Parsers() + parsers, _ := versions() parser, found := parsers[version] if !found { return nil, fmt.Errorf("unsupported config version: %v (valid versions: %s)", version, strings.Join(slices.Sorted(maps.Keys(parsers)), ", ")) @@ -81,7 +81,8 @@ func parseCurrentVersion(data []byte, version string) (any, error) { func migrateToLatestConfig(c any, raw []byte) (latest.Config, error) { var err error - for _, upgrade := range Upgrades() { + _, upgraders := versions() + for _, upgrade := range upgraders { c, err = upgrade(c, raw) if err != nil { return latest.Config{}, err diff --git a/pkg/config/latest/parse.go b/pkg/config/latest/parse.go index 2cc8559b9..c72bbf756 100644 --- a/pkg/config/latest/parse.go +++ b/pkg/config/latest/parse.go @@ -1,9 +1,30 @@ package latest -import "github.com/goccy/go-yaml" +import ( + "github.com/goccy/go-yaml" -func Parse(data []byte) (Config, error) { + "github.com/docker/cagent/pkg/config/types" + previous "github.com/docker/cagent/pkg/config/v5" +) + +func Register(parsers map[string]func([]byte) (any, error), upgraders *[]func(any, []byte) (any, error)) { + parsers[Version] = func(d []byte) (any, error) { return parse(d) } + *upgraders = append(*upgraders, upgradeIfNeeded) +} + +func parse(data []byte) (Config, error) { var cfg Config err := yaml.UnmarshalWithOptions(data, &cfg, yaml.Strict()) return cfg, err } + +func upgradeIfNeeded(c any, 
_ []byte) (any, error) { + old, ok := c.(previous.Config) + if !ok { + return c, nil + } + + var config Config + types.CloneThroughJSON(old, &config) + return config, nil +} diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index f601ef794..60f5d1e56 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -13,7 +13,7 @@ import ( "github.com/docker/cagent/pkg/config/types" ) -const Version = "5" +const Version = "6" // Config represents the entire configuration file type Config struct { diff --git a/pkg/config/latest/types_test.go b/pkg/config/latest/types_test.go index 9ae9a951b..ff75d4555 100644 --- a/pkg/config/latest/types_test.go +++ b/pkg/config/latest/types_test.go @@ -133,7 +133,7 @@ agents: instructions: "You are a helpful assistant." `) - _, err := Parse(input) + _, err := parse(input) require.Error(t, err) require.Contains(t, err.Error(), "instructions") } @@ -148,7 +148,7 @@ agents: instruction: "You are a helpful assistant." `) - cfg, err := Parse(input) + cfg, err := parse(input) require.NoError(t, err) require.Len(t, cfg.Agents, 1) require.Equal(t, "root", cfg.Agents[0].Name) diff --git a/pkg/config/latest/upgrade.go b/pkg/config/latest/upgrade.go deleted file mode 100644 index 987b6747e..000000000 --- a/pkg/config/latest/upgrade.go +++ /dev/null @@ -1,17 +0,0 @@ -package latest - -import ( - "github.com/docker/cagent/pkg/config/types" - previous "github.com/docker/cagent/pkg/config/v4" -) - -func UpgradeIfNeeded(c any, _ []byte) (any, error) { - old, ok := c.(previous.Config) - if !ok { - return c, nil - } - - var config Config - types.CloneThroughJSON(old, &config) - return config, nil -} diff --git a/pkg/config/v0/parse.go b/pkg/config/v0/parse.go index 44ac4e45b..461e195e6 100644 --- a/pkg/config/v0/parse.go +++ b/pkg/config/v0/parse.go @@ -2,8 +2,17 @@ package v0 import "github.com/goccy/go-yaml" -func Parse(data []byte) (Config, error) { +func Register(parsers map[string]func([]byte) (any, error), upgraders 
*[]func(any, []byte) (any, error)) { + parsers[Version] = func(d []byte) (any, error) { return parse(d) } + *upgraders = append(*upgraders, upgradeIfNeeded) +} + +func parse(data []byte) (Config, error) { var cfg Config err := yaml.UnmarshalWithOptions(data, &cfg, yaml.Strict()) return cfg, err } + +func upgradeIfNeeded(old any, _ []byte) (any, error) { + return old, nil +} diff --git a/pkg/config/v0/upgrade.go b/pkg/config/v0/upgrade.go deleted file mode 100644 index 7b7237825..000000000 --- a/pkg/config/v0/upgrade.go +++ /dev/null @@ -1,5 +0,0 @@ -package v0 - -func UpgradeIfNeeded(old any, _ []byte) (any, error) { - return old, nil -} diff --git a/pkg/config/v1/parse.go b/pkg/config/v1/parse.go index 8a1374e88..cb3fd974c 100644 --- a/pkg/config/v1/parse.go +++ b/pkg/config/v1/parse.go @@ -1,9 +1,72 @@ package v1 -import "github.com/goccy/go-yaml" +import ( + "github.com/goccy/go-yaml" -func Parse(data []byte) (Config, error) { + "github.com/docker/cagent/pkg/config/types" + previous "github.com/docker/cagent/pkg/config/v0" +) + +func Register(parsers map[string]func([]byte) (any, error), upgraders *[]func(any, []byte) (any, error)) { + parsers[Version] = func(d []byte) (any, error) { return parse(d) } + *upgraders = append(*upgraders, upgradeIfNeeded) +} + +func parse(data []byte) (Config, error) { var cfg Config err := yaml.UnmarshalWithOptions(data, &cfg, yaml.Strict()) return cfg, err } + +func upgradeIfNeeded(c any, _ []byte) (any, error) { + old, ok := c.(previous.Config) + if !ok { + return c, nil + } + + var config Config + types.CloneThroughJSON(old, &config) + + // model.Type --> model.Provider + for name := range old.Models { + oldModel := old.Models[name] + newModel := config.Models[name] + + newModel.Provider = oldModel.Type + config.Models[name] = newModel + } + + // todo:true --> toolsets: [{type: todo}] + // think:true --> toolsets: [{type: think}] + // memory:{path: PATH} --> toolsets: [{type: memory, path: PATH}] + for name := range old.Agents { 
+ oldAgent := old.Agents[name] + newAgent := config.Agents[name] + + var toolsets []Toolset + + if oldAgent.Todo.Enabled { + toolsets = append(toolsets, Toolset{ + Type: "todo", + Shared: oldAgent.Todo.Shared, + }) + } + if oldAgent.Think { + toolsets = append(toolsets, Toolset{ + Type: "think", + }) + } + if oldAgent.MemoryConfig.Path != "" { + toolsets = append(toolsets, Toolset{ + Type: "memory", + Path: oldAgent.MemoryConfig.Path, + }) + } + + toolsets = append(toolsets, newAgent.Toolsets...) + newAgent.Toolsets = toolsets + config.Agents[name] = newAgent + } + + return config, nil +} diff --git a/pkg/config/v1/upgrade.go b/pkg/config/v1/upgrade.go deleted file mode 100644 index a91827fd0..000000000 --- a/pkg/config/v1/upgrade.go +++ /dev/null @@ -1,59 +0,0 @@ -package v1 - -import ( - "github.com/docker/cagent/pkg/config/types" - previous "github.com/docker/cagent/pkg/config/v0" -) - -func UpgradeIfNeeded(c any, _ []byte) (any, error) { - old, ok := c.(previous.Config) - if !ok { - return c, nil - } - - var config Config - types.CloneThroughJSON(old, &config) - - // model.Type --> model.Provider - for name := range old.Models { - oldModel := old.Models[name] - newModel := config.Models[name] - - newModel.Provider = oldModel.Type - config.Models[name] = newModel - } - - // todo:true --> toolsets: [{type: todo}] - // think:true --> toolsets: [{type: think}] - // memory:{path: PATH} --> toolsets: [{type: memory, path: PATH}] - for name := range old.Agents { - oldAgent := old.Agents[name] - newAgent := config.Agents[name] - - var toolsets []Toolset - - if oldAgent.Todo.Enabled { - toolsets = append(toolsets, Toolset{ - Type: "todo", - Shared: oldAgent.Todo.Shared, - }) - } - if oldAgent.Think { - toolsets = append(toolsets, Toolset{ - Type: "think", - }) - } - if oldAgent.MemoryConfig.Path != "" { - toolsets = append(toolsets, Toolset{ - Type: "memory", - Path: oldAgent.MemoryConfig.Path, - }) - } - - toolsets = append(toolsets, newAgent.Toolsets...) 
- newAgent.Toolsets = toolsets - config.Agents[name] = newAgent - } - - return config, nil -} diff --git a/pkg/config/v2/parse.go b/pkg/config/v2/parse.go index d7778fd1f..9b0194e79 100644 --- a/pkg/config/v2/parse.go +++ b/pkg/config/v2/parse.go @@ -1,9 +1,54 @@ package v2 -import "github.com/goccy/go-yaml" +import ( + "errors" -func Parse(data []byte) (Config, error) { + "github.com/goccy/go-yaml" + + "github.com/docker/cagent/pkg/config/types" + previous "github.com/docker/cagent/pkg/config/v1" +) + +func Register(parsers map[string]func([]byte) (any, error), upgraders *[]func(any, []byte) (any, error)) { + parsers[Version] = func(d []byte) (any, error) { return parse(d) } + *upgraders = append(*upgraders, upgradeIfNeeded) +} + +func parse(data []byte) (Config, error) { var cfg Config err := yaml.UnmarshalWithOptions(data, &cfg, yaml.Strict()) return cfg, err } + +func upgradeIfNeeded(c any, _ []byte) (any, error) { + old, ok := c.(previous.Config) + if !ok { + return c, nil + } + + if len(old.Env) > 0 { + return Config{}, errors.New("top-level Env is not supported anymore") + } + + for i := range old.Models { + model := old.Models[i] + + if len(model.Env) > 0 { + return Config{}, errors.New("model Env is not supported anymore") + } + } + + for _, agent := range old.Agents { + for i := range agent.Toolsets { + toolSet := agent.Toolsets[i] + + if len(toolSet.Envfiles) > 0 { + return Config{}, errors.New("toolset Envfiles is not supported anymore") + } + } + } + + var config Config + types.CloneThroughJSON(old, &config) + return config, nil +} diff --git a/pkg/config/v2/upgrade.go b/pkg/config/v2/upgrade.go deleted file mode 100644 index 3a5263dd2..000000000 --- a/pkg/config/v2/upgrade.go +++ /dev/null @@ -1,41 +0,0 @@ -package v2 - -import ( - "errors" - - "github.com/docker/cagent/pkg/config/types" - previous "github.com/docker/cagent/pkg/config/v1" -) - -func UpgradeIfNeeded(c any, _ []byte) (any, error) { - old, ok := c.(previous.Config) - if !ok { - return c, 
nil - } - - if len(old.Env) > 0 { - return Config{}, errors.New("top-level Env is not supported anymore") - } - - for i := range old.Models { - model := old.Models[i] - - if len(model.Env) > 0 { - return Config{}, errors.New("model Env is not supported anymore") - } - } - - for _, agent := range old.Agents { - for i := range agent.Toolsets { - toolSet := agent.Toolsets[i] - - if len(toolSet.Envfiles) > 0 { - return Config{}, errors.New("toolset Envfiles is not supported anymore") - } - } - } - - var config Config - types.CloneThroughJSON(old, &config) - return config, nil -} diff --git a/pkg/config/v3/parse.go b/pkg/config/v3/parse.go index 7561ed41c..998502156 100644 --- a/pkg/config/v3/parse.go +++ b/pkg/config/v3/parse.go @@ -1,9 +1,30 @@ package v3 -import "github.com/goccy/go-yaml" +import ( + "github.com/goccy/go-yaml" -func Parse(data []byte) (Config, error) { + "github.com/docker/cagent/pkg/config/types" + previous "github.com/docker/cagent/pkg/config/v2" +) + +func Register(parsers map[string]func([]byte) (any, error), upgraders *[]func(any, []byte) (any, error)) { + parsers[Version] = func(d []byte) (any, error) { return parse(d) } + *upgraders = append(*upgraders, upgradeIfNeeded) +} + +func parse(data []byte) (Config, error) { var cfg Config err := yaml.UnmarshalWithOptions(data, &cfg, yaml.Strict()) return cfg, err } + +func upgradeIfNeeded(c any, _ []byte) (any, error) { + old, ok := c.(previous.Config) + if !ok { + return c, nil + } + + var config Config + types.CloneThroughJSON(old, &config) + return config, nil +} diff --git a/pkg/config/v3/upgrade.go b/pkg/config/v3/upgrade.go deleted file mode 100644 index add2cf622..000000000 --- a/pkg/config/v3/upgrade.go +++ /dev/null @@ -1,17 +0,0 @@ -package v3 - -import ( - "github.com/docker/cagent/pkg/config/types" - previous "github.com/docker/cagent/pkg/config/v2" -) - -func UpgradeIfNeeded(c any, _ []byte) (any, error) { - old, ok := c.(previous.Config) - if !ok { - return c, nil - } - - var config 
Config - types.CloneThroughJSON(old, &config) - return config, nil -} diff --git a/pkg/config/v4/parse.go b/pkg/config/v4/parse.go index b7ca1373c..38bc04fcf 100644 --- a/pkg/config/v4/parse.go +++ b/pkg/config/v4/parse.go @@ -1,9 +1,55 @@ package v4 -import "github.com/goccy/go-yaml" +import ( + "github.com/goccy/go-yaml" -func Parse(data []byte) (Config, error) { + "github.com/docker/cagent/pkg/config/types" + previous "github.com/docker/cagent/pkg/config/v3" +) + +func Register(parsers map[string]func([]byte) (any, error), upgraders *[]func(any, []byte) (any, error)) { + parsers[Version] = func(d []byte) (any, error) { return parse(d) } + *upgraders = append(*upgraders, upgradeIfNeeded) +} + +func parse(data []byte) (Config, error) { var cfg Config err := yaml.UnmarshalWithOptions(data, &cfg, yaml.Strict()) return cfg, err } + +func upgradeIfNeeded(c any, raw []byte) (any, error) { + old, ok := c.(previous.Config) + if !ok { + return c, nil + } + + // Put the agents on the side + previousAgents := old.Agents + old.Agents = nil + + var config Config + types.CloneThroughJSON(old, &config) + + // For agents, we have to read them in the order they appear in the raw config + type Original struct { + Agents yaml.MapSlice `yaml:"agents"` + } + + var original Original + if err := yaml.Unmarshal(raw, &original); err != nil { + return nil, err + } + + for _, agent := range original.Agents { + name := agent.Key.(string) + + var agentConfig AgentConfig + types.CloneThroughJSON(previousAgents[name], &agentConfig) + agentConfig.Name = name + + config.Agents = append(config.Agents, agentConfig) + } + + return config, nil +} diff --git a/pkg/config/v4/types_test.go b/pkg/config/v4/types_test.go index 8a3de445d..992134209 100644 --- a/pkg/config/v4/types_test.go +++ b/pkg/config/v4/types_test.go @@ -133,7 +133,7 @@ agents: instructions: "You are a helpful assistant." 
`) - _, err := Parse(input) + _, err := parse(input) require.Error(t, err) require.Contains(t, err.Error(), "instructions") } @@ -148,7 +148,7 @@ agents: instruction: "You are a helpful assistant." `) - cfg, err := Parse(input) + cfg, err := parse(input) require.NoError(t, err) require.Len(t, cfg.Agents, 1) require.Equal(t, "root", cfg.Agents[0].Name) diff --git a/pkg/config/v4/upgrade.go b/pkg/config/v4/upgrade.go deleted file mode 100644 index bafae6751..000000000 --- a/pkg/config/v4/upgrade.go +++ /dev/null @@ -1,44 +0,0 @@ -package v4 - -import ( - "github.com/goccy/go-yaml" - - "github.com/docker/cagent/pkg/config/types" - previous "github.com/docker/cagent/pkg/config/v3" -) - -func UpgradeIfNeeded(c any, raw []byte) (any, error) { - old, ok := c.(previous.Config) - if !ok { - return c, nil - } - - // Put the agents on the side - previousAgents := old.Agents - old.Agents = nil - - var config Config - types.CloneThroughJSON(old, &config) - - // For agents, we have to read in what they order they appear in the raw config - type Original struct { - Agents yaml.MapSlice `yaml:"agents"` - } - - var original Original - if err := yaml.Unmarshal(raw, &original); err != nil { - return nil, err - } - - for _, agent := range original.Agents { - name := agent.Key.(string) - - var agentConfig AgentConfig - types.CloneThroughJSON(previousAgents[name], &agentConfig) - agentConfig.Name = name - - config.Agents = append(config.Agents, agentConfig) - } - - return config, nil -} diff --git a/pkg/config/v5/model_config_clone_test.go b/pkg/config/v5/model_config_clone_test.go new file mode 100644 index 000000000..7ba0cbe82 --- /dev/null +++ b/pkg/config/v5/model_config_clone_test.go @@ -0,0 +1,79 @@ +package v5 + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestModelConfig_Clone_DeepCopiesPointerFields(t *testing.T) { + t.Parallel() + + temp := 0.7 + maxTokens := int64(4096) + topP := 0.9 + parallel := true + trackUsage := true + + original := 
&ModelConfig{ + Provider: "openai", + Model: "gpt-4o", + Temperature: &temp, + MaxTokens: &maxTokens, + TopP: &topP, + ParallelToolCalls: ¶llel, + TrackUsage: &trackUsage, + ThinkingBudget: &ThinkingBudget{Effort: "high"}, + ProviderOpts: map[string]any{"key": "value"}, + Routing: []RoutingRule{ + {Model: "fast", Examples: []string{"quick question"}}, + }, + } + + clone := original.Clone() + + // Mutate every pointer/collection field in the original. + *original.Temperature = 0.1 + *original.MaxTokens = 1 + *original.TopP = 0.1 + *original.ParallelToolCalls = false + *original.TrackUsage = false + original.ThinkingBudget.Effort = "low" + original.ProviderOpts["key"] = "mutated" + original.Routing[0].Examples[0] = "mutated" + + // Clone must be unaffected. + assert.InDelta(t, 0.7, *clone.Temperature, 0.001) + assert.Equal(t, int64(4096), *clone.MaxTokens) + assert.InDelta(t, 0.9, *clone.TopP, 0.001) + assert.True(t, *clone.ParallelToolCalls) + assert.True(t, *clone.TrackUsage) + assert.Equal(t, "high", clone.ThinkingBudget.Effort) + assert.Equal(t, "value", clone.ProviderOpts["key"]) + assert.Equal(t, "quick question", clone.Routing[0].Examples[0]) +} + +func TestModelConfig_Clone_Nil(t *testing.T) { + t.Parallel() + + var m *ModelConfig + assert.Nil(t, m.Clone()) +} + +func TestModelConfig_Clone_MinimalFields(t *testing.T) { + t.Parallel() + + original := &ModelConfig{ + Provider: "anthropic", + Model: "claude-sonnet-4-5", + } + + clone := original.Clone() + + assert.Equal(t, "anthropic", clone.Provider) + assert.Equal(t, "claude-sonnet-4-5", clone.Model) + assert.Nil(t, clone.Temperature) + assert.Nil(t, clone.MaxTokens) + assert.Nil(t, clone.ProviderOpts) + assert.Nil(t, clone.Routing) +} diff --git a/pkg/config/v5/parse.go b/pkg/config/v5/parse.go new file mode 100644 index 000000000..adddb33a2 --- /dev/null +++ b/pkg/config/v5/parse.go @@ -0,0 +1,30 @@ +package v5 + +import ( + "github.com/goccy/go-yaml" + + "github.com/docker/cagent/pkg/config/types" + 
previous "github.com/docker/cagent/pkg/config/v4" +) + +func Register(parsers map[string]func([]byte) (any, error), upgraders *[]func(any, []byte) (any, error)) { + parsers[Version] = func(d []byte) (any, error) { return parse(d) } + *upgraders = append(*upgraders, upgradeIfNeeded) +} + +func parse(data []byte) (Config, error) { + var cfg Config + err := yaml.UnmarshalWithOptions(data, &cfg, yaml.Strict()) + return cfg, err +} + +func upgradeIfNeeded(c any, _ []byte) (any, error) { + old, ok := c.(previous.Config) + if !ok { + return c, nil + } + + var config Config + types.CloneThroughJSON(old, &config) + return config, nil +} diff --git a/pkg/config/v5/schema_test.go b/pkg/config/v5/schema_test.go new file mode 100644 index 000000000..001f8cb8c --- /dev/null +++ b/pkg/config/v5/schema_test.go @@ -0,0 +1,225 @@ +package v5 + +import ( + "encoding/json" + "maps" + "os" + "reflect" + "sort" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// schemaFile is the path to the JSON schema file relative to the repo root. +const schemaFile = "../../../agent-schema.json" + +// jsonSchema mirrors the subset of JSON Schema we need for comparison. +type jsonSchema struct { + Properties map[string]jsonSchema `json:"properties,omitempty"` + Definitions map[string]jsonSchema `json:"definitions,omitempty"` + Ref string `json:"$ref,omitempty"` + Items *jsonSchema `json:"items,omitempty"` + AdditionalProperties any `json:"additionalProperties,omitempty"` +} + +// resolveRef follows a $ref like "#/definitions/Foo" and returns the +// referenced schema. When no $ref is present it returns the receiver unchanged. 
+func (s jsonSchema) resolveRef(root jsonSchema) jsonSchema { + if s.Ref == "" { + return s + } + const prefix = "#/definitions/" + if !strings.HasPrefix(s.Ref, prefix) { + return s + } + name := strings.TrimPrefix(s.Ref, prefix) + if def, ok := root.Definitions[name]; ok { + return def + } + return s +} + +// structJSONFields returns the set of JSON property names declared on a Go +// struct type via `json:",…"` tags. Fields tagged with `json:"-"` are +// excluded. It recurses into anonymous (embedded) struct fields so that +// promoted fields are included. +func structJSONFields(t reflect.Type) map[string]bool { + if t.Kind() == reflect.Pointer { + t = t.Elem() + } + fields := make(map[string]bool) + for f := range t.Fields() { + // Recurse into anonymous (embedded) structs. + if f.Anonymous { + maps.Copy(fields, structJSONFields(f.Type)) + continue + } + + tag := f.Tag.Get("json") + if tag == "" || tag == "-" { + continue + } + name, _, _ := strings.Cut(tag, ",") + if name != "" && name != "-" { + fields[name] = true + } + } + return fields +} + +// schemaProperties returns the set of property names from a JSON schema +// definition. It does NOT follow $ref on individual properties – it only +// looks at the top-level "properties" map. +func schemaProperties(def jsonSchema) map[string]bool { + props := make(map[string]bool, len(def.Properties)) + for k := range def.Properties { + props[k] = true + } + return props +} + +func sortedKeys(m map[string]bool) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} + +// TestSchemaMatchesGoTypes verifies that every JSON-tagged field in the Go +// config structs has a corresponding property in agent-schema.json (and +// vice-versa). This prevents the schema from silently drifting out of sync +// with the Go types. 
+func TestSchemaMatchesGoTypes(t *testing.T) { + t.Parallel() + + data, err := os.ReadFile(schemaFile) + require.NoError(t, err, "failed to read schema file – run this test from the repo root") + + var root jsonSchema + require.NoError(t, json.Unmarshal(data, &root)) + + // mapping maps a JSON Schema definition name (or pseudo-name for inline + // schemas) to the corresponding Go type. For top-level definitions that + // live in the "definitions" section of the schema we use their exact + // name. For schemas inlined inside a parent property we use + // "Parent.property" as the key. + type entry struct { + goType reflect.Type + schemaDef jsonSchema + schemaName string // human-readable name for error messages + } + + entries := []entry{ + // Top-level Config + {reflect.TypeFor[Config](), root, "Config (top-level)"}, + } + + // Definitions that map 1:1 to a Go struct. + definitionMap := map[string]reflect.Type{ + "AgentConfig": reflect.TypeFor[AgentConfig](), + "FallbackConfig": reflect.TypeFor[FallbackConfig](), + "ModelConfig": reflect.TypeFor[ModelConfig](), + "Metadata": reflect.TypeFor[Metadata](), + "ProviderConfig": reflect.TypeFor[ProviderConfig](), + "Toolset": reflect.TypeFor[Toolset](), + "Remote": reflect.TypeFor[Remote](), + "SandboxConfig": reflect.TypeFor[SandboxConfig](), + "ScriptShellToolConfig": reflect.TypeFor[ScriptShellToolConfig](), + "PostEditConfig": reflect.TypeFor[PostEditConfig](), + "PermissionsConfig": reflect.TypeFor[PermissionsConfig](), + "HooksConfig": reflect.TypeFor[HooksConfig](), + "HookMatcherConfig": reflect.TypeFor[HookMatcherConfig](), + "HookDefinition": reflect.TypeFor[HookDefinition](), + "RoutingRule": reflect.TypeFor[RoutingRule](), + "ApiConfig": reflect.TypeFor[APIToolConfig](), + } + + for name, goType := range definitionMap { + def, ok := root.Definitions[name] + require.True(t, ok, "schema definition %q not found", name) + entries = append(entries, entry{goType, def, name}) + } + + // Inline schemas that don't have 
their own top-level definition but are + // nested inside a parent property. + type inlineEntry struct { + goType reflect.Type + // path navigates from a schema definition to the inline object, + // e.g. []string{"RAGConfig", "results"} → definitions.RAGConfig.properties.results + path []string + name string + } + + inlines := []inlineEntry{ + {reflect.TypeFor[StructuredOutput](), []string{"AgentConfig", "structured_output"}, "StructuredOutput (AgentConfig.structured_output)"}, + {reflect.TypeFor[RAGConfig](), []string{"RAGConfig"}, "RAGConfig"}, + {reflect.TypeFor[RAGToolConfig](), []string{"RAGConfig", "tool"}, "RAGToolConfig (RAGConfig.tool)"}, + {reflect.TypeFor[RAGResultsConfig](), []string{"RAGConfig", "results"}, "RAGResultsConfig (RAGConfig.results)"}, + {reflect.TypeFor[RAGFusionConfig](), []string{"RAGConfig", "results", "fusion"}, "RAGFusionConfig (RAGConfig.results.fusion)"}, + {reflect.TypeFor[RAGRerankingConfig](), []string{"RAGConfig", "results", "reranking"}, "RAGRerankingConfig (RAGConfig.results.reranking)"}, + {reflect.TypeFor[RAGChunkingConfig](), []string{"RAGConfig", "strategies", "*", "chunking"}, "RAGChunkingConfig (RAGConfig.strategies[].chunking)"}, + } + + for _, il := range inlines { + def := navigateSchema(t, root, il.path) + entries = append(entries, entry{il.goType, def, il.name}) + } + + // Now compare each entry. + for _, e := range entries { + goFields := structJSONFields(e.goType) + schemaProps := schemaProperties(e.schemaDef) + + missingInSchema := diff(goFields, schemaProps) + missingInGo := diff(schemaProps, goFields) + + assert.Empty(t, sortedKeys(missingInSchema), + "%s: Go struct has JSON fields not present in the schema", e.schemaName) + assert.Empty(t, sortedKeys(missingInGo), + "%s: schema has properties not present in the Go struct", e.schemaName) + } +} + +// navigateSchema walks from a top-level definition through nested properties. +// path[0] is the definition name; subsequent elements are property names. 
+// The special element "*" dereferences an array's "items" schema. +func navigateSchema(t *testing.T, root jsonSchema, path []string) jsonSchema { + t.Helper() + require.NotEmpty(t, path) + + cur, ok := root.Definitions[path[0]] + require.True(t, ok, "definition %q not found", path[0]) + + // Resolve top-level $ref if present. + cur = cur.resolveRef(root) + + for _, segment := range path[1:] { + if segment == "*" { + require.NotNil(t, cur.Items, "expected items schema at %v", path) + cur = *cur.Items + cur = cur.resolveRef(root) + continue + } + prop, ok := cur.Properties[segment] + require.True(t, ok, "property %q not found at %v", segment, path) + prop = prop.resolveRef(root) + cur = prop + } + return cur +} + +// diff returns keys present in a but not in b. +func diff(a, b map[string]bool) map[string]bool { + d := make(map[string]bool) + for k := range a { + if !b[k] { + d[k] = true + } + } + return d +} diff --git a/pkg/config/v5/skills_config_test.go b/pkg/config/v5/skills_config_test.go new file mode 100644 index 000000000..732255569 --- /dev/null +++ b/pkg/config/v5/skills_config_test.go @@ -0,0 +1,270 @@ +package v5 + +import ( + "encoding/json" + "testing" + + "github.com/goccy/go-yaml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSkillsConfig_UnmarshalYAML(t *testing.T) { + tests := []struct { + name string + input string + expected SkillsConfig + }{ + { + name: "boolean true", + input: "true", + expected: SkillsConfig{Sources: []string{"local"}}, + }, + { + name: "boolean false", + input: "false", + expected: SkillsConfig{Sources: nil}, + }, + { + name: "list with local only", + input: "[local]", + expected: SkillsConfig{Sources: []string{"local"}}, + }, + { + name: "list with remote URL", + input: "[\"http://example.com\"]", + expected: SkillsConfig{Sources: []string{"http://example.com"}}, + }, + { + name: "list with local and remote", + input: "[local, \"https://skills.example.com\"]", + expected: 
SkillsConfig{Sources: []string{ + "local", + "https://skills.example.com", + }}, + }, + { + name: "multiline list", + input: `- local +- https://example.com +- http://internal.corp`, + expected: SkillsConfig{Sources: []string{ + "local", + "https://example.com", + "http://internal.corp", + }}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var cfg SkillsConfig + err := yaml.Unmarshal([]byte(tt.input), &cfg) + require.NoError(t, err) + assert.Equal(t, tt.expected, cfg) + }) + } +} + +func TestSkillsConfig_MarshalYAML(t *testing.T) { + tests := []struct { + name string + input SkillsConfig + expected string + }{ + { + name: "disabled", + input: SkillsConfig{}, + expected: "false\n", + }, + { + name: "local only marshals as true", + input: SkillsConfig{Sources: []string{"local"}}, + expected: "true\n", + }, + { + name: "list with remote", + input: SkillsConfig{Sources: []string{"local", "https://example.com"}}, + expected: "- local\n- https://example.com\n", + }, + { + name: "remote only", + input: SkillsConfig{Sources: []string{"https://example.com"}}, + expected: "- https://example.com\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := yaml.Marshal(tt.input) + require.NoError(t, err) + assert.Equal(t, tt.expected, string(out)) + }) + } +} + +func TestSkillsConfig_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input string + expected SkillsConfig + }{ + { + name: "boolean true", + input: "true", + expected: SkillsConfig{Sources: []string{"local"}}, + }, + { + name: "boolean false", + input: "false", + expected: SkillsConfig{Sources: nil}, + }, + { + name: "list with local", + input: `["local"]`, + expected: SkillsConfig{Sources: []string{"local"}}, + }, + { + name: "list with remote URLs", + input: `["local", "https://skills.example.com"]`, + expected: SkillsConfig{Sources: []string{"local", "https://skills.example.com"}}, + }, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + var cfg SkillsConfig + err := json.Unmarshal([]byte(tt.input), &cfg) + require.NoError(t, err) + assert.Equal(t, tt.expected, cfg) + }) + } +} + +func TestSkillsConfig_MarshalJSON(t *testing.T) { + tests := []struct { + name string + input SkillsConfig + expected string + }{ + { + name: "disabled", + input: SkillsConfig{}, + expected: "false", + }, + { + name: "local only as true", + input: SkillsConfig{Sources: []string{"local"}}, + expected: "true", + }, + { + name: "list with remote", + input: SkillsConfig{Sources: []string{"local", "https://example.com"}}, + expected: `["local","https://example.com"]`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := json.Marshal(tt.input) + require.NoError(t, err) + assert.Equal(t, tt.expected, string(out)) + }) + } +} + +func TestSkillsConfig_Enabled(t *testing.T) { + assert.False(t, SkillsConfig{}.Enabled()) + assert.False(t, SkillsConfig{Sources: nil}.Enabled()) + assert.False(t, SkillsConfig{Sources: []string{}}.Enabled()) + assert.True(t, SkillsConfig{Sources: []string{"local"}}.Enabled()) + assert.True(t, SkillsConfig{Sources: []string{"https://example.com"}}.Enabled()) +} + +func TestSkillsConfig_HasLocal(t *testing.T) { + assert.False(t, SkillsConfig{}.HasLocal()) + assert.False(t, SkillsConfig{Sources: []string{"https://example.com"}}.HasLocal()) + assert.True(t, SkillsConfig{Sources: []string{"local"}}.HasLocal()) + assert.True(t, SkillsConfig{Sources: []string{"local", "https://example.com"}}.HasLocal()) +} + +func TestSkillsConfig_RemoteURLs(t *testing.T) { + assert.Empty(t, SkillsConfig{}.RemoteURLs()) + assert.Empty(t, SkillsConfig{Sources: []string{"local"}}.RemoteURLs()) + assert.Equal(t, + []string{"https://example.com", "http://internal.corp"}, + SkillsConfig{Sources: []string{"local", "https://example.com", "http://internal.corp"}}.RemoteURLs(), + ) +} + +func TestSkillsConfig_JSONRoundTrip(t *testing.T) { + // This tests the 
upgrade path from v4 (bool) to v5 (SkillsConfig) via CloneThroughJSON + t.Run("bool true round trips through JSON", func(t *testing.T) { + jsonData := []byte("true") + var cfg SkillsConfig + require.NoError(t, json.Unmarshal(jsonData, &cfg)) + assert.True(t, cfg.Enabled()) + assert.True(t, cfg.HasLocal()) + assert.Equal(t, []string{"local"}, cfg.Sources) + + out, err := json.Marshal(cfg) + require.NoError(t, err) + assert.Equal(t, "true", string(out)) + }) + + t.Run("bool false round trips through JSON", func(t *testing.T) { + jsonData := []byte("false") + var cfg SkillsConfig + require.NoError(t, json.Unmarshal(jsonData, &cfg)) + assert.False(t, cfg.Enabled()) + assert.Nil(t, cfg.Sources) + + out, err := json.Marshal(cfg) + require.NoError(t, err) + assert.Equal(t, "false", string(out)) + }) + + t.Run("list round trips through JSON", func(t *testing.T) { + jsonData := []byte(`["local","https://example.com"]`) + var cfg SkillsConfig + require.NoError(t, json.Unmarshal(jsonData, &cfg)) + assert.True(t, cfg.Enabled()) + assert.Equal(t, []string{"local", "https://example.com"}, cfg.Sources) + + out, err := json.Marshal(cfg) + require.NoError(t, err) + assert.Equal(t, `["local","https://example.com"]`, string(out)) + }) +} + +func TestSkillsConfig_InAgentConfig(t *testing.T) { + yamlInput := ` +model: openai/gpt-4 +skills: + - local + - https://skills.example.com +toolsets: + - type: filesystem +` + var agent AgentConfig + err := yaml.Unmarshal([]byte(yamlInput), &agent) + require.NoError(t, err) + assert.True(t, agent.Skills.Enabled()) + assert.True(t, agent.Skills.HasLocal()) + assert.Equal(t, []string{"https://skills.example.com"}, agent.Skills.RemoteURLs()) +} + +func TestSkillsConfig_InAgentConfigBool(t *testing.T) { + yamlInput := ` +model: openai/gpt-4 +skills: true +toolsets: + - type: filesystem +` + var agent AgentConfig + err := yaml.Unmarshal([]byte(yamlInput), &agent) + require.NoError(t, err) + assert.True(t, agent.Skills.Enabled()) + assert.True(t, 
agent.Skills.HasLocal()) + assert.Empty(t, agent.Skills.RemoteURLs()) +} diff --git a/pkg/config/v5/types.go b/pkg/config/v5/types.go new file mode 100644 index 000000000..d8b5a505c --- /dev/null +++ b/pkg/config/v5/types.go @@ -0,0 +1,1310 @@ +package v5 + +import ( + "cmp" + "encoding/json" + "fmt" + "maps" + "strings" + "time" + + "github.com/goccy/go-yaml" + + "github.com/docker/cagent/pkg/config/types" +) + +const Version = "5" + +// Config represents the entire configuration file +type Config struct { + Version string `json:"version,omitempty"` + Agents Agents `json:"agents,omitempty"` + Providers map[string]ProviderConfig `json:"providers,omitempty"` + Models map[string]ModelConfig `json:"models,omitempty"` + RAG map[string]RAGConfig `json:"rag,omitempty"` + Metadata Metadata `json:"metadata"` + Permissions *PermissionsConfig `json:"permissions,omitempty"` +} + +type Agents []AgentConfig + +func (c *Agents) UnmarshalYAML(unmarshal func(any) error) error { + var items yaml.MapSlice + if err := unmarshal(&items); err != nil { + return err + } + + agents := make([]AgentConfig, 0, len(items)) + for _, item := range items { + name, ok := item.Key.(string) + if !ok { + return fmt.Errorf("agent name must be a string") + } + + valueBytes, err := yaml.Marshal(item.Value) + if err != nil { + return fmt.Errorf("failed to marshal agent config for %s: %w", name, err) + } + + var agent AgentConfig + if err := yaml.UnmarshalWithOptions(valueBytes, &agent, yaml.DisallowUnknownField()); err != nil { + return fmt.Errorf("failed to unmarshal agent config for %s: %w", name, err) + } + + agent.Name = name + agents = append(agents, agent) + } + + *c = agents + return nil +} + +func (c Agents) MarshalYAML() ([]byte, error) { + mapSlice := make(yaml.MapSlice, 0, len(c)) + + for _, agent := range c { + mapSlice = append(mapSlice, yaml.MapItem{ + Key: agent.Name, + Value: agent, + }) + } + + return yaml.Marshal(mapSlice) +} + +func (c *Agents) First() AgentConfig { + if len(*c) > 0 { 
+ return (*c)[0] + } + panic("no agents configured") +} + +func (c *Agents) Lookup(name string) (AgentConfig, bool) { + for _, agent := range *c { + if agent.Name == name { + return agent, true + } + } + return AgentConfig{}, false +} + +func (c *Agents) Update(name string, update func(a *AgentConfig)) bool { + for i := range *c { + if (*c)[i].Name == name { + update(&(*c)[i]) + return true + } + } + return false +} + +// ProviderConfig represents a reusable provider definition. +// It allows users to define custom providers with default base URLs and token keys. +// Models can reference these providers by name, inheriting the defaults. +type ProviderConfig struct { + // APIType specifies which API schema to use. Supported values: + // - "openai_chatcompletions" (default): Use the OpenAI Chat Completions API + // - "openai_responses": Use the OpenAI Responses API + APIType string `json:"api_type,omitempty"` + // BaseURL is the base URL for the provider's API endpoint + BaseURL string `json:"base_url"` + // TokenKey is the environment variable name containing the API token + TokenKey string `json:"token_key,omitempty"` +} + +// FallbackConfig represents fallback model configuration for an agent. +// Controls which models to try when the primary fails and how retries/cooldowns work. +// Most users only need to specify Models — the defaults handle common scenarios automatically. +type FallbackConfig struct { + // Models is a list of fallback models to try in order if the primary fails. + // Each entry can be a model name from the models section or an inline provider/model format. + Models []string `json:"models,omitempty"` + // Retries is the number of retries per model with exponential backoff. + // Default is 2 (giving 3 total attempts per model). Use -1 to disable retries entirely. + // Retries only apply to retryable errors (5xx, timeouts); non-retryable errors (429, 4xx) + // skip immediately to the next model. 
+ Retries int `json:"retries,omitempty"` + // Cooldown is the duration to stick with a successful fallback model before + // retrying the primary. Only applies after a non-retryable error (e.g., 429). + // Default is 1 minute. Use Go duration format (e.g., "1m", "30s", "2m30s"). + Cooldown Duration `json:"cooldown"` +} + +// Duration is a wrapper around time.Duration that supports YAML/JSON unmarshaling +// from string format (e.g., "1m", "30s", "2h30m"). +type Duration struct { + time.Duration +} + +// UnmarshalYAML implements custom unmarshaling for Duration from string format +func (d *Duration) UnmarshalYAML(unmarshal func(any) error) error { + if d == nil { + return fmt.Errorf("cannot unmarshal into nil Duration") + } + + var s string + if err := unmarshal(&s); err != nil { + // Try as integer (seconds) + var secs int + if err2 := unmarshal(&secs); err2 == nil { + d.Duration = time.Duration(secs) * time.Second + return nil + } + return err + } + if s == "" { + d.Duration = 0 + return nil + } + dur, err := time.ParseDuration(s) + if err != nil { + return fmt.Errorf("invalid duration format %q: %w", s, err) + } + d.Duration = dur + return nil +} + +// MarshalYAML implements custom marshaling for Duration to string format +func (d Duration) MarshalYAML() ([]byte, error) { + if d.Duration == 0 { + return yaml.Marshal("") + } + return yaml.Marshal(d.String()) +} + +// UnmarshalJSON implements custom unmarshaling for Duration from string format +func (d *Duration) UnmarshalJSON(data []byte) error { + if d == nil { + return fmt.Errorf("cannot unmarshal into nil Duration") + } + + var s string + if err := json.Unmarshal(data, &s); err != nil { + // Try as integer (seconds) + var secs int + if err2 := json.Unmarshal(data, &secs); err2 == nil { + d.Duration = time.Duration(secs) * time.Second + return nil + } + return err + } + if s == "" { + d.Duration = 0 + return nil + } + dur, err := time.ParseDuration(s) + if err != nil { + return fmt.Errorf("invalid duration 
format %q: %w", s, err) + } + d.Duration = dur + return nil +} + +// MarshalJSON implements custom marshaling for Duration to string format +func (d Duration) MarshalJSON() ([]byte, error) { + if d.Duration == 0 { + return json.Marshal("") + } + return json.Marshal(d.String()) +} + +// AgentConfig represents a single agent configuration +type AgentConfig struct { + Name string + Model string `json:"model,omitempty"` + Fallback *FallbackConfig `json:"fallback,omitempty"` + Description string `json:"description,omitempty"` + WelcomeMessage string `json:"welcome_message,omitempty"` + Toolsets []Toolset `json:"toolsets,omitempty"` + Instruction string `json:"instruction,omitempty"` + SubAgents []string `json:"sub_agents,omitempty"` + Handoffs []string `json:"handoffs,omitempty"` + RAG []string `json:"rag,omitempty"` + AddDate bool `json:"add_date,omitempty"` + AddEnvironmentInfo bool `json:"add_environment_info,omitempty"` + CodeModeTools bool `json:"code_mode_tools,omitempty"` + AddDescriptionParameter bool `json:"add_description_parameter,omitempty"` + MaxIterations int `json:"max_iterations,omitempty"` + NumHistoryItems int `json:"num_history_items,omitempty"` + AddPromptFiles []string `json:"add_prompt_files,omitempty" yaml:"add_prompt_files,omitempty"` + Commands types.Commands `json:"commands,omitempty"` + StructuredOutput *StructuredOutput `json:"structured_output,omitempty"` + Skills SkillsConfig `json:"skills,omitempty"` + Hooks *HooksConfig `json:"hooks,omitempty"` +} + +const SkillSourceLocal = "local" + +// SkillsConfig controls skill discovery sources for an agent. +// Supports two YAML formats: +// - Boolean: `skills: true` (equivalent to ["local"]) or `skills: false` (disabled) +// - List: `skills: ["local", "http://example.com"]` +// +// The special source "local" loads skills from the filesystem (standard locations). +// HTTP/HTTPS URLs load skills from remote servers per the well-known skills discovery spec. 
+type SkillsConfig struct { //nolint:recvcheck // MarshalYAML/MarshalJSON must use value receiver, UnmarshalYAML/UnmarshalJSON must use pointer + Sources []string +} + +func (s SkillsConfig) Enabled() bool { + return len(s.Sources) > 0 +} + +func (s SkillsConfig) HasLocal() bool { + for _, src := range s.Sources { + if src == SkillSourceLocal { + return true + } + } + return false +} + +func (s SkillsConfig) RemoteURLs() []string { + var urls []string + for _, src := range s.Sources { + if strings.HasPrefix(src, "http://") || strings.HasPrefix(src, "https://") { + urls = append(urls, src) + } + } + return urls +} + +func (s *SkillsConfig) UnmarshalYAML(unmarshal func(any) error) error { + var b bool + if err := unmarshal(&b); err == nil { + if b { + s.Sources = []string{SkillSourceLocal} + } else { + s.Sources = nil + } + return nil + } + + var sources []string + if err := unmarshal(&sources); err != nil { + return fmt.Errorf("skills must be a boolean or a list of sources") + } + s.Sources = sources + return nil +} + +func (s SkillsConfig) MarshalYAML() ([]byte, error) { + if len(s.Sources) == 0 { + return yaml.Marshal(false) + } + if len(s.Sources) == 1 && s.Sources[0] == SkillSourceLocal { + return yaml.Marshal(true) + } + return yaml.Marshal(s.Sources) +} + +func (s *SkillsConfig) UnmarshalJSON(data []byte) error { + var b bool + if err := json.Unmarshal(data, &b); err == nil { + if b { + s.Sources = []string{SkillSourceLocal} + } else { + s.Sources = nil + } + return nil + } + + var sources []string + if err := json.Unmarshal(data, &sources); err != nil { + return fmt.Errorf("skills must be a boolean or a list of sources") + } + s.Sources = sources + return nil +} + +func (s SkillsConfig) MarshalJSON() ([]byte, error) { + if len(s.Sources) == 0 { + return json.Marshal(false) + } + if len(s.Sources) == 1 && s.Sources[0] == SkillSourceLocal { + return json.Marshal(true) + } + return json.Marshal(s.Sources) +} + +// GetFallbackModels returns the fallback models 
from the config. +func (a *AgentConfig) GetFallbackModels() []string { + if a.Fallback != nil { + return a.Fallback.Models + } + return nil +} + +// GetFallbackRetries returns the fallback retries from the config. +func (a *AgentConfig) GetFallbackRetries() int { + if a.Fallback != nil { + return a.Fallback.Retries + } + return 0 +} + +// GetFallbackCooldown returns the fallback cooldown duration from the config. +// Returns the configured cooldown, or 0 if not set (caller should apply default). +func (a *AgentConfig) GetFallbackCooldown() time.Duration { + if a.Fallback != nil { + return a.Fallback.Cooldown.Duration + } + return 0 +} + +// ModelConfig represents the configuration for a model +type ModelConfig struct { + // Name is the manifest model name (map key), populated at runtime. + // Not serialized — set by teamloader/model_switcher when resolving models. + Name string `json:"-"` + Provider string `json:"provider,omitempty"` + Model string `json:"model,omitempty"` + // DisplayModel holds the original model name from the YAML config, before alias resolution. + // When set, provider.ID() returns Provider + "/" + DisplayModel instead of the resolved name. + // This ensures the UI shows the user-configured name (e.g., "claude-haiku-4-5") + // while the API uses the resolved name (e.g., "claude-haiku-4-5-20251001"). + DisplayModel string `json:"-"` + Temperature *float64 `json:"temperature,omitempty"` + MaxTokens *int64 `json:"max_tokens,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + BaseURL string `json:"base_url,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + TokenKey string `json:"token_key,omitempty"` + // ProviderOpts allows provider-specific options. 
+ ProviderOpts map[string]any `json:"provider_opts,omitempty"` + TrackUsage *bool `json:"track_usage,omitempty"` + // ThinkingBudget controls reasoning effort/budget: + // - For OpenAI: accepts string levels "minimal", "low", "medium", "high" + // - For Anthropic: accepts integer token budget (1024-32000) + // - For other providers: may be ignored + ThinkingBudget *ThinkingBudget `json:"thinking_budget,omitempty"` + // Routing defines rules for routing requests to different models. + // When routing is configured, this model becomes a rule-based router: + // - The provider/model fields define the fallback model + // - Each routing rule maps to a different model based on examples + Routing []RoutingRule `json:"routing,omitempty"` +} + +// Clone returns a deep copy of the ModelConfig. +func (m *ModelConfig) Clone() *ModelConfig { + if m == nil { + return nil + } + var c ModelConfig + types.CloneThroughJSON(m, &c) + // Preserve fields excluded from JSON serialization + c.Name = m.Name + c.DisplayModel = m.DisplayModel + return &c +} + +// DisplayOrModel returns DisplayModel if set (i.e., alias resolution preserved the original name), +// otherwise falls back to Model. +func (m *ModelConfig) DisplayOrModel() string { + return cmp.Or(m.DisplayModel, m.Model) +} + +// FlexibleModelConfig wraps ModelConfig to support both shorthand and full syntax. 
+// It can be unmarshaled from either: +// - A shorthand string: "provider/model" (e.g., "anthropic/claude-sonnet-4-5") +// - A full model definition with all options +type FlexibleModelConfig struct { + ModelConfig +} + +// UnmarshalYAML implements custom unmarshaling for flexible model config +func (f *FlexibleModelConfig) UnmarshalYAML(unmarshal func(any) error) error { + // Try string shorthand first + var shorthand string + if err := unmarshal(&shorthand); err == nil && shorthand != "" { + provider, model, ok := strings.Cut(shorthand, "/") + if !ok || provider == "" || model == "" { + return fmt.Errorf("invalid model shorthand %q: expected format 'provider/model'", shorthand) + } + f.Provider = provider + f.Model = model + return nil + } + + // Try full model config + var cfg ModelConfig + if err := unmarshal(&cfg); err != nil { + return err + } + f.ModelConfig = cfg + return nil +} + +// MarshalYAML outputs shorthand format if only provider/model are set +func (f FlexibleModelConfig) MarshalYAML() ([]byte, error) { + if f.isShorthandOnly() { + return yaml.Marshal(f.Provider + "/" + f.Model) + } + return yaml.Marshal(f.ModelConfig) +} + +// isShorthandOnly returns true if only provider and model are set +func (f *FlexibleModelConfig) isShorthandOnly() bool { + return f.Temperature == nil && + f.MaxTokens == nil && + f.TopP == nil && + f.FrequencyPenalty == nil && + f.PresencePenalty == nil && + f.BaseURL == "" && + f.ParallelToolCalls == nil && + f.TokenKey == "" && + len(f.ProviderOpts) == 0 && + f.TrackUsage == nil && + f.ThinkingBudget == nil && + len(f.Routing) == 0 +} + +// RoutingRule defines a single routing rule for model selection. +// Each rule maps example phrases to a target model. 
+type RoutingRule struct { + // Model is a reference to another model in the models section or an inline model spec (e.g., "openai/gpt-4o") + Model string `json:"model"` + // Examples are phrases that should trigger routing to this model + Examples []string `json:"examples"` +} + +type Metadata struct { + Author string `json:"author,omitempty"` + License string `json:"license,omitempty"` + Description string `json:"description,omitempty"` + Readme string `json:"readme,omitempty"` + Version string `json:"version,omitempty"` +} + +// Commands represents a set of named prompts for quick-starting conversations. +// It supports two YAML formats: +// +// commands: +// +// df: "check disk space" +// ls: "list files" +// +// or +// +// commands: +// - df: "check disk space" +// - ls: "list files" +// Commands YAML unmarshalling is implemented in pkg/config/types/commands.go + +// ScriptShellToolConfig represents a custom shell tool configuration +type ScriptShellToolConfig struct { + Cmd string `json:"cmd"` + Description string `json:"description"` + + // Args is directly passed as "properties" in the JSON schema + Args map[string]any `json:"args,omitempty"` + + // Required is directly passed as "required" in the JSON schema + Required []string `json:"required"` + + Env map[string]string `json:"env,omitempty"` + WorkingDir string `json:"working_dir,omitempty"` +} + +type APIToolConfig struct { + Instruction string `json:"instruction,omitempty"` + Name string `json:"name,omitempty"` + Required []string `json:"required,omitempty"` + Args map[string]any `json:"args,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + Method string `json:"method,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + // OutputSchema optionally describes the API response as JSON Schema for MCP/Code Mode consumers; runtime still returns the raw string body. 
+ OutputSchema map[string]any `json:"output_schema,omitempty"` +} + +// PostEditConfig represents a post-edit command configuration +type PostEditConfig struct { + Path string `json:"path"` + Cmd string `json:"cmd"` +} + +// Toolset represents a tool configuration +type Toolset struct { + Type string `json:"type,omitempty"` + Tools []string `json:"tools,omitempty"` + Instruction string `json:"instruction,omitempty"` + Toon string `json:"toon,omitempty"` + + Defer DeferConfig `json:"defer" yaml:"defer,omitempty"` + + // For the `mcp` tool + Command string `json:"command,omitempty"` + Args []string `json:"args,omitempty"` + Ref string `json:"ref,omitempty"` + Remote Remote `json:"remote"` + Config any `json:"config,omitempty"` + + // For the `a2a` and `openapi` tools + Name string `json:"name,omitempty"` + URL string `json:"url,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + + // For `shell`, `script`, `mcp` or `lsp` tools + Env map[string]string `json:"env,omitempty"` + + // For the `shell` tool - sandbox mode + Sandbox *SandboxConfig `json:"sandbox,omitempty"` + + // For the `todo` tool + Shared bool `json:"shared,omitempty"` + + // For the `memory` and `tasks` tools + Path string `json:"path,omitempty"` + + // For the `script` tool + Shell map[string]ScriptShellToolConfig `json:"shell,omitempty"` + + // For the `filesystem` tool - post-edit commands + PostEdit []PostEditConfig `json:"post_edit,omitempty"` + + APIConfig APIToolConfig `json:"api_config"` + + // For the `filesystem` tool - VCS integration + IgnoreVCS *bool `json:"ignore_vcs,omitempty"` + + // For the `fetch` tool + Timeout int `json:"timeout,omitempty"` +} + +func (t *Toolset) UnmarshalYAML(unmarshal func(any) error) error { + type alias Toolset + var tmp alias + if err := unmarshal(&tmp); err != nil { + return err + } + *t = Toolset(tmp) + return t.validate() +} + +type Remote struct { + URL string `json:"url"` + TransportType string `json:"transport_type,omitempty"` + Headers 
map[string]string `json:"headers,omitempty"` +} + +// SandboxConfig represents the configuration for running shell commands in a Docker container. +// When enabled, all shell commands run inside a sandboxed Linux container with only +// specified paths bind-mounted. +type SandboxConfig struct { + // Image is the Docker image to use for the sandbox container. + // Defaults to "alpine:latest" if not specified. + Image string `json:"image,omitempty"` + + // Paths is a list of paths to bind-mount into the container. + // Each path can optionally have a ":ro" suffix for read-only access. + // Default is read-write (:rw) if no suffix is specified. + // Example: [".", "/tmp", "/config:ro"] + Paths []string `json:"paths"` +} + +// DeferConfig represents the deferred loading configuration for a toolset. +// It can be either a boolean (true to defer all tools) or a slice of strings +// (list of tool names to defer). +type DeferConfig struct { //nolint:recvcheck // MarshalYAML must use value receiver for YAML slice encoding, UnmarshalYAML must use pointer + // DeferAll is true when all tools should be deferred + DeferAll bool `json:"-"` + // Tools is the list of specific tool names to defer (empty if DeferAll is true) + Tools []string `json:"-"` +} + +func (d DeferConfig) IsEmpty() bool { + return !d.DeferAll && len(d.Tools) == 0 +} + +func (d *DeferConfig) UnmarshalYAML(unmarshal func(any) error) error { + var b bool + if err := unmarshal(&b); err == nil { + d.DeferAll = b + d.Tools = nil + return nil + } + + var tools []string + if err := unmarshal(&tools); err == nil { + d.DeferAll = false + d.Tools = tools + return nil + } + + return nil +} + +func (d DeferConfig) MarshalYAML() ([]byte, error) { + if d.DeferAll { + return yaml.Marshal(true) + } + if len(d.Tools) == 0 { + // Return false for empty config - this will be omitted by yaml encoder + return yaml.Marshal(false) + } + return yaml.Marshal(d.Tools) +} + +// ThinkingBudget represents reasoning budget configuration. 
+// It accepts either a string effort level or an integer token budget: +// - String: "minimal", "low", "medium", "high" (for OpenAI) +// - Integer: token count (for Anthropic, range 1024-32768) +type ThinkingBudget struct { + // Effort stores string-based reasoning effort levels + Effort string `json:"effort,omitempty"` + // Tokens stores integer-based token budgets + Tokens int `json:"tokens,omitempty"` +} + +func (t *ThinkingBudget) UnmarshalYAML(unmarshal func(any) error) error { + // Try integer tokens first + var n int + if err := unmarshal(&n); err == nil { + *t = ThinkingBudget{Tokens: n} + return nil + } + + // Try string level + var s string + if err := unmarshal(&s); err == nil { + *t = ThinkingBudget{Effort: s} + return nil + } + + return nil +} + +// MarshalYAML implements custom marshaling to output simple string or int format +func (t ThinkingBudget) MarshalYAML() ([]byte, error) { + // If Effort string is set (non-empty), marshal as string + if t.Effort != "" { + return yaml.Marshal(t.Effort) + } + + // Otherwise marshal as integer (includes 0, -1, and positive values) + return yaml.Marshal(t.Tokens) +} + +// MarshalJSON implements custom marshaling to output simple string or int format +// This ensures JSON and YAML have the same flattened format for consistency +func (t ThinkingBudget) MarshalJSON() ([]byte, error) { + // If Effort string is set (non-empty), marshal as string + if t.Effort != "" { + return fmt.Appendf(nil, "%q", t.Effort), nil + } + + // Otherwise marshal as integer (includes 0, -1, and positive values) + return fmt.Appendf(nil, "%d", t.Tokens), nil +} + +// UnmarshalJSON implements custom unmarshaling to accept simple string or int format +// This ensures JSON and YAML have the same flattened format for consistency +func (t *ThinkingBudget) UnmarshalJSON(data []byte) error { + // Try integer tokens first + var n int + if err := json.Unmarshal(data, &n); err == nil { + *t = ThinkingBudget{Tokens: n} + return nil + } + + // Try 
string level + var s string + if err := json.Unmarshal(data, &s); err == nil { + *t = ThinkingBudget{Effort: s} + return nil + } + + return nil +} + +// StructuredOutput defines a JSON schema for structured output +type StructuredOutput struct { + // Name is the name of the response format + Name string `json:"name"` + // Description is optional description of the response format + Description string `json:"description,omitempty"` + // Schema is a JSON schema object defining the structure + Schema map[string]any `json:"schema"` + // Strict enables strict schema adherence (OpenAI only) + Strict bool `json:"strict,omitempty"` +} + +// RAGToolConfig represents tool-specific configuration for a RAG source +type RAGToolConfig struct { + Name string `json:"name,omitempty"` // Custom name for the tool (defaults to RAG source name if empty) + Description string `json:"description,omitempty"` // Tool description (what the tool does) + Instruction string `json:"instruction,omitempty"` // Tool instruction (how to use the tool effectively) +} + +// RAGConfig represents a RAG (Retrieval-Augmented Generation) configuration +// Uses a unified strategies array for flexible, extensible configuration +type RAGConfig struct { + Tool RAGToolConfig `json:"tool"` // Tool configuration + Docs []string `json:"docs,omitempty"` // Shared documents across all strategies + RespectVCS *bool `json:"respect_vcs,omitempty"` // Whether to respect VCS ignore files like .gitignore (default: true) + Strategies []RAGStrategyConfig `json:"strategies,omitempty"` // Array of strategy configurations + Results RAGResultsConfig `json:"results"` +} + +// GetRespectVCS returns whether VCS ignore files should be respected, defaulting to true +func (c *RAGConfig) GetRespectVCS() bool { + if c.RespectVCS == nil { + return true + } + return *c.RespectVCS +} + +// RAGStrategyConfig represents a single retrieval strategy configuration +// Strategy-specific fields are stored in Params (validated by strategy 
implementation) +type RAGStrategyConfig struct { //nolint:recvcheck // Marshal methods must use value receiver for YAML/JSON slice encoding, Unmarshal must use pointer + Type string `json:"type"` // Strategy type: "chunked-embeddings", "bm25", etc. + Docs []string `json:"docs,omitempty"` // Strategy-specific documents (augments shared docs) + Database RAGDatabaseConfig `json:"database"` // Database configuration + Chunking RAGChunkingConfig `json:"chunking"` // Chunking configuration + Limit int `json:"limit,omitempty"` // Max results from this strategy (for fusion input) + + // Strategy-specific parameters (arbitrary key-value pairs) + // Examples: + // - chunked-embeddings: embedding_model, similarity_metric, threshold, vector_dimensions + // - bm25: k1, b, threshold + Params map[string]any // Flattened into parent JSON +} + +// UnmarshalYAML implements custom unmarshaling to capture all extra fields into Params +// This allows strategies to have flexible, strategy-specific configuration parameters +// without requiring changes to the core config schema +func (s *RAGStrategyConfig) UnmarshalYAML(unmarshal func(any) error) error { + // First unmarshal into a map to capture everything + var raw map[string]any + if err := unmarshal(&raw); err != nil { + return err + } + + // Extract known fields + if t, ok := raw["type"].(string); ok { + s.Type = t + delete(raw, "type") + } + + if docs, ok := raw["docs"].([]any); ok { + s.Docs = make([]string, len(docs)) + for i, d := range docs { + if str, ok := d.(string); ok { + s.Docs[i] = str + } + } + delete(raw, "docs") + } + + if dbRaw, ok := raw["database"]; ok { + // Unmarshal database config using helper + var db RAGDatabaseConfig + unmarshalDatabaseConfig(dbRaw, &db) + s.Database = db + delete(raw, "database") + } + + if chunkRaw, ok := raw["chunking"]; ok { + var chunk RAGChunkingConfig + unmarshalChunkingConfig(chunkRaw, &chunk) + s.Chunking = chunk + delete(raw, "chunking") + } + + if limit, ok := raw["limit"].(int); 
ok { + s.Limit = limit + delete(raw, "limit") + } + + // Everything else goes into Params for strategy-specific configuration + s.Params = raw + + return nil +} + +// MarshalYAML implements custom marshaling to flatten Params into parent level +func (s RAGStrategyConfig) MarshalYAML() ([]byte, error) { + result := s.buildFlattenedMap() + return yaml.Marshal(result) +} + +// MarshalJSON implements custom marshaling to flatten Params into parent level +// This ensures JSON and YAML have the same flattened format for consistency +func (s RAGStrategyConfig) MarshalJSON() ([]byte, error) { + result := s.buildFlattenedMap() + return json.Marshal(result) +} + +// UnmarshalJSON implements custom unmarshaling to capture all extra fields into Params +// This ensures JSON and YAML have the same flattened format for consistency +func (s *RAGStrategyConfig) UnmarshalJSON(data []byte) error { + // First unmarshal into a map to capture everything + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + // Extract known fields + if t, ok := raw["type"].(string); ok { + s.Type = t + delete(raw, "type") + } + + if docs, ok := raw["docs"].([]any); ok { + s.Docs = make([]string, len(docs)) + for i, d := range docs { + if str, ok := d.(string); ok { + s.Docs[i] = str + } + } + delete(raw, "docs") + } + + if dbRaw, ok := raw["database"]; ok { + if dbStr, ok := dbRaw.(string); ok { + var db RAGDatabaseConfig + db.value = dbStr + s.Database = db + } + delete(raw, "database") + } + + if chunkRaw, ok := raw["chunking"]; ok { + // Re-marshal and unmarshal chunking config + chunkBytes, _ := json.Marshal(chunkRaw) + var chunk RAGChunkingConfig + if err := json.Unmarshal(chunkBytes, &chunk); err == nil { + s.Chunking = chunk + } + delete(raw, "chunking") + } + + if limit, ok := raw["limit"].(float64); ok { + s.Limit = int(limit) + delete(raw, "limit") + } + + // Everything else goes into Params for strategy-specific configuration + s.Params = raw + + 
return nil +} + +// buildFlattenedMap creates a flattened map representation for marshaling +// Used by both MarshalYAML and MarshalJSON to ensure consistent format +func (s RAGStrategyConfig) buildFlattenedMap() map[string]any { + result := make(map[string]any) + + if s.Type != "" { + result["type"] = s.Type + } + if len(s.Docs) > 0 { + result["docs"] = s.Docs + } + if !s.Database.IsEmpty() { + dbStr, _ := s.Database.AsString() + result["database"] = dbStr + } + // Only include chunking if any fields are set + if s.Chunking.Size > 0 || s.Chunking.Overlap > 0 || s.Chunking.RespectWordBoundaries { + result["chunking"] = s.Chunking + } + if s.Limit > 0 { + result["limit"] = s.Limit + } + + // Flatten Params into the same level + maps.Copy(result, s.Params) + + return result +} + +// unmarshalDatabaseConfig handles DatabaseConfig unmarshaling from raw YAML data. +// For RAG strategies, the database configuration is intentionally simple: +// a single string value under the `database` key that points to the SQLite +// database file on disk. 
TODO(krissetto): eventually support more db types +func unmarshalDatabaseConfig(src any, dst *RAGDatabaseConfig) { + s, ok := src.(string) + if !ok { + return + } + + dst.value = s +} + +// unmarshalChunkingConfig handles ChunkingConfig unmarshaling from raw YAML data +func unmarshalChunkingConfig(src any, dst *RAGChunkingConfig) { + m, ok := src.(map[string]any) + if !ok { + return + } + + // Handle size - try various numeric types that YAML might produce + if size, ok := m["size"]; ok { + dst.Size = coerceToInt(size) + } + + // Handle overlap - try various numeric types that YAML might produce + if overlap, ok := m["overlap"]; ok { + dst.Overlap = coerceToInt(overlap) + } + + // Handle respect_word_boundaries - YAML should give us a bool + if rwb, ok := m["respect_word_boundaries"]; ok { + if val, ok := rwb.(bool); ok { + dst.RespectWordBoundaries = val + } + } + + // Handle code_aware - YAML should give us a bool + if ca, ok := m["code_aware"]; ok { + if val, ok := ca.(bool); ok { + dst.CodeAware = val + } + } +} + +// coerceToInt converts various numeric types to int +func coerceToInt(v any) int { + switch val := v.(type) { + case int: + return val + case int64: + return int(val) + case uint64: + return int(val) + case float64: + return int(val) + default: + return 0 + } +} + +// RAGDatabaseConfig represents database configuration for RAG strategies. +// Currently it only supports a single string value which is interpreted as +// the path to a SQLite database file. 
+type RAGDatabaseConfig struct { + value any // nil (unset) or string path +} + +// UnmarshalYAML implements custom unmarshaling for DatabaseConfig +func (d *RAGDatabaseConfig) UnmarshalYAML(unmarshal func(any) error) error { + var str string + if err := unmarshal(&str); err == nil { + d.value = str + return nil + } + + return fmt.Errorf("database must be a string path to a sqlite database") +} + +// AsString returns the database config as a connection string +// For simple string configs, returns as-is +// For structured configs, builds connection string based on type +func (d *RAGDatabaseConfig) AsString() (string, error) { + if d.value == nil { + return "", nil + } + + if str, ok := d.value.(string); ok { + return str, nil + } + + return "", fmt.Errorf("invalid database configuration: expected string path") +} + +// IsEmpty returns true if no database is configured +func (d *RAGDatabaseConfig) IsEmpty() bool { + return d.value == nil +} + +// RAGChunkingConfig represents text chunking configuration +type RAGChunkingConfig struct { + Size int `json:"size,omitempty"` + Overlap int `json:"overlap,omitempty"` + RespectWordBoundaries bool `json:"respect_word_boundaries,omitempty"` + // CodeAware enables code-aware chunking for source files. When true, the + // chunking strategy uses tree-sitter for AST-based chunking, producing + // semantically aligned chunks (e.g., whole functions). Falls back to + // plain text chunking for unsupported languages. 
+ CodeAware bool `json:"code_aware,omitempty"` +} + +// UnmarshalYAML implements custom unmarshaling to apply sensible defaults for chunking +func (c *RAGChunkingConfig) UnmarshalYAML(unmarshal func(any) error) error { + // Use a struct with pointer to distinguish "not set" from "explicitly set to false" + var raw struct { + Size int `yaml:"size"` + Overlap int `yaml:"overlap"` + RespectWordBoundaries *bool `yaml:"respect_word_boundaries"` + } + + if err := unmarshal(&raw); err != nil { + return err + } + + c.Size = raw.Size + c.Overlap = raw.Overlap + + // Apply default of true for RespectWordBoundaries if not explicitly set + if raw.RespectWordBoundaries != nil { + c.RespectWordBoundaries = *raw.RespectWordBoundaries + } else { + c.RespectWordBoundaries = true + } + + return nil +} + +// RAGResultsConfig represents result post-processing configuration (common across strategies) +type RAGResultsConfig struct { + Limit int `json:"limit,omitempty"` // Maximum number of results to return (top K) + Fusion *RAGFusionConfig `json:"fusion,omitempty"` // How to combine results from multiple strategies + Reranking *RAGRerankingConfig `json:"reranking,omitempty"` // Optional reranking configuration + Deduplicate bool `json:"deduplicate,omitempty"` // Remove duplicate documents across strategies + IncludeScore bool `json:"include_score,omitempty"` // Include relevance scores in results + ReturnFullContent bool `json:"return_full_content,omitempty"` // Return full document content instead of just matched chunks +} + +// RAGRerankingConfig represents reranking configuration +type RAGRerankingConfig struct { + Model string `json:"model"` // Model reference for reranking (e.g., "hf.co/ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF") + TopK int `json:"top_k,omitempty"` // Optional: only rerank top K results (0 = rerank all) + Threshold float64 `json:"threshold,omitempty"` // Optional: minimum score threshold after reranking (default: 0.5) + Criteria string `json:"criteria,omitempty"` // 
Optional: domain-specific relevance criteria to guide scoring +} + +// UnmarshalYAML implements custom unmarshaling to apply sensible defaults for reranking +func (r *RAGRerankingConfig) UnmarshalYAML(unmarshal func(any) error) error { + // Use a struct with pointer to distinguish "not set" from "explicitly set to 0" + var raw struct { + Model string `yaml:"model"` + TopK int `yaml:"top_k"` + Threshold *float64 `yaml:"threshold"` + Criteria string `yaml:"criteria"` + } + + if err := unmarshal(&raw); err != nil { + return err + } + + r.Model = raw.Model + r.TopK = raw.TopK + r.Criteria = raw.Criteria + + // Apply default threshold of 0.5 if not explicitly set + // This filters documents with negative logits (sigmoid < 0.5 = not relevant) + if raw.Threshold != nil { + r.Threshold = *raw.Threshold + } else { + r.Threshold = 0.5 + } + + return nil +} + +// defaultRAGResultsConfig returns the default results configuration +func defaultRAGResultsConfig() RAGResultsConfig { + return RAGResultsConfig{ + Limit: 15, + Deduplicate: true, + IncludeScore: false, + ReturnFullContent: false, + } +} + +// UnmarshalYAML implements custom unmarshaling so we can apply sensible defaults +func (r *RAGResultsConfig) UnmarshalYAML(unmarshal func(any) error) error { + var raw struct { + Limit int `json:"limit,omitempty"` + Fusion *RAGFusionConfig `json:"fusion,omitempty"` + Reranking *RAGRerankingConfig `json:"reranking,omitempty"` + Deduplicate *bool `json:"deduplicate,omitempty"` + IncludeScore *bool `json:"include_score,omitempty"` + ReturnFullContent *bool `json:"return_full_content,omitempty"` + } + + if err := unmarshal(&raw); err != nil { + return err + } + + // Start from defaults and then overwrite with any provided values. 
+ def := defaultRAGResultsConfig() + *r = def + + if raw.Limit != 0 { + r.Limit = raw.Limit + } + r.Fusion = raw.Fusion + r.Reranking = raw.Reranking + + if raw.Deduplicate != nil { + r.Deduplicate = *raw.Deduplicate + } + if raw.IncludeScore != nil { + r.IncludeScore = *raw.IncludeScore + } + if raw.ReturnFullContent != nil { + r.ReturnFullContent = *raw.ReturnFullContent + } + + return nil +} + +// UnmarshalYAML for RAGConfig ensures that the Results field is always +// initialized with defaults, even when the `results` block is omitted. +func (c *RAGConfig) UnmarshalYAML(unmarshal func(any) error) error { + type alias RAGConfig + tmp := alias{ + Results: defaultRAGResultsConfig(), + } + if err := unmarshal(&tmp); err != nil { + return err + } + *c = RAGConfig(tmp) + return nil +} + +// RAGFusionConfig represents configuration for combining multi-strategy results +type RAGFusionConfig struct { + Strategy string `json:"strategy,omitempty"` // Fusion strategy: "rrf" (Reciprocal Rank Fusion), "weighted", "max" + K int `json:"k,omitempty"` // RRF parameter k (default: 60) + Weights map[string]float64 `json:"weights,omitempty"` // Strategy weights for weighted fusion +} + +// PermissionsConfig represents tool permission configuration. +// Allow/Ask/Deny model. 
// - Allow: Tools matching these patterns are auto-approved (like --yolo for specific tools)
// - Ask: Tools matching these patterns always require user approval, even if the tool is read-only
// - Deny: Tools matching these patterns are always rejected, even with --yolo
//
// Patterns support glob-style matching (e.g., "shell", "read_*", "mcp:github:*")
// The evaluation order is: Deny (checked first), then Allow, then Ask (explicit), then default
// (read-only tools auto-approved, others ask)
type PermissionsConfig struct {
	// Allow holds patterns for tools that are approved without asking the user.
	Allow []string `json:"allow,omitempty"`
	// Ask holds patterns for tools that must always be confirmed by the user,
	// even ones (e.g. read-only tools) that would normally auto-approve.
	Ask []string `json:"ask,omitempty"`
	// Deny holds patterns for tools that are unconditionally rejected.
	Deny []string `json:"deny,omitempty"`
}

// HooksConfig represents the hooks configuration for an agent.
// Hooks allow running shell commands at various points in the agent lifecycle.
+type HooksConfig struct { + // PreToolUse hooks run before tool execution + PreToolUse []HookMatcherConfig `json:"pre_tool_use,omitempty" yaml:"pre_tool_use,omitempty"` + + // PostToolUse hooks run after tool execution + PostToolUse []HookMatcherConfig `json:"post_tool_use,omitempty" yaml:"post_tool_use,omitempty"` + + // SessionStart hooks run when a session begins + SessionStart []HookDefinition `json:"session_start,omitempty" yaml:"session_start,omitempty"` + + // SessionEnd hooks run when a session ends + SessionEnd []HookDefinition `json:"session_end,omitempty" yaml:"session_end,omitempty"` +} + +// IsEmpty returns true if no hooks are configured +func (h *HooksConfig) IsEmpty() bool { + if h == nil { + return true + } + return len(h.PreToolUse) == 0 && + len(h.PostToolUse) == 0 && + len(h.SessionStart) == 0 && + len(h.SessionEnd) == 0 +} + +// HookMatcherConfig represents a hook matcher with its hooks. +// Used for tool-related hooks (PreToolUse, PostToolUse). +type HookMatcherConfig struct { + // Matcher is a regex pattern to match tool names (e.g., "shell|edit_file") + // Use "*" to match all tools. Case-sensitive. 
+ Matcher string `json:"matcher,omitempty" yaml:"matcher,omitempty"` + + // Hooks are the hooks to execute when the matcher matches + Hooks []HookDefinition `json:"hooks" yaml:"hooks"` +} + +// HookDefinition represents a single hook configuration +type HookDefinition struct { + // Type specifies the hook type (currently only "command" is supported) + Type string `json:"type" yaml:"type"` + + // Command is the shell command to execute + Command string `json:"command,omitempty" yaml:"command,omitempty"` + + // Timeout is the execution timeout in seconds (default: 60) + Timeout int `json:"timeout,omitempty" yaml:"timeout,omitempty"` +} + +// validate validates the HooksConfig +func (h *HooksConfig) validate() error { + // Validate PreToolUse matchers + for i, m := range h.PreToolUse { + if err := m.validate("pre_tool_use", i); err != nil { + return err + } + } + + // Validate PostToolUse matchers + for i, m := range h.PostToolUse { + if err := m.validate("post_tool_use", i); err != nil { + return err + } + } + + // Validate SessionStart hooks + for i, hook := range h.SessionStart { + if err := hook.validate("session_start", i); err != nil { + return err + } + } + + // Validate SessionEnd hooks + for i, hook := range h.SessionEnd { + if err := hook.validate("session_end", i); err != nil { + return err + } + } + + return nil +} + +// validate validates a HookMatcherConfig +func (m *HookMatcherConfig) validate(eventType string, index int) error { + if len(m.Hooks) == 0 { + return fmt.Errorf("hooks.%s[%d]: at least one hook is required", eventType, index) + } + + for i, hook := range m.Hooks { + if err := hook.validate(fmt.Sprintf("%s[%d].hooks", eventType, index), i); err != nil { + return err + } + } + + return nil +} + +// validate validates a HookDefinition +func (h *HookDefinition) validate(prefix string, index int) error { + if h.Type == "" { + return fmt.Errorf("hooks.%s[%d]: type is required", prefix, index) + } + + if h.Type != "command" { + return 
fmt.Errorf("hooks.%s[%d]: unsupported hook type '%s' (only 'command' is supported)", prefix, index, h.Type) + } + + if h.Command == "" { + return fmt.Errorf("hooks.%s[%d]: command is required for command hooks", prefix, index) + } + + return nil +} diff --git a/pkg/config/v5/types_test.go b/pkg/config/v5/types_test.go new file mode 100644 index 000000000..2cdc2413a --- /dev/null +++ b/pkg/config/v5/types_test.go @@ -0,0 +1,253 @@ +package v5 + +import ( + "testing" + + "github.com/goccy/go-yaml" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/config/types" +) + +func TestCommandsUnmarshal_Map(t *testing.T) { + var c types.Commands + input := []byte(` +df: "check disk" +ls: "list files" +`) + err := yaml.Unmarshal(input, &c) + require.NoError(t, err) + require.Equal(t, "check disk", c["df"].Instruction) + require.Equal(t, "list files", c["ls"].Instruction) +} + +func TestCommandsUnmarshal_List(t *testing.T) { + var c types.Commands + input := []byte(` +- df: "check disk" +- ls: "list files" +`) + err := yaml.Unmarshal(input, &c) + require.NoError(t, err) + require.Equal(t, "check disk", c["df"].Instruction) + require.Equal(t, "list files", c["ls"].Instruction) +} + +func TestThinkingBudget_MarshalUnmarshal_String(t *testing.T) { + t.Parallel() + + // Test string effort level + input := []byte(`thinking_budget: minimal`) + var config struct { + ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"` + } + + // Unmarshal + err := yaml.Unmarshal(input, &config) + require.NoError(t, err) + require.NotNil(t, config.ThinkingBudget) + require.Equal(t, "minimal", config.ThinkingBudget.Effort) + require.Equal(t, 0, config.ThinkingBudget.Tokens) + + // Marshal back + output, err := yaml.Marshal(config) + require.NoError(t, err) + require.Equal(t, "thinking_budget: minimal\n", string(output)) +} + +func TestThinkingBudget_MarshalUnmarshal_Integer(t *testing.T) { + t.Parallel() + + // Test integer token budget + input := []byte(`thinking_budget: 8192`) 
+ var config struct { + ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"` + } + + // Unmarshal + err := yaml.Unmarshal(input, &config) + require.NoError(t, err) + require.NotNil(t, config.ThinkingBudget) + require.Empty(t, config.ThinkingBudget.Effort) + require.Equal(t, 8192, config.ThinkingBudget.Tokens) + + // Marshal back + output, err := yaml.Marshal(config) + require.NoError(t, err) + require.Equal(t, "thinking_budget: 8192\n", string(output)) +} + +func TestThinkingBudget_MarshalUnmarshal_NegativeInteger(t *testing.T) { + t.Parallel() + + // Test negative integer token budget (e.g., -1 for Gemini dynamic thinking) + input := []byte(`thinking_budget: -1`) + var config struct { + ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"` + } + + // Unmarshal + err := yaml.Unmarshal(input, &config) + require.NoError(t, err) + require.NotNil(t, config.ThinkingBudget) + require.Empty(t, config.ThinkingBudget.Effort) + require.Equal(t, -1, config.ThinkingBudget.Tokens) + + // Marshal back + output, err := yaml.Marshal(config) + require.NoError(t, err) + require.Equal(t, "thinking_budget: -1\n", string(output)) +} + +func TestThinkingBudget_MarshalUnmarshal_Zero(t *testing.T) { + t.Parallel() + + // Test zero token budget (e.g., 0 for Gemini no thinking) + input := []byte(`thinking_budget: 0`) + var config struct { + ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"` + } + + // Unmarshal + err := yaml.Unmarshal(input, &config) + require.NoError(t, err) + require.NotNil(t, config.ThinkingBudget) + require.Empty(t, config.ThinkingBudget.Effort) + require.Equal(t, 0, config.ThinkingBudget.Tokens) + + // Marshal back + output, err := yaml.Marshal(config) + require.NoError(t, err) + require.Equal(t, "thinking_budget: 0\n", string(output)) +} + +func TestAgents_UnmarshalYAML_RejectsUnknownFields(t *testing.T) { + t.Parallel() + + // "instructions" (plural) is not a valid field; the correct field is "instruction" (singular). 
+ // Agents.UnmarshalYAML must reject it so that typos don't silently drop config. + input := []byte(`version: "5" +agents: + root: + model: openai/gpt-4o + instructions: "You are a helpful assistant." +`) + + _, err := parse(input) + require.Error(t, err) + require.Contains(t, err.Error(), "instructions") +} + +func TestAgents_UnmarshalYAML_AcceptsValidConfig(t *testing.T) { + t.Parallel() + + input := []byte(`version: "5" +agents: + root: + model: openai/gpt-4o + instruction: "You are a helpful assistant." +`) + + cfg, err := parse(input) + require.NoError(t, err) + require.Len(t, cfg.Agents, 1) + require.Equal(t, "root", cfg.Agents[0].Name) + require.Equal(t, "You are a helpful assistant.", cfg.Agents[0].Instruction) +} + +func TestRAGStrategyConfig_MarshalUnmarshal_FlattenedParams(t *testing.T) { + t.Parallel() + + // Test that params are flattened during unmarshal and remain flattened after marshal + input := []byte(`type: chunked-embeddings +model: embeddinggemma +database: ./rag/test.db +threshold: 0.5 +vector_dimensions: 768 +`) + + var strategy RAGStrategyConfig + + // Unmarshal + err := yaml.Unmarshal(input, &strategy) + require.NoError(t, err) + require.Equal(t, "chunked-embeddings", strategy.Type) + require.Equal(t, "./rag/test.db", mustGetDBString(t, strategy.Database)) + require.NotNil(t, strategy.Params) + require.Equal(t, "embeddinggemma", strategy.Params["model"]) + require.InEpsilon(t, 0.5, strategy.Params["threshold"], 0.001) + // YAML may unmarshal numbers as different numeric types (int, uint64, float64) + require.InEpsilon(t, float64(768), toFloat64(strategy.Params["vector_dimensions"]), 0.001) + + // Marshal back + output, err := yaml.Marshal(strategy) + require.NoError(t, err) + + // Verify it's still flattened (no "params:" key) + outputStr := string(output) + require.Contains(t, outputStr, "type: chunked-embeddings") + require.Contains(t, outputStr, "model: embeddinggemma") + require.Contains(t, outputStr, "threshold: 0.5") + 
require.Contains(t, outputStr, "vector_dimensions: 768") + require.NotContains(t, outputStr, "params:") + + // Unmarshal again to verify round-trip + var strategy2 RAGStrategyConfig + err = yaml.Unmarshal(output, &strategy2) + require.NoError(t, err) + require.Equal(t, strategy.Type, strategy2.Type) + require.Equal(t, strategy.Params["model"], strategy2.Params["model"]) + require.Equal(t, strategy.Params["threshold"], strategy2.Params["threshold"]) + // YAML may unmarshal numbers as different numeric types (int, uint64, float64) + // Just verify the numeric value is correct + require.InEpsilon(t, float64(768), toFloat64(strategy2.Params["vector_dimensions"]), 0.001) +} + +func TestRAGStrategyConfig_MarshalUnmarshal_WithDatabase(t *testing.T) { + t.Parallel() + + input := []byte(`type: chunked-embeddings +database: ./test.db +model: test-model +`) + + var strategy RAGStrategyConfig + err := yaml.Unmarshal(input, &strategy) + require.NoError(t, err) + + // Marshal back + output, err := yaml.Marshal(strategy) + require.NoError(t, err) + + // Should contain database as a simple string, not nested with sub-fields + outputStr := string(output) + require.Contains(t, outputStr, "database: ./test.db") + require.NotContains(t, outputStr, " value:") // Should not be nested with internal fields + require.Contains(t, outputStr, "model: test-model") + require.NotContains(t, outputStr, "params:") // Should be flattened +} + +func mustGetDBString(t *testing.T, db RAGDatabaseConfig) string { + t.Helper() + str, err := db.AsString() + require.NoError(t, err) + return str +} + +// toFloat64 converts various numeric types to float64 for comparison +func toFloat64(v any) float64 { + switch val := v.(type) { + case int: + return float64(val) + case int64: + return float64(val) + case uint64: + return float64(val) + case float64: + return val + case float32: + return float64(val) + default: + return 0 + } +} diff --git a/pkg/config/v5/validate.go b/pkg/config/v5/validate.go new file mode 
100644 index 000000000..9ca79eac4 --- /dev/null +++ b/pkg/config/v5/validate.go @@ -0,0 +1,157 @@ +package v5 + +import ( + "errors" + "strings" +) + +func (t *Config) UnmarshalYAML(unmarshal func(any) error) error { + type alias Config + var tmp alias + if err := unmarshal(&tmp); err != nil { + return err + } + *t = Config(tmp) + return t.validate() +} + +func (t *Config) validate() error { + for i := range t.Agents { + agent := &t.Agents[i] + + // Validate fallback config + if err := agent.validateFallback(); err != nil { + return err + } + + for j := range agent.Toolsets { + if err := agent.Toolsets[j].validate(); err != nil { + return err + } + } + if agent.Hooks != nil { + if err := agent.Hooks.validate(); err != nil { + return err + } + } + } + + return nil +} + +// validateFallback validates the fallback configuration for an agent +func (a *AgentConfig) validateFallback() error { + if a.Fallback == nil { + return nil + } + + // -1 is allowed as a special value meaning "explicitly no retries" + if a.Fallback.Retries < -1 { + return errors.New("fallback.retries must be >= -1 (use -1 for no retries, 0 for default)") + } + if a.Fallback.Cooldown.Duration < 0 { + return errors.New("fallback.cooldown must be non-negative") + } + + return nil +} + +func (t *Toolset) validate() error { + // Attributes used on the wrong toolset type. 
+ if len(t.Shell) > 0 && t.Type != "script" { + return errors.New("shell can only be used with type 'script'") + } + if t.Path != "" && t.Type != "memory" && t.Type != "tasks" { + return errors.New("path can only be used with type 'memory' or 'tasks'") + } + if len(t.PostEdit) > 0 && t.Type != "filesystem" { + return errors.New("post_edit can only be used with type 'filesystem'") + } + if t.IgnoreVCS != nil && t.Type != "filesystem" { + return errors.New("ignore_vcs can only be used with type 'filesystem'") + } + if len(t.Env) > 0 && (t.Type != "shell" && t.Type != "script" && t.Type != "mcp" && t.Type != "lsp") { + return errors.New("env can only be used with type 'shell', 'script', 'mcp' or 'lsp'") + } + if t.Sandbox != nil && t.Type != "shell" { + return errors.New("sandbox can only be used with type 'shell'") + } + if t.Shared && t.Type != "todo" { + return errors.New("shared can only be used with type 'todo'") + } + if t.Command != "" && t.Type != "mcp" && t.Type != "lsp" { + return errors.New("command can only be used with type 'mcp' or 'lsp'") + } + if len(t.Args) > 0 && t.Type != "mcp" && t.Type != "lsp" { + return errors.New("args can only be used with type 'mcp' or 'lsp'") + } + if t.Ref != "" && t.Type != "mcp" { + return errors.New("ref can only be used with type 'mcp'") + } + if (t.Remote.URL != "" || t.Remote.TransportType != "") && t.Type != "mcp" { + return errors.New("remote can only be used with type 'mcp'") + } + if (len(t.Remote.Headers) > 0) && (t.Type != "mcp" && t.Type != "a2a") { + return errors.New("remote headers can only be used with type 'mcp' or 'a2a'") + } + if len(t.Headers) > 0 && t.Type != "openapi" && t.Type != "a2a" { + return errors.New("headers can only be used with type 'openapi' or 'a2a'") + } + if t.Config != nil && t.Type != "mcp" { + return errors.New("config can only be used with type 'mcp'") + } + if t.URL != "" && t.Type != "a2a" && t.Type != "openapi" { + return errors.New("url can only be used with type 'a2a' or 
'openapi'") + } + if t.Name != "" && (t.Type != "mcp" && t.Type != "a2a") { + return errors.New("name can only be used with type 'mcp' or 'a2a'") + } + + switch t.Type { + case "shell": + if t.Sandbox != nil && len(t.Sandbox.Paths) == 0 { + return errors.New("sandbox requires at least one path to be set") + } + case "memory": + if t.Path == "" { + return errors.New("memory toolset requires a path to be set") + } + case "tasks": + // path defaults to ./tasks.json if not set + case "mcp": + count := 0 + if t.Command != "" { + count++ + } + if t.Remote.URL != "" { + count++ + } + if t.Ref != "" { + count++ + } + if count == 0 { + return errors.New("either command, remote or ref must be set") + } + if count > 1 { + return errors.New("either command, remote or ref must be set, but only one of those") + } + + if t.Ref != "" && !strings.Contains(t.Ref, "docker:") { + return errors.New("only docker refs are supported for MCP tools, e.g., 'docker:context7'") + } + case "a2a": + if t.URL == "" { + return errors.New("a2a toolset requires a url to be set") + } + case "lsp": + if t.Command == "" { + return errors.New("lsp toolset requires a command to be set") + } + case "openapi": + if t.URL == "" { + return errors.New("openapi toolset requires a url to be set") + } + } + + return nil +} diff --git a/pkg/config/v5/validate_test.go b/pkg/config/v5/validate_test.go new file mode 100644 index 000000000..16f45a9cb --- /dev/null +++ b/pkg/config/v5/validate_test.go @@ -0,0 +1,191 @@ +package v5 + +import ( + "testing" + + "github.com/goccy/go-yaml" + "github.com/stretchr/testify/require" +) + +func TestToolset_Validate_LSP(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config string + wantErr string + }{ + { + name: "valid lsp with command", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: lsp + command: gopls +`, + wantErr: "", + }, + { + name: "lsp missing command", + config: ` +version: "3" +agents: + root: + model: 
"openai/gpt-4" + toolsets: + - type: lsp +`, + wantErr: "lsp toolset requires a command to be set", + }, + { + name: "lsp with args", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: lsp + command: gopls + args: + - -remote=auto +`, + wantErr: "", + }, + { + name: "lsp with env", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: lsp + command: gopls + env: + GOFLAGS: "-mod=vendor" +`, + wantErr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var cfg Config + err := yaml.Unmarshal([]byte(tt.config), &cfg) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestToolset_Validate_Sandbox(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config string + wantErr string + }{ + { + name: "valid shell with sandbox", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: shell + sandbox: + image: alpine:latest + paths: + - . + - /tmp +`, + wantErr: "", + }, + { + name: "shell sandbox with readonly path", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: shell + sandbox: + paths: + - ./:rw + - /config:ro +`, + wantErr: "", + }, + { + name: "shell sandbox without paths", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: shell + sandbox: + image: alpine:latest +`, + wantErr: "sandbox requires at least one path to be set", + }, + { + name: "sandbox on non-shell toolset", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: filesystem + sandbox: + paths: + - . 
+`, + wantErr: "sandbox can only be used with type 'shell'", + }, + { + name: "shell without sandbox is valid", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: shell +`, + wantErr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var cfg Config + err := yaml.Unmarshal([]byte(tt.config), &cfg) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/config/versions.go b/pkg/config/versions.go index 90b9cfe50..d33a090d6 100644 --- a/pkg/config/versions.go +++ b/pkg/config/versions.go @@ -7,28 +7,20 @@ import ( v2 "github.com/docker/cagent/pkg/config/v2" v3 "github.com/docker/cagent/pkg/config/v3" v4 "github.com/docker/cagent/pkg/config/v4" + v5 "github.com/docker/cagent/pkg/config/v5" ) -func Parsers() map[string]func([]byte) (any, error) { - return map[string]func([]byte) (any, error){ - v0.Version: func(d []byte) (any, error) { return v0.Parse(d) }, - v1.Version: func(d []byte) (any, error) { return v1.Parse(d) }, - v2.Version: func(d []byte) (any, error) { return v2.Parse(d) }, - v3.Version: func(d []byte) (any, error) { return v3.Parse(d) }, - v4.Version: func(d []byte) (any, error) { return v4.Parse(d) }, +func versions() (map[string]func([]byte) (any, error), []func(any, []byte) (any, error)) { + parsers := map[string]func([]byte) (any, error){} + var upgraders []func(any, []byte) (any, error) - latest.Version: func(d []byte) (any, error) { return latest.Parse(d) }, - } -} - -func Upgrades() []func(any, []byte) (any, error) { - return []func(any, []byte) (any, error){ - v0.UpgradeIfNeeded, - v1.UpgradeIfNeeded, - v2.UpgradeIfNeeded, - v3.UpgradeIfNeeded, - v4.UpgradeIfNeeded, + v0.Register(parsers, &upgraders) + v1.Register(parsers, &upgraders) + v2.Register(parsers, &upgraders) + v3.Register(parsers, &upgraders) + v4.Register(parsers, &upgraders) + 
v5.Register(parsers, &upgraders) + latest.Register(parsers, &upgraders) - latest.UpgradeIfNeeded, - } + return parsers, upgraders }