diff --git a/docs/server/docs.go b/docs/server/docs.go index 9c565b583f..7a1d649488 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -2033,6 +2033,9 @@ const docTemplate = `{ "description": "Port for the HTTP proxy to listen on", "type": "integer" }, + "runtime_config": { + "$ref": "#/components/schemas/templates.RuntimeConfig" + }, "secrets": { "description": "Secret parameters to inject", "items": { @@ -2609,6 +2612,9 @@ const docTemplate = `{ "description": "Port for the HTTP proxy to listen on", "type": "integer" }, + "runtime_config": { + "$ref": "#/components/schemas/templates.RuntimeConfig" + }, "secrets": { "description": "Secret parameters to inject", "items": { diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 1bb019e592..c41cbf5031 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -2026,6 +2026,9 @@ "description": "Port for the HTTP proxy to listen on", "type": "integer" }, + "runtime_config": { + "$ref": "#/components/schemas/templates.RuntimeConfig" + }, "secrets": { "description": "Secret parameters to inject", "items": { @@ -2602,6 +2605,9 @@ "description": "Port for the HTTP proxy to listen on", "type": "integer" }, + "runtime_config": { + "$ref": "#/components/schemas/templates.RuntimeConfig" + }, "secrets": { "description": "Secret parameters to inject", "items": { diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index f43cda5079..b8829759ec 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -1833,6 +1833,8 @@ components: proxy_port: description: Port for the HTTP proxy to listen on type: integer + runtime_config: + $ref: '#/components/schemas/templates.RuntimeConfig' secrets: description: Secret parameters to inject items: @@ -2262,6 +2264,8 @@ components: proxy_port: description: Port for the HTTP proxy to listen on type: integer + runtime_config: + $ref: '#/components/schemas/templates.RuntimeConfig' secrets: description: Secret parameters to inject items: diff --git a/pkg/api/v1/workload_service.go b/pkg/api/v1/workload_service.go index 0c8b607144..1504456978 100644 --- a/pkg/api/v1/workload_service.go +++ b/pkg/api/v1/workload_service.go @@ -7,6 +7,7 @@ import ( "context" "errors" "fmt" + "strings" "time" groupval "github.com/stacklok/toolhive-core/validation/group" @@ -14,6 +15,7 @@ import ( "github.com/stacklok/toolhive/pkg/auth/remote" "github.com/stacklok/toolhive/pkg/config" "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/container/templates" "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/networking" @@ -162,6 +164,7 @@ func (s *WorkloadService) BuildFullRunConfig( var imageURL string var imageMetadata *regtypes.ImageMetadata var serverMetadata regtypes.ServerMetadata + runtimeConfigOverride := runtimeConfigFromRequest(req) if req.URL != "" { // Configure remote authentication if OAuth config is provided @@ -180,8 +183,8 @@ func (s *WorkloadService) BuildFullRunConfig( req.Image, "", // We do not let the user specify a CA cert path here. retriever.VerifyImageWarn, - "", // TODO Add support for registry groups lookups for API - nil, // No runtime override from API (yet) + "", // TODO Add support for registry groups lookups for API + runtimeConfigOverride, ) if err != nil { // Check if the error is due to context timeout @@ -272,6 +275,11 @@ func (s *WorkloadService) BuildFullRunConfig( runner.WithTelemetryConfigFromFlags("", false, false, false, "", 0.0, nil, false, nil, false), } + // Runtime overrides only apply to protocol-scheme image builds. + if runtimeConfigOverride != nil && req.URL == "" { + options = append(options, runner.WithRuntimeConfig(runtimeConfigOverride)) + } + // Add header forward configuration if specified if req.HeaderForward != nil { if len(req.HeaderForward.AddPlaintextHeaders) > 0 { @@ -361,6 +369,29 @@ func createRequestToRemoteAuthConfig( return remoteAuthConfig } +func runtimeConfigFromRequest(req *createRequest) *templates.RuntimeConfig { + if req == nil || req.RuntimeConfig == nil { + return nil + } + + runtimeConfig := &templates.RuntimeConfig{} + if builderImage := strings.TrimSpace(req.RuntimeConfig.BuilderImage); builderImage != "" { + runtimeConfig.BuilderImage = builderImage + } + if len(req.RuntimeConfig.AdditionalPackages) > 0 { + for _, pkg := range req.RuntimeConfig.AdditionalPackages { + if trimmedPkg := strings.TrimSpace(pkg); trimmedPkg != "" { + runtimeConfig.AdditionalPackages = append(runtimeConfig.AdditionalPackages, trimmedPkg) + } + } + } + if runtimeConfig.BuilderImage == "" && len(runtimeConfig.AdditionalPackages) == 0 { + return nil + } + + return runtimeConfig +} + // GetWorkloadNamesFromRequest gets workload names from either the names field or group func (s *WorkloadService) GetWorkloadNamesFromRequest(ctx context.Context, req bulkOperationRequest) ([]string, error) { if len(req.Names) > 0 { diff --git a/pkg/api/v1/workload_service_test.go b/pkg/api/v1/workload_service_test.go index 20b6918048..c39bf09a9d 100644 --- a/pkg/api/v1/workload_service_test.go +++ b/pkg/api/v1/workload_service_test.go @@ -13,6 +13,7 @@ import ( "go.uber.org/mock/gomock" "github.com/stacklok/toolhive/pkg/config" + "github.com/stacklok/toolhive/pkg/container/templates" groupsmocks "github.com/stacklok/toolhive/pkg/groups/mocks" workloadsmocks "github.com/stacklok/toolhive/pkg/workloads/mocks" ) @@ -150,3 +151,87 @@ func TestNewWorkloadService(t *testing.T) { service := NewWorkloadService(nil, nil, nil, false) require.NotNil(t, service) } + +func TestRuntimeConfigFromRequest(t *testing.T) { + t.Parallel() + + t.Run("nil request", func(t *testing.T) { + t.Parallel() + assert.Nil(t, runtimeConfigFromRequest(nil)) + }) + + t.Run("nil runtime config", func(t *testing.T) { + t.Parallel() + req := &createRequest{} + assert.Nil(t, runtimeConfigFromRequest(req)) + }) + + t.Run("empty runtime config returns nil", func(t *testing.T) { + t.Parallel() + + req := &createRequest{ + updateRequest: updateRequest{ + RuntimeConfig: &templates.RuntimeConfig{ + BuilderImage: " ", + AdditionalPackages: []string{"", " "}, + }, + }, + } + + assert.Nil(t, runtimeConfigFromRequest(req)) + }) + + t.Run("trims builder image", func(t *testing.T) { + t.Parallel() + + req := &createRequest{ + updateRequest: updateRequest{ + RuntimeConfig: &templates.RuntimeConfig{ + BuilderImage: " golang:1.24-alpine ", + }, + }, + } + + result := runtimeConfigFromRequest(req) + require.NotNil(t, result) + assert.Equal(t, "golang:1.24-alpine", result.BuilderImage) + }) + + t.Run("trims and filters additional packages", func(t *testing.T) { + t.Parallel() + + req := &createRequest{ + updateRequest: updateRequest{ + RuntimeConfig: &templates.RuntimeConfig{ + AdditionalPackages: []string{" git ", "", " ", "curl"}, + }, + }, + } + + result := runtimeConfigFromRequest(req) + require.NotNil(t, result) + assert.Equal(t, []string{"git", "curl"}, result.AdditionalPackages) + }) + + t.Run("copies runtime config", func(t *testing.T) { + t.Parallel() + + req := &createRequest{ + updateRequest: updateRequest{ + RuntimeConfig: &templates.RuntimeConfig{ + BuilderImage: "golang:1.24-alpine", + AdditionalPackages: []string{"git"}, + }, + }, + } + + result := runtimeConfigFromRequest(req) + require.NotNil(t, result) + assert.Equal(t, "golang:1.24-alpine", result.BuilderImage) + assert.Equal(t, []string{"git"}, result.AdditionalPackages) + + // Verify a copy is made for slice fields. + req.RuntimeConfig.AdditionalPackages[0] = "curl" + assert.Equal(t, []string{"git"}, result.AdditionalPackages) + }) +} diff --git a/pkg/api/v1/workload_types.go b/pkg/api/v1/workload_types.go index 9f69d57563..32cb4ff912 100644 --- a/pkg/api/v1/workload_types.go +++ b/pkg/api/v1/workload_types.go @@ -9,6 +9,7 @@ import ( httpval "github.com/stacklok/toolhive-core/validation/http" "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/container/templates" "github.com/stacklok/toolhive/pkg/core" "github.com/stacklok/toolhive/pkg/permissions" "github.com/stacklok/toolhive/pkg/registry/registry" @@ -39,6 +40,8 @@ type workloadStatusResponse struct { type updateRequest struct { // Docker image to use Image string `json:"image"` + // RuntimeConfig allows overriding runtime build configuration for protocol schemes. + RuntimeConfig *templates.RuntimeConfig `json:"runtime_config,omitempty"` // Host to bind to Host string `json:"host"` // Command arguments to pass to the container @@ -295,6 +298,7 @@ func runConfigToCreateRequest(runConfig *runner.RunConfig) *createRequest { return &createRequest{ updateRequest: updateRequest{ Image: runConfig.Image, + RuntimeConfig: runConfig.RuntimeConfig, Host: runConfig.Host, CmdArguments: runConfig.CmdArgs, TargetPort: runConfig.TargetPort, diff --git a/pkg/api/v1/workloads_test.go b/pkg/api/v1/workloads_test.go index 579a8a9c5e..58bcf1e01c 100644 --- a/pkg/api/v1/workloads_test.go +++ b/pkg/api/v1/workloads_test.go @@ -104,11 +104,13 @@ func TestCreateWorkload(t *testing.T) { logger.Initialize() tests := []struct { - name string - requestBody string - setupMock func(*testing.T, *workloadsmocks.MockManager, *runtimemocks.MockRuntime, *groupsmocks.MockManager) - expectedStatus int - expectedBody string + name string + requestBody string + setupMock func(*testing.T, *workloadsmocks.MockManager, *runtimemocks.MockRuntime, *groupsmocks.MockManager) + expectedServerOrImage string + expectedRuntimeConfig *templates.RuntimeConfig + expectedStatus int + expectedBody string }{ { name: "invalid JSON", @@ -137,6 +139,28 @@ func TestCreateWorkload(t *testing.T) { expectedStatus: http.StatusBadRequest, expectedBody: "Invalid proxy_mode", }, + { + name: "with runtime config override", + requestBody: `{"name": "test-workload", "image": "go://github.com/example/server", "runtime_config": {"builder_image": "golang:1.24-alpine", "additional_packages": ["ca-certificates"]}}`, + setupMock: func(_ *testing.T, wm *workloadsmocks.MockManager, _ *runtimemocks.MockRuntime, gm *groupsmocks.MockManager) { + wm.EXPECT().DoesWorkloadExist(gomock.Any(), "test-workload").Return(false, nil) + gm.EXPECT().Exists(gomock.Any(), "default").Return(true, nil) + wm.EXPECT().RunWorkloadDetached(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, runConfig *runner.RunConfig) error { + assert.NotNil(t, runConfig.RuntimeConfig) + assert.Equal(t, "golang:1.24-alpine", runConfig.RuntimeConfig.BuilderImage) + assert.Equal(t, []string{"ca-certificates"}, runConfig.RuntimeConfig.AdditionalPackages) + return nil + }) + }, + expectedRuntimeConfig: &templates.RuntimeConfig{ + BuilderImage: "golang:1.24-alpine", + AdditionalPackages: []string{"ca-certificates"}, + }, + expectedServerOrImage: "go://github.com/example/server", + expectedStatus: http.StatusCreated, + expectedBody: "test-workload", + }, { name: "with tool filters", requestBody: `{"name": "test-workload", "image": "test-image", "tools": ["filter1", "filter2"]}`, @@ -212,12 +236,17 @@ func TestCreateWorkload(t *testing.T) { mockGroupManager := groupsmocks.NewMockManager(ctrl) tt.setupMock(t, mockWorkloadManager, mockRuntime, mockGroupManager) + expectedServerOrImage := tt.expectedServerOrImage + if expectedServerOrImage == "" { + expectedServerOrImage = "test-image" + } mockRetriever := makeMockRetriever(t, - "test-image", + expectedServerOrImage, "test-image", ®types.ImageMetadata{Image: "test-image"}, nil, + tt.expectedRuntimeConfig, ) routes := &WorkloadRoutes{ @@ -411,6 +440,7 @@ func TestUpdateWorkload(t *testing.T) { "test-image", ®types.ImageMetadata{Image: "test-image"}, nil, + nil, ) routes := &WorkloadRoutes{ @@ -556,6 +586,7 @@ func TestUpdateWorkload_PortReuse(t *testing.T) { "test-image", ®types.ImageMetadata{Image: "test-image"}, nil, + nil, ) routes := &WorkloadRoutes{ @@ -595,12 +626,14 @@ func makeMockRetriever( returnedImage string, returnedServerMetadata regtypes.ServerMetadata, returnedError error, + expectedRuntimeConfig *templates.RuntimeConfig, ) retriever.Retriever { t.Helper() - return func(_ context.Context, serverOrImage string, _ string, verificationType string, _ string, _ *templates.RuntimeConfig) (string, regtypes.ServerMetadata, error) { + return func(_ context.Context, serverOrImage string, _ string, verificationType string, _ string, runtimeConfig *templates.RuntimeConfig) (string, regtypes.ServerMetadata, error) { assert.Equal(t, expectedServerOrImage, serverOrImage) assert.Equal(t, retriever.VerifyImageWarn, verificationType) + assert.Equal(t, expectedRuntimeConfig, runtimeConfig) return returnedImage, returnedServerMetadata, returnedError } } diff --git a/pkg/api/v1/workloads_types_test.go b/pkg/api/v1/workloads_types_test.go index 1b41bf3f5e..5e00790eac 100644 --- a/pkg/api/v1/workloads_types_test.go +++ b/pkg/api/v1/workloads_types_test.go @@ -12,6 +12,7 @@ import ( "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/auth/remote" + "github.com/stacklok/toolhive/pkg/container/templates" "github.com/stacklok/toolhive/pkg/permissions" "github.com/stacklok/toolhive/pkg/runner" "github.com/stacklok/toolhive/pkg/secrets" @@ -400,6 +401,25 @@ func TestRunConfigToCreateRequest(t *testing.T) { assert.Empty(t, result.ToolsOverride["read"].Description) }) + t.Run("with runtime config", func(t *testing.T) { + t.Parallel() + + runConfig := &runner.RunConfig{ + Name: "test-workload", + RuntimeConfig: &templates.RuntimeConfig{ + BuilderImage: "node:20-alpine", + AdditionalPackages: []string{"git"}, + }, + } + + result := runConfigToCreateRequest(runConfig) + + require.NotNil(t, result) + require.NotNil(t, result.RuntimeConfig) + assert.Equal(t, "node:20-alpine", result.RuntimeConfig.BuilderImage) + assert.Equal(t, []string{"git"}, result.RuntimeConfig.AdditionalPackages) + }) + t.Run("nil runConfig", func(t *testing.T) { t.Parallel() diff --git a/pkg/runner/protocol.go b/pkg/runner/protocol.go index 4824cf86b7..f47bd8715e 100644 --- a/pkg/runner/protocol.go +++ b/pkg/runner/protocol.go @@ -157,20 +157,44 @@ func loadRuntimeConfig( transportType templates.TransportType, override *templates.RuntimeConfig, ) *templates.RuntimeConfig { - // If override is provided, use it - if override != nil { - return override + // Resolve base config from user configuration or defaults. + baseConfig := getBaseRuntimeConfig(transportType) + if override == nil { + return baseConfig } + // Merge overrides into base config so omitted fields inherit sane defaults. + merged := &templates.RuntimeConfig{ + BuilderImage: baseConfig.BuilderImage, + AdditionalPackages: append([]string{}, baseConfig.AdditionalPackages...), + } + + if strings.TrimSpace(override.BuilderImage) != "" { + merged.BuilderImage = strings.TrimSpace(override.BuilderImage) + } + if len(override.AdditionalPackages) > 0 { + merged.AdditionalPackages = append(merged.AdditionalPackages, override.AdditionalPackages...) + } + + return merged +} + +func getBaseRuntimeConfig(transportType templates.TransportType) *templates.RuntimeConfig { // Try loading from user config provider := config.NewProvider() if userConfig, err := provider.GetRuntimeConfig(string(transportType)); err == nil && userConfig != nil { - return userConfig + return &templates.RuntimeConfig{ + BuilderImage: userConfig.BuilderImage, + AdditionalPackages: append([]string{}, userConfig.AdditionalPackages...), + } } // Fall back to defaults defaultConfig := templates.GetDefaultRuntimeConfig(transportType) - return &defaultConfig + return &templates.RuntimeConfig{ + BuilderImage: defaultConfig.BuilderImage, + AdditionalPackages: append([]string{}, defaultConfig.AdditionalPackages...), + } } // addBuildEnvToTemplate loads build environment variables from config and adds them to template data. diff --git a/pkg/runner/protocol_test.go b/pkg/runner/protocol_test.go index f90fe50bb8..c39daec415 100644 --- a/pkg/runner/protocol_test.go +++ b/pkg/runner/protocol_test.go @@ -5,6 +5,7 @@ package runner import ( "context" + "reflect" "strings" "testing" @@ -442,3 +443,59 @@ func TestCreateTemplateData(t *testing.T) { }) } } + +func TestLoadRuntimeConfig_MergesMissingOverrideFields(t *testing.T) { + t.Parallel() + + base := loadRuntimeConfig(templates.TransportTypeGO, nil) + if base == nil { + t.Fatal("loadRuntimeConfig returned nil base config") + return + } + if base.BuilderImage == "" { + t.Fatal("base runtime config has empty builder image") + } + + override := &templates.RuntimeConfig{ + AdditionalPackages: []string{"curl"}, + } + got := loadRuntimeConfig(templates.TransportTypeGO, override) + if got == nil { + t.Fatal("loadRuntimeConfig returned nil merged config") + return + } + + // Missing builder image in override should inherit the base builder image. + if got.BuilderImage != base.BuilderImage { + t.Fatalf("BuilderImage = %q, want base %q", got.BuilderImage, base.BuilderImage) + } + + // Additional packages should be appended to base defaults. + expectedPackages := append([]string{}, base.AdditionalPackages...) + expectedPackages = append(expectedPackages, "curl") + if !reflect.DeepEqual(got.AdditionalPackages, expectedPackages) { + t.Fatalf("AdditionalPackages = %v, want %v", got.AdditionalPackages, expectedPackages) + } + + // Ensure merged config is detached from input slices. + override.AdditionalPackages[0] = "git" + if got.AdditionalPackages[len(got.AdditionalPackages)-1] != "curl" { + t.Fatalf("AdditionalPackages mutated via override input: got %v", got.AdditionalPackages) + } +} + +func TestLoadRuntimeConfig_UsesOverrideBuilderImage(t *testing.T) { + t.Parallel() + + customImage := "golang:1.24-alpine" + got := loadRuntimeConfig(templates.TransportTypeGO, &templates.RuntimeConfig{ + BuilderImage: customImage, + }) + if got == nil { + t.Fatal("loadRuntimeConfig returned nil config") + return + } + if got.BuilderImage != customImage { + t.Fatalf("BuilderImage = %q, want %q", got.BuilderImage, customImage) + } +}