diff --git a/cmd/thv/app/run_flags.go b/cmd/thv/app/run_flags.go index a64e6c7ccd..aa77196879 100644 --- a/cmd/thv/app/run_flags.go +++ b/cmd/thv/app/run_flags.go @@ -45,6 +45,7 @@ type RunFlags struct { ProxyPort int TargetPort int TargetHost string + Publish []string // Server configuration Name string @@ -154,6 +155,8 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) { "target-host", transport.LocalhostIPv4, "Host to forward traffic to (only applicable to SSE or Streamable HTTP transport)") + cmd.Flags().StringArrayVarP(&config.Publish, "publish", "p", []string{}, + "Publish a container's port(s) to the host (format: hostPort:containerPort)") cmd.Flags().StringVar( &config.PermissionProfile, "permission-profile", @@ -596,6 +599,7 @@ func buildRunnerConfig( LoadGlobal: runFlags.IgnoreGlobally, PrintOverlays: runFlags.PrintOverlays, }), + runner.WithPublish(runFlags.Publish), } // Load tools override configuration diff --git a/docs/cli/thv_run.md b/docs/cli/thv_run.md index 21d1cb0de8..3eac04527c 100644 --- a/docs/cli/thv_run.md +++ b/docs/cli/thv_run.md @@ -155,6 +155,7 @@ thv run [flags] SERVER_OR_IMAGE_OR_PROTOCOL [-- ARGS...] --print-resolved-overlays Debug: show resolved container paths for tmpfs overlays (default false) --proxy-mode string Proxy mode for stdio (streamable-http or sse (deprecated, will be removed)) (default "streamable-http") --proxy-port int Port for the HTTP proxy to listen on (host port) + -p, --publish stringArray Publish a container's port(s) to the host (format: hostPort:containerPort) --remote-auth Enable OAuth/OIDC authentication to remote MCP server (default false) --remote-auth-authorize-url string OAuth authorization endpoint URL (alternative to --remote-auth-issuer for non-OIDC OAuth) --remote-auth-bearer-token string Bearer token for remote server authentication (alternative to OAuth) diff --git a/docs/server/docs.go b/docs/server/docs.go index 045231194e..1c7fe7d7b6 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -882,6 +882,14 @@ const docTemplate = `{ "proxy_mode": { "$ref": "#/components/schemas/types.ProxyMode" }, + "publish": { + "description": "Publish lists ports to publish to the host in format \"hostPort:containerPort\"", + "items": { + "type": "string" + }, + "type": "array", + "uniqueItems": false + }, "remote_auth_config": { "$ref": "#/components/schemas/remote.Config" }, diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 12a7021d89..0f9fa974eb 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -875,6 +875,14 @@ "proxy_mode": { "$ref": "#/components/schemas/types.ProxyMode" }, + "publish": { + "description": "Publish lists ports to publish to the host in format \"hostPort:containerPort\"", + "items": { + "type": "string" + }, + "type": "array", + "uniqueItems": false + }, "remote_auth_config": { "$ref": "#/components/schemas/remote.Config" }, diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index 4977b81609..839d620bca 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -837,6 +837,12 @@ components: type: integer proxy_mode: $ref: '#/components/schemas/types.ProxyMode' + publish: + description: Publish lists ports to publish to the host in format "hostPort:containerPort" + items: + type: string + type: array + uniqueItems: false remote_auth_config: $ref: '#/components/schemas/remote.Config' remote_url: diff --git a/pkg/container/docker/client.go b/pkg/container/docker/client.go index 8652e896a1..91f7c91a15 100644 --- a/pkg/container/docker/client.go +++ b/pkg/container/docker/client.go @@ -1619,7 +1619,7 @@ func generatePortBindings(labels map[string]string, portBindings map[string][]runtime.PortBinding) (map[string][]runtime.PortBinding, int, error) { var hostPort int // check if we need to map to a random port of not - if _, ok := labels["toolhive-auxiliary"]; ok && labels["toolhive-auxiliary"] == "true" { + if _, ok := labels[ToolhiveAuxiliaryWorkloadLabel]; ok && labels[ToolhiveAuxiliaryWorkloadLabel] == LabelValueTrue { // find first port var err error for _, bindings := range portBindings { @@ -1633,17 +1633,25 @@ func generatePortBindings(labels map[string]string, } } } else { - // bind to a random host port - hostPort = networking.FindAvailable() - if hostPort == 0 { - return nil, 0, fmt.Errorf("could not find an available port") - } - // first port binding needs to map to the host port + // For consistency, we only use FindAvailable for the primary port if it's not already set for key, bindings := range portBindings { if len(bindings) > 0 { - bindings[0].HostPort = fmt.Sprintf("%d", hostPort) - portBindings[key] = bindings + hostPortStr := bindings[0].HostPort + if hostPortStr == "" || hostPortStr == "0" { + hostPort = networking.FindAvailable() + if hostPort == 0 { + return nil, 0, fmt.Errorf("could not find an available port") + } + bindings[0].HostPort = fmt.Sprintf("%d", hostPort) + portBindings[key] = bindings + } else { + var err error + hostPort, err = strconv.Atoi(hostPortStr) + if err != nil { + return nil, 0, fmt.Errorf("failed to convert host port %s to int: %w", hostPortStr, err) + } + } break } } diff --git a/pkg/container/docker/client_helpers_test.go b/pkg/container/docker/client_helpers_test.go index 477f8fce4c..7a5de202a1 100644 --- a/pkg/container/docker/client_helpers_test.go +++ b/pkg/container/docker/client_helpers_test.go @@ -118,6 +118,43 @@ func TestGeneratePortBindings_NonAuxiliaryAssignsRandomPortAndMutatesFirstBindin assert.Equal(t, 1, countMatches, "expected exactly one first binding to be updated to hostPort=%s", expected) } +func TestGeneratePortBindings_NonAuxiliaryKeepsExplicitHostPort(t *testing.T) { + t.Parallel() + + labels := map[string]string{} // not auxiliary + in := map[string][]runtime.PortBinding{ + "8080/tcp": { + {HostIP: "", HostPort: "9090"}, + }, + } + out, hostPort, err := generatePortBindings(labels, in) + require.NoError(t, err) + require.Equal(t, 9090, hostPort) + + require.Contains(t, out, "8080/tcp") + require.Len(t, out["8080/tcp"], 1) + assert.Equal(t, "9090", out["8080/tcp"][0].HostPort) +} + +func TestGeneratePortBindings_NonAuxiliaryAssignsRandomPortForZero(t *testing.T) { + t.Parallel() + + labels := map[string]string{} // not auxiliary + in := map[string][]runtime.PortBinding{ + "8080/tcp": { + {HostIP: "", HostPort: "0"}, + }, + } + out, hostPort, err := generatePortBindings(labels, in) + require.NoError(t, err) + require.NotZero(t, hostPort) + + require.Contains(t, out, "8080/tcp") + require.Len(t, out["8080/tcp"], 1) + assert.NotEqual(t, "0", out["8080/tcp"][0].HostPort) + assert.Equal(t, fmt.Sprintf("%d", hostPort), out["8080/tcp"][0].HostPort) +} + func TestAddEgressEnvVars_SetsAll(t *testing.T) { t.Parallel() diff --git a/pkg/networking/port.go b/pkg/networking/port.go index fc79adb2ad..54092cc123 100644 --- a/pkg/networking/port.go +++ b/pkg/networking/port.go @@ -11,6 +11,8 @@ import ( "log/slog" "math/big" "net" + "strconv" + "strings" gopsutilnet "github.com/shirou/gopsutil/v4/net" ) @@ -180,6 +182,48 @@ func IsPreRegisteredClient(clientID string) bool { return clientID != "" } +// ParsePortSpec parses a port specification string in the format "hostPort:containerPort" or just "containerPort". +// Returns the host port string and container port integer. +// If only a container port is provided, a random available host port is selected. +func ParsePortSpec(portSpec string) (string, int, error) { + slog.Debug("Parsing port spec", "spec", portSpec) + // Check if it's in host:container format + if strings.Contains(portSpec, ":") { + parts := strings.Split(portSpec, ":") + if len(parts) != 2 { + return "", 0, fmt.Errorf("invalid port specification: %s (expected 'hostPort:containerPort')", portSpec) + } + + hostPortStr := parts[0] + containerPortStr := parts[1] + + // Verify host port is a valid integer (or empty string if we supported random host port with :, but here we expect explicit) + if _, err := strconv.Atoi(hostPortStr); err != nil { + return "", 0, fmt.Errorf("invalid host port in spec '%s': %w", portSpec, err) + } + + containerPort, err := strconv.Atoi(containerPortStr) + if err != nil { + return "", 0, fmt.Errorf("invalid container port in spec '%s': %w", portSpec, err) + } + + return hostPortStr, containerPort, nil + } + + // Try parsing as just container port + containerPort, err := strconv.Atoi(portSpec) + if err == nil { + // Find a random available host port + hostPort := FindAvailable() + if hostPort == 0 { + return "", 0, fmt.Errorf("could not find an available port for container port %d", containerPort) + } + return fmt.Sprintf("%d", hostPort), containerPort, nil + } + + return "", 0, fmt.Errorf("invalid port specification: %s (expected 'hostPort:containerPort' or 'containerPort')", portSpec) +} + // GetProcessOnPort returns the PID of the process listening on the given TCP port. // Returns 0 if the port is free or if the holder cannot be determined. // Uses gopsutil which provides cross-platform support (Linux: /proc, Windows: GetExtendedTcpTable, diff --git a/pkg/networking/port_test.go b/pkg/networking/port_test.go index 6a88f3cf4d..f51b919d42 100644 --- a/pkg/networking/port_test.go +++ b/pkg/networking/port_test.go @@ -71,20 +71,82 @@ func TestValidateCallbackPort(t *testing.T) { err := networking.ValidateCallbackPort(tt.port, tt.clientID) if tt.wantError { - if err == nil { - t.Errorf("ValidateCallbackPort() expected error but got nil") - } else if tt.errorMsg != "" && err.Error() != tt.errorMsg { - t.Errorf("ValidateCallbackPort() error = %v, want %v", err.Error(), tt.errorMsg) + require.Error(t, err) + if tt.errorMsg != "" { + require.EqualError(t, err, tt.errorMsg) } } else { - if err != nil { - t.Errorf("ValidateCallbackPort() unexpected error = %v", err) - } + require.NoError(t, err) } }) } } +func TestParsePortSpec(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + portSpec string + expectedHostPort string + expectedContainer int + wantError bool + }{ + { + name: "host:container", + portSpec: "8003:8001", + expectedHostPort: "8003", + expectedContainer: 8001, + wantError: false, + }, + { + name: "container only", + portSpec: "8001", + expectedHostPort: "", // Random + expectedContainer: 8001, + wantError: false, + }, + { + name: "invalid format", + portSpec: "invalid", + expectedHostPort: "", + expectedContainer: 0, + wantError: true, + }, + { + name: "invalid host port", + portSpec: "abc:8001", + expectedHostPort: "", + expectedContainer: 0, + wantError: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + hostPort, containerPort, err := networking.ParsePortSpec(tt.portSpec) + + if tt.wantError { + require.Error(t, err, "ParsePortSpec(%s) expected error", tt.portSpec) + return + } + + require.NoError(t, err, "ParsePortSpec(%s) unexpected error", tt.portSpec) + + if tt.expectedHostPort != "" { + require.Equal(t, tt.expectedHostPort, hostPort, "ParsePortSpec(%s) unexpected host port", tt.portSpec) + } else { + require.NotEmpty(t, hostPort, "ParsePortSpec(%s) hostPort is empty, want random port", tt.portSpec) + } + + require.Equal(t, tt.expectedContainer, containerPort, "ParsePortSpec(%s) unexpected container port", tt.portSpec) + }) + } +} + func TestGetProcessOnPort_InvalidPort(t *testing.T) { t.Parallel() diff --git a/pkg/runner/config.go b/pkg/runner/config.go index 7afc9e2b20..d33f8eb1f0 100644 --- a/pkg/runner/config.go +++ b/pkg/runner/config.go @@ -83,6 +83,9 @@ type RunConfig struct { // TargetHost is the host to forward traffic to (only applicable to SSE transport) TargetHost string `json:"target_host,omitempty" yaml:"target_host,omitempty"` + // Publish lists ports to publish to the host in format "hostPort:containerPort" + Publish []string `json:"publish,omitempty" yaml:"publish,omitempty"` + // PermissionProfileNameOrPath is the name or path of the permission profile PermissionProfileNameOrPath string `json:"permission_profile_name_or_path,omitempty" yaml:"permission_profile_name_or_path,omitempty"` //nolint:lll diff --git a/pkg/runner/config_builder.go b/pkg/runner/config_builder.go index b1373a79bd..e0f80e060e 100644 --- a/pkg/runner/config_builder.go +++ b/pkg/runner/config_builder.go @@ -174,6 +174,14 @@ func WithTargetHost(targetHost string) RunConfigBuilderOption { } } +// WithPublish sets the published ports +func WithPublish(publish []string) RunConfigBuilderOption { + return func(b *runConfigBuilder) error { + b.config.Publish = publish + return nil + } +} + // WithDebug sets debug mode func WithDebug(debug bool) RunConfigBuilderOption { return func(b *runConfigBuilder) error { diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index af235acfc1..bf4539d7e6 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -306,6 +306,7 @@ func (r *Runner) Run(ctx context.Context) error { r.Config.Host, r.Config.TargetPort, r.Config.TargetHost, + r.Config.Publish, ) if err != nil { return fmt.Errorf("failed to set up workload: %w", err) diff --git a/pkg/runtime/setup.go b/pkg/runtime/setup.go index c183cfb35f..1bdf77b683 100644 --- a/pkg/runtime/setup.go +++ b/pkg/runtime/setup.go @@ -13,6 +13,7 @@ import ( "github.com/stacklok/toolhive-core/permissions" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/ignore" + "github.com/stacklok/toolhive/pkg/networking" "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -50,6 +51,7 @@ func Setup( host string, targetPort int, targetHost string, + publishedPorts []string, ) (*SetupResult, error) { // Add transport-specific environment variables env, ok := transportEnvMap[transportType] @@ -74,6 +76,26 @@ func Setup( containerOptions.K8sPodTemplatePatch = k8sPodTemplatePatch containerOptions.IgnoreConfig = ignoreConfig + // Process published ports + for _, portSpec := range publishedPorts { + hostPort, containerPort, err := networking.ParsePortSpec(portSpec) + if err != nil { + return nil, fmt.Errorf("failed to parse published port '%s': %w", portSpec, err) + } + + // Add to exposed ports + containerPortStr := fmt.Sprintf("%d/tcp", containerPort) + containerOptions.ExposedPorts[containerPortStr] = struct{}{} + + // Add to port bindings + // Check if we already have bindings for this port + bindings := containerOptions.PortBindings[containerPortStr] + bindings = append(bindings, rt.PortBinding{ + HostPort: hostPort, + }) + containerOptions.PortBindings[containerPortStr] = bindings + } + if transportType == types.TransportTypeStdio { containerOptions.AttachStdio = true } else { @@ -90,7 +112,13 @@ func Setup( } // Set the port bindings - containerOptions.PortBindings[containerPortStr] = portBindings + // Note: if the user explicitly publishes the target port using --publish, + // we append the default transport binding to the list of bindings for that port. + if _, ok := containerOptions.PortBindings[containerPortStr]; ok { + containerOptions.PortBindings[containerPortStr] = append(containerOptions.PortBindings[containerPortStr], portBindings...) + } else { + containerOptions.PortBindings[containerPortStr] = portBindings + } } // Create the container