Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Notable Changes

### CLI
* `ssh connect` now supports specifying a serverless usage policy with `--usage-policy-id` ([#5781](https://github.com/databricks/cli/pull/5781)).

* `ssh connect` now accepts a `--base-environment` flag to run a serverless session on a custom base environment. It takes an `env.yaml` path, a `workspace-base-environments/...` resource ID, or a base environment display name, and is rejected together with `--environment-version` or `--cluster` ([#5706](https://github.com/databricks/cli/pull/5706)).
* `databricks aitools install` is now plugin-first: it installs the Databricks plugin through each agent's own CLI (Claude Code, Codex, GitHub Copilot) instead of copying raw skill files. Agents without a plugin (OpenCode, Antigravity) still get skill files, and Cursor prints the `/add-plugin databricks` step. Use `--skills-only` to force raw skill files for every agent, or `--path <dir>` to write skills to a directory ([#5738](https://github.com/databricks/cli/pull/5738)).
Expand Down
3 changes: 3 additions & 0 deletions experimental/ssh/cmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Connect to a dedicated cluster:
var environmentVersion int
var baseEnvironment string
var autoApprove bool
var usagePolicyID string

cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks dedicated cluster ID")
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects")
Expand All @@ -50,6 +51,7 @@ Connect to a dedicated cluster:
cmd.Flags().StringVar(&connectionName, "name", "", "Connection name to reuse across sessions (serverless only)")
cmd.Flags().StringVar(&accelerator, "accelerator", "", "Serverless GPU accelerator type (GPU_1xA10 or GPU_8xH100)")
cmd.Flags().StringVar(&ide, "ide", "", "Open remote IDE window (vscode or cursor)")
cmd.Flags().StringVar(&usagePolicyID, "usage-policy-id", "", "Usage policy ID for the serverless SSH server job (serverless only)")

cmd.Flags().BoolVar(&proxyMode, "proxy", false, "ProxyCommand mode")
cmd.Flags().MarkHidden("proxy")
Expand Down Expand Up @@ -130,6 +132,7 @@ Connect to a dedicated cluster:
BaseEnvironment: baseEnvironment,
AdditionalArgs: args,
AutoApprove: autoApprove,
UsagePolicyID: usagePolicyID,
}
if err := opts.Validate(); err != nil {
return err
Expand Down
3 changes: 3 additions & 0 deletions experimental/ssh/cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ and proxies them to local SSH daemon processes.`,
var secretScopeName string
var authorizedKeySecretName string
var serverless bool
var usagePolicyID string

cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
cmd.MarkFlagRequired("cluster")
Expand All @@ -43,6 +44,7 @@ and proxies them to local SSH daemon processes.`,
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down after no pings from clients")
cmd.Flags().StringVar(&version, "version", "", "Client version of the Databricks CLI")
cmd.Flags().BoolVar(&serverless, "serverless", false, "Enable serverless mode for Jupyter initialization")
cmd.Flags().StringVar(&usagePolicyID, "usage-policy-id", "", "Usage policy ID the job was submitted with")

cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
// The server can be executed under a directory with an invalid bundle configuration.
Expand Down Expand Up @@ -71,6 +73,7 @@ and proxies them to local SSH daemon processes.`,
DefaultPort: defaultServerPort,
PortRange: serverPortRange,
Serverless: serverless,
UsagePolicyID: usagePolicyID,
}
return server.Run(ctx, wsc, opts)
}
Expand Down
170 changes: 115 additions & 55 deletions experimental/ssh/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ type ClientOptions struct {
BaseEnvironment string
// If true, skip confirmation prompts for IDE extension install and IDE settings updates.
AutoApprove bool
// Id of the usage policy to use for the serverless SSH server job. Serverless only.
UsagePolicyID string
}

func (o *ClientOptions) Validate() error {
Expand All @@ -128,6 +130,9 @@ func (o *ClientOptions) Validate() error {
if o.Accelerator != "" && o.ConnectionName == "" {
return errors.New("--accelerator flag can only be used with serverless compute (--name flag)")
}
if o.UsagePolicyID != "" && o.ClusterID != "" {
return errors.New("--usage-policy-id flag can only be used with serverless compute (--name flag)")
}
if o.Accelerator != "" && o.Accelerator != "GPU_1xA10" && o.Accelerator != "GPU_8xH100" {
return fmt.Errorf("invalid accelerator value: %q, expected %q or %q", o.Accelerator, "GPU_1xA10", "GPU_8xH100")
}
Expand Down Expand Up @@ -214,6 +219,9 @@ func (o *ClientOptions) ToProxyCommand() (string, error) {
if o.Accelerator != "" {
proxyCommand += " --accelerator=" + o.Accelerator
}
if o.UsagePolicyID != "" {
proxyCommand += " --usage-policy-id=" + o.UsagePolicyID
}
} else {
proxyCommand = fmt.Sprintf("%q ssh connect --proxy --cluster=%s --auto-start-cluster=%t --shutdown-delay=%s",
executablePath, o.ClusterID, o.AutoStartCluster, o.ShutdownDelay.String())
Expand Down Expand Up @@ -463,14 +471,25 @@ func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, k
return nil
}

// serverMetadata describes a running SSH server, combining the persisted workspace
// metadata with the user name validated live via Driver Proxy.
type serverMetadata struct {
Port int
UserName string
// ClusterID required for Driver Proxy connections. For serverless it comes from the persisted metadata.
ClusterID string
// UsagePolicyID the server was started with, used to decide whether a running server can be reused.
UsagePolicyID string
}

// getServerMetadata retrieves the server metadata from the workspace and validates it via Driver Proxy.
// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless).
// For dedicated clusters, clusterID should be the same as sessionID.
// For serverless, clusterID is read from the workspace metadata.
func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, sessionID, clusterID, version, liteswap string) (int, string, string, error) {
func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, sessionID, clusterID, version, liteswap string) (serverMetadata, error) {
wsMetadata, err := sshWorkspace.GetWorkspaceMetadata(ctx, client, version, sessionID)
if err != nil {
return 0, "", "", errors.Join(errServerMetadata, err)
return serverMetadata{}, errors.Join(errServerMetadata, err)
}
log.Debugf(ctx, "Workspace metadata: %+v", wsMetadata)

Expand All @@ -481,33 +500,38 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient,
}

if effectiveClusterID == "" {
return 0, "", "", errors.Join(errServerMetadata, errors.New("cluster ID not available in metadata"))
return serverMetadata{}, errors.Join(errServerMetadata, errors.New("cluster ID not available in metadata"))
}

req, err := newDriverProxyRequest(ctx, client, effectiveClusterID, wsMetadata.Port, "metadata", liteswap)
if err != nil {
return 0, "", "", err
return serverMetadata{}, err
}
log.Debugf(ctx, "Metadata URL: %s", req.URL)
httpClient := &http.Client{Transport: client.Config.HTTPTransport}
resp, err := httpClient.Do(req)
if err != nil {
return 0, "", "", err
return serverMetadata{}, err
}
defer resp.Body.Close()

bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return 0, "", "", err
return serverMetadata{}, err
}
log.Debugf(ctx, "Metadata response: %s", string(bodyBytes))
log.Debugf(ctx, "Metadata response status code: %d", resp.StatusCode)

if resp.StatusCode != http.StatusOK {
return 0, "", "", errors.Join(errServerMetadata, fmt.Errorf("server is not ok, status code %d", resp.StatusCode))
return serverMetadata{}, errors.Join(errServerMetadata, fmt.Errorf("server is not ok, status code %d", resp.StatusCode))
}

return wsMetadata.Port, string(bodyBytes), effectiveClusterID, nil
return serverMetadata{
Port: wsMetadata.Port,
UserName: string(bodyBytes),
ClusterID: effectiveClusterID,
UsagePolicyID: wsMetadata.UsagePolicyID,
}, nil
}

// newDriverProxyRequest builds an authenticated GET request to one of the SSH server's
Expand Down Expand Up @@ -559,36 +583,10 @@ func fetchServerErrorLogs(ctx context.Context, client *databricks.WorkspaceClien
return strings.TrimSpace(string(body))
}

// submitSSHTunnelJob submits the bootstrap job and waits for the SSH server task to start.
// It returns the job run ID (when known) so callers can fetch and surface the run's error
// details if the server never comes up.
func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (int64, error) {
// Assemble the SubmitRun request that bootstraps the SSH server.
// Extracted from submitSSHTunnelJob so this logic can be unit tested.
func buildSSHServerSubmitRun(version, secretScopeName, jobNotebookPath, baseEnvironment string, opts ClientOptions) jobs.SubmitRun {
sessionID := opts.SessionIdentifier()
contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, sessionID)
if err != nil {
return 0, fmt.Errorf("failed to get workspace content directory: %w", err)
}

err = client.Workspace.MkdirsByPath(ctx, contentDir)
if err != nil {
return 0, fmt.Errorf("failed to create directory in the remote workspace: %w", err)
}

sshTunnelJobName := "ssh-server-bootstrap-" + sessionID
jobNotebookPath := filepath.ToSlash(filepath.Join(contentDir, "ssh-server-bootstrap"))
notebookContent := "# Databricks notebook source\n" + sshServerBootstrapScript
encodedContent := base64.StdEncoding.EncodeToString([]byte(notebookContent))

err = client.Workspace.Import(ctx, workspace.Import{
Path: jobNotebookPath,
Format: workspace.ImportFormatSource,
Content: encodedContent,
Language: workspace.LanguagePython,
Overwrite: true,
})
if err != nil {
return 0, fmt.Errorf("failed to create ssh-tunnel notebook: %w", err)
}

baseParams := map[string]string{
"version": version,
Expand All @@ -598,10 +596,11 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
"maxClients": strconv.Itoa(opts.MaxClients),
"sessionId": sessionID,
"serverless": strconv.FormatBool(opts.IsServerlessMode()),
// Recorded in the server's metadata.json so reconnects can tell which usage policy
// the running server was started under.
"usagePolicyId": opts.UsagePolicyID,
}

log.Infof(ctx, "Submitting a job to start the ssh server...")

task := jobs.SubmitTask{
TaskKey: sshServerTaskKey,
NotebookTask: &jobs.NotebookTask{
Expand All @@ -614,7 +613,6 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
if opts.IsServerlessMode() {
task.EnvironmentKey = serverlessEnvironmentKey
if opts.Accelerator != "" {
log.Infof(ctx, "Using accelerator: %s", opts.Accelerator)
task.Compute = &jobs.Compute{
HardwareAccelerator: compute.HardwareAcceleratorType(opts.Accelerator),
}
Expand All @@ -624,20 +622,17 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
}

submitRequest := jobs.SubmitRun{
RunName: sshTunnelJobName,
RunName: "ssh-server-bootstrap-" + sessionID,
TimeoutSeconds: int(opts.ServerTimeout.Seconds()),
Tasks: []jobs.SubmitTask{task},
BudgetPolicyId: opts.UsagePolicyID,
}

if opts.IsServerlessMode() {
// base_environment and environment_version are mutually exclusive: a custom
// base environment carries its own version, so we don't also set one.
var spec compute.Environment
if opts.BaseEnvironment != "" {
baseEnvironment, err := resolveBaseEnvironment(ctx, client, opts.BaseEnvironment)
if err != nil {
return 0, err
}
if baseEnvironment != "" {
spec.BaseEnvironment = baseEnvironment
} else {
spec.EnvironmentVersion = strconv.Itoa(max(opts.EnvironmentVersion, minEnvironmentVersion))
Expand All @@ -650,6 +645,54 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
}
}

return submitRequest
}

// submitSSHTunnelJob submits the bootstrap job and waits for the SSH server task to start.
// It returns the job run ID (when known) so callers can fetch and surface the run's error
// details if the server never comes up.
func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (int64, error) {
sessionID := opts.SessionIdentifier()
contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, sessionID)
if err != nil {
return 0, fmt.Errorf("failed to get workspace content directory: %w", err)
}

err = client.Workspace.MkdirsByPath(ctx, contentDir)
if err != nil {
return 0, fmt.Errorf("failed to create directory in the remote workspace: %w", err)
}

jobNotebookPath := filepath.ToSlash(filepath.Join(contentDir, "ssh-server-bootstrap"))
notebookContent := "# Databricks notebook source\n" + sshServerBootstrapScript
encodedContent := base64.StdEncoding.EncodeToString([]byte(notebookContent))

err = client.Workspace.Import(ctx, workspace.Import{
Path: jobNotebookPath,
Format: workspace.ImportFormatSource,
Content: encodedContent,
Language: workspace.LanguagePython,
Overwrite: true,
})
if err != nil {
return 0, fmt.Errorf("failed to create ssh-tunnel notebook: %w", err)
}

log.Infof(ctx, "Submitting a job to start the ssh server...")
if opts.IsServerlessMode() && opts.Accelerator != "" {
log.Infof(ctx, "Using accelerator: %s", opts.Accelerator)
}

var baseEnvironment string
if opts.IsServerlessMode() && opts.BaseEnvironment != "" {
baseEnvironment, err = resolveBaseEnvironment(ctx, client, opts.BaseEnvironment)
if err != nil {
return 0, err
}
}

submitRequest := buildSSHServerSubmitRun(version, secretScopeName, jobNotebookPath, baseEnvironment, opts)

waiter, err := client.Jobs.Submit(ctx, submitRequest)
if err != nil {
return 0, fmt.Errorf("failed to submit job: %w", err)
Expand Down Expand Up @@ -1046,18 +1089,31 @@ func hostKeyChangedHint(stderr, hostName, knownHostsFile string) string {
"Remove the stale entry and reconnect:\n " + cmd
}

func usagePolicyMatches(storedPolicy, requestedPolicy string) bool {
return requestedPolicy == "" || storedPolicy == requestedPolicy
}

func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (string, int, string, error) {
sessionID := opts.SessionIdentifier()
// For dedicated clusters, use clusterID; for serverless, it will be read from metadata
clusterID := opts.ClusterID

serverPort, userName, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap)
if errors.Is(err, errServerMetadata) {
meta, err := getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap)
if err != nil && !errors.Is(err, errServerMetadata) {
return "", 0, "", err
}

// Start a new server when none is running, or when the running one was started under a
// different usage policy. A job's usage policy is fixed at submission, so we can't retarget
// the existing server; the new server overwrites metadata.json and the old one idles out via
// shutdownDelay.
needNewServer := err != nil || !usagePolicyMatches(meta.UsagePolicyID, opts.UsagePolicyID)
if needNewServer {
cmdio.LogString(ctx, "Starting SSH server...")

runID, err := submitSSHTunnelJob(ctx, client, version, secretScopeName, opts)
if err != nil {
return "", 0, "", fmt.Errorf("failed to submit and start ssh server job: %w", err)
runID, submitErr := submitSSHTunnelJob(ctx, client, version, secretScopeName, opts)
if submitErr != nil {
return "", 0, "", fmt.Errorf("failed to submit and start ssh server job: %w", submitErr)
}

sp := cmdio.NewSpinner(ctx, cmdio.WithElapsedTime())
Expand All @@ -1068,7 +1124,13 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
if ctx.Err() != nil {
return "", 0, "", ctx.Err()
}
serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap)
meta, err = getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap)
// Accept only once metadata reflects the requested usage policy, so we don't latch
// onto a server a previous connection started under a different policy before our new
// server has overwritten metadata.json.
if err == nil && !usagePolicyMatches(meta.UsagePolicyID, opts.UsagePolicyID) {
err = fmt.Errorf("found a running SSH server with usage policy %q, waiting for the one with %q", meta.UsagePolicyID, opts.UsagePolicyID)
}
if err == nil {
cmdio.LogString(ctx, "Health check successful, starting ssh WebSocket connection...")
break
Expand All @@ -1085,11 +1147,9 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
return "", 0, "", fmt.Errorf("failed to start the ssh server: %w\n%s", err, describeRunFailure(ctx, client, runID))
}
}
} else if err != nil {
return "", 0, "", err
}

return userName, serverPort, effectiveClusterID, nil
return meta.UserName, meta.Port, meta.ClusterID, nil
}

func logSshTunnelEvent(ctx context.Context, opts ClientOptions, isSuccess, isReconnect bool, serverStartTimeMs int64) {
Expand Down
14 changes: 14 additions & 0 deletions experimental/ssh/internal/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@
{
name: "base environment with serverless GPU accelerator",
opts: client.ClientOptions{ConnectionName: "my-conn", Accelerator: "GPU_1xA10", BaseEnvironment: "my-gpu-env"},
},

Check failure on line 114 in experimental/ssh/internal/client/client_test.go

View workflow job for this annotation

GitHub Actions / lint

File is not properly formatted (gofmt)
{
name: "usage policy with cluster ID",
opts: client.ClientOptions{ClusterID: "abc-123", UsagePolicyID: "pol-1"},
wantErr: "--usage-policy-id flag can only be used with serverless compute (--name flag)",
},
{
name: "usage policy with connection name",
opts: client.ClientOptions{ConnectionName: "my-conn", UsagePolicyID: "pol-1"},
},
}

Expand Down Expand Up @@ -233,6 +242,11 @@
opts: client.ClientOptions{ConnectionName: "my-conn", Accelerator: "GPU_1xA10", ShutdownDelay: 2 * time.Minute},
want: quoted + " ssh connect --proxy --name=my-conn --shutdown-delay=2m0s --accelerator=GPU_1xA10",
},
{
name: "serverless with usage policy",
opts: client.ClientOptions{ConnectionName: "my-conn", UsagePolicyID: "pol-1", ShutdownDelay: 2 * time.Minute},
want: quoted + " ssh connect --proxy --name=my-conn --shutdown-delay=2m0s --usage-policy-id=pol-1",
},
{
name: "with metadata",
opts: client.ClientOptions{ClusterID: "abc-123", ServerMetadata: "user,2222,abc-123"},
Expand Down
Loading
Loading