diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md
index a6cccd2a5cb..9d0c9c2f634 100644
--- a/NEXT_CHANGELOG.md
+++ b/NEXT_CHANGELOG.md
@@ -5,6 +5,7 @@
### Notable Changes
### CLI
+* `ssh connect` now supports specifying a serverless usage policy with `--usage-policy-id` ([#5781](https://github.com/databricks/cli/pull/5781)).
* `ssh connect` now accepts a `--base-environment` flag to run a serverless session on a custom base environment. It takes an `env.yaml` path, a `workspace-base-environments/...` resource ID, or a base environment display name, and is rejected together with `--environment-version` or `--cluster` ([#5706](https://github.com/databricks/cli/pull/5706)).
* `databricks aitools install` is now plugin-first: it installs the Databricks plugin through each agent's own CLI (Claude Code, Codex, GitHub Copilot) instead of copying raw skill files. Agents without a plugin (OpenCode, Antigravity) still get skill files, and Cursor prints the `/add-plugin databricks` step. Use `--skills-only` to force raw skill files for every agent, or `--path` to write skills to a directory ([#5738](https://github.com/databricks/cli/pull/5738)).
diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go
index 972ccf4a81a..2c50b871902 100644
--- a/experimental/ssh/cmd/connect.go
+++ b/experimental/ssh/cmd/connect.go
@@ -41,6 +41,7 @@ Connect to a dedicated cluster:
var environmentVersion int
var baseEnvironment string
var autoApprove bool
+ var usagePolicyID string
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks dedicated cluster ID")
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects")
@@ -50,6 +51,7 @@ Connect to a dedicated cluster:
cmd.Flags().StringVar(&connectionName, "name", "", "Connection name to reuse across sessions (serverless only)")
cmd.Flags().StringVar(&accelerator, "accelerator", "", "Serverless GPU accelerator type (GPU_1xA10 or GPU_8xH100)")
cmd.Flags().StringVar(&ide, "ide", "", "Open remote IDE window (vscode or cursor)")
+ cmd.Flags().StringVar(&usagePolicyID, "usage-policy-id", "", "Usage policy ID for the serverless SSH server job (serverless only)")
cmd.Flags().BoolVar(&proxyMode, "proxy", false, "ProxyCommand mode")
cmd.Flags().MarkHidden("proxy")
@@ -130,6 +132,7 @@ Connect to a dedicated cluster:
BaseEnvironment: baseEnvironment,
AdditionalArgs: args,
AutoApprove: autoApprove,
+ UsagePolicyID: usagePolicyID,
}
if err := opts.Validate(); err != nil {
return err
diff --git a/experimental/ssh/cmd/server.go b/experimental/ssh/cmd/server.go
index 47e16cdc649..21c651b2365 100644
--- a/experimental/ssh/cmd/server.go
+++ b/experimental/ssh/cmd/server.go
@@ -29,6 +29,7 @@ and proxies them to local SSH daemon processes.`,
var secretScopeName string
var authorizedKeySecretName string
var serverless bool
+ var usagePolicyID string
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
cmd.MarkFlagRequired("cluster")
@@ -43,6 +44,7 @@ and proxies them to local SSH daemon processes.`,
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down after no pings from clients")
cmd.Flags().StringVar(&version, "version", "", "Client version of the Databricks CLI")
cmd.Flags().BoolVar(&serverless, "serverless", false, "Enable serverless mode for Jupyter initialization")
+ cmd.Flags().StringVar(&usagePolicyID, "usage-policy-id", "", "Usage policy ID the job was submitted with")
cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
// The server can be executed under a directory with an invalid bundle configuration.
@@ -71,6 +73,7 @@ and proxies them to local SSH daemon processes.`,
DefaultPort: defaultServerPort,
PortRange: serverPortRange,
Serverless: serverless,
+ UsagePolicyID: usagePolicyID,
}
return server.Run(ctx, wsc, opts)
}
diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go
index 9fc20aa56df..cb9919cfa0f 100644
--- a/experimental/ssh/internal/client/client.go
+++ b/experimental/ssh/internal/client/client.go
@@ -119,6 +119,8 @@ type ClientOptions struct {
BaseEnvironment string
// If true, skip confirmation prompts for IDE extension install and IDE settings updates.
AutoApprove bool
+ // ID of the usage policy to use for the serverless SSH server job. Serverless only.
+ UsagePolicyID string
}
func (o *ClientOptions) Validate() error {
@@ -128,6 +130,9 @@ func (o *ClientOptions) Validate() error {
if o.Accelerator != "" && o.ConnectionName == "" {
return errors.New("--accelerator flag can only be used with serverless compute (--name flag)")
}
+ if o.UsagePolicyID != "" && o.ClusterID != "" {
+ return errors.New("--usage-policy-id flag can only be used with serverless compute (--name flag)")
+ }
if o.Accelerator != "" && o.Accelerator != "GPU_1xA10" && o.Accelerator != "GPU_8xH100" {
return fmt.Errorf("invalid accelerator value: %q, expected %q or %q", o.Accelerator, "GPU_1xA10", "GPU_8xH100")
}
@@ -214,6 +219,9 @@ func (o *ClientOptions) ToProxyCommand() (string, error) {
if o.Accelerator != "" {
proxyCommand += " --accelerator=" + o.Accelerator
}
+ if o.UsagePolicyID != "" {
+ proxyCommand += " --usage-policy-id=" + o.UsagePolicyID
+ }
} else {
proxyCommand = fmt.Sprintf("%q ssh connect --proxy --cluster=%s --auto-start-cluster=%t --shutdown-delay=%s",
executablePath, o.ClusterID, o.AutoStartCluster, o.ShutdownDelay.String())
@@ -463,14 +471,25 @@ func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, k
return nil
}
+// serverMetadata describes a running SSH server, combining the persisted workspace
+// metadata with the user name validated live via Driver Proxy.
+type serverMetadata struct {
+ // Port is the server port read from the persisted workspace metadata.
+ Port int
+ // UserName is the body of the server's Driver Proxy "metadata" endpoint response.
+ UserName string
+ // ClusterID is required for Driver Proxy connections. For serverless it comes from the persisted metadata.
+ ClusterID string
+ // UsagePolicyID is the usage policy the server was started with, used to decide whether a running server can be reused.
+ UsagePolicyID string
+}
+
// getServerMetadata retrieves the server metadata from the workspace and validates it via Driver Proxy.
// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless).
// For dedicated clusters, clusterID should be the same as sessionID.
// For serverless, clusterID is read from the workspace metadata.
-func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, sessionID, clusterID, version, liteswap string) (int, string, string, error) {
+func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, sessionID, clusterID, version, liteswap string) (serverMetadata, error) {
wsMetadata, err := sshWorkspace.GetWorkspaceMetadata(ctx, client, version, sessionID)
if err != nil {
- return 0, "", "", errors.Join(errServerMetadata, err)
+ return serverMetadata{}, errors.Join(errServerMetadata, err)
}
log.Debugf(ctx, "Workspace metadata: %+v", wsMetadata)
@@ -481,33 +500,38 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient,
}
if effectiveClusterID == "" {
- return 0, "", "", errors.Join(errServerMetadata, errors.New("cluster ID not available in metadata"))
+ return serverMetadata{}, errors.Join(errServerMetadata, errors.New("cluster ID not available in metadata"))
}
req, err := newDriverProxyRequest(ctx, client, effectiveClusterID, wsMetadata.Port, "metadata", liteswap)
if err != nil {
- return 0, "", "", err
+ return serverMetadata{}, err
}
log.Debugf(ctx, "Metadata URL: %s", req.URL)
httpClient := &http.Client{Transport: client.Config.HTTPTransport}
resp, err := httpClient.Do(req)
if err != nil {
- return 0, "", "", err
+ return serverMetadata{}, err
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
- return 0, "", "", err
+ return serverMetadata{}, err
}
log.Debugf(ctx, "Metadata response: %s", string(bodyBytes))
log.Debugf(ctx, "Metadata response status code: %d", resp.StatusCode)
if resp.StatusCode != http.StatusOK {
- return 0, "", "", errors.Join(errServerMetadata, fmt.Errorf("server is not ok, status code %d", resp.StatusCode))
+ return serverMetadata{}, errors.Join(errServerMetadata, fmt.Errorf("server is not ok, status code %d", resp.StatusCode))
}
- return wsMetadata.Port, string(bodyBytes), effectiveClusterID, nil
+ return serverMetadata{
+ Port: wsMetadata.Port,
+ UserName: string(bodyBytes),
+ ClusterID: effectiveClusterID,
+ UsagePolicyID: wsMetadata.UsagePolicyID,
+ }, nil
}
// newDriverProxyRequest builds an authenticated GET request to one of the SSH server's
@@ -559,36 +583,10 @@ func fetchServerErrorLogs(ctx context.Context, client *databricks.WorkspaceClien
return strings.TrimSpace(string(body))
}
-// submitSSHTunnelJob submits the bootstrap job and waits for the SSH server task to start.
-// It returns the job run ID (when known) so callers can fetch and surface the run's error
-// details if the server never comes up.
-func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (int64, error) {
+// buildSSHServerSubmitRun assembles the SubmitRun request that bootstraps the SSH server.
+// It was extracted from submitSSHTunnelJob so this logic can be unit tested.
+func buildSSHServerSubmitRun(version, secretScopeName, jobNotebookPath, baseEnvironment string, opts ClientOptions) jobs.SubmitRun {
sessionID := opts.SessionIdentifier()
- contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, sessionID)
- if err != nil {
- return 0, fmt.Errorf("failed to get workspace content directory: %w", err)
- }
-
- err = client.Workspace.MkdirsByPath(ctx, contentDir)
- if err != nil {
- return 0, fmt.Errorf("failed to create directory in the remote workspace: %w", err)
- }
-
- sshTunnelJobName := "ssh-server-bootstrap-" + sessionID
- jobNotebookPath := filepath.ToSlash(filepath.Join(contentDir, "ssh-server-bootstrap"))
- notebookContent := "# Databricks notebook source\n" + sshServerBootstrapScript
- encodedContent := base64.StdEncoding.EncodeToString([]byte(notebookContent))
-
- err = client.Workspace.Import(ctx, workspace.Import{
- Path: jobNotebookPath,
- Format: workspace.ImportFormatSource,
- Content: encodedContent,
- Language: workspace.LanguagePython,
- Overwrite: true,
- })
- if err != nil {
- return 0, fmt.Errorf("failed to create ssh-tunnel notebook: %w", err)
- }
baseParams := map[string]string{
"version": version,
@@ -598,10 +596,11 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
"maxClients": strconv.Itoa(opts.MaxClients),
"sessionId": sessionID,
"serverless": strconv.FormatBool(opts.IsServerlessMode()),
+ // Recorded in the server's metadata.json so reconnects can tell which usage policy
+ // the running server was started under.
+ "usagePolicyId": opts.UsagePolicyID,
}
- log.Infof(ctx, "Submitting a job to start the ssh server...")
-
task := jobs.SubmitTask{
TaskKey: sshServerTaskKey,
NotebookTask: &jobs.NotebookTask{
@@ -614,7 +613,6 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
if opts.IsServerlessMode() {
task.EnvironmentKey = serverlessEnvironmentKey
if opts.Accelerator != "" {
- log.Infof(ctx, "Using accelerator: %s", opts.Accelerator)
task.Compute = &jobs.Compute{
HardwareAccelerator: compute.HardwareAcceleratorType(opts.Accelerator),
}
@@ -624,20 +622,17 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
}
submitRequest := jobs.SubmitRun{
- RunName: sshTunnelJobName,
+ RunName: "ssh-server-bootstrap-" + sessionID,
TimeoutSeconds: int(opts.ServerTimeout.Seconds()),
Tasks: []jobs.SubmitTask{task},
+ BudgetPolicyId: opts.UsagePolicyID,
}
if opts.IsServerlessMode() {
// base_environment and environment_version are mutually exclusive: a custom
// base environment carries its own version, so we don't also set one.
var spec compute.Environment
- if opts.BaseEnvironment != "" {
- baseEnvironment, err := resolveBaseEnvironment(ctx, client, opts.BaseEnvironment)
- if err != nil {
- return 0, err
- }
+ if baseEnvironment != "" {
spec.BaseEnvironment = baseEnvironment
} else {
spec.EnvironmentVersion = strconv.Itoa(max(opts.EnvironmentVersion, minEnvironmentVersion))
@@ -650,6 +645,54 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
}
}
+ return submitRequest
+}
+
+// submitSSHTunnelJob submits the bootstrap job and waits for the SSH server task to start.
+// It returns the job run ID (when known) so callers can fetch and surface the run's error
+// details if the server never comes up.
+func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (int64, error) {
+ sessionID := opts.SessionIdentifier()
+ contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, sessionID)
+ if err != nil {
+ return 0, fmt.Errorf("failed to get workspace content directory: %w", err)
+ }
+
+ err = client.Workspace.MkdirsByPath(ctx, contentDir)
+ if err != nil {
+ return 0, fmt.Errorf("failed to create directory in the remote workspace: %w", err)
+ }
+
+ jobNotebookPath := filepath.ToSlash(filepath.Join(contentDir, "ssh-server-bootstrap"))
+ notebookContent := "# Databricks notebook source\n" + sshServerBootstrapScript
+ encodedContent := base64.StdEncoding.EncodeToString([]byte(notebookContent))
+
+ err = client.Workspace.Import(ctx, workspace.Import{
+ Path: jobNotebookPath,
+ Format: workspace.ImportFormatSource,
+ Content: encodedContent,
+ Language: workspace.LanguagePython,
+ Overwrite: true,
+ })
+ if err != nil {
+ return 0, fmt.Errorf("failed to create ssh-tunnel notebook: %w", err)
+ }
+
+ log.Infof(ctx, "Submitting a job to start the ssh server...")
+ if opts.IsServerlessMode() && opts.Accelerator != "" {
+ log.Infof(ctx, "Using accelerator: %s", opts.Accelerator)
+ }
+
+ var baseEnvironment string
+ if opts.IsServerlessMode() && opts.BaseEnvironment != "" {
+ baseEnvironment, err = resolveBaseEnvironment(ctx, client, opts.BaseEnvironment)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ submitRequest := buildSSHServerSubmitRun(version, secretScopeName, jobNotebookPath, baseEnvironment, opts)
+
waiter, err := client.Jobs.Submit(ctx, submitRequest)
if err != nil {
return 0, fmt.Errorf("failed to submit job: %w", err)
@@ -1046,18 +1089,31 @@ func hostKeyChangedHint(stderr, hostName, knownHostsFile string) string {
"Remove the stale entry and reconnect:\n " + cmd
}
+// usagePolicyMatches reports whether a server recorded with storedPolicy satisfies a
+// connection that asked for requestedPolicy. An empty requestedPolicy matches any stored
+// value; otherwise the two IDs must be equal.
+func usagePolicyMatches(storedPolicy, requestedPolicy string) bool {
+ return requestedPolicy == "" || storedPolicy == requestedPolicy
+}
+
func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (string, int, string, error) {
sessionID := opts.SessionIdentifier()
// For dedicated clusters, use clusterID; for serverless, it will be read from metadata
clusterID := opts.ClusterID
- serverPort, userName, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap)
- if errors.Is(err, errServerMetadata) {
+ meta, err := getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap)
+ if err != nil && !errors.Is(err, errServerMetadata) {
+ return "", 0, "", err
+ }
+
+ // Start a new server when none is running, or when the running one was started under a
+ // different usage policy. A job's usage policy is fixed at submission, so we can't retarget
+ // the existing server; the new server overwrites metadata.json and the old one idles out via
+ // shutdownDelay.
+ needNewServer := err != nil || !usagePolicyMatches(meta.UsagePolicyID, opts.UsagePolicyID)
+ if needNewServer {
cmdio.LogString(ctx, "Starting SSH server...")
- runID, err := submitSSHTunnelJob(ctx, client, version, secretScopeName, opts)
- if err != nil {
- return "", 0, "", fmt.Errorf("failed to submit and start ssh server job: %w", err)
+ runID, submitErr := submitSSHTunnelJob(ctx, client, version, secretScopeName, opts)
+ if submitErr != nil {
+ return "", 0, "", fmt.Errorf("failed to submit and start ssh server job: %w", submitErr)
}
sp := cmdio.NewSpinner(ctx, cmdio.WithElapsedTime())
@@ -1068,7 +1124,13 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
if ctx.Err() != nil {
return "", 0, "", ctx.Err()
}
- serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap)
+ meta, err = getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap)
+ // Accept only once metadata reflects the requested usage policy, so we don't latch
+ // onto a server a previous connection started under a different policy before our new
+ // server has overwritten metadata.json.
+ if err == nil && !usagePolicyMatches(meta.UsagePolicyID, opts.UsagePolicyID) {
+ err = fmt.Errorf("found a running SSH server with usage policy %q, waiting for the one with %q", meta.UsagePolicyID, opts.UsagePolicyID)
+ }
if err == nil {
cmdio.LogString(ctx, "Health check successful, starting ssh WebSocket connection...")
break
@@ -1085,11 +1147,9 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
return "", 0, "", fmt.Errorf("failed to start the ssh server: %w\n%s", err, describeRunFailure(ctx, client, runID))
}
}
- } else if err != nil {
- return "", 0, "", err
}
- return userName, serverPort, effectiveClusterID, nil
+ return meta.UserName, meta.Port, meta.ClusterID, nil
}
func logSshTunnelEvent(ctx context.Context, opts ClientOptions, isSuccess, isReconnect bool, serverStartTimeMs int64) {
diff --git a/experimental/ssh/internal/client/client_test.go b/experimental/ssh/internal/client/client_test.go
index 48fb6f0c1f4..c146f6bd813 100644
--- a/experimental/ssh/internal/client/client_test.go
+++ b/experimental/ssh/internal/client/client_test.go
@@ -111,6 +111,15 @@ func TestValidate(t *testing.T) {
{
name: "base environment with serverless GPU accelerator",
opts: client.ClientOptions{ConnectionName: "my-conn", Accelerator: "GPU_1xA10", BaseEnvironment: "my-gpu-env"},
+ },
+ {
+ name: "usage policy with cluster ID",
+ opts: client.ClientOptions{ClusterID: "abc-123", UsagePolicyID: "pol-1"},
+ wantErr: "--usage-policy-id flag can only be used with serverless compute (--name flag)",
+ },
+ {
+ name: "usage policy with connection name",
+ opts: client.ClientOptions{ConnectionName: "my-conn", UsagePolicyID: "pol-1"},
},
}
@@ -233,6 +242,11 @@ func TestToProxyCommand(t *testing.T) {
opts: client.ClientOptions{ConnectionName: "my-conn", Accelerator: "GPU_1xA10", ShutdownDelay: 2 * time.Minute},
want: quoted + " ssh connect --proxy --name=my-conn --shutdown-delay=2m0s --accelerator=GPU_1xA10",
},
+ {
+ name: "serverless with usage policy",
+ opts: client.ClientOptions{ConnectionName: "my-conn", UsagePolicyID: "pol-1", ShutdownDelay: 2 * time.Minute},
+ want: quoted + " ssh connect --proxy --name=my-conn --shutdown-delay=2m0s --usage-policy-id=pol-1",
+ },
{
name: "with metadata",
opts: client.ClientOptions{ClusterID: "abc-123", ServerMetadata: "user,2222,abc-123"},
diff --git a/experimental/ssh/internal/client/policy_internal_test.go b/experimental/ssh/internal/client/policy_internal_test.go
new file mode 100644
index 00000000000..f501f0ba6e6
--- /dev/null
+++ b/experimental/ssh/internal/client/policy_internal_test.go
@@ -0,0 +1,26 @@
+package client
+
+import "testing"
+
+// TestUsagePolicyMatches is a table-driven check of usagePolicyMatches: an empty requested
+// policy matches any stored value, and a non-empty one matches only an equal stored value.
+func TestUsagePolicyMatches(t *testing.T) {
+ tests := []struct {
+ name string
+ stored string
+ requested string
+ want bool
+ }{
+ {name: "empty request matches any server", stored: "pol-1", requested: "", want: true},
+ {name: "empty request matches server without policy", stored: "", requested: "", want: true},
+ {name: "equal policies match", stored: "pol-1", requested: "pol-1", want: true},
+ {name: "different policies do not match", stored: "pol-1", requested: "pol-2", want: false},
+ {name: "request against server without policy does not match", stored: "", requested: "pol-1", want: false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := usagePolicyMatches(tt.stored, tt.requested); got != tt.want {
+ t.Errorf("usagePolicyMatches(%q, %q) = %v, want %v", tt.stored, tt.requested, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/experimental/ssh/internal/client/ssh-server-bootstrap.py b/experimental/ssh/internal/client/ssh-server-bootstrap.py
index 87d0d2756fe..28a20f73688 100644
--- a/experimental/ssh/internal/client/ssh-server-bootstrap.py
+++ b/experimental/ssh/internal/client/ssh-server-bootstrap.py
@@ -26,6 +26,7 @@
dbutils.widgets.text("shutdownDelay", "10m")
dbutils.widgets.text("sessionId", "")
dbutils.widgets.text("serverless", "false")
+dbutils.widgets.text("usagePolicyId", "")
def cleanup():
@@ -126,6 +127,7 @@ def run_ssh_server():
if not session_id:
raise RuntimeError("Session ID is required. Please provide it using the 'sessionId' widget.")
serverless = dbutils.widgets.get("serverless")
+ usage_policy_id = dbutils.widgets.get("usagePolicyId")
# Mark this process's WSFS command origin so workspace-file activity from the
# remote SSH session is attributable
@@ -172,6 +174,10 @@ def run_ssh_server():
"--log-file=stdout",
]
+ # Recorded in the server's metadata.json so reconnects can match the usage policy.
+ if usage_policy_id:
+ server_args.append(f"--usage-policy-id={usage_policy_id}")
+
# Tee the server output instead of inheriting stdout: the run-page logs remain the only
# place to debug a RUNNING server, but on failure we attach the log tail to the exception
# so "ssh connect" can print it (the Jobs run-output API has no stdout logs for notebook tasks).
diff --git a/experimental/ssh/internal/client/submit_internal_test.go b/experimental/ssh/internal/client/submit_internal_test.go
new file mode 100644
index 00000000000..5b1af006b1d
--- /dev/null
+++ b/experimental/ssh/internal/client/submit_internal_test.go
@@ -0,0 +1,76 @@
+package client
+
+import (
+ "testing"
+ "time"
+
+ "github.com/databricks/databricks-sdk-go/service/compute"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestBuildSSHServerSubmitRun checks the SubmitRun request built for serverless sessions
+// (usage policy, accelerator, base environment) and for a dedicated cluster.
+func TestBuildSSHServerSubmitRun(t *testing.T) {
+ const notebookPath = "/Workspace/Users/me/.databricks/ssh-tunnel/v1/conn/ssh-server-bootstrap"
+
+ t.Run("serverless with usage policy", func(t *testing.T) {
+ opts := ClientOptions{
+ ConnectionName: "conn",
+ UsagePolicyID: "pol-1",
+ ServerTimeout: time.Hour,
+ EnvironmentVersion: 4,
+ }
+ got := buildSSHServerSubmitRun("v1", "scope", notebookPath, "", opts)
+
+ // Usage policy flows onto the run and into the base params the server reads.
+ assert.Equal(t, "pol-1", got.BudgetPolicyId)
+ assert.Equal(t, "pol-1", got.Tasks[0].NotebookTask.BaseParameters["usagePolicyId"])
+ assert.Equal(t, "true", got.Tasks[0].NotebookTask.BaseParameters["serverless"])
+
+ // Serverless runs on an environment, not an existing cluster.
+ assert.Equal(t, serverlessEnvironmentKey, got.Tasks[0].EnvironmentKey)
+ assert.Empty(t, got.Tasks[0].ExistingClusterId)
+ assert.Len(t, got.Environments, 1)
+ assert.Nil(t, got.Tasks[0].Compute)
+ })
+
+ t.Run("serverless with accelerator", func(t *testing.T) {
+ opts := ClientOptions{
+ ConnectionName: "conn",
+ Accelerator: "GPU_1xA10",
+ ServerTimeout: time.Hour,
+ }
+ got := buildSSHServerSubmitRun("v1", "scope", notebookPath, "", opts)
+
+ // Fail cleanly instead of panicking on a nil dereference if no compute was attached.
+ require.NotNil(t, got.Tasks[0].Compute)
+ assert.Equal(t, compute.HardwareAcceleratorType("GPU_1xA10"), got.Tasks[0].Compute.HardwareAccelerator)
+ })
+
+ t.Run("serverless with base environment", func(t *testing.T) {
+ opts := ClientOptions{
+ ConnectionName: "conn",
+ ServerTimeout: time.Hour,
+ EnvironmentVersion: 4,
+ BaseEnvironment: "my-env",
+ }
+ got := buildSSHServerSubmitRun("v1", "scope", notebookPath, "workspace-base-environments/dbe_123", opts)
+
+ // A resolved base environment carries its own version, so environment_version is not set.
+ require.Len(t, got.Environments, 1)
+ assert.Equal(t, "workspace-base-environments/dbe_123", got.Environments[0].Spec.BaseEnvironment)
+ assert.Empty(t, got.Environments[0].Spec.EnvironmentVersion)
+ })
+
+ t.Run("dedicated cluster", func(t *testing.T) {
+ opts := ClientOptions{
+ ClusterID: "abc-123",
+ ServerTimeout: time.Hour,
+ }
+ got := buildSSHServerSubmitRun("v1", "scope", notebookPath, "", opts)
+
+ // Usage policy is serverless-only; a dedicated run carries none and targets the cluster.
+ assert.Empty(t, got.BudgetPolicyId)
+ assert.Empty(t, got.Tasks[0].NotebookTask.BaseParameters["usagePolicyId"])
+ assert.Equal(t, "abc-123", got.Tasks[0].ExistingClusterId)
+ assert.Empty(t, got.Tasks[0].EnvironmentKey)
+ assert.Empty(t, got.Environments)
+ })
+}
diff --git a/experimental/ssh/internal/server/server.go b/experimental/ssh/internal/server/server.go
index b07b6863c00..a5e89a7b701 100644
--- a/experimental/ssh/internal/server/server.go
+++ b/experimental/ssh/internal/server/server.go
@@ -37,6 +37,9 @@ type ServerOptions struct {
SessionID string
// Serverless indicates whether the server is running on serverless compute.
Serverless bool
+ // UsagePolicyID the job was submitted with. Persisted to metadata.json so reconnects
+ // can tell which usage policy the running server was started under.
+ UsagePolicyID string
// The directory to store sshd configuration
ConfigDir string
// The name of the secrets scope to use for client and server keys
@@ -66,8 +69,9 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ServerOpt
// Save metadata including ClusterID (required for Driver Proxy connections in serverless mode)
metadata := &workspace.WorkspaceMetadata{
- Port: port,
- ClusterID: opts.ClusterID,
+ Port: port,
+ ClusterID: opts.ClusterID,
+ UsagePolicyID: opts.UsagePolicyID,
}
err = workspace.SaveWorkspaceMetadata(ctx, client, opts.Version, opts.SessionID, metadata)
if err != nil {
diff --git a/experimental/ssh/internal/workspace/workspace.go b/experimental/ssh/internal/workspace/workspace.go
index 0a28b684ebc..576e8a6df9f 100644
--- a/experimental/ssh/internal/workspace/workspace.go
+++ b/experimental/ssh/internal/workspace/workspace.go
@@ -19,6 +19,9 @@ type WorkspaceMetadata struct {
Port int `json:"port"`
// ClusterID is required for Driver Proxy websocket connections (for any compute type, including serverless)
ClusterID string `json:"cluster_id,omitempty"`
+ // UsagePolicyID records the usage policy the server's job was submitted with, so a
+ // reconnect can tell whether a running server matches the requested usage policy.
+ UsagePolicyID string `json:"usage_policy_id,omitempty"`
}
func getWorkspaceRootDir(ctx context.Context, client *databricks.WorkspaceClient) (string, error) {