From 214605f6902905267b4544c1bfcee3d3bd13d91e Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Thu, 9 Apr 2026 10:57:50 +0100 Subject: [PATCH 1/5] Store auth credentials in the OS keychain (macOS + Linux) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move auth tokens and refresh tokens out of the plaintext config.yaml into the OS-native credential store via 99designs/keyring: - macOS: Keychain Services with per-app ACL (re-prompts after binary updates when another process tries to read the item) - Linux: Secret Service (GNOME Keyring / KWallet) when a D-Bus session is available - Windows support to follow Credentials are managed through a SecureStore interface (pkg/keychain) with Get/Set/DeleteCredentials methods. A TokenHolder (pkg/httputil) is populated once during PersistentPreRunE and read by API client request editors on every outbound request — replacing the previous per-request context lookup for the token. On first run, MigrateLegacyCredentials copies token/refreshtoken fields from existing config.yaml contexts into the secure store and clears them from the config file. On Linux, when no Secret Service daemon is reachable (headless CI, containers without D-Bus), the store falls back to a plaintext JSON file at ~/.astro/credentials.json (mode 0600). This matches the previous config.yaml security posture and is acceptable because CI environments typically use ASTRO_API_TOKEN rather than interactive login credentials. Windows uses the same plaintext file fallback for now; the Credential Manager backend is added in the follow-up commit. --- airflow-client/airflow-client.go | 24 ++--- airflow-client/airflow-client_test.go | 111 ++++++++++++-------- airflow/container.go | 3 +- airflow/docker.go | 17 +++- airflow/docker_test.go | 46 +++++---- airflow/mocks/ContainerHandler.go | 11 +- airflow/standalone.go | 3 +- airflow/standalone_test.go | 2 +- airflow/standalone_windows.go | 3 +- astro-client-core/client.go | 9 +- astro-client-core/client.test.go | 2 +- astro-client-iam-core/client.go | 5 +- astro-client-iam-core/client.test.go | 2 +- astro-client-platform-core/client.go | 6 +- cloud/auth/auth.go | 35 ++++--- cloud/auth/auth_test.go | 141 ++++++++++++-------------- cloud/auth/types.go | 24 ----- cloud/deploy/bundle.go | 2 +- cloud/deploy/deploy.go | 8 +- cloud/deploy/deploy_test.go | 26 ++--- cloud/deployment/deployment.go | 2 +- cloud/deployment/fromfile/fromfile.go | 6 +- cloud/organization/organization.go | 4 - cloud/platformclient/client.go | 4 +- cloud/platformclient/client_test.go | 2 +- cmd/airflow.go | 13 +-- cmd/airflow_test.go | 32 +++--- cmd/api/airflow.go | 11 +- cmd/api/airflow_test.go | 19 ++-- cmd/api/api.go | 12 ++- cmd/api/api_test.go | 8 +- cmd/api/cloud.go | 12 ++- cmd/api/cloud_test.go | 15 ++- cmd/auth.go | 78 ++++++++------ cmd/auth_test.go | 74 +++++++------- cmd/cloud/setup.go | 93 ++++++----------- cmd/cloud/setup_test.go | 128 ++++++++++++----------- cmd/root.go | 29 +++--- cmd/root_hooks.go | 65 +++++++++--- cmd/root_test.go | 7 ++ cmd/software/deploy.go | 6 +- cmd/software/deploy_test.go | 31 +++--- cmd/software/deployment_logs.go | 12 ++- cmd/software/deployment_teams_test.go | 2 +- cmd/software/root.go | 5 +- cmd/software/root_test.go | 6 +- cmd/software/utils_test.go | 4 +- cmd/software/workspace_teams_test.go | 2 +- cmd/software/workspace_user_test.go | 2 +- config/context.go | 71 ++++++++----- config/context_test.go | 49 ++------- config/migrate_test.go | 131 ++++++++++++++++++++++++ context/context_test.go | 3 +- go.mod | 7 ++ go.sum | 15 +++ houston/app_test.go | 18 ++-- houston/auth_test.go | 8 +- houston/decorator.go | 2 +- houston/deployment_teams_test.go | 16 +-- houston/deployment_test.go | 48 ++++----- houston/deployment_user_test.go | 16 +-- houston/houston.go | 20 ++-- houston/houston_test.go | 2 +- houston/runtime_test.go | 8 +- houston/service_account_test.go | 24 ++--- houston/teams_test.go | 20 ++-- houston/user_test.go | 4 +- houston/workspace_teams_test.go | 20 ++-- houston/workspace_test.go | 28 ++--- houston/workspace_users_test.go | 24 ++--- pkg/httputil/token_holder.go | 34 +++++++ pkg/httputil/token_holder_test.go | 20 ++++ pkg/keychain/keychain.go | 79 +++++++++++++++ pkg/keychain/keychain_darwin.go | 31 ++++++ pkg/keychain/keychain_file.go | 108 ++++++++++++++++++++ pkg/keychain/keychain_linux.go | 39 +++++++ pkg/keychain/keychain_test.go | 68 +++++++++++++ pkg/keychain/keychain_windows.go | 18 ++++ pkg/testing/testing.go | 1 - software/auth/auth.go | 33 +++--- software/auth/auth_test.go | 35 +++---- software/deploy/deploy.go | 30 ++++-- software/deploy/deploy_test.go | 75 +++++++------- software/deployment/logs.go | 4 +- software/deployment/logs_test.go | 4 +- 85 files changed, 1450 insertions(+), 827 deletions(-) create mode 100644 config/migrate_test.go create mode 100644 pkg/httputil/token_holder.go create mode 100644 pkg/httputil/token_holder_test.go create mode 100644 pkg/keychain/keychain.go create mode 100644 pkg/keychain/keychain_darwin.go create mode 100644 pkg/keychain/keychain_file.go create mode 100644 pkg/keychain/keychain_linux.go create mode 100644 pkg/keychain/keychain_test.go create mode 100644 pkg/keychain/keychain_windows.go diff --git a/airflow-client/airflow-client.go b/airflow-client/airflow-client.go index e8a56f8c2..65e77cb7e 100644 --- a/airflow-client/airflow-client.go +++ b/airflow-client/airflow-client.go @@ -11,7 +11,6 @@ import ( "github.com/hashicorp/go-retryablehttp" - "github.com/astronomer/astro-cli/context" "github.com/astronomer/astro-cli/pkg/httputil" ) @@ -41,12 +40,14 @@ type Client interface { // Client containers the logger and HTTPClient used to communicate with the Astronomer API type HTTPClient struct { *httputil.HTTPClient + tokenHolder *httputil.TokenHolder } -// NewAstroClient returns a new Client with the logger and HTTP client setup. -func NewAirflowClient(c *httputil.HTTPClient) *HTTPClient { +// NewAirflowClient returns a new Client with the logger and HTTP client setup. +func NewAirflowClient(c *httputil.HTTPClient, tokenHolder *httputil.TokenHolder) *HTTPClient { return &HTTPClient{ - c, + HTTPClient: c, + tokenHolder: tokenHolder, } } @@ -242,17 +243,14 @@ func checkRetryPolicy(method string) retryablehttp.CheckRetry { } func (c *HTTPClient) DoAirflowClient(doOpts *httputil.DoOptions) (*Response, error) { - cl, err := context.GetCurrentContext() - if err != nil { - return nil, err - } - - if cl.Token != "" { - doOpts.Headers = map[string]string{ - "authorization": cl.Token, + if c.tokenHolder != nil { + if tok := c.tokenHolder.Get(); tok != "" { + if doOpts.Headers == nil { + doOpts.Headers = map[string]string{} + } + doOpts.Headers["authorization"] = tok } } - req, err := retryablehttp.NewRequest(doOpts.Method, doOpts.Path, doOpts.Data) if err != nil { return nil, err diff --git a/airflow-client/airflow-client_test.go b/airflow-client/airflow-client_test.go index 5a290fc4b..c764a3ffb 100644 --- a/airflow-client/airflow-client_test.go +++ b/airflow-client/airflow-client_test.go @@ -32,7 +32,7 @@ func TestAirflowClient(t *testing.T) { } func (s *Suite) TestNew() { - client := NewAirflowClient(httputil.NewHTTPClient()) + client := NewAirflowClient(httputil.NewHTTPClient(), nil) s.NotNil(client, "Can't create new Astro client") } @@ -46,7 +46,8 @@ func (s *Suite) TestDoAirflowClient() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) doOpts := &httputil.DoOptions{ Path: "/test", Headers: map[string]string{ @@ -110,7 +111,8 @@ func (s *Suite) TestGetConnections() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) response, err := airflowClient.GetConnections("test-airflow-url") s.NoError(err) @@ -126,7 +128,7 @@ func (s *Suite) TestGetConnections() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + airflowClient := NewAirflowClient(client, nil) _, err := airflowClient.GetConnections("test-airflow-url") s.Error(err) @@ -141,7 +143,8 @@ func (s *Suite) TestGetConnections() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) _, err := airflowClient.GetConnections("test-airflow-url") s.Error(err) @@ -174,7 +177,8 @@ func (s *Suite) TestUpdateConnection() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.UpdateConnection("test-airflow-url", mockConn) s.NoError(err) @@ -188,7 +192,8 @@ func (s *Suite) TestUpdateConnection() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.UpdateConnection("test-airflow-url", mockConn) s.Error(err) @@ -203,7 +208,8 @@ func (s *Suite) TestUpdateConnection() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) // Pass a nil connection to force JSON marshal error err := airflowClient.UpdateConnection("test-airflow-url", mockConn) @@ -219,7 +225,8 @@ func (s *Suite) TestUpdateConnection() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.UpdateConnection("test-airflow-url", mockConn) s.Error(err) @@ -252,7 +259,8 @@ func (s *Suite) TestCreateConnection() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.CreateConnection("test-airflow-url", mockConn) s.NoError(err) @@ -266,7 +274,8 @@ func (s *Suite) TestCreateConnection() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.CreateConnection("test-airflow-url", mockConn) s.Error(err) @@ -281,7 +290,8 @@ func (s *Suite) TestCreateConnection() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) // Pass a nil connection to force JSON marshal error err := airflowClient.CreateConnection("test-airflow-url", nil) @@ -297,7 +307,8 @@ func (s *Suite) TestCreateConnection() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.CreateConnection("test-airflow-url", mockConn) s.Error(err) @@ -330,7 +341,8 @@ func (s *Suite) TestCreateVariable() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.CreateVariable("test-airflow-url", *mockVar) s.NoError(err) @@ -344,7 +356,8 @@ func (s *Suite) TestCreateVariable() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.CreateVariable("test-airflow-url", *mockVar) s.Error(err) @@ -359,7 +372,8 @@ func (s *Suite) TestCreateVariable() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.CreateVariable("test-airflow-url", Variable{Key: "", Value: "test-value"}) s.Error(err) @@ -385,7 +399,8 @@ func (s *Suite) TestGetVariables() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) response, err := airflowClient.GetVariables("test-airflow-url") s.NoError(err) @@ -401,7 +416,7 @@ func (s *Suite) TestGetVariables() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + airflowClient := NewAirflowClient(client, nil) _, err := airflowClient.GetVariables("test-airflow-url") s.Error(err) @@ -435,7 +450,8 @@ func (s *Suite) TestUpdateVariable() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.UpdateVariable("test-airflow-url", *mockVar) s.NoError(err) @@ -449,7 +465,8 @@ func (s *Suite) TestUpdateVariable() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.UpdateVariable("test-airflow-url", *mockVar) s.Error(err) @@ -464,7 +481,8 @@ func (s *Suite) TestUpdateVariable() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.UpdateVariable("test-airflow-url", Variable{Key: "", Value: "test-value"}) s.Error(err) @@ -514,7 +532,8 @@ func (s *Suite) TestCreatePool() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.CreatePool("test-airflow-url", *mockPool) s.NoError(err) @@ -528,7 +547,8 @@ func (s *Suite) TestCreatePool() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.CreatePool("test-airflow-url", *mockPool) s.Error(err) @@ -543,7 +563,8 @@ func (s *Suite) TestCreatePool() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) // Pass a nil pool to force JSON marshal error err := airflowClient.CreatePool("test-airflow-url", *mockPool) @@ -559,7 +580,8 @@ func (s *Suite) TestCreatePool() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.CreatePool("test-airflow-url", *mockPool) s.Error(err) @@ -592,7 +614,8 @@ func (s *Suite) TestUpdatePool() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.UpdatePool("test-airflow-url", *mockPool) s.NoError(err) @@ -623,7 +646,8 @@ func (s *Suite) TestUpdatePool() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err = airflowClient.UpdatePool("test-airflow-url", defaultPool) s.NoError(err) @@ -637,7 +661,8 @@ func (s *Suite) TestUpdatePool() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.UpdatePool("test-airflow-url", *mockPool) s.Error(err) @@ -652,7 +677,8 @@ func (s *Suite) TestUpdatePool() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) // Pass a nil pool to force JSON marshal error err := airflowClient.UpdatePool("test-airflow-url", Pool{}) @@ -668,7 +694,8 @@ func (s *Suite) TestUpdatePool() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) err := airflowClient.UpdatePool("test-airflow-url", *mockPool) s.Error(err) @@ -693,7 +720,8 @@ func (s *Suite) TestGetPools() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) response, err := airflowClient.GetPools("test-airflow-url") s.NoError(err) @@ -709,7 +737,7 @@ func (s *Suite) TestGetPools() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + airflowClient := NewAirflowClient(client, nil) response, err := airflowClient.GetPools("test-airflow-url") s.Error(err) @@ -726,7 +754,8 @@ func (s *Suite) TestGetPools() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + th := httputil.NewTokenHolder("token") + airflowClient := NewAirflowClient(client, th) response, err := airflowClient.GetPools("test-airflow-url") s.Error(err) @@ -759,7 +788,7 @@ func (s *Suite) TestDoAirflowClientRetry() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + airflowClient := NewAirflowClient(client, nil) doOpts := &httputil.DoOptions{ Path: "/test", @@ -783,7 +812,7 @@ func (s *Suite) TestDoAirflowClientRetry() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + airflowClient := NewAirflowClient(client, nil) doOpts := &httputil.DoOptions{ Path: "/test", @@ -807,7 +836,7 @@ func (s *Suite) TestDoAirflowClientRetry() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + airflowClient := NewAirflowClient(client, nil) doOpts := &httputil.DoOptions{ Path: "/test", @@ -829,7 +858,7 @@ func (s *Suite) TestDoAirflowClientRetry() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + airflowClient := NewAirflowClient(client, nil) doOpts := &httputil.DoOptions{ Path: "/test", @@ -848,7 +877,7 @@ func (s *Suite) TestDoAirflowClientRetry() { callCount++ return stdctx.Canceled }) - airflowClient := NewAirflowClient(cancelTransport) + airflowClient := NewAirflowClient(cancelTransport, nil) doOpts := &httputil.DoOptions{ Path: "/test", @@ -897,7 +926,7 @@ func (s *Suite) TestGetConnectionsPagination() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + airflowClient := NewAirflowClient(client, nil) response, err := airflowClient.GetConnections("test-airflow-url") s.NoError(err) @@ -941,7 +970,7 @@ func (s *Suite) TestGetVariablesPagination() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + airflowClient := NewAirflowClient(client, nil) response, err := airflowClient.GetVariables("test-airflow-url") s.NoError(err) @@ -985,7 +1014,7 @@ func (s *Suite) TestGetPoolsPagination() { Header: make(http.Header), } }) - airflowClient := NewAirflowClient(client) + airflowClient := NewAirflowClient(client, nil) response, err := airflowClient.GetPools("test-airflow-url") s.NoError(err) diff --git a/airflow/container.go b/airflow/container.go index 88096c8c5..ca0704998 100644 --- a/airflow/container.go +++ b/airflow/container.go @@ -18,6 +18,7 @@ import ( astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/pkg/fileutil" + "github.com/astronomer/astro-cli/pkg/keychain" "github.com/astronomer/astro-cli/pkg/logger" "github.com/astronomer/astro-cli/pkg/util" ) @@ -40,7 +41,7 @@ type ContainerHandler interface { ComposeExport(settingsFile, composeFile string) error Pytest(pytestFile, customImageName, deployImageName, pytestArgsString, buildSecretString string) (string, error) Parse(customImageName, deployImageName, buildSecretString string) error - UpgradeTest(runtimeVersion, deploymentID, customImageName, buildSecretString string, versionTest, dagTest, lintTest, includeLintDeprecations, lintFix bool, lintConfigFile string, astroPlatformCore astroplatformcore.ClientWithResponsesInterface) error + UpgradeTest(runtimeVersion, deploymentID, customImageName, buildSecretString string, versionTest, dagTest, lintTest, includeLintDeprecations, lintFix bool, lintConfigFile string, astroPlatformCore astroplatformcore.ClientWithResponsesInterface, store keychain.SecureStore) error } // RegistryHandler defines methods require to handle all operations with registry diff --git a/airflow/docker.go b/airflow/docker.go index 7b3670340..d2b273f68 100644 --- a/airflow/docker.go +++ b/airflow/docker.go @@ -42,6 +42,7 @@ import ( "github.com/astronomer/astro-cli/docker" "github.com/astronomer/astro-cli/pkg/ansi" "github.com/astronomer/astro-cli/pkg/fileutil" + "github.com/astronomer/astro-cli/pkg/keychain" "github.com/astronomer/astro-cli/pkg/logger" "github.com/astronomer/astro-cli/pkg/spinner" "github.com/astronomer/astro-cli/pkg/util" @@ -701,7 +702,7 @@ func (d *DockerCompose) Pytest(pytestFile, customImageName, deployImageName, pyt return exitCode, errors.New("something went wrong while Pytesting your DAGs") } -func (d *DockerCompose) UpgradeTest(newVersion, deploymentID, customImage, buildSecretString string, versionTest, dagTest, lintTest, includeLintDeprecations, lintFix bool, lintConfigFile string, astroPlatformCore astroplatformcore.CoreClient) error { //nolint:gocognit,gocyclo +func (d *DockerCompose) UpgradeTest(newVersion, deploymentID, customImage, buildSecretString string, versionTest, dagTest, lintTest, includeLintDeprecations, lintFix bool, lintConfigFile string, astroPlatformCore astroplatformcore.CoreClient, store keychain.SecureStore) error { //nolint:gocognit,gocyclo // figure out which tests to run if !versionTest && !dagTest && !lintTest { versionTest = true @@ -719,7 +720,7 @@ func (d *DockerCompose) UpgradeTest(newVersion, deploymentID, customImage, build } // if user supplies deployment id pull down current image if deploymentID != "" { - err := d.pullImageFromDeployment(deploymentID, astroPlatformCore) + err := d.pullImageFromDeployment(deploymentID, astroPlatformCore, store) if err != nil { return err } @@ -806,7 +807,7 @@ func (d *DockerCompose) UpgradeTest(newVersion, deploymentID, customImage, build return nil } -func (d *DockerCompose) pullImageFromDeployment(deploymentID string, platformCoreClient astroplatformcore.CoreClient) error { +func (d *DockerCompose) pullImageFromDeployment(deploymentID string, platformCoreClient astroplatformcore.CoreClient, store keychain.SecureStore) error { c, err := config.GetCurrentContext() if err != nil { return err @@ -817,9 +818,15 @@ func (d *DockerCompose) pullImageFromDeployment(deploymentID string, platformCor return err } deploymentImage := fmt.Sprintf("%s:%s", currentDeployment.ImageRepository, currentDeployment.ImageTag) - token := c.Token + if store == nil { + return fmt.Errorf("no credentials found for %s: credential store unavailable", c.Domain) + } + creds, err := store.GetCredentials(c.Domain) + if err != nil { + return fmt.Errorf("no credentials found for %s, please run 'astro login': %w", c.Domain, err) + } fmt.Printf("\nPulling image from Astro Deployment %s\n\n", currentDeployment.Name) - err = d.imageHandler.Pull(deploymentImage, registryUsername, token) + err = d.imageHandler.Pull(deploymentImage, registryUsername, creds.Token) if err != nil { return err } diff --git a/airflow/docker_test.go b/airflow/docker_test.go index 58d1e0bea..78a9b5a2e 100644 --- a/airflow/docker_test.go +++ b/airflow/docker_test.go @@ -26,6 +26,7 @@ import ( astroplatformcore_mocks "github.com/astronomer/astro-cli/astro-client-platform-core/mocks" "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/pkg/fileutil" + "github.com/astronomer/astro-cli/pkg/keychain" "github.com/astronomer/astro-cli/pkg/logger" testUtil "github.com/astronomer/astro-cli/pkg/testing" ) @@ -1255,13 +1256,16 @@ func (s *Suite) TestDockerComposeUpgradeTest() { mockDockerCompose.imageHandler = imageHandler // Add default values for new lint flags - err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, true, true, false, false, "", nil) // All tests enabled by default + err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, true, true, false, false, "", nil, nil) // All tests enabled by default s.NoError(err) imageHandler.AssertExpectations(s.T()) }) s.Run("success with deployment id", func() { + testUtil.InitTestConfig(testUtil.LocalPlatform) + store := keychain.NewTestStore() + _ = store.SetCredentials("localhost", keychain.Credentials{Token: "Bearer test-token"}) imageHandler := new(mocks.ImageHandler) mockPlatformCoreClient := new(astroplatformcore_mocks.ClientWithResponsesInterface) @@ -1275,8 +1279,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { imageHandler.On("Pytest", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, airflowTypes.ImageBuildConfig{Path: mockDockerCompose.airflowHome, NoCache: false}).Return("0", nil).Once() mockDockerCompose.imageHandler = imageHandler - // Add default values for new lint flags - err := mockDockerCompose.UpgradeTest("new-version", "test-deployment-id", "", "", true, true, true, false, false, "", mockPlatformCoreClient) // All tests enabled by default + err := mockDockerCompose.UpgradeTest("new-version", "test-deployment-id", "", "", true, true, true, false, false, "", mockPlatformCoreClient, store) // All tests enabled by default s.NoError(err) imageHandler.AssertExpectations(s.T()) @@ -1288,7 +1291,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { mockDockerCompose.imageHandler = imageHandler // Add default values for new lint flags - err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, true, true, false, false, "", nil) + err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, true, true, false, false, "", nil, nil) s.Error(err) imageHandler.AssertExpectations(s.T()) }) @@ -1300,7 +1303,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { mockDockerCompose.imageHandler = imageHandler // Add default values for new lint flags - err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, true, true, false, false, "", nil) + err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, true, true, false, false, "", nil, nil) s.Error(err) imageHandler.AssertExpectations(s.T()) }) @@ -1313,7 +1316,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { mockDockerCompose.imageHandler = imageHandler // Add default values for new lint flags - err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, true, true, false, false, "", nil) // versionTest=true is required for this path + err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, true, true, false, false, "", nil, nil) // versionTest=true is required for this path s.Error(err) imageHandler.AssertExpectations(s.T()) }) @@ -1328,7 +1331,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { mockDockerCompose.imageHandler = imageHandler // Add default values for new lint flags - err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, true, true, false, false, "", nil) // versionTest=true is required for this path + err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, true, true, false, false, "", nil, nil) // versionTest=true is required for this path s.Error(err) imageHandler.AssertExpectations(s.T()) }) @@ -1344,7 +1347,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { mockDockerCompose.imageHandler = imageHandler // Add default values for new lint flags - err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, true, true, false, false, "", nil) // dagTest=true is required for this path + err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, true, true, false, false, "", nil, nil) // dagTest=true is required for this path s.Error(err) imageHandler.AssertExpectations(s.T()) }) @@ -1359,7 +1362,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { mockDockerCompose.imageHandler = imageHandler // Add default values for new lint flags - err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, true, true, false, false, "", nil) // dagTest=true is required for this path + err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, true, true, false, false, "", nil, nil) // dagTest=true is required for this path s.Error(err) imageHandler.AssertExpectations(s.T()) }) @@ -1370,13 +1373,15 @@ func (s *Suite) TestDockerComposeUpgradeTest() { mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(nil, errMock).Once() // Error on first call mockDockerCompose.imageHandler = imageHandler - // Add default values for new lint flags - err := mockDockerCompose.UpgradeTest("new-version", "deployment-id", "", "", false, false, false, false, false, "", mockPlatformCoreClient) + err := mockDockerCompose.UpgradeTest("new-version", "deployment-id", "", "", false, false, false, false, false, "", mockPlatformCoreClient, nil) s.Error(err) // No image handler expectations needed as it fails before pull/build }) s.Run("image pull failure", func() { + testUtil.InitTestConfig(testUtil.LocalPlatform) + store := keychain.NewTestStore() + _ = store.SetCredentials("localhost", keychain.Credentials{Token: "Bearer test-token"}) imageHandler := new(mocks.ImageHandler) mockPlatformCoreClient := new(astroplatformcore_mocks.ClientWithResponsesInterface) mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Twice() @@ -1384,8 +1389,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { imageHandler.On("Pull", mock.Anything, mock.Anything, mock.Anything).Return(errMockDocker) mockDockerCompose.imageHandler = imageHandler - // Add default values for new lint flags - err := mockDockerCompose.UpgradeTest("new-version", "test-deployment-id", "", "", false, false, false, false, false, "", mockPlatformCoreClient) + err := mockDockerCompose.UpgradeTest("new-version", "test-deployment-id", "", "", false, false, false, false, false, "", mockPlatformCoreClient, store) s.Error(err) imageHandler.AssertExpectations(s.T()) // Only Pull is called }) @@ -1399,7 +1403,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { mockDockerCompose.imageHandler = imageHandler // Add default values for new lint flags - err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, false, false, false, false, "", nil) // versionTest=true is required for this path + err := mockDockerCompose.UpgradeTest("new-version", "", "", "", true, false, false, false, false, "", nil, nil) // versionTest=true is required for this path s.Error(err) imageHandler.AssertExpectations(s.T()) }) @@ -1411,7 +1415,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { s.NoError(err) // Add default values for new lint flags - err = mockDockerCompose.UpgradeTest("new-version", "deployment-id", "", "", false, false, false, false, false, "", nil) + err = mockDockerCompose.UpgradeTest("new-version", "deployment-id", "", "", false, false, false, false, false, "", nil, nil) s.Error(err) // Expect error due to missing context/domain }) @@ -1429,7 +1433,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { mockDockerCompose.imageHandler = imageHandler mockDockerCompose.ruffImageHandler = ruffImageHandler // Call with lintTest=true, includeLintDeprecations=false, lintFix=false, lintConfigFile="" - err := mockDockerCompose.UpgradeTest("3.0-1", "", "", "", false, false, true, false, false, "", nil) + err := mockDockerCompose.UpgradeTest("3.0-1", "", "", "", false, false, true, false, false, "", nil, nil) s.NoError(err) imageHandler.AssertExpectations(s.T()) @@ -1448,7 +1452,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { mockDockerCompose.imageHandler = imageHandler mockDockerCompose.ruffImageHandler = ruffImageHandler // Call with lintTest=true, includeLintDeprecations=true, lintFix=false, lintConfigFile="" - err := mockDockerCompose.UpgradeTest("3.0-1", "", "", "", false, false, true, true, false, "", nil) + err := mockDockerCompose.UpgradeTest("3.0-1", "", "", "", false, false, true, true, false, "", nil, nil) s.NoError(err) imageHandler.AssertExpectations(s.T()) @@ -1482,7 +1486,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { originalWorkingPath := config.WorkingPath config.WorkingPath = cwd defer func() { config.WorkingPath = originalWorkingPath }() - err = mockDockerCompose.UpgradeTest("3.0-1", "", "", "", false, false, true, false, false, "my-custom-ruff.toml", nil) + err = mockDockerCompose.UpgradeTest("3.0-1", "", "", "", false, false, true, false, false, "my-custom-ruff.toml", nil, nil) s.NoError(err) imageHandler.AssertExpectations(s.T()) @@ -1501,7 +1505,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { mockDockerCompose.imageHandler = imageHandler mockDockerCompose.ruffImageHandler = ruffImageHandler // Call with lintTest=true, includeLintDeprecations=false, lintFix=false, lintConfigFile="" - err := mockDockerCompose.UpgradeTest("3.0-1", "", "", "", false, false, true, false, false, "", nil) + err := mockDockerCompose.UpgradeTest("3.0-1", "", "", "", false, false, true, false, false, "", nil, nil) s.Error(err) s.Contains(err.Error(), "one of the tests run above failed") @@ -1520,7 +1524,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { mockDockerCompose.ruffImageHandler = ruffImageHandler // Call with lintTest=true, includeLintDeprecations=false, lintFix=false, lintConfigFile="" // Target version is 2.0.0, so lint test should be skipped internally - err := mockDockerCompose.UpgradeTest("2.0.0", "", "", "", false, false, true, false, false, "", nil) + err := mockDockerCompose.UpgradeTest("2.0.0", "", "", "", false, false, true, false, false, "", nil, nil) s.NoError(err) // Should succeed without running lint imageHandler.AssertExpectations(s.T()) @@ -1540,7 +1544,7 @@ func (s *Suite) TestDockerComposeUpgradeTest() { mockDockerCompose.imageHandler = imageHandler mockDockerCompose.ruffImageHandler = ruffImageHandler // Call with lintTest=true, includeLintDeprecations=false, lintFix=true, lintConfigFile="" - err := mockDockerCompose.UpgradeTest("3.0-1", "", "", "", false, false, true, false, true, "", nil) + err := mockDockerCompose.UpgradeTest("3.0-1", "", "", "", false, false, true, false, true, "", nil, nil) s.NoError(err) imageHandler.AssertExpectations(s.T()) diff --git a/airflow/mocks/ContainerHandler.go b/airflow/mocks/ContainerHandler.go index a07489721..d3962f048 100644 --- a/airflow/mocks/ContainerHandler.go +++ b/airflow/mocks/ContainerHandler.go @@ -5,6 +5,7 @@ package mocks import ( "github.com/astronomer/astro-cli/airflow/types" astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" + "github.com/astronomer/astro-cli/pkg/keychain" mock "github.com/stretchr/testify/mock" ) @@ -295,17 +296,17 @@ func (_m *ContainerHandler) Stop(waitForExit bool) error { return r0 } -// UpgradeTest provides a mock function with given fields: runtimeVersion, deploymentID, customImageName, buildSecretString, versionTest, dagTest, lintTest, includeLintDeprecations, lintFix, lintConfigFile, astroPlatformCore -func (_m *ContainerHandler) UpgradeTest(runtimeVersion string, deploymentID string, customImageName string, buildSecretString string, versionTest bool, dagTest bool, lintTest bool, includeLintDeprecations bool, lintFix bool, lintConfigFile string, astroPlatformCore astroplatformcore.ClientWithResponsesInterface) error { - ret := _m.Called(runtimeVersion, deploymentID, customImageName, buildSecretString, versionTest, dagTest, lintTest, includeLintDeprecations, lintFix, lintConfigFile, astroPlatformCore) +// UpgradeTest provides a mock function with given fields: runtimeVersion, deploymentID, customImageName, buildSecretString, versionTest, dagTest, lintTest, includeLintDeprecations, lintFix, lintConfigFile, astroPlatformCore, store +func (_m *ContainerHandler) UpgradeTest(runtimeVersion string, deploymentID string, customImageName string, buildSecretString string, versionTest bool, dagTest bool, lintTest bool, includeLintDeprecations bool, lintFix bool, lintConfigFile string, astroPlatformCore astroplatformcore.ClientWithResponsesInterface, store keychain.SecureStore) error { + ret := _m.Called(runtimeVersion, deploymentID, customImageName, buildSecretString, versionTest, dagTest, lintTest, includeLintDeprecations, lintFix, lintConfigFile, astroPlatformCore, store) if len(ret) == 0 { panic("no return value specified for UpgradeTest") } var r0 error - if rf, ok := ret.Get(0).(func(string, string, string, string, bool, bool, bool, bool, bool, string, astroplatformcore.ClientWithResponsesInterface) error); ok { - r0 = rf(runtimeVersion, deploymentID, customImageName, buildSecretString, versionTest, dagTest, lintTest, includeLintDeprecations, lintFix, lintConfigFile, astroPlatformCore) + if rf, ok := ret.Get(0).(func(string, string, string, string, bool, bool, bool, bool, bool, string, astroplatformcore.ClientWithResponsesInterface, keychain.SecureStore) error); ok { + r0 = rf(runtimeVersion, deploymentID, customImageName, buildSecretString, versionTest, dagTest, lintTest, includeLintDeprecations, lintFix, lintConfigFile, astroPlatformCore, store) } else { r0 = ret.Error(0) } diff --git a/airflow/standalone.go b/airflow/standalone.go index 5993a4183..272909d72 100644 --- a/airflow/standalone.go +++ b/airflow/standalone.go @@ -29,6 +29,7 @@ import ( "github.com/astronomer/astro-cli/pkg/airflowrt" "github.com/astronomer/astro-cli/pkg/ansi" "github.com/astronomer/astro-cli/pkg/fileutil" + "github.com/astronomer/astro-cli/pkg/keychain" "github.com/astronomer/astro-cli/pkg/logger" "github.com/astronomer/astro-cli/pkg/spinner" "github.com/astronomer/astro-cli/pkg/util" @@ -990,7 +991,7 @@ func (s *Standalone) Parse(_, _, _ string) error { return nil } -func (s *Standalone) UpgradeTest(_, _, _, _ string, _, _, _, _, _ bool, _ string, _ astroplatformcore.ClientWithResponsesInterface) error { +func (s *Standalone) UpgradeTest(_, _, _, _ string, _, _, _, _, _ bool, _ string, _ astroplatformcore.ClientWithResponsesInterface, _ keychain.SecureStore) error { return errors.New("astro dev upgrade-test is not available in standalone mode") } diff --git a/airflow/standalone_test.go b/airflow/standalone_test.go index 4c11ee2d0..42af72748 100644 --- a/airflow/standalone_test.go +++ b/airflow/standalone_test.go @@ -238,7 +238,7 @@ func (s *Suite) TestStandaloneUnsupportedCommands() { s.Error(composeErr) s.Contains(composeErr.Error(), "not available in standalone mode") - upgradeErr := handler.UpgradeTest("", "", "", "", false, false, false, false, false, "", nil) + upgradeErr := handler.UpgradeTest("", "", "", "", false, false, false, false, false, "", nil, nil) s.Error(upgradeErr) s.Contains(upgradeErr.Error(), "not available in standalone mode") } diff --git a/airflow/standalone_windows.go b/airflow/standalone_windows.go index 80e49bdf7..57125d66f 100644 --- a/airflow/standalone_windows.go +++ b/airflow/standalone_windows.go @@ -10,6 +10,7 @@ import ( "github.com/astronomer/astro-cli/airflow/types" airflowversions "github.com/astronomer/astro-cli/airflow_versions" astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" + "github.com/astronomer/astro-cli/pkg/keychain" ) var ( @@ -51,6 +52,6 @@ func (s *Standalone) Pytest(_, _, _, _, _ string) (string, error) { return "", errStandaloneWindows } func (s *Standalone) Parse(_, _, _ string) error { return errStandaloneWindows } -func (s *Standalone) UpgradeTest(_, _, _, _ string, _, _, _, _, _ bool, _ string, _ astroplatformcore.ClientWithResponsesInterface) error { +func (s *Standalone) UpgradeTest(_, _, _, _ string, _, _, _, _, _ bool, _ string, _ astroplatformcore.ClientWithResponsesInterface, _ keychain.SecureStore) error { return errStandaloneWindows } diff --git a/astro-client-core/client.go b/astro-client-core/client.go index f0df34b1a..d9ee55873 100644 --- a/astro-client-core/client.go +++ b/astro-client-core/client.go @@ -12,15 +12,16 @@ var NormalizeAPIError = httputil.NormalizeAPIError // a shorter alias type CoreClient = ClientWithResponsesInterface -// create api client for astro core services -func NewCoreClient(c *httputil.HTTPClient) *ClientWithResponses { - // we append base url in request editor, so set to an empty string here +// NewCoreClient creates an API client for astro core services. +// The provided TokenHolder is read on every request — set it via +// TokenHolder.Set after credentials are resolved in PersistentPreRunE. +func NewCoreClient(c *httputil.HTTPClient, holder *httputil.TokenHolder) *ClientWithResponses { cl, _ := NewClientWithResponses("", WithHTTPClient(c.HTTPClient), WithRequestEditorFn(httputil.NewRequestEditorFn(func() (string, string, error) { ctx, err := context.GetCurrentContext() if err != nil { return "", "", err } - return ctx.Token, ctx.GetPublicRESTAPIURL("v1alpha1"), nil + return holder.Get(), ctx.GetPublicRESTAPIURL("v1alpha1"), nil }))) return cl } diff --git a/astro-client-core/client.test.go b/astro-client-core/client.test.go index 5e28721ba..ec950ffe0 100644 --- a/astro-client-core/client.test.go +++ b/astro-client-core/client.test.go @@ -9,6 +9,6 @@ import ( ) func TestNewCoreClient(t *testing.T) { - client := NewCoreClient(httputil.NewHTTPClient()) + client := NewCoreClient(httputil.NewHTTPClient(), &httputil.TokenHolder{}) assert.NotNil(t, client, "Can't create new Astro Core client") } diff --git a/astro-client-iam-core/client.go b/astro-client-iam-core/client.go index 94a1fdc9c..1a820de41 100644 --- a/astro-client-iam-core/client.go +++ b/astro-client-iam-core/client.go @@ -12,14 +12,13 @@ var NormalizeAPIError = httputil.NormalizeAPIError // a shorter alias type CoreClient = ClientWithResponsesInterface -func NewIamCoreClient(c *httputil.HTTPClient) *ClientWithResponses { - // we append base url in request editor, so set to an empty string here +func NewIamCoreClient(c *httputil.HTTPClient, holder *httputil.TokenHolder) *ClientWithResponses { cl, _ := NewClientWithResponses("", WithHTTPClient(c.HTTPClient), WithRequestEditorFn(httputil.NewRequestEditorFn(func() (string, string, error) { ctx, err := context.GetCurrentContext() if err != nil { return "", "", err } - return ctx.Token, ctx.GetPublicRESTAPIURL("iam/v1beta1"), nil + return holder.Get(), ctx.GetPublicRESTAPIURL("iam/v1beta1"), nil }))) return cl } diff --git a/astro-client-iam-core/client.test.go b/astro-client-iam-core/client.test.go index 63a7dfa11..a41af28a6 100644 --- a/astro-client-iam-core/client.test.go +++ b/astro-client-iam-core/client.test.go @@ -9,6 +9,6 @@ import ( ) func TestNewIamCoreClient(t *testing.T) { - client := NewIamCoreClient(httputil.NewHTTPClient()) + client := NewIamCoreClient(httputil.NewHTTPClient(), &httputil.TokenHolder{}) assert.NotNil(t, client, "Can't create new Astro IAM Core client") } diff --git a/astro-client-platform-core/client.go b/astro-client-platform-core/client.go index c78f8217c..610d9bfcc 100644 --- a/astro-client-platform-core/client.go +++ b/astro-client-platform-core/client.go @@ -12,15 +12,13 @@ var NormalizeAPIError = httputil.NormalizeAPIError // a shorter alias type CoreClient = ClientWithResponsesInterface -// create api client for astro platform core services -func NewPlatformCoreClient(c *httputil.HTTPClient) *ClientWithResponses { - // we append base url in request editor, so set to an empty string here +func NewPlatformCoreClient(c *httputil.HTTPClient, holder *httputil.TokenHolder) *ClientWithResponses { cl, _ := NewClientWithResponses("", WithHTTPClient(c.HTTPClient), WithRequestEditorFn(httputil.NewRequestEditorFn(func() (string, string, error) { ctx, err := context.GetCurrentContext() if err != nil { return "", "", err } - return ctx.Token, ctx.GetPublicRESTAPIURL("platform/v1beta1"), nil + return holder.Get(), ctx.GetPublicRESTAPIURL("platform/v1beta1"), nil }))) return cl } diff --git a/cloud/auth/auth.go b/cloud/auth/auth.go index b17ef7e99..2db1e3e2a 100644 --- a/cloud/auth/auth.go +++ b/cloud/auth/auth.go @@ -24,6 +24,7 @@ import ( "github.com/astronomer/astro-cli/pkg/astroauth" "github.com/astronomer/astro-cli/pkg/domainutil" "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/keychain" "github.com/astronomer/astro-cli/pkg/logger" "github.com/astronomer/astro-cli/pkg/util" ) @@ -340,7 +341,7 @@ func CheckUserSession(c *config.Context, coreClient astrocore.CoreClient, platfo } // Login handles authentication to astronomer api and registry -func Login(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { +func Login(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { var res Result domain = domainutil.FormatDomain(domain) authConfig, err := FetchDomainAuthConfig(domain) @@ -387,9 +388,20 @@ func Login(domain, token string, coreClient astrocore.CoreClient, platformCoreCl return err } - err = res.writeToContext(&c) - if err != nil { - return err + creds := keychain.Credentials{ + Token: "Bearer " + res.AccessToken, + RefreshToken: res.RefreshToken, + UserEmail: res.UserEmail, + ExpiresAt: time.Now().Add(time.Duration(res.ExpiresIn) * time.Second), + } + if store == nil { + return fmt.Errorf("credential store not available; cannot save login credentials") + } + if err := store.SetCredentials(domain, creds); err != nil { + return fmt.Errorf("storing credentials: %w", err) + } + if tokenHolder != nil { + tokenHolder.Set(creds.Token) } fmt.Printf("Logging in as %s\n", ansi.Green(res.UserEmail)) @@ -404,16 +416,11 @@ func Login(domain, token string, coreClient astrocore.CoreClient, platformCoreCl } // Logout logs a user out of the docker registry. Will need to logout of Astro next. -func Logout(domain string, out io.Writer) { - c, _ := context.GetContext(domain) - - err = c.SetContextKey("token", "") - if err != nil { - return - } - err = c.SetContextKey("user_email", "") - if err != nil { - return +func Logout(domain string, store keychain.SecureStore, out io.Writer) { + if store == nil { + fmt.Fprintln(out, "Warning: credential store not available; local credentials may not be cleared") + } else if err := store.DeleteCredentials(domain); err != nil { + fmt.Fprintf(out, "Failed to remove credentials from secure store: %s\n", err.Error()) } // remove the current context diff --git a/cloud/auth/auth_test.go b/cloud/auth/auth_test.go index 7fe8597ea..99fc85078 100644 --- a/cloud/auth/auth_test.go +++ b/cloud/auth/auth_test.go @@ -12,14 +12,15 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" astrocore "github.com/astronomer/astro-cli/astro-client-core" astrocore_mocks "github.com/astronomer/astro-cli/astro-client-core/mocks" astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" astroplatformcore_mocks "github.com/astronomer/astro-cli/astro-client-platform-core/mocks" "github.com/astronomer/astro-cli/config" - "github.com/astronomer/astro-cli/context" "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/keychain" testUtil "github.com/astronomer/astro-cli/pkg/testing" ) @@ -680,8 +681,12 @@ func TestLogin(t *testing.T) { mockCoreClient.On("GetSelfUserWithResponse", mock.Anything, mock.Anything).Return(&mockGetSelfResponse, nil).Once() mockPlatformCoreClient.On("ListOrganizationsWithResponse", mock.Anything, &astroplatformcore.ListOrganizationsParams{}).Return(&mockOrganizationsResponse, nil).Once() mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Once() - err := Login("astronomer.io", "", mockCoreClient, mockPlatformCoreClient, os.Stdout, false) + store := keychain.NewTestStore() + err := Login("astronomer.io", "", store, nil, mockCoreClient, mockPlatformCoreClient, os.Stdout, false) assert.NoError(t, err) + creds, err := store.GetCredentials("astronomer.io") + require.NoError(t, err) + assert.Equal(t, "Bearer test-token", creds.Token) mockCoreClient.AssertExpectations(t) mockPlatformCoreClient.AssertExpectations(t) }) @@ -723,7 +728,7 @@ func TestLogin(t *testing.T) { mockCoreClient.On("GetSelfUserWithResponse", mock.Anything, mock.Anything).Return(&mockGetSelfResponse, nil).Once() mockPlatformCoreClient.On("ListOrganizationsWithResponse", mock.Anything, &astroplatformcore.ListOrganizationsParams{}).Return(&mockOrganizationsResponse, nil).Once() - err = Login("pr5723.cloud.astronomer-dev.io", "", mockCoreClient, mockPlatformCoreClient, os.Stdout, false) + err = Login("pr5723.cloud.astronomer-dev.io", "", keychain.NewTestStore(), nil, mockCoreClient, mockPlatformCoreClient, os.Stdout, false) assert.NoError(t, err) mockCoreClient.AssertExpectations(t) mockPlatformCoreClient.AssertExpectations(t) @@ -752,14 +757,14 @@ func TestLogin(t *testing.T) { mockCoreClient.On("GetSelfUserWithResponse", mock.Anything, mock.Anything).Return(&mockGetSelfResponse, nil).Once() mockPlatformCoreClient.On("ListOrganizationsWithResponse", mock.Anything, &astroplatformcore.ListOrganizationsParams{}).Return(&mockOrganizationsResponse, nil).Once() - err := Login("astronomer.io", "OAuth Token", mockCoreClient, mockPlatformCoreClient, os.Stdout, false) + err := Login("astronomer.io", "OAuth Token", keychain.NewTestStore(), nil, mockCoreClient, mockPlatformCoreClient, os.Stdout, false) assert.NoError(t, err) mockCoreClient.AssertExpectations(t) mockPlatformCoreClient.AssertExpectations(t) }) t.Run("invalid domain", func(t *testing.T) { - err := Login("fail.astronomer.io", "", nil, nil, os.Stdout, false) + err := Login("fail.astronomer.io", "", nil, nil, nil, nil, os.Stdout, false) assert.Error(t, err) assert.Contains(t, err.Error(), "Invalid domain.") }) @@ -769,7 +774,7 @@ func TestLogin(t *testing.T) { return "", errMock } authenticator = Authenticator{callbackHandler: callbackHandler} - err := Login("cloud.astronomer.io", "", nil, nil, os.Stdout, false) + err := Login("cloud.astronomer.io", "", nil, nil, nil, nil, os.Stdout, false) assert.ErrorIs(t, err, errMock) }) @@ -793,7 +798,7 @@ func TestLogin(t *testing.T) { mockCoreClient := new(astrocore_mocks.ClientWithResponsesInterface) mockPlatformCoreClient := new(astroplatformcore_mocks.ClientWithResponsesInterface) mockCoreClient.On("GetSelfUserWithResponse", mock.Anything, mock.Anything).Return(&mockGetSelfErrorResponse, nil).Once() - err := Login("", "", mockCoreClient, mockPlatformCoreClient, os.Stdout, false) + err := Login("", "", keychain.NewTestStore(), nil, mockCoreClient, mockPlatformCoreClient, os.Stdout, false) assert.Contains(t, err.Error(), "failed to fetch self user") mockCoreClient.AssertExpectations(t) mockPlatformCoreClient.AssertExpectations(t) @@ -827,7 +832,7 @@ func TestLogin(t *testing.T) { // initialize stdin with user email input defer testUtil.MockUserInput(t, "test.user@astronomer.io")() // do the test - err = Login("astronomer.io", "", mockCoreClient, mockPlatformCoreClient, os.Stdout, true) + err = Login("astronomer.io", "", keychain.NewTestStore(), nil, mockCoreClient, mockPlatformCoreClient, os.Stdout, true) assert.NoError(t, err) mockCoreClient.AssertExpectations(t) mockPlatformCoreClient.AssertExpectations(t) @@ -860,15 +865,13 @@ func TestLogin(t *testing.T) { } // initialize user input with email defer testUtil.MockUserInput(t, "test.user@astronomer.io")() - err := Login("astronomer.io", "", mockCoreClient, mockPlatformCoreClient, os.Stdout, true) + store := keychain.NewTestStore() + err := Login("astronomer.io", "", store, nil, mockCoreClient, mockPlatformCoreClient, os.Stdout, true) assert.NoError(t, err) - // assert that everything got set in the right spot - domainContext, err := context.GetContext("astronomer.io") - assert.NoError(t, err) - currentContext, err := context.GetContext("localhost") - assert.NoError(t, err) - assert.Equal(t, domainContext.Token, "Bearer access_token") - assert.Equal(t, currentContext.Token, "token") + // assert that credentials were stored in the keychain + creds, err := store.GetCredentials("astronomer.io") + require.NoError(t, err) + assert.Equal(t, "Bearer access_token", creds.Token) mockCoreClient.AssertExpectations(t) mockPlatformCoreClient.AssertExpectations(t) }) @@ -878,75 +881,61 @@ func TestLogout(t *testing.T) { testUtil.InitTestConfig(testUtil.LocalPlatform) t.Run("success", func(t *testing.T) { buf := new(bytes.Buffer) - Logout("astronomer.io", buf) + Logout("astronomer.io", keychain.NewTestStore(), buf) assert.Equal(t, "Successfully logged out of Astronomer\n", buf.String()) }) - t.Run("success_with_email", func(t *testing.T) { - assertions := func(expUserEmail string, expToken string) { - contexts, err := config.GetContexts() - assert.NoError(t, err) - context := contexts.Contexts["localhost"] - - assert.NoError(t, err) - assert.Equal(t, expUserEmail, context.UserEmail) - assert.Equal(t, expToken, context.Token) - } + t.Run("success_with_credentials_deleted", func(t *testing.T) { testUtil.InitTestConfig(testUtil.LocalPlatform) - c, err := config.GetCurrentContext() - assert.NoError(t, err) - err = c.SetContextKey("user_email", "test.user@astronomer.io") - assert.NoError(t, err) - err = c.SetContextKey("token", "Bearer some-token") - assert.NoError(t, err) - // test before - assertions("test.user@astronomer.io", "Bearer some-token") + store := keychain.NewTestStore() + err := store.SetCredentials("localhost", keychain.Credentials{ + Token: "Bearer some-token", + UserEmail: "test.user@astronomer.io", + }) + require.NoError(t, err) - // log out - c, err = config.GetCurrentContext() - assert.NoError(t, err) - Logout(c.Domain, os.Stdout) + c, err := config.GetCurrentContext() + require.NoError(t, err) + Logout(c.Domain, store, os.Stdout) - // test after logout - assertions("", "") + _, err = store.GetCredentials("localhost") + assert.ErrorIs(t, err, keychain.ErrNotFound) }) } -func Test_writeResultToContext(t *testing.T) { - assertConfigContents := func(expToken string, expRefresh string, expExpires time.Time, expUserEmail string) { - context, err := config.GetCurrentContext() - assert.NoError(t, err) - // test the output on the config file - assert.Equal(t, expToken, context.Token) - assert.Equal(t, expRefresh, context.RefreshToken) - expiresIn, err := context.GetExpiresIn() - assert.NoError(t, err) - assert.Equal(t, expExpires.Round(time.Second), expiresIn.Round(time.Second)) - assert.Equal(t, expUserEmail, context.UserEmail) - assert.NoError(t, err) - } +func TestLogin_storesCredentialsInKeychain(t *testing.T) { testUtil.InitTestConfig(testUtil.LocalPlatform) - c, err := config.GetCurrentContext() - assert.NoError(t, err) - err = c.SetContextKey("token", "old_token") - assert.NoError(t, err) - // test input - res := Result{ - AccessToken: "new_token", - RefreshToken: "new_refresh_token", - ExpiresIn: 1234, - UserEmail: "test.user@astronomer.io", + mockUserInfo := UserInfo{Email: "test.user@astronomer.io"} + userInfoRequester := func(authConfig Config, accessToken string) (UserInfo, error) { + return mockUserInfo, nil } - // test before changes - var timeZero time.Time - assertConfigContents("old_token", "", timeZero, "") - - // apply function - c, err = config.GetCurrentContext() - assert.NoError(t, err) - err = res.writeToContext(&c) - assert.NoError(t, err) - - // test after changes - assertConfigContents("Bearer new_token", "new_refresh_token", time.Now().Add(1234*time.Second), "test.user@astronomer.io") + authenticator = Authenticator{ + userInfoRequester: userInfoRequester, + callbackHandler: func() (string, error) { return "authorizationCode", nil }, + tokenRequester: func(authConfig Config, verifier, code string) (Result, error) { + return Result{ + RefreshToken: "new_refresh_token", + AccessToken: "new_access_token", + ExpiresIn: 1234, + }, nil + }, + } + mockCoreClient := new(astrocore_mocks.ClientWithResponsesInterface) + mockPlatformCoreClient := new(astroplatformcore_mocks.ClientWithResponsesInterface) + mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Once() + mockCoreClient.On("GetSelfUserWithResponse", mock.Anything, mock.Anything).Return(&mockGetSelfResponse, nil).Once() + mockPlatformCoreClient.On("ListOrganizationsWithResponse", mock.Anything, &astroplatformcore.ListOrganizationsParams{}).Return(&mockOrganizationsResponse, nil).Once() + + store := keychain.NewTestStore() + err := Login("astronomer.io", "", store, nil, mockCoreClient, mockPlatformCoreClient, os.Stdout, true) + require.NoError(t, err) + + creds, err := store.GetCredentials("astronomer.io") + require.NoError(t, err) + assert.Equal(t, "Bearer new_access_token", creds.Token) + assert.Equal(t, "new_refresh_token", creds.RefreshToken) + assert.Equal(t, "test.user@astronomer.io", creds.UserEmail) + assert.WithinDuration(t, time.Now().Add(1234*time.Second), creds.ExpiresAt, 5*time.Second) + mockCoreClient.AssertExpectations(t) + mockPlatformCoreClient.AssertExpectations(t) } diff --git a/cloud/auth/types.go b/cloud/auth/types.go index f3846b488..e4feb90e5 100644 --- a/cloud/auth/types.go +++ b/cloud/auth/types.go @@ -1,7 +1,6 @@ package auth import ( - "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/pkg/astroauth" ) @@ -43,26 +42,3 @@ type CallbackMessage struct { authorizationCode string errorMessage string } - -func (res Result) writeToContext(c *config.Context) error { - err = c.SetContextKey("token", "Bearer "+res.AccessToken) - if err != nil { - return err - } - - err = c.SetContextKey("refreshtoken", res.RefreshToken) - if err != nil { - return err - } - - err = c.SetExpiresIn(res.ExpiresIn) - if err != nil { - return err - } - - err = c.SetContextKey("user_email", res.UserEmail) - if err != nil { - return err - } - return nil -} diff --git a/cloud/deploy/bundle.go b/cloud/deploy/bundle.go index ef154cd9e..accfeeb9b 100644 --- a/cloud/deploy/bundle.go +++ b/cloud/deploy/bundle.go @@ -44,7 +44,7 @@ func DeployBundle(input *DeployBundleInput) error { } // if CI/CD is enforced, check the subject can deploy - if currentDeployment.IsCicdEnforced && !canCiCdDeploy(c.Token) { + if currentDeployment.IsCicdEnforced && !canCiCdDeploy("Bearer "+os.Getenv("ASTRO_API_TOKEN")) { return fmt.Errorf(errCiCdEnforcementUpdate, currentDeployment.Name) } diff --git a/cloud/deploy/deploy.go b/cloud/deploy/deploy.go index 71d181244..f06f65ec8 100644 --- a/cloud/deploy/deploy.go +++ b/cloud/deploy/deploy.go @@ -224,7 +224,7 @@ func Deploy(deployInput InputDeploy, platformCoreClient astroplatformcore.CoreCl } if deployInfo.cicdEnforcement { - if !canCiCdDeploy(c.Token) { + if !canCiCdDeploy("Bearer " + os.Getenv("ASTRO_API_TOKEN")) { return fmt.Errorf(errCiCdEnforcementUpdate, deployInfo.name) //nolint } } @@ -413,7 +413,7 @@ func Deploy(deployInput InputDeploy, platformCoreClient astroplatformcore.CoreCl imageHandler := airflowImageHandler(deployInfo.deployImage) fmt.Println("Pushing image to Astronomer registry") - _, err = imageHandler.Push(remoteImage, registryUsername, c.Token, false) + _, err = imageHandler.Push(remoteImage, registryUsername, "Bearer "+os.Getenv("ASTRO_API_TOKEN"), false) if err != nil { return err } @@ -920,7 +920,7 @@ func setupClientDependencyFiles(buildDir string) error { // DeployClientImage handles the client deploy functionality func DeployClientImage(deployInput InputClientDeploy, platformCoreClient astroplatformcore.CoreClient) error { //nolint:gocritic - c, err := config.GetCurrentContext() + _, err := config.GetCurrentContext() if err != nil { return errors.Wrap(err, "failed to get current context") } @@ -969,7 +969,7 @@ func DeployClientImage(deployInput InputClientDeploy, platformCoreClient astropl } baseImageRegistry := config.CFG.RemoteBaseImageRegistry.GetString() fmt.Printf("Authenticating with base image registry: %s\n", baseImageRegistry) - err := airflow.DockerLogin(baseImageRegistry, registryUsername, c.Token) + err := airflow.DockerLogin(baseImageRegistry, registryUsername, "Bearer "+os.Getenv("ASTRO_API_TOKEN")) if err != nil { fmt.Println("Failed to authenticate with Astronomer registry that contains the base agent image used in the Dockerfile.client file.") fmt.Println("This could be because either your token has expired or you don't have permission to pull the base agent image.") diff --git a/cloud/deploy/deploy_test.go b/cloud/deploy/deploy_test.go index 74dcbcf2f..498971f21 100644 --- a/cloud/deploy/deploy_test.go +++ b/cloud/deploy/deploy_test.go @@ -209,7 +209,7 @@ func TestDeployWithoutDagsDeploySuccess(t *testing.T) { ctx, err := config.GetCurrentContext() assert.NoError(t, err) - ctx.Token = "test testing" + err = ctx.SetContext() assert.NoError(t, err) @@ -309,7 +309,7 @@ func TestDeployOnRemoteExecutionDeployment(t *testing.T) { ctx, err := config.GetCurrentContext() assert.NoError(t, err) - ctx.Token = "test testing" + err = ctx.SetContext() assert.NoError(t, err) @@ -452,7 +452,7 @@ func TestDeployWithDagsDeploySuccess(t *testing.T) { ctx, err := config.GetCurrentContext() assert.NoError(t, err) - ctx.Token = "test testing" + err = ctx.SetContext() assert.NoError(t, err) @@ -681,7 +681,7 @@ func TestNoDagsDeploy(t *testing.T) { ctx, err := config.GetCurrentContext() assert.NoError(t, err) - ctx.Token = "test testing" + err = ctx.SetContext() assert.NoError(t, err) @@ -717,7 +717,6 @@ func TestNoDagsDeployForceSkipsPrompt(t *testing.T) { ctx, err := config.GetCurrentContext() assert.NoError(t, err) - ctx.Token = "test testing" err = ctx.SetContext() assert.NoError(t, err) @@ -753,7 +752,6 @@ func TestNoDagsImageDeployForceSkipsPrompt(t *testing.T) { ctx, err := config.GetCurrentContext() assert.NoError(t, err) - ctx.Token = "test testing" err = ctx.SetContext() assert.NoError(t, err) @@ -1374,6 +1372,8 @@ func TestDeployClientImage(t *testing.T) { }() t.Run("successful client deploy", func(t *testing.T) { + t.Setenv("ASTRO_API_TOKEN", "test-token") + // Set up temporary directory with Dockerfile.client tempDir, err := os.MkdirTemp("", "test-deploy-*") assert.NoError(t, err) @@ -1394,7 +1394,7 @@ func TestDeployClientImage(t *testing.T) { testUtil.InitTestConfig(testUtil.CloudPlatform) ctx, err := config.GetCurrentContext() assert.NoError(t, err) - ctx.Token = "test-token" + err = ctx.SetContext() assert.NoError(t, err) // Mock DockerLogin @@ -1435,7 +1435,7 @@ func TestDeployClientImage(t *testing.T) { assert.True(t, dockerLoginCalled, "DockerLogin should have been called") assert.Equal(t, "images.astronomer.cloud", capturedRegistry) assert.Equal(t, "cli", capturedUsername) - assert.Equal(t, "test-token", capturedToken) + assert.Equal(t, "Bearer test-token", capturedToken) mockImageHandler.AssertExpectations(t) }) @@ -1444,7 +1444,7 @@ func TestDeployClientImage(t *testing.T) { testUtil.InitTestConfig(testUtil.CloudPlatform) ctx, err := config.GetCurrentContext() assert.NoError(t, err) - ctx.Token = "test-token" + err = ctx.SetContext() assert.NoError(t, err) @@ -1470,7 +1470,7 @@ func TestDeployClientImage(t *testing.T) { testUtil.InitTestConfig(testUtil.CloudPlatform) ctx, err := config.GetCurrentContext() assert.NoError(t, err) - ctx.Token = "test-token" + err = ctx.SetContext() assert.NoError(t, err) @@ -1515,7 +1515,7 @@ func TestDeployClientImage(t *testing.T) { testUtil.InitTestConfig(testUtil.CloudPlatform) ctx, err := config.GetCurrentContext() assert.NoError(t, err) - ctx.Token = "test-token" + err = ctx.SetContext() assert.NoError(t, err) @@ -1571,7 +1571,7 @@ func TestDeployClientImage(t *testing.T) { testUtil.InitTestConfig(testUtil.CloudPlatform) ctx, err := config.GetCurrentContext() assert.NoError(t, err) - ctx.Token = "test-token" + err = ctx.SetContext() assert.NoError(t, err) @@ -1612,7 +1612,7 @@ func TestDeployClientImage(t *testing.T) { testUtil.InitTestConfig(testUtil.CloudPlatform) ctx, err := config.GetCurrentContext() assert.NoError(t, err) - ctx.Token = "test-token" + err = ctx.SetContext() assert.NoError(t, err) diff --git a/cloud/deployment/deployment.go b/cloud/deployment/deployment.go index 8dd9a3483..3e7229da9 100644 --- a/cloud/deployment/deployment.go +++ b/cloud/deployment/deployment.go @@ -893,7 +893,7 @@ func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, exec isCicdEnforced = true } if !force && isCicdEnforced && dagDeploy != "" { - if !canCiCdDeploy(c.Token) { + if !canCiCdDeploy("Bearer " + os.Getenv("ASTRO_API_TOKEN")) { fmt.Printf("\nWarning: You are trying to update the dag deploy setting with ci-cd enforcement enabled. Once the setting is updated, you will not be able to deploy your dags using the CLI. Until you deploy your dags, dags will not be visible in the UI nor will new tasks start." + "\nAfter the setting is updated, either disable cicd enforcement and then deploy your dags OR deploy your dags via CICD or using API Tokens.") y, _ := input.Confirm("\n\nAre you sure you want to continue?") diff --git a/cloud/deployment/fromfile/fromfile.go b/cloud/deployment/fromfile/fromfile.go index 8dbe6a9a0..4ddb2e225 100644 --- a/cloud/deployment/fromfile/fromfile.go +++ b/cloud/deployment/fromfile/fromfile.go @@ -694,11 +694,7 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c } // update deployment if !force && deploymentFromFile.Deployment.Configuration.APIKeyOnlyDeployments && dagDeploy { - c, err := config.GetCurrentContext() - if err != nil { - return err - } - if !canCiCdDeploy(c.Token) { + if !canCiCdDeploy("Bearer " + os.Getenv("ASTRO_API_TOKEN")) { fmt.Printf("\nWarning: You are trying to update dag deploy setting on a deployment with ci-cd enforcement enabled. You will not be able to deploy your dags using the CLI and that dags will not be visible in the UI and new tasks will not start." + "\nEither disable ci-cd enforcement or please cancel this operation and use API Tokens instead.") y, _ := input.Confirm("\n\nAre you sure you want to continue?") diff --git a/cloud/organization/organization.go b/cloud/organization/organization.go index d17c6c83a..d5a3883d5 100644 --- a/cloud/organization/organization.go +++ b/cloud/organization/organization.go @@ -200,10 +200,6 @@ func SwitchWithContext(domain string, targetOrg *astroplatformcore.Organization, orgProduct = fmt.Sprintf("%s", *targetOrg.Product) //nolint } _ = c.SetOrganizationContext(targetOrg.Id, orgProduct) - // need to reset all relevant keys because of https://github.com/spf13/viper/issues/1106 :shrug - _ = c.SetContextKey("token", c.Token) - _ = c.SetContextKey("refreshtoken", c.RefreshToken) - _ = c.SetContextKey("user_email", c.UserEmail) c, _ = context.GetCurrentContext() // call check user session which will trigger workspace switcher flow err := CheckUserSession(&c, coreClient, platformCoreClient, out) diff --git a/cloud/platformclient/client.go b/cloud/platformclient/client.go index ae5e92af1..b1b10294b 100644 --- a/cloud/platformclient/client.go +++ b/cloud/platformclient/client.go @@ -6,6 +6,6 @@ import ( ) // NewPlatformCoreClient creates an API client for Astro platform core services. -func NewPlatformCoreClient(c *httputil.HTTPClient) *astroplatformcore.ClientWithResponses { - return astroplatformcore.NewPlatformCoreClient(c) +func NewPlatformCoreClient(c *httputil.HTTPClient, holder *httputil.TokenHolder) *astroplatformcore.ClientWithResponses { + return astroplatformcore.NewPlatformCoreClient(c, holder) } diff --git a/cloud/platformclient/client_test.go b/cloud/platformclient/client_test.go index 809f4e234..414c0dd4e 100644 --- a/cloud/platformclient/client_test.go +++ b/cloud/platformclient/client_test.go @@ -9,6 +9,6 @@ import ( ) func TestNewPlatformCoreClient(t *testing.T) { - client := NewPlatformCoreClient(httputil.NewHTTPClient()) + client := NewPlatformCoreClient(httputil.NewHTTPClient(), &httputil.TokenHolder{}) assert.NotNil(t, client, "Can't create new Astro Platform Core client") } diff --git a/cmd/airflow.go b/cmd/airflow.go index 618df9bf3..e4e514e94 100644 --- a/cmd/airflow.go +++ b/cmd/airflow.go @@ -29,6 +29,7 @@ import ( "github.com/astronomer/astro-cli/pkg/fileutil" "github.com/astronomer/astro-cli/pkg/httputil" "github.com/astronomer/astro-cli/pkg/input" + "github.com/astronomer/astro-cli/pkg/keychain" "github.com/astronomer/astro-cli/pkg/output" "github.com/astronomer/astro-cli/pkg/util" ) @@ -150,7 +151,7 @@ astro dev init --remote-execution-enabled --remote-image-repository quay.io/acme proxyPortFlag string ) -func newDevRootCmd(platformCoreClient astroplatformcore.CoreClient, astroCoreClient astrocore.CoreClient) *cobra.Command { +func newDevRootCmd(platformCoreClient astroplatformcore.CoreClient, astroCoreClient astrocore.CoreClient, store keychain.SecureStore) *cobra.Command { cmd := &cobra.Command{ Use: "dev", Aliases: []string{"d"}, @@ -182,7 +183,7 @@ func newDevRootCmd(platformCoreClient astroplatformcore.CoreClient, astroCoreCli newAirflowRestartCmd(astroCoreClient), newAirflowBashCmd(), newAirflowObjectRootCmd(), - newAirflowUpgradeTestCmd(platformCoreClient), + newAirflowUpgradeTestCmd(platformCoreClient, store), newProxyRootCmd(), ) return cmd @@ -256,14 +257,14 @@ func newAirflowInitCmd() *cobra.Command { return cmd } -func newAirflowUpgradeTestCmd(platformCoreClient astroplatformcore.CoreClient) *cobra.Command { +func newAirflowUpgradeTestCmd(platformCoreClient astroplatformcore.CoreClient, store keychain.SecureStore) *cobra.Command { cmd := &cobra.Command{ Use: "upgrade-test", Short: "Test compatibility with a new Airflow or Runtime version", Long: "Run compatibility tests to check if your environment and DAGs work with a new version of Airflow or Astro Runtime. Produces reports covering dependency version changes, DAG import errors, and Airflow deprecation lint issues. Does not modify your project or local environment.", PreRunE: EnsureRuntime, RunE: func(cmd *cobra.Command, args []string) error { - return airflowUpgradeTest(cmd, platformCoreClient) + return airflowUpgradeTest(cmd, platformCoreClient, store) }, } cmd.Flags().StringVarP(&airflowVersion, "airflow-version", "a", "", "The version of Airflow you want to upgrade to. The default is the latest available version. Tests are run against the equivalent Astro Runtime version.") @@ -757,7 +758,7 @@ func ensureProjectName(args []string, projectName string) (string, error) { return projectName, nil } -func airflowUpgradeTest(cmd *cobra.Command, platformCoreClient astroplatformcore.CoreClient) error { //nolint:gocognit +func airflowUpgradeTest(cmd *cobra.Command, platformCoreClient astroplatformcore.CoreClient, store keychain.SecureStore) error { //nolint:gocognit // Validate runtimeVersion and airflowVersion if airflowVersion != "" && runtimeVersion != "" { return errInvalidBothAirflowAndRuntimeVersionsUpgrade @@ -800,7 +801,7 @@ func airflowUpgradeTest(cmd *cobra.Command, platformCoreClient astroplatformcore buildSecretString = util.GetbuildSecretString(buildSecrets) - err = containerHandler.UpgradeTest(runtimeVersion, deploymentID, customImageName, buildSecretString, versionTest, dagTest, lintTest, lintDeprecations, lintFix, lintConfigFile, platformCoreClient) + err = containerHandler.UpgradeTest(runtimeVersion, deploymentID, customImageName, buildSecretString, versionTest, dagTest, lintTest, lintDeprecations, lintFix, lintConfigFile, platformCoreClient, store) if err != nil { return err } diff --git a/cmd/airflow_test.go b/cmd/airflow_test.go index 15efc8e29..6eb6f0181 100644 --- a/cmd/airflow_test.go +++ b/cmd/airflow_test.go @@ -148,7 +148,7 @@ func (s *AirflowSuite) TestDevInitCommandSoftware() { } func (s *AirflowSuite) TestNewAirflowDevRootCmd() { - cmd := newDevRootCmd(nil, nil) + cmd := newDevRootCmd(nil, nil, nil) s.Nil(cmd.PersistentPreRunE(new(cobra.Command), []string{})) } @@ -761,71 +761,71 @@ func (s *AirflowSuite) TestAirflowStart() { func (s *AirflowSuite) TestAirflowUpgradeTest() { s.Run("success", func() { - cmd := newAirflowUpgradeTestCmd(nil) + cmd := newAirflowUpgradeTestCmd(nil, nil) mockContainerHandler := new(mocks.ContainerHandler) containerHandlerInit = func(airflowHome, envFile, dockerfile, imageName string) (airflow.ContainerHandler, error) { - mockContainerHandler.On("UpgradeTest", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, false, false, false, false, "", nil).Return(nil).Once() + mockContainerHandler.On("UpgradeTest", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, false, false, false, false, "", nil, nil).Return(nil).Once() return mockContainerHandler, nil } - err := airflowUpgradeTest(cmd, nil) + err := airflowUpgradeTest(cmd, nil, nil) s.NoError(err) mockContainerHandler.AssertExpectations(s.T()) }) s.Run("failure", func() { - cmd := newAirflowUpgradeTestCmd(nil) + cmd := newAirflowUpgradeTestCmd(nil, nil) mockContainerHandler := new(mocks.ContainerHandler) containerHandlerInit = func(airflowHome, envFile, dockerfile, imageName string) (airflow.ContainerHandler, error) { - mockContainerHandler.On("UpgradeTest", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, false, false, false, false, "", nil).Return(errMock).Once() + mockContainerHandler.On("UpgradeTest", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, false, false, false, false, "", nil, nil).Return(errMock).Once() return mockContainerHandler, nil } - err := airflowUpgradeTest(cmd, nil) + err := airflowUpgradeTest(cmd, nil, nil) s.ErrorIs(err, errMock) mockContainerHandler.AssertExpectations(s.T()) }) s.Run("containerHandlerInit failure", func() { - cmd := newAirflowUpgradeTestCmd(nil) + cmd := newAirflowUpgradeTestCmd(nil, nil) containerHandlerInit = func(airflowHome, envFile, dockerfile, imageName string) (airflow.ContainerHandler, error) { return nil, errMock } - err := airflowUpgradeTest(cmd, nil) + err := airflowUpgradeTest(cmd, nil, nil) s.ErrorIs(err, errMock) }) s.Run("Both airflow and runtime version used", func() { - cmd := newAirflowUpgradeTestCmd(nil) + cmd := newAirflowUpgradeTestCmd(nil, nil) airflowVersion = "something" runtimeVersion = "something" - err := airflowUpgradeTest(cmd, nil) + err := airflowUpgradeTest(cmd, nil, nil) s.ErrorIs(err, errInvalidBothAirflowAndRuntimeVersionsUpgrade) }) s.Run("Both runtime version and custom image used", func() { - cmd := newAirflowUpgradeTestCmd(nil) + cmd := newAirflowUpgradeTestCmd(nil, nil) customImageName = "something" runtimeVersion = "something" - err := airflowUpgradeTest(cmd, nil) + err := airflowUpgradeTest(cmd, nil, nil) s.ErrorIs(err, errInvalidBothCustomImageandVersion) }) s.Run("Both airflow version and custom image used", func() { - cmd := newAirflowUpgradeTestCmd(nil) + cmd := newAirflowUpgradeTestCmd(nil, nil) customImageName = "something" airflowVersion = "something" - err := airflowUpgradeTest(cmd, nil) + err := airflowUpgradeTest(cmd, nil, nil) s.ErrorIs(err, errInvalidBothCustomImageandVersion) }) } @@ -1919,7 +1919,7 @@ func (s *AirflowSuite) TestDevCommandLocalSubcommandRemoved() { func (s *AirflowSuite) TestStandaloneDockerFlagsMutuallyExclusive() { // Verify that the flags are registered as mutually exclusive on the dev root command - cmd := newDevRootCmd(nil, nil) + cmd := newDevRootCmd(nil, nil, nil) s.NotNil(cmd.PersistentFlags().Lookup("standalone")) s.NotNil(cmd.PersistentFlags().Lookup("docker")) diff --git a/cmd/api/airflow.go b/cmd/api/airflow.go index 30bc5a715..60a7c2163 100644 --- a/cmd/api/airflow.go +++ b/cmd/api/airflow.go @@ -43,18 +43,19 @@ type AirflowOptions struct { // Internal detectedVersion string // The Airflow version being used (detected or overridden) CredentialsExplicit bool // true when --username or --password was explicitly passed + tokenHolder *httputil.TokenHolder } // NewAirflowCmd creates the 'astro api airflow' command. // //nolint:dupl -func NewAirflowCmd(out io.Writer) *cobra.Command { +func NewAirflowCmd(out io.Writer, tokenHolder *httputil.TokenHolder) *cobra.Command { opts := &AirflowOptions{ RequestOptions: RequestOptions{ Out: out, ErrOut: os.Stderr, - // specCache is initialized lazily when we know the Airflow version }, + tokenHolder: tokenHolder, } cmd := &cobra.Command{ @@ -459,7 +460,7 @@ func resolveDeploymentAirflowURL(opts *AirflowOptions) (baseURL, authToken strin } // Check for token - if ctx.Token == "" { + if opts.tokenHolder == nil || opts.tokenHolder.Get() == "" { return "", "", fmt.Errorf("not authenticated. Run 'astro login' to authenticate") } @@ -473,7 +474,7 @@ func resolveDeploymentAirflowURL(opts *AirflowOptions) (baseURL, authToken strin } // Create platform client - platformCoreClient := platformclient.NewPlatformCoreClient(httputil.NewHTTPClient()) + platformCoreClient := platformclient.NewPlatformCoreClient(httputil.NewHTTPClient(), opts.tokenHolder) // Fetch deployment dep, err := deployment.CoreGetDeployment(orgID, opts.DeploymentID, platformCoreClient) @@ -492,7 +493,7 @@ func resolveDeploymentAirflowURL(opts *AirflowOptions) (baseURL, authToken strin airflowURL = "https://" + airflowURL } - return airflowURL, ctx.Token, nil + return airflowURL, opts.tokenHolder.Get(), nil } // runAirflowInteractive runs the airflow API command in interactive mode. diff --git a/cmd/api/airflow_test.go b/cmd/api/airflow_test.go index 0d251df45..d78b1ec29 100644 --- a/cmd/api/airflow_test.go +++ b/cmd/api/airflow_test.go @@ -14,13 +14,14 @@ import ( astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" "github.com/astronomer/astro-cli/cloud/deployment" "github.com/astronomer/astro-cli/config" + "github.com/astronomer/astro-cli/pkg/httputil" "github.com/astronomer/astro-cli/pkg/openapi" testUtil "github.com/astronomer/astro-cli/pkg/testing" ) func TestNewAirflowCmd(t *testing.T) { out := new(bytes.Buffer) - cmd := NewAirflowCmd(out) + cmd := NewAirflowCmd(out, nil) assert.Equal(t, "airflow ", cmd.Use) assert.NotEmpty(t, cmd.Short) @@ -39,7 +40,7 @@ func TestNewAirflowCmd(t *testing.T) { func TestAirflowCmdFlags(t *testing.T) { out := new(bytes.Buffer) - cmd := NewAirflowCmd(out) + cmd := NewAirflowCmd(out, nil) // Check airflow-specific flags exist (these are persistent flags so they're inherited by subcommands) assert.NotNil(t, cmd.PersistentFlags().Lookup("api-url")) @@ -128,7 +129,7 @@ func TestResolveAirflowAPIURL_DeploymentID_NoToken(t *testing.T) { // Set up context without token ctx, err := config.GetCurrentContext() require.NoError(t, err) - ctx.Token = "" + err = ctx.SetContext() require.NoError(t, err) @@ -149,7 +150,7 @@ func TestResolveAirflowAPIURL_DeploymentID_Success(t *testing.T) { // Set up context with token and organization ctx, err := config.GetCurrentContext() require.NoError(t, err) - ctx.Token = "test-token" + ctx.Organization = "test-org" err = ctx.SetContext() require.NoError(t, err) @@ -168,8 +169,10 @@ func TestResolveAirflowAPIURL_DeploymentID_Success(t *testing.T) { }, nil } + th := httputil.NewTokenHolder("test-token") opts := &AirflowOptions{ DeploymentID: "test-deployment-id", + tokenHolder: th, } baseURL, authToken, err := resolveAirflowAPIURL(opts) @@ -186,7 +189,7 @@ func TestResolveAirflowAPIURL_DeploymentID_WithOrgOverride(t *testing.T) { // Set up context with token but different organization ctx, err := config.GetCurrentContext() require.NoError(t, err) - ctx.Token = "test-token" + ctx.Organization = "context-org" err = ctx.SetContext() require.NoError(t, err) @@ -206,9 +209,11 @@ func TestResolveAirflowAPIURL_DeploymentID_WithOrgOverride(t *testing.T) { }, nil } + th := httputil.NewTokenHolder("test-token") opts := &AirflowOptions{ DeploymentID: "test-deployment-id", OrganizationID: "override-org", + tokenHolder: th, } baseURL, authToken, err := resolveAirflowAPIURL(opts) @@ -225,7 +230,7 @@ func TestResolveAirflowAPIURL_DeploymentID_NoAirflowURL(t *testing.T) { // Set up context with token and organization ctx, err := config.GetCurrentContext() require.NoError(t, err) - ctx.Token = "test-token" + ctx.Organization = "test-org" err = ctx.SetContext() require.NoError(t, err) @@ -241,8 +246,10 @@ func TestResolveAirflowAPIURL_DeploymentID_NoAirflowURL(t *testing.T) { }, nil } + th := httputil.NewTokenHolder("test-token") opts := &AirflowOptions{ DeploymentID: "test-deployment-id", + tokenHolder: th, } _, _, err = resolveAirflowAPIURL(opts) diff --git a/cmd/api/api.go b/cmd/api/api.go index be07b33c0..51f49561a 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -7,15 +7,17 @@ import ( "github.com/fatih/color" "github.com/spf13/cobra" + + "github.com/astronomer/astro-cli/pkg/httputil" ) // NewAPICmd creates the parent 'astro api' command. -func NewAPICmd() *cobra.Command { - return NewAPICmdWithOutput(os.Stdout) +func NewAPICmd(tokenHolder *httputil.TokenHolder) *cobra.Command { + return NewAPICmdWithOutput(os.Stdout, tokenHolder) } // NewAPICmdWithOutput creates the parent 'astro api' command with a custom output writer. -func NewAPICmdWithOutput(out io.Writer) *cobra.Command { +func NewAPICmdWithOutput(out io.Writer, tokenHolder *httputil.TokenHolder) *cobra.Command { var noColor bool cmd := &cobra.Command{ @@ -68,8 +70,8 @@ Use "astro api [command] --help" for more information about a command.`, cmd.PersistentFlags().BoolVar(&noColor, "no-color", false, "Disable colorized output") - cmd.AddCommand(NewAirflowCmd(out)) - cmd.AddCommand(NewCloudCmd(out)) + cmd.AddCommand(NewAirflowCmd(out, tokenHolder)) + cmd.AddCommand(NewCloudCmd(out, tokenHolder)) cmd.AddCommand(NewRegistryCmd(out)) return cmd diff --git a/cmd/api/api_test.go b/cmd/api/api_test.go index 9557a97d8..b70fb75f3 100644 --- a/cmd/api/api_test.go +++ b/cmd/api/api_test.go @@ -17,7 +17,7 @@ func newRootWithAPI(rootHook func(cmd *cobra.Command, args []string) error) *cob Use: "astro", PersistentPreRunE: rootHook, } - apiCmd := NewAPICmdWithOutput(new(bytes.Buffer)) + apiCmd := NewAPICmdWithOutput(new(bytes.Buffer), nil) root.AddCommand(apiCmd) return apiCmd } @@ -53,7 +53,7 @@ func TestPersistentPreRunE_RootHookErrorPropagates(t *testing.T) { func TestPersistentPreRunE_NoRootHook(t *testing.T) { // Root has no PersistentPreRunE -- should not panic or error. root := &cobra.Command{Use: "astro"} - apiCmd := NewAPICmdWithOutput(new(bytes.Buffer)) + apiCmd := NewAPICmdWithOutput(new(bytes.Buffer), nil) root.AddCommand(apiCmd) child := &cobra.Command{Use: "child"} @@ -77,7 +77,7 @@ func TestPersistentPreRunE_SilenceUsagePropagated(t *testing.T) { } func TestNewAPICmdWithOutput_SubcommandRegistration(t *testing.T) { - apiCmd := NewAPICmdWithOutput(new(bytes.Buffer)) + apiCmd := NewAPICmdWithOutput(new(bytes.Buffer), nil) names := make([]string, 0, len(apiCmd.Commands())) for _, sub := range apiCmd.Commands() { @@ -88,6 +88,6 @@ func TestNewAPICmdWithOutput_SubcommandRegistration(t *testing.T) { } func TestNewAPICmdWithOutput_Flags(t *testing.T) { - apiCmd := NewAPICmdWithOutput(new(bytes.Buffer)) + apiCmd := NewAPICmdWithOutput(new(bytes.Buffer), nil) assert.NotNil(t, apiCmd.PersistentFlags().Lookup("no-color")) } diff --git a/cmd/api/cloud.go b/cmd/api/cloud.go index 049a08140..4f5312474 100644 --- a/cmd/api/cloud.go +++ b/cmd/api/cloud.go @@ -14,6 +14,7 @@ import ( "github.com/astronomer/astro-cli/context" "github.com/astronomer/astro-cli/pkg/ansi" "github.com/astronomer/astro-cli/pkg/domainutil" + "github.com/astronomer/astro-cli/pkg/httputil" "github.com/astronomer/astro-cli/pkg/openapi" ) @@ -22,18 +23,19 @@ type CloudOptions struct { RequestOptions SpecURL string // hidden flag: alternative OpenAPI spec URL SpecTokenEnvVar string // hidden flag: env var name containing auth token for spec fetch + tokenHolder *httputil.TokenHolder } // NewCloudCmd creates the 'astro api cloud' command. // //nolint:dupl -func NewCloudCmd(out io.Writer) *cobra.Command { +func NewCloudCmd(out io.Writer, tokenHolder *httputil.TokenHolder) *cobra.Command { opts := &CloudOptions{ RequestOptions: RequestOptions{ Out: out, ErrOut: os.Stderr, - // specCache is initialized lazily when the domain is known }, + tokenHolder: tokenHolder, } cmd := &cobra.Command{ @@ -159,7 +161,7 @@ func runCloud(opts *CloudOptions) error { } // Check for token - if ctx.Token == "" { + if opts.tokenHolder == nil || opts.tokenHolder.Get() == "" { return fmt.Errorf("not authenticated. Run 'astro login' to authenticate") } @@ -232,11 +234,11 @@ func runCloud(opts *CloudOptions) error { // Generate curl command if requested if opts.GenerateCurl { - return generateCurl(opts.Out, method, url, ctx.Token, opts.RequestHeaders, params, opts.RequestInputFile) + return generateCurl(opts.Out, method, url, opts.tokenHolder.Get(), opts.RequestHeaders, params, opts.RequestInputFile) } // Build and execute the request - return executeRequest(&opts.RequestOptions, method, url, ctx.Token, params) + return executeRequest(&opts.RequestOptions, method, url, opts.tokenHolder.Get(), params) } // isOperationID checks if the input looks like an operation ID rather than a path. diff --git a/cmd/api/cloud_test.go b/cmd/api/cloud_test.go index f54d2b476..9c3121fba 100644 --- a/cmd/api/cloud_test.go +++ b/cmd/api/cloud_test.go @@ -36,7 +36,7 @@ contexts: func TestNewCloudCmd(t *testing.T) { out := new(bytes.Buffer) - cmd := NewCloudCmd(out) + cmd := NewCloudCmd(out, nil) assert.Equal(t, "cloud ", cmd.Use) assert.NotEmpty(t, cmd.Short) @@ -55,7 +55,7 @@ func TestNewCloudCmd(t *testing.T) { func TestCloudCmdFlags(t *testing.T) { out := new(bytes.Buffer) - cmd := NewCloudCmd(out) + cmd := NewCloudCmd(out, nil) // Request flags assert.NotNil(t, cmd.Flags().Lookup("method")) @@ -345,7 +345,7 @@ func TestPlaceholderRE(t *testing.T) { func TestCloudCmd_NoArgs_ShowsHelp(t *testing.T) { out := new(bytes.Buffer) - cmd := NewCloudCmd(out) + cmd := NewCloudCmd(out, nil) // Verify Args validator err := cmd.Args(cmd, nil) assert.NoError(t, err) // MaximumNArgs(1) allows 0 @@ -359,7 +359,7 @@ func TestCloudCmd_NoArgs_ShowsHelp(t *testing.T) { func TestCloudCmdLongDescription(t *testing.T) { out := new(bytes.Buffer) - cmd := NewCloudCmd(out) + cmd := NewCloudCmd(out, nil) assert.Contains(t, cmd.Long, "Astro Cloud API") assert.Contains(t, cmd.Example, "astro api cloud") } @@ -368,7 +368,7 @@ func TestCloudCmdLongDescription(t *testing.T) { func TestCloudSpecURLFlag(t *testing.T) { out := new(bytes.Buffer) - cmd := NewCloudCmd(out) + cmd := NewCloudCmd(out, nil) flag := cmd.PersistentFlags().Lookup("spec-url") require.NotNil(t, flag, "--spec-url flag should exist") @@ -383,7 +383,6 @@ func TestInitCloudSpecCache_SpecURL(t *testing.T) { opts := &CloudOptions{SpecURL: "https://example.com/spec.json"} ctx := &config.Context{ Domain: "example.com", - Token: "my-secret-token", } err := initCloudSpecCache(opts, ctx) @@ -421,7 +420,6 @@ func TestRunCloud_SpecURL_BaseURL(t *testing.T) { ctx := &config.Context{ Domain: "example.com", - Token: "test-token", Organization: "org-123", } @@ -442,7 +440,7 @@ func TestRunCloud_SpecURL_BaseURL(t *testing.T) { func TestCloudSpecTokenEnvVarFlag(t *testing.T) { out := new(bytes.Buffer) - cmd := NewCloudCmd(out) + cmd := NewCloudCmd(out, nil) flag := cmd.PersistentFlags().Lookup("spec-token-env-var") require.NotNil(t, flag, "--spec-token-env-var flag should exist") @@ -460,7 +458,6 @@ func TestInitCloudSpecCache_SpecTokenEnvVar(t *testing.T) { } ctx := &config.Context{ Domain: "example.com", - Token: "context-token", } err := initCloudSpecCache(opts, ctx) diff --git a/cmd/auth.go b/cmd/auth.go index 3b008a340..0a31e4581 100644 --- a/cmd/auth.go +++ b/cmd/auth.go @@ -1,6 +1,7 @@ package cmd import ( + "errors" "fmt" "io" "strings" @@ -10,9 +11,10 @@ import ( astrocore "github.com/astronomer/astro-cli/astro-client-core" astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" cloudAuth "github.com/astronomer/astro-cli/cloud/auth" - "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/context" "github.com/astronomer/astro-cli/pkg/domainutil" + "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/keychain" softwareAuth "github.com/astronomer/astro-cli/software/auth" ) @@ -28,20 +30,20 @@ var ( ) // newLoginCommand is a top-level alias for "astro auth login" kept for backward compatibility. -func newLoginCommand(coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) *cobra.Command { - cmd := newAuthLoginCommand(coreClient, platformCoreClient, out) +func newLoginCommand(store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) *cobra.Command { + cmd := newAuthLoginCommand(store, tokenHolder, coreClient, platformCoreClient, out) cmd.Long = "Authenticate to Astro or Astro Private Cloud. This is an alias for 'astro auth login'." return cmd } // newLogoutCommand is a top-level alias for "astro auth logout" kept for backward compatibility. -func newLogoutCommand(out io.Writer) *cobra.Command { - cmd := newAuthLogoutCommand(out) +func newLogoutCommand(store keychain.SecureStore, out io.Writer) *cobra.Command { + cmd := newAuthLogoutCommand(store, out) cmd.Long = "Log out of Astronomer. This is an alias for 'astro auth logout'." return cmd } -func login(cmd *cobra.Command, args []string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) error { +func login(cmd *cobra.Command, args []string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) error { // Silence Usage as we have now validated command input cmd.SilenceUsage = true @@ -53,22 +55,22 @@ func login(cmd *cobra.Command, args []string, coreClient astrocore.CoreClient, p if context.IsCloudDomain(ctx.Domain) { fmt.Fprintf(out, "To login to Astro Private Cloud follow the instructions below. If you are attempting to login in to Astro cancel the login and run 'astro login'.\n\n") } - return softwareLogin(args[0], oAuth, "", "", houstonVersion, houstonClient, out) + return softwareLogin(args[0], oAuth, "", "", houstonVersion, store, houstonClient, out) } - return cloudLogin(args[0], token, coreClient, platformCoreClient, out, shouldDisplayLoginLink) + return cloudLogin(args[0], token, store, tokenHolder, coreClient, platformCoreClient, out, shouldDisplayLoginLink) } // Log back into the current context in case no domain is passed ctx, err := context.GetCurrentContext() if err != nil || ctx.Domain == "" { // Default case when no domain is passed, and error getting current context - return cloudLogin(domainutil.DefaultDomain, token, coreClient, platformCoreClient, out, shouldDisplayLoginLink) + return cloudLogin(domainutil.DefaultDomain, token, store, tokenHolder, coreClient, platformCoreClient, out, shouldDisplayLoginLink) } else if context.IsCloudDomain(ctx.Domain) { - return cloudLogin(ctx.Domain, token, coreClient, platformCoreClient, out, shouldDisplayLoginLink) + return cloudLogin(ctx.Domain, token, store, tokenHolder, coreClient, platformCoreClient, out, shouldDisplayLoginLink) } - return softwareLogin(ctx.Domain, oAuth, "", "", houstonVersion, houstonClient, out) + return softwareLogin(ctx.Domain, oAuth, "", "", houstonVersion, store, houstonClient, out) } -func logout(cmd *cobra.Command, args []string, out io.Writer) error { +func logout(cmd *cobra.Command, args []string, store keychain.SecureStore, out io.Writer) error { var domain string if len(args) == 1 { domain = args[0] @@ -84,35 +86,35 @@ func logout(cmd *cobra.Command, args []string, out io.Writer) error { cmd.SilenceUsage = true if context.IsCloudDomain(domain) { - cloudLogout(domain, out) + cloudLogout(domain, store, out) } else { - softwareLogout(domain) + softwareLogout(domain, store) } return nil } -func newAuthRootCmd(coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) *cobra.Command { +func newAuthRootCmd(store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) *cobra.Command { cmd := &cobra.Command{ Use: "auth", Short: "Manage authentication to Astronomer", Long: "Commands for authenticating to Astro or Astro Private Cloud", } cmd.AddCommand( - newAuthLoginCommand(coreClient, platformCoreClient, out), - newAuthLogoutCommand(out), - newAuthTokenCommand(out), + newAuthLoginCommand(store, tokenHolder, coreClient, platformCoreClient, out), + newAuthLogoutCommand(store, out), + newAuthTokenCommand(store, out), ) return cmd } -func newAuthLoginCommand(coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) *cobra.Command { +func newAuthLoginCommand(store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) *cobra.Command { cmd := &cobra.Command{ Use: "login [BASEDOMAIN]", Short: "Log in to Astronomer", Long: "Authenticate to Astro or Astro Private Cloud", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return login(cmd, args, coreClient, platformCoreClient, out) + return login(cmd, args, store, tokenHolder, coreClient, platformCoreClient, out) }, } @@ -122,20 +124,20 @@ func newAuthLoginCommand(coreClient astrocore.CoreClient, platformCoreClient ast return cmd } -func newAuthLogoutCommand(out io.Writer) *cobra.Command { +func newAuthLogoutCommand(store keychain.SecureStore, out io.Writer) *cobra.Command { cmd := &cobra.Command{ Use: "logout", Short: "Log out of Astronomer", Long: "Log out of Astronomer", RunE: func(cmd *cobra.Command, args []string) error { - return logout(cmd, args, out) + return logout(cmd, args, store, out) }, Args: cobra.MaximumNArgs(1), } return cmd } -func newAuthTokenCommand(out io.Writer) *cobra.Command { +func newAuthTokenCommand(store keychain.SecureStore, out io.Writer) *cobra.Command { var tokenDomain string cmd := &cobra.Command{ Use: "token", @@ -143,33 +145,43 @@ func newAuthTokenCommand(out io.Writer) *cobra.Command { Long: "Print the current authentication token to standard output. This is useful for using the token in scripts or CI/CD pipelines.", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - return printAuthToken(cmd, tokenDomain, out) + return printAuthToken(cmd, store, tokenDomain, out) }, } cmd.Flags().StringVarP(&tokenDomain, "domain", "d", "", "Print the token for a specific context domain instead of the current context") return cmd } -func printAuthToken(cmd *cobra.Command, contextDomain string, out io.Writer) error { +func printAuthToken(cmd *cobra.Command, store keychain.SecureStore, contextDomain string, out io.Writer) error { // Silence Usage as we have now validated command input cmd.SilenceUsage = true - var c config.Context - var err error + var domain string if contextDomain != "" { - c, err = context.GetContext(contextDomain) + domain = contextDomain } else { - c, err = context.GetCurrentContext() + c, err := context.GetCurrentContext() + if err != nil { + return err + } + domain = c.Domain } + + if store == nil { + return fmt.Errorf("no token found. Please run 'astro login' to authenticate") + } + creds, err := store.GetCredentials(domain) if err != nil { - return err + if errors.Is(err, keychain.ErrNotFound) { + return fmt.Errorf("no token found. Please run 'astro login' to authenticate") + } + return fmt.Errorf("reading credentials: %w", err) } - - if c.Token == "" { + if creds.Token == "" { return fmt.Errorf("no token found. Please run 'astro login' to authenticate") } - rawToken := strings.TrimPrefix(c.Token, "Bearer ") + rawToken := strings.TrimPrefix(creds.Token, "Bearer ") fmt.Fprintln(out, rawToken) return nil } diff --git a/cmd/auth_test.go b/cmd/auth_test.go index 7222d25c2..24e26e2f9 100644 --- a/cmd/auth_test.go +++ b/cmd/auth_test.go @@ -11,6 +11,8 @@ import ( astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/houston" + "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/keychain" testUtil "github.com/astronomer/astro-cli/pkg/testing" ) @@ -26,38 +28,38 @@ func (s *CmdSuite) TestLogin() { cloudDomain := "astronomer.io" softwareDomain := "astronomer_dev.com" - cloudLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + cloudLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { s.Equal(cloudDomain, domain) return nil } - softwareLogin = func(domain string, oAuthOnly bool, username, password, houstonVersion string, client houston.ClientInterface, out io.Writer) error { + softwareLogin = func(domain string, oAuthOnly bool, username, password, houstonVersion string, store keychain.SecureStore, client houston.ClientInterface, out io.Writer) error { s.Equal(softwareDomain, domain) return nil } // cloud login success - login(&cobra.Command{}, []string{cloudDomain}, nil, nil, buf) + login(&cobra.Command{}, []string{cloudDomain}, nil, nil, nil, nil, buf) // software login success testUtil.InitTestConfig(testUtil.Initial) - login(&cobra.Command{}, []string{softwareDomain}, nil, nil, buf) + login(&cobra.Command{}, []string{softwareDomain}, nil, nil, nil, nil, buf) // no domain, cloud login testUtil.InitTestConfig(testUtil.CloudPlatform) - login(&cobra.Command{}, []string{}, nil, nil, buf) + login(&cobra.Command{}, []string{}, nil, nil, nil, nil, buf) // no domain, software login testUtil.InitTestConfig(testUtil.SoftwarePlatform) - login(&cobra.Command{}, []string{}, nil, nil, buf) + login(&cobra.Command{}, []string{}, nil, nil, nil, nil, buf) // no domain, no current context set config.ResetCurrentContext() - login(&cobra.Command{}, []string{}, nil, nil, buf) + login(&cobra.Command{}, []string{}, nil, nil, nil, nil, buf) testUtil.InitTestConfig(testUtil.LocalPlatform) softwareDomain = "software.astronomer.io" - login(&cobra.Command{}, []string{softwareDomain}, nil, nil, buf) + login(&cobra.Command{}, []string{softwareDomain}, nil, nil, nil, nil, buf) s.Contains(buf.String(), "To login to Astro Private Cloud follow the instructions below. If you are attempting to login in to Astro cancel the login and run 'astro login'.\n\n") } @@ -65,96 +67,96 @@ func (s *CmdSuite) TestLogout() { localDomain := "localhost" softwareDomain := "astronomer_dev.com" - cloudLogout = func(domain string, out io.Writer) { + cloudLogout = func(domain string, store keychain.SecureStore, out io.Writer) { s.Equal(localDomain, domain) } - softwareLogout = func(domain string) { + softwareLogout = func(domain string, store keychain.SecureStore) { s.Equal(softwareDomain, domain) } // cloud logout success - err := logout(&cobra.Command{}, []string{localDomain}, os.Stdout) + err := logout(&cobra.Command{}, []string{localDomain}, nil, os.Stdout) s.NoError(err) // software logout success - err = logout(&cobra.Command{}, []string{softwareDomain}, os.Stdout) + err = logout(&cobra.Command{}, []string{softwareDomain}, nil, os.Stdout) s.NoError(err) // no domain, cloud logout testUtil.InitTestConfig(testUtil.LocalPlatform) - err = logout(&cobra.Command{}, []string{}, os.Stdout) + err = logout(&cobra.Command{}, []string{}, nil, os.Stdout) s.NoError(err) // no domain, software logout testUtil.InitTestConfig(testUtil.SoftwarePlatform) - err = logout(&cobra.Command{}, []string{}, os.Stdout) + err = logout(&cobra.Command{}, []string{}, nil, os.Stdout) s.NoError(err) // no domain, no current context set config.ResetCurrentContext() - err = logout(&cobra.Command{}, []string{}, os.Stdout) + err = logout(&cobra.Command{}, []string{}, nil, os.Stdout) s.EqualError(err, "no context set, have you authenticated to Astro or Astro Private Cloud? Run astro login and try again") } func (s *CmdSuite) TestAuthToken() { buf := new(bytes.Buffer) - - // Test with valid token (with Bearer prefix) testUtil.InitTestConfig(testUtil.CloudPlatform) c, err := config.GetCurrentContext() s.NoError(err) expectedToken := "test-token-12345" - err = c.SetContextKey("token", "Bearer "+expectedToken) - s.NoError(err) - err = printAuthToken(&cobra.Command{}, "", buf) + store := keychain.NewTestStore() + + // Test with valid token (with Bearer prefix) + s.NoError(store.SetCredentials(c.Domain, keychain.Credentials{Token: "Bearer " + expectedToken})) + err = printAuthToken(&cobra.Command{}, store, "", buf) s.NoError(err) s.Equal(expectedToken+"\n", buf.String()) // Test with token without Bearer prefix buf.Reset() - err = c.SetContextKey("token", expectedToken) - s.NoError(err) - - err = printAuthToken(&cobra.Command{}, "", buf) + s.NoError(store.SetCredentials(c.Domain, keychain.Credentials{Token: expectedToken})) + err = printAuthToken(&cobra.Command{}, store, "", buf) s.NoError(err) s.Equal(expectedToken+"\n", buf.String()) // Test with no token (not authenticated) buf.Reset() - err = c.SetContextKey("token", "") - s.NoError(err) + s.NoError(store.SetCredentials(c.Domain, keychain.Credentials{})) + err = printAuthToken(&cobra.Command{}, store, "", buf) + s.EqualError(err, "no token found. Please run 'astro login' to authenticate") - err = printAuthToken(&cobra.Command{}, "", buf) + // Test with nil store + buf.Reset() + err = printAuthToken(&cobra.Command{}, nil, "", buf) s.EqualError(err, "no token found. Please run 'astro login' to authenticate") // Test with no current context set buf.Reset() config.ResetCurrentContext() - err = printAuthToken(&cobra.Command{}, "", buf) + err = printAuthToken(&cobra.Command{}, store, "", buf) s.Error(err) } func (s *CmdSuite) TestAuthTokenWithContext() { buf := new(bytes.Buffer) - - // Set up a specific context with a token testUtil.InitTestConfig(testUtil.CloudPlatform) c, err := config.GetCurrentContext() s.NoError(err) expectedToken := "context-specific-token" - err = c.SetContextKey("token", "Bearer "+expectedToken) - s.NoError(err) + + store := keychain.NewTestStore() + s.NoError(store.SetCredentials(c.Domain, keychain.Credentials{Token: "Bearer " + expectedToken})) // Retrieve token using explicit context domain - err = printAuthToken(&cobra.Command{}, c.Domain, buf) + err = printAuthToken(&cobra.Command{}, store, c.Domain, buf) s.NoError(err) s.Equal(expectedToken+"\n", buf.String()) - // Test with non-existent context + // Test with domain that has no credentials buf.Reset() - err = printAuthToken(&cobra.Command{}, "nonexistent.domain.com", buf) - s.Error(err) + err = printAuthToken(&cobra.Command{}, store, "nonexistent.domain.com", buf) + s.EqualError(err, "no token found. Please run 'astro login' to authenticate") } func (s *CmdSuite) TestAuthRootCmd() { diff --git a/cmd/cloud/setup.go b/cmd/cloud/setup.go index 692c07869..223a99f53 100644 --- a/cmd/cloud/setup.go +++ b/cmd/cloud/setup.go @@ -22,6 +22,7 @@ import ( "github.com/astronomer/astro-cli/cloud/organization" "github.com/astronomer/astro-cli/context" "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/keychain" "github.com/astronomer/astro-cli/pkg/logger" "github.com/astronomer/astro-cli/pkg/util" ) @@ -47,6 +48,7 @@ type TokenResponse struct { IDToken string `json:"id_token"` TokenType string `json:"token_type"` ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` Scope string `json:"scope"` Error *string `json:"error,omitempty"` ErrorDescription string `json:"error_description,omitempty"` @@ -64,7 +66,7 @@ type CustomClaims struct { } //nolint:gocognit -func Setup(cmd *cobra.Command, platformCoreClient astroplatformcore.CoreClient, coreClient astrocore.CoreClient) error { +func Setup(cmd *cobra.Command, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, platformCoreClient astroplatformcore.CoreClient, coreClient astrocore.CoreClient) error { // If the user is trying to login or logout no need to go through auth setup. if cmd.CalledAs() == "login" || cmd.CalledAs() == "logout" { return nil @@ -108,7 +110,7 @@ func Setup(cmd *cobra.Command, platformCoreClient astroplatformcore.CoreClient, } // Check for APITokens before API keys or refresh tokens - apiToken, err := checkAPIToken(isDeploymentFile, platformCoreClient) + apiToken, err := checkAPIToken(isDeploymentFile, tokenHolder, platformCoreClient) if err != nil { return err } @@ -117,14 +119,14 @@ func Setup(cmd *cobra.Command, platformCoreClient astroplatformcore.CoreClient, } // run auth setup for any command that requires auth - apiKey, err := checkAPIKeys(platformCoreClient, isDeploymentFile) + apiKey, err := checkAPIKeys(platformCoreClient, tokenHolder, isDeploymentFile) if err != nil { return err } if apiKey { return nil } - err = checkToken(coreClient, platformCoreClient, os.Stdout) + err = checkToken(store, tokenHolder, coreClient, platformCoreClient, os.Stdout) if err != nil { return err } @@ -132,60 +134,43 @@ func Setup(cmd *cobra.Command, platformCoreClient astroplatformcore.CoreClient, return nil } -func checkToken(coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) error { - c, err := context.GetCurrentContext() // get current context +func checkToken(store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) error { + c, err := context.GetCurrentContext() if err != nil { return err } - expireTime, _ := c.GetExpiresIn() - // check if user is logged in - if c.Token == "Bearer " || c.Token == "" || c.Domain == "" { - // guide the user through the login process if not logged in - err := authLogin(c.Domain, "", coreClient, platformCoreClient, out, false) - if err != nil { - return err - } - return nil - } else if isExpired(expireTime, accessTokenExpThreshold) { + creds, err := store.GetCredentials(c.Domain) + if err != nil || creds.Token == "" { + return authLogin(c.Domain, "", store, tokenHolder, coreClient, platformCoreClient, out, false) + } + + if isExpired(creds.ExpiresAt, accessTokenExpThreshold) { authConfig, err := auth.FetchDomainAuthConfig(c.Domain) if err != nil { return err } - res, err := refresh(c.RefreshToken, authConfig) + res, err := refresh(creds.RefreshToken, authConfig) if err != nil { - // guide the user through the login process if refresh doesn't work - err := authLogin(c.Domain, "", coreClient, platformCoreClient, out, false) - if err != nil { - return err - } + return authLogin(c.Domain, "", store, tokenHolder, coreClient, platformCoreClient, out, false) } - // persist the updated context with the renewed access token - err = c.SetContextKey("token", "Bearer "+res.AccessToken) - if err != nil { - return err + newCreds := keychain.Credentials{ + Token: "Bearer " + res.AccessToken, + RefreshToken: creds.RefreshToken, + UserEmail: creds.UserEmail, + ExpiresAt: time.Now().Add(time.Duration(res.ExpiresIn) * time.Second), } - err = c.SetExpiresIn(res.ExpiresIn) - if err != nil { - return err + if res.RefreshToken != "" { + newCreds.RefreshToken = res.RefreshToken } - err = c.SetContextKey("workspace", c.Workspace) - if err != nil { - return err - } - err = c.SetContextKey("workspace", c.LastUsedWorkspace) - if err != nil { - return err - } - err = c.SetContextKey("organization", c.Organization) - if err != nil { - return err - } - err = c.SetContextKey("organization_product", c.OrganizationProduct) - if err != nil { + if err := store.SetCredentials(c.Domain, newCreds); err != nil { return err } + tokenHolder.Set(newCreds.Token) + return nil } + + tokenHolder.Set(creds.Token) return nil } @@ -234,7 +219,7 @@ func refresh(refreshToken string, authConfig auth.Config) (TokenResponse, error) return tokenRes, nil } -func checkAPIKeys(platformCoreClient astroplatformcore.CoreClient, isDeploymentFile bool) (bool, error) { +func checkAPIKeys(platformCoreClient astroplatformcore.CoreClient, tokenHolder *httputil.TokenHolder, isDeploymentFile bool) (bool, error) { // check os variables astronomerKeyID := os.Getenv("ASTRONOMER_KEY_ID") astronomerKeySecret := os.Getenv("ASTRONOMER_KEY_SECRET") @@ -315,15 +300,8 @@ func checkAPIKeys(platformCoreClient astroplatformcore.CoreClient, isDeploymentF return false, errors.New(tokenRes.ErrorDescription) } - err = c.SetContextKey("token", "Bearer "+tokenRes.AccessToken) - if err != nil { - return false, err - } + tokenHolder.Set("Bearer " + tokenRes.AccessToken) - err = c.SetExpiresIn(tokenRes.ExpiresIn) - if err != nil { - return false, err - } orgs, err := organization.ListOrganizations(platformCoreClient) if err != nil { return false, err @@ -352,7 +330,7 @@ func checkAPIKeys(platformCoreClient astroplatformcore.CoreClient, isDeploymentF return true, nil } -func checkAPIToken(isDeploymentFile bool, platformCoreClient astroplatformcore.CoreClient) (bool, error) { +func checkAPIToken(isDeploymentFile bool, tokenHolder *httputil.TokenHolder, platformCoreClient astroplatformcore.CoreClient) (bool, error) { // check os variables astroAPIToken := os.Getenv("ASTRO_API_TOKEN") if astroAPIToken == "" { @@ -389,15 +367,8 @@ func checkAPIToken(isDeploymentFile bool, platformCoreClient astroplatformcore.C } } - err = c.SetContextKey("token", "Bearer "+astroAPIToken) - if err != nil { - return false, err - } + tokenHolder.Set("Bearer " + astroAPIToken) - err = c.SetExpiresIn(time.Now().AddDate(1, 0, 0).Unix()) - if err != nil { - return false, err - } // Parse the token to peek at the custom claims claims, err := parseAPIToken(astroAPIToken) if err != nil { diff --git a/cmd/cloud/setup_test.go b/cmd/cloud/setup_test.go index aa3209066..1e1bb6fea 100644 --- a/cmd/cloud/setup_test.go +++ b/cmd/cloud/setup_test.go @@ -20,6 +20,8 @@ import ( astroplatformcore_mocks "github.com/astronomer/astro-cli/astro-client-platform-core/mocks" "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/context" + "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/keychain" testUtil "github.com/astronomer/astro-cli/pkg/testing" "github.com/astronomer/astro-cli/pkg/util" ) @@ -38,7 +40,7 @@ func TestSetup(t *testing.T) { cmd := &cobra.Command{Use: "login"} cmd, err := cmd.ExecuteC() assert.NoError(t, err) - err = Setup(cmd, nil, nil) + err = Setup(cmd, nil, nil, nil, nil) assert.NoError(t, err) }) @@ -50,7 +52,7 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "astro"} rootCmd.AddCommand(cmd) - err = Setup(cmd, nil, nil) + err = Setup(cmd, nil, nil, nil, nil) assert.NoError(t, err) }) @@ -67,11 +69,11 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "astro"} rootCmd.AddCommand(cmd) - authLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } - err = Setup(cmd, nil, nil) + err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, nil, nil) assert.NoError(t, err) }) @@ -88,11 +90,11 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "astro"} rootCmd.AddCommand(cmd) - authLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } - err = Setup(cmd, nil, nil) + err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, nil, nil) assert.NoError(t, err) }) @@ -104,7 +106,7 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "astro"} rootCmd.AddCommand(cmd) - err = Setup(cmd, nil, nil) + err = Setup(cmd, nil, nil, nil, nil) assert.NoError(t, err) }) @@ -116,7 +118,7 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "astro"} rootCmd.AddCommand(cmd) - err = Setup(cmd, nil, nil) + err = Setup(cmd, nil, nil, nil, nil) assert.NoError(t, err) }) @@ -128,7 +130,7 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "astro"} rootCmd.AddCommand(cmd) - err = Setup(cmd, nil, nil) + err = Setup(cmd, nil, nil, nil, nil) assert.NoError(t, err) }) @@ -140,7 +142,7 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "context"} rootCmd.AddCommand(cmd) - err = Setup(cmd, nil, nil) + err = Setup(cmd, nil, nil, nil, nil) assert.NoError(t, err) }) @@ -152,7 +154,7 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "completion"} rootCmd.AddCommand(cmd) - err = Setup(cmd, nil, nil) + err = Setup(cmd, nil, nil, nil, nil) assert.NoError(t, err) }) @@ -168,11 +170,11 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "deployment"} rootCmd.AddCommand(cmd) - authLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } - err = Setup(cmd, nil, nil) + err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, nil, nil) assert.NoError(t, err) }) @@ -188,11 +190,11 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "astro"} rootCmd.AddCommand(cmd) - authLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } - err = Setup(cmd, nil, nil) + err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, nil, nil) assert.NoError(t, err) }) @@ -216,10 +218,10 @@ func TestSetup(t *testing.T) { RegisteredClaims: jwt.RegisteredClaims{ Issuer: "test-issuer", Subject: "test-subject", - Audience: jwt.ClaimStrings{"audience1", "audience2"}, // Audience can be a single string or an array of strings - ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), // Set expiration date 24 hours from now - NotBefore: jwt.NewNumericDate(time.Now()), // Set not before to current time - IssuedAt: jwt.NewNumericDate(time.Now()), // Set issued at to current time + Audience: jwt.ClaimStrings{"audience1", "audience2"}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + NotBefore: jwt.NewNumericDate(time.Now()), + IssuedAt: jwt.NewNumericDate(time.Now()), ID: "test-id", }, } @@ -241,7 +243,7 @@ func TestSetup(t *testing.T) { t.Setenv("ASTRO_API_TOKEN", "token") - err = Setup(cmd, mockPlatformCoreClient, mockCoreClient) + err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, mockPlatformCoreClient, mockCoreClient) assert.NoError(t, err) mockPlatformCoreClient.AssertExpectations(t) }) @@ -265,7 +267,7 @@ func TestSetup(t *testing.T) { t.Setenv("ASTRO_API_TOKEN", "bad token") - err = Setup(cmd, mockPlatformCoreClient, mockCoreClient) + err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, mockPlatformCoreClient, mockCoreClient) assert.Error(t, err) }) @@ -282,13 +284,13 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "astro"} rootCmd.AddCommand(cmd) - authLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } t.Setenv("ASTRO_API_TOKEN", "") - err = Setup(cmd, mockPlatformCoreClient, mockCoreClient) + err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, mockPlatformCoreClient, mockCoreClient) assert.NoError(t, err) }) @@ -318,7 +320,7 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "astro"} rootCmd.AddCommand(cmd) - authLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -339,7 +341,7 @@ func TestSetup(t *testing.T) { Header: make(http.Header), } }) - err = Setup(cmd, mockPlatformCoreClient, mockCoreClient) + err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, mockPlatformCoreClient, mockCoreClient) assert.NoError(t, err) mockPlatformCoreClient.AssertExpectations(t) mockCoreClient.AssertExpectations(t) @@ -369,7 +371,7 @@ func TestCheckAPIKeys(t *testing.T) { mockPlatformCoreClient.On("ListOrganizationsWithResponse", mock.Anything, mock.Anything).Return(&mockOrgsResponse, nil).Once() mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Once() - authLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -396,8 +398,8 @@ func TestCheckAPIKeys(t *testing.T) { err = context.Switch(domain) assert.NoError(t, err) - // run CheckAPIKeys - _, err = checkAPIKeys(mockPlatformCoreClient, false) + holder := &httputil.TokenHolder{} + _, err = checkAPIKeys(mockPlatformCoreClient, holder, false) assert.NoError(t, err) mockPlatformCoreClient.AssertExpectations(t) mockCoreClient.AssertExpectations(t) @@ -409,26 +411,34 @@ func TestCheckToken(t *testing.T) { mockPlatformCoreClient := new(astroplatformcore_mocks.ClientWithResponsesInterface) t.Run("test check token", func(t *testing.T) { mockCoreClient := new(astrocore_mocks.ClientWithResponsesInterface) - authLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } - // run checkToken - err := checkToken(mockCoreClient, mockPlatformCoreClient, nil) + holder := &httputil.TokenHolder{} + err := checkToken(keychain.NewTestStore(), holder, mockCoreClient, mockPlatformCoreClient, nil) assert.NoError(t, err) }) t.Run("trigger login when no token is found", func(t *testing.T) { mockCoreClient := new(astrocore_mocks.ClientWithResponsesInterface) - authLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return errorLogin } - ctx, err := context.GetCurrentContext() - assert.NoError(t, err) - ctx.SetContextKey("token", "") - // run checkToken - err = checkToken(mockCoreClient, mockPlatformCoreClient, nil) + holder := &httputil.TokenHolder{} + err := checkToken(keychain.NewTestStore(), holder, mockCoreClient, mockPlatformCoreClient, nil) assert.Contains(t, err.Error(), "failed to login") }) + t.Run("valid token already in store sets token holder", func(t *testing.T) { + mockCoreClient := new(astrocore_mocks.ClientWithResponsesInterface) + + store := keychain.NewTestStore() + _ = store.SetCredentials("astronomer.io", keychain.Credentials{Token: "Bearer tok", ExpiresAt: time.Now().Add(time.Hour)}) + + holder := &httputil.TokenHolder{} + err := checkToken(store, holder, mockCoreClient, mockPlatformCoreClient, nil) + assert.NoError(t, err) + assert.Equal(t, "Bearer tok", holder.Get()) + }) } func TestCheckAPIToken(t *testing.T) { @@ -457,17 +467,17 @@ func TestCheckAPIToken(t *testing.T) { RegisteredClaims: jwt.RegisteredClaims{ Issuer: "test-issuer", Subject: "test-subject", - Audience: jwt.ClaimStrings{"audience1", "audience2"}, // Audience can be a single string or an array of strings - ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), // Set expiration date 24 hours from now - NotBefore: jwt.NewNumericDate(time.Now()), // Set not before to current time - IssuedAt: jwt.NewNumericDate(time.Now()), // Set issued at to current time + Audience: jwt.ClaimStrings{"audience1", "audience2"}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + NotBefore: jwt.NewNumericDate(time.Now()), + IssuedAt: jwt.NewNumericDate(time.Now()), ID: "test-id", }, } mockPlatformCoreClient := new(astroplatformcore_mocks.ClientWithResponsesInterface) t.Run("test context switch", func(t *testing.T) { - authLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -484,13 +494,13 @@ func TestCheckAPIToken(t *testing.T) { err := context.Switch(domain) assert.NoError(t, err) - // run checkAPIToken - _, err = checkAPIToken(true, mockPlatformCoreClient) + holder := &httputil.TokenHolder{} + _, err = checkAPIToken(true, holder, mockPlatformCoreClient) assert.NoError(t, err) }) t.Run("failed to parse api token", func(t *testing.T) { - authLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -507,12 +517,12 @@ func TestCheckAPIToken(t *testing.T) { err := context.Switch(domain) assert.NoError(t, err) - // run checkAPIToken - _, err = checkAPIToken(true, mockPlatformCoreClient) + holder := &httputil.TokenHolder{} + _, err = checkAPIToken(true, holder, mockPlatformCoreClient) assert.Error(t, err) }) t.Run("unable to fetch current context", func(t *testing.T) { - authLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -526,8 +536,8 @@ func TestCheckAPIToken(t *testing.T) { err := config.ResetCurrentContext() assert.NoError(t, err) - // run checkAPIToken - _, err = checkAPIToken(true, mockPlatformCoreClient) + holder := &httputil.TokenHolder{} + _, err = checkAPIToken(true, holder, mockPlatformCoreClient) assert.NoError(t, err) }) @@ -546,7 +556,7 @@ func TestCheckAPIToken(t *testing.T) { }, } - authLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -563,8 +573,8 @@ func TestCheckAPIToken(t *testing.T) { err := context.Switch(domain) assert.NoError(t, err) - // run checkAPIToken - _, err = checkAPIToken(false, mockPlatformCoreClient) + holder := &httputil.TokenHolder{} + _, err = checkAPIToken(false, holder, mockPlatformCoreClient) assert.ErrorIs(t, err, errNotAPIToken) }) @@ -586,7 +596,7 @@ func TestCheckAPIToken(t *testing.T) { }, } - authLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -603,8 +613,8 @@ func TestCheckAPIToken(t *testing.T) { err := context.Switch(domain) assert.NoError(t, err) - // run checkAPIToken - _, err = checkAPIToken(true, mockPlatformCoreClient) + holder := &httputil.TokenHolder{} + _, err = checkAPIToken(true, holder, mockPlatformCoreClient) assert.ErrorIs(t, err, errExpiredAPIToken) }) @@ -625,7 +635,7 @@ func TestCheckAPIToken(t *testing.T) { }, } - authLogin = func(domain, token string, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -642,8 +652,8 @@ func TestCheckAPIToken(t *testing.T) { err := context.Switch(domain) assert.NoError(t, err) - // run checkAPIToken - _, err = checkAPIToken(true, mockPlatformCoreClient) + holder := &httputil.TokenHolder{} + _, err = checkAPIToken(true, holder, mockPlatformCoreClient) assert.NoError(t, err) }) } diff --git a/cmd/root.go b/cmd/root.go index 14bbd6b0e..b6494de18 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -20,12 +20,14 @@ import ( "github.com/astronomer/astro-cli/internal/telemetry" "github.com/astronomer/astro-cli/pkg/ansi" "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/keychain" ) var ( verboseLevel string houstonClient houston.ClientInterface houstonVersion string + newSecureStore = keychain.New ) const ( @@ -36,13 +38,16 @@ const ( // NewRootCmd adds all of the primary commands for the cli func NewRootCmd() *cobra.Command { var err error + tokenHolder := &httputil.TokenHolder{} + store, storeErr := newSecureStore() + httpClient := houston.NewHTTPClient() - houstonClient = houston.NewClient(httpClient) + houstonClient = houston.NewClient(httpClient, tokenHolder) - airflowClient := airflowclient.NewAirflowClient(httputil.NewHTTPClient()) - astroCoreClient := astrocore.NewCoreClient(httputil.NewHTTPClient()) - astroCoreIamClient := astroiamcore.NewIamCoreClient(httputil.NewHTTPClient()) - platformCoreClient := platformclient.NewPlatformCoreClient(httputil.NewHTTPClient()) + airflowClient := airflowclient.NewAirflowClient(httputil.NewHTTPClient(), tokenHolder) + astroCoreClient := astrocore.NewCoreClient(httputil.NewHTTPClient(), tokenHolder) + astroCoreIamClient := astroiamcore.NewIamCoreClient(httputil.NewHTTPClient(), tokenHolder) + platformCoreClient := platformclient.NewPlatformCoreClient(httputil.NewHTTPClient(), tokenHolder) ctx := cloudPlatform isCloudCtx := context.IsCloudContext() @@ -74,22 +79,22 @@ Welcome to the Astro CLI, the modern command line interface for data orchestrati } return utils.ChainRunEs( SetupLogging, - CreateRootPersistentPreRunE(astroCoreClient, platformCoreClient), + CreateRootPersistentPreRunE(storeErr, store, tokenHolder, astroCoreClient, platformCoreClient), telemetry.CreateTrackingHook(), )(cmd, args) }, } rootCmd.AddCommand( - newLoginCommand(astroCoreClient, platformCoreClient, os.Stdout), - newLogoutCommand(os.Stdout), - newAuthRootCmd(astroCoreClient, platformCoreClient, os.Stdout), + newLoginCommand(store, tokenHolder, astroCoreClient, platformCoreClient, os.Stdout), + newLogoutCommand(store, os.Stdout), + newAuthRootCmd(store, tokenHolder, astroCoreClient, platformCoreClient, os.Stdout), newVersionCommand(), - newDevRootCmd(platformCoreClient, astroCoreClient), + newDevRootCmd(platformCoreClient, astroCoreClient, store), newContextCmd(os.Stdout), newConfigRootCmd(os.Stdout), newRunCommand(), - api.NewAPICmd(), + api.NewAPICmd(tokenHolder), newTelemetryCmd(os.Stdout), newTelemetrySendCmd(), ) @@ -100,7 +105,7 @@ Welcome to the Astro CLI, the modern command line interface for data orchestrati ) } else { // Include all the commands to be exposed for software users rootCmd.AddCommand( - softwareCmd.AddCmds(houstonClient, os.Stdout)..., + softwareCmd.AddCmds(houstonClient, store, os.Stdout)..., ) softwareCmd.VersionMatchCmds(rootCmd, []string{"astro"}) } diff --git a/cmd/root_hooks.go b/cmd/root_hooks.go index ae105f4b3..8225463b2 100644 --- a/cmd/root_hooks.go +++ b/cmd/root_hooks.go @@ -2,6 +2,7 @@ package cmd import ( "errors" + "fmt" "net/http" "os" "strings" @@ -15,6 +16,8 @@ import ( softwareCmd "github.com/astronomer/astro-cli/cmd/software" "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/context" + "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/keychain" "github.com/astronomer/astro-cli/version" ) @@ -26,32 +29,68 @@ func SetupLogging(_ *cobra.Command, _ []string) error { // CreateRootPersistentPreRunE takes clients as arguments and returns a cobra // pre-run hook that sets up the context and checks for the latest version. -func CreateRootPersistentPreRunE(astroCoreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient) func(cmd *cobra.Command, args []string) error { +func CreateRootPersistentPreRunE(storeErr error, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, astroCoreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient) func(cmd *cobra.Command, args []string) error { return func(cmd *cobra.Command, args []string) error { + // login/logout don't need existing credentials, skip auth setup + if cmd.CalledAs() == "login" || cmd.CalledAs() == "logout" { + return nil + } + + if storeErr != nil { + return fmt.Errorf("secure credential store unavailable: %w", storeErr) + } + // Check for latest version if config.CFG.UpgradeMessage.GetBool() { - // create http client with 3 second timeout, setting an aggressive timeout since its not mandatory to get a response in each command execution httpClient := &http.Client{Timeout: 3 * time.Second} - - // compare current version to latest err := version.CompareVersions(cmd.Context(), httpClient) if err != nil { softwareCmd.InitDebugLogs = append(softwareCmd.InitDebugLogs, "Error comparing CLI versions: "+err.Error()) } } + + if migrated, err := config.MigrateLegacyCredentials(store); err != nil { + softwareCmd.InitDebugLogs = append(softwareCmd.InitDebugLogs, "credential migration error: "+err.Error()) + } else if migrated > 0 { + fmt.Printf("Migrated credentials for %d context(s) to your system's secure store.\n", migrated) + } + if context.IsCloudContext() { - err := cloudCmd.Setup(cmd, platformCoreClient, astroCoreClient) - if err != nil { - if strings.Contains(err.Error(), "token is invalid or malformed") { - return errors.New("API Token is invalid or malformed") //nolint - } - if strings.Contains(err.Error(), "the API token given has expired") { - return errors.New("API Token is expired") //nolint - } - softwareCmd.InitDebugLogs = append(softwareCmd.InitDebugLogs, "Error during cmd setup: "+err.Error()) + if err := handleCloudSetup(cmd, store, tokenHolder, platformCoreClient, astroCoreClient); err != nil { + return err } + } else { + loadSoftwareToken(store, tokenHolder) } softwareCmd.PrintDebugLogs() return nil } } + +func handleCloudSetup(cmd *cobra.Command, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, platformCoreClient astroplatformcore.CoreClient, astroCoreClient astrocore.CoreClient) error { + err := cloudCmd.Setup(cmd, store, tokenHolder, platformCoreClient, astroCoreClient) + if err == nil { + return nil + } + if strings.Contains(err.Error(), "token is invalid or malformed") { + return errors.New("API Token is invalid or malformed") //nolint + } + if strings.Contains(err.Error(), "the API token given has expired") { + return errors.New("API Token is expired") //nolint + } + softwareCmd.InitDebugLogs = append(softwareCmd.InitDebugLogs, "Error during cmd setup: "+err.Error()) + return nil +} + +func loadSoftwareToken(store keychain.SecureStore, tokenHolder *httputil.TokenHolder) { + if store == nil { + return + } + c, err := context.GetCurrentContext() + if err != nil { + return + } + if creds, credErr := store.GetCredentials(c.Domain); credErr == nil { + tokenHolder.Set(creds.Token) + } +} diff --git a/cmd/root_test.go b/cmd/root_test.go index 149d4b400..221069c2e 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -7,10 +7,17 @@ import ( "github.com/spf13/cobra" "github.com/stretchr/testify/suite" + "github.com/astronomer/astro-cli/pkg/keychain" testUtil "github.com/astronomer/astro-cli/pkg/testing" "github.com/astronomer/astro-cli/version" ) +func init() { + newSecureStore = func() (keychain.SecureStore, error) { + return keychain.NewTestStore(), nil + } +} + type CmdSuite struct { suite.Suite } diff --git a/cmd/software/deploy.go b/cmd/software/deploy.go index 62fe636a7..4ab7530b7 100644 --- a/cmd/software/deploy.go +++ b/cmd/software/deploy.go @@ -114,7 +114,7 @@ func deployAirflow(cmd *cobra.Command, args []string) error { } if isDagOnlyDeploy { - return DagsOnlyDeploy(houstonClient, ws, deploymentID, config.WorkingPath, nil, true, description) + return DagsOnlyDeploy(houstonClient, store, ws, deploymentID, config.WorkingPath, nil, true, description) } if imagePresentOnRemote { @@ -127,7 +127,7 @@ func deployAirflow(cmd *cobra.Command, args []string) error { } } else { // Since we prompt the user to enter the deploymentID in come cases for DeployAirflowImage, reusing the same deploymentID for DagsOnlyDeploy - deploymentID, err = DeployAirflowImage(houstonClient, config.WorkingPath, deploymentID, ws, ignoreCacheDeploy, forcePrompt, description, isImageOnlyDeploy, imageName) + deploymentID, err = DeployAirflowImage(houstonClient, store, config.WorkingPath, deploymentID, ws, ignoreCacheDeploy, forcePrompt, description, isImageOnlyDeploy, imageName) if err != nil { return err } @@ -139,7 +139,7 @@ func deployAirflow(cmd *cobra.Command, args []string) error { return nil } - err = DagsOnlyDeploy(houstonClient, ws, deploymentID, config.WorkingPath, nil, true, description) + err = DagsOnlyDeploy(houstonClient, store, ws, deploymentID, config.WorkingPath, nil, true, description) // Don't throw the error if dag-deploy itself is disabled if errors.Is(err, deploy.ErrDagOnlyDeployDisabledInConfig) || errors.Is(err, deploy.ErrDagOnlyDeployNotEnabledForDeployment) { return nil diff --git a/cmd/software/deploy_test.go b/cmd/software/deploy_test.go index 47108464b..b2d2f1150 100644 --- a/cmd/software/deploy_test.go +++ b/cmd/software/deploy_test.go @@ -6,6 +6,7 @@ import ( "github.com/spf13/cobra" "github.com/astronomer/astro-cli/houston" + "github.com/astronomer/astro-cli/pkg/keychain" testUtil "github.com/astronomer/astro-cli/pkg/testing" "github.com/astronomer/astro-cli/software/deploy" ) @@ -28,14 +29,14 @@ func (s *Suite) TestDeploy() { EnsureProjectDir = func(cmd *cobra.Command, args []string) error { return nil } - DeployAirflowImage = func(houstonClient houston.ClientInterface, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { + DeployAirflowImage = func(houstonClient houston.ClientInterface, store keychain.SecureStore, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { if description == "" { return deploymentID, fmt.Errorf("description should not be empty") } return deploymentID, nil } - DagsOnlyDeploy = func(houstonClient houston.ClientInterface, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { + DagsOnlyDeploy = func(houstonClient houston.ClientInterface, store keychain.SecureStore, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { return nil } @@ -53,7 +54,7 @@ func (s *Suite) TestDeploy() { s.NoError(err) // Test when the default description is used - DeployAirflowImage = func(houstonClient houston.ClientInterface, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { + DeployAirflowImage = func(houstonClient houston.ClientInterface, store keychain.SecureStore, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { expectedDesc := "Deployed via " if description != expectedDesc { return deploymentID, fmt.Errorf("expected description to be '%s', but got '%s'", expectedDesc, description) @@ -68,20 +69,20 @@ func (s *Suite) TestDeploy() { DagsOnlyDeploy = deploy.DagsOnlyDeploy s.Run("error should be returned for astro deploy, if DeployAirflowImage throws error", func() { - DeployAirflowImage = func(houstonClient houston.ClientInterface, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { + DeployAirflowImage = func(houstonClient houston.ClientInterface, store keychain.SecureStore, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { return deploymentID, deploy.ErrNoWorkspaceID } err := execDeployCmd([]string{"-f"}...) s.ErrorIs(err, deploy.ErrNoWorkspaceID) - DeployAirflowImage = func(houstonClient houston.ClientInterface, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { + DeployAirflowImage = func(houstonClient houston.ClientInterface, store keychain.SecureStore, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { return deploymentID, nil } }) s.Run("error should be returned for astro deploy, if dags deploy throws error and the feature is enabled", func() { - DagsOnlyDeploy = func(houstonClient houston.ClientInterface, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { + DagsOnlyDeploy = func(houstonClient houston.ClientInterface, store keychain.SecureStore, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { return deploy.ErrNoWorkspaceID } err := execDeployCmd([]string{"-f"}...) @@ -89,7 +90,7 @@ func (s *Suite) TestDeploy() { }) s.Run("Test for the flag --dags when the feature is disabled", func() { - DagsOnlyDeploy = func(houstonClient houston.ClientInterface, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { + DagsOnlyDeploy = func(houstonClient houston.ClientInterface, store keychain.SecureStore, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { return deploy.ErrDagOnlyDeployDisabledInConfig } err := execDeployCmd([]string{"test-deployment-id", "--dags", "--force"}...) @@ -102,7 +103,7 @@ func (s *Suite) TestDeploy() { }) s.Run("Test for the flag --image for image deployment", func() { - DeployAirflowImage = func(houstonClient houston.ClientInterface, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { + DeployAirflowImage = func(houstonClient houston.ClientInterface, store keychain.SecureStore, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { return deploymentID, deploy.ErrDeploymentTypeIncorrectForImageOnly } err := execDeployCmd([]string{"test-deployment-id", "--image", "--force"}...) @@ -110,11 +111,11 @@ func (s *Suite) TestDeploy() { }) s.Run("Test for the flag --image for dags-only deployment", func() { - DeployAirflowImage = func(houstonClient houston.ClientInterface, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { + DeployAirflowImage = func(houstonClient houston.ClientInterface, store keychain.SecureStore, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { return deploymentID, nil } // This function is not called since --image is passed - DagsOnlyDeploy = func(houstonClient houston.ClientInterface, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { + DagsOnlyDeploy = func(houstonClient houston.ClientInterface, store keychain.SecureStore, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { return deploy.ErrNoWorkspaceID } err := execDeployCmd([]string{"test-deployment-id", "--image", "--force"}...) @@ -123,11 +124,11 @@ func (s *Suite) TestDeploy() { s.Run("Test for the flag --image-name", func() { var capturedImageName string - DeployAirflowImage = func(houstonClient houston.ClientInterface, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { + DeployAirflowImage = func(houstonClient houston.ClientInterface, store keychain.SecureStore, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { capturedImageName = imageName // Capture the imageName return deploymentID, nil } - DagsOnlyDeploy = func(houstonClient houston.ClientInterface, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { + DagsOnlyDeploy = func(houstonClient houston.ClientInterface, store keychain.SecureStore, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { return nil } testImageName := "test-image-name" // Set the expected image name @@ -138,14 +139,14 @@ func (s *Suite) TestDeploy() { }) s.Run("Test for the flag --image-name with --remote. Dags should be deployed but DeployAirflowImage shouldn't be called", func() { - DagsOnlyDeploy = func(houstonClient houston.ClientInterface, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { + DagsOnlyDeploy = func(houstonClient houston.ClientInterface, store keychain.SecureStore, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { return nil } // Create a flag to track if DeployAirflowImage is called deployAirflowImageCalled := false // Mock function for DeployAirflowImage - DeployAirflowImage = func(houstonClient houston.ClientInterface, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { + DeployAirflowImage = func(houstonClient houston.ClientInterface, store keychain.SecureStore, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { deployAirflowImageCalled = true // Set the flag if this function is called return deploymentID, nil } @@ -177,7 +178,7 @@ func (s *Suite) TestDeploy() { }) s.Run("error should be returned if BYORegistryEnabled is true but BYORegistryDomain is empty", func() { - DagsOnlyDeploy = func(houstonClient houston.ClientInterface, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { + DagsOnlyDeploy = func(houstonClient houston.ClientInterface, store keychain.SecureStore, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { return deploy.ErrBYORegistryDomainNotSet } err := execDeployCmd([]string{"-f"}...) diff --git a/cmd/software/deployment_logs.go b/cmd/software/deployment_logs.go index b72a6c774..6cc3db6a2 100644 --- a/cmd/software/deployment_logs.go +++ b/cmd/software/deployment_logs.go @@ -6,6 +6,7 @@ import ( "github.com/spf13/cobra" + "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/software/deployment" ) @@ -144,7 +145,16 @@ astro deployment logs triggerer YOU_DEPLOYMENT_ID -s string-to-find func fetchRemoteLogs(component string, args []string, out io.Writer) error { if follow { - return deployment.SubscribeDeploymentLog(args[0], component, search, since) + var token string + if store != nil { + c, err := config.GetCurrentContext() + if err == nil { + if creds, credErr := store.GetCredentials(c.Domain); credErr == nil { + token = creds.Token + } + } + } + return deployment.SubscribeDeploymentLog(args[0], component, search, token, since) } return deployment.Log(args[0], component, search, since, houstonClient, out) } diff --git a/cmd/software/deployment_teams_test.go b/cmd/software/deployment_teams_test.go index 427441d69..382d11439 100644 --- a/cmd/software/deployment_teams_test.go +++ b/cmd/software/deployment_teams_test.go @@ -101,7 +101,7 @@ func (s *Suite) TestDeploymentTeamsListCmd() { Header: make(http.Header), } }) - houstonClient = houston.NewClient(client) + houstonClient = houston.NewClient(client, nil) buf := new(bytes.Buffer) cmd := newDeploymentTeamListCmd(buf) s.NotNil(cmd) diff --git a/cmd/software/root.go b/cmd/software/root.go index 9105a2218..60fb81ac7 100644 --- a/cmd/software/root.go +++ b/cmd/software/root.go @@ -9,6 +9,7 @@ import ( "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/houston" + "github.com/astronomer/astro-cli/pkg/keychain" "github.com/astronomer/astro-cli/pkg/logger" ) @@ -19,14 +20,16 @@ var ( houstonClient houston.ClientInterface appConfig *houston.AppConfig houstonVersion string + store keychain.SecureStore workspaceID string teamID string ) // AddCmds adds all the command initialized in this package for the cmd package to import -func AddCmds(client houston.ClientInterface, out io.Writer) []*cobra.Command { +func AddCmds(client houston.ClientInterface, s keychain.SecureStore, out io.Writer) []*cobra.Command { houstonClient = client + store = s var err error // There is no clusterID in the GetAppConfig call at this point of lifecycle, so we are getting the app config for the default cluster diff --git a/cmd/software/root_test.go b/cmd/software/root_test.go index a8e14aeb2..c426c8a8a 100644 --- a/cmd/software/root_test.go +++ b/cmd/software/root_test.go @@ -44,7 +44,7 @@ func (s *AddCmdSuite) TestAddCmds() { houstonMock.On("GetAppConfig", "").Return(appConfig, nil) houstonMock.On("GetPlatformVersion", nil).Return("0.30.0", nil) buf := new(bytes.Buffer) - cmds := AddCmds(houstonMock, buf) + cmds := AddCmds(houstonMock, nil, buf) for cmdIdx := range cmds { s.Contains([]string{"deployment", "deploy [DEPLOYMENT ID]", "user", "workspace", "team"}, cmds[cmdIdx].Use) } @@ -56,7 +56,7 @@ func (s *AddCmdSuite) TestAppConfigFailure() { houstonMock.On("GetAppConfig", "").Return(nil, errMock) houstonMock.On("GetPlatformVersion", nil).Return("0.30.0", nil) buf := new(bytes.Buffer) - cmds := AddCmds(houstonMock, buf) + cmds := AddCmds(houstonMock, nil, buf) for cmdIdx := range cmds { s.Contains([]string{"deployment", "deploy [DEPLOYMENT ID]", "user", "workspace", "team"}, cmds[cmdIdx].Use) } @@ -75,7 +75,7 @@ func (s *AddCmdSuite) TestPlatformVersionFailure() { houstonMock.On("GetAppConfig", "").Return(appConfig, nil) houstonMock.On("GetPlatformVersion", nil).Return("", errMock) buf := new(bytes.Buffer) - cmds := AddCmds(houstonMock, buf) + cmds := AddCmds(houstonMock, nil, buf) for cmdIdx := range cmds { s.Contains([]string{"deployment", "deploy [DEPLOYMENT ID]", "user", "workspace", "team"}, cmds[cmdIdx].Use) } diff --git a/cmd/software/utils_test.go b/cmd/software/utils_test.go index 185ce5b81..a348a10cf 100644 --- a/cmd/software/utils_test.go +++ b/cmd/software/utils_test.go @@ -18,7 +18,7 @@ func (s *Suite) TestVersionMatchCmds() { mockAPI.On("GetAppConfig", "").Return(&houston.AppConfig{Version: "0.27.0"}, nil) mockAPI.On("GetPlatformVersion", nil).Return("0.27.0", nil) cmd := &cobra.Command{Use: "astro"} - childCMDs := AddCmds(mockAPI, buf) + childCMDs := AddCmds(mockAPI, nil, buf) cmd.AddCommand(childCMDs...) VersionMatchCmds(cmd, []string{"astro"}) @@ -46,7 +46,7 @@ func (s *Suite) TestVersionMatchCmds() { mockAPI.On("GetAppConfig", "").Return(&houston.AppConfig{Version: "0.30.0"}, nil) mockAPI.On("GetPlatformVersion", nil).Return("0.30.0", nil) cmd := &cobra.Command{Use: "astro"} - childCMDs := AddCmds(mockAPI, buf) + childCMDs := AddCmds(mockAPI, nil, buf) cmd.AddCommand(childCMDs...) VersionMatchCmds(cmd, []string{"astro"}) diff --git a/cmd/software/workspace_teams_test.go b/cmd/software/workspace_teams_test.go index 8ba3d4622..2bd1a38ef 100644 --- a/cmd/software/workspace_teams_test.go +++ b/cmd/software/workspace_teams_test.go @@ -102,7 +102,7 @@ func (s *Suite) TestWorkspaceTeamsListCmd() { Header: make(http.Header), } }) - houstonClient = houston.NewClient(client) + houstonClient = houston.NewClient(client, nil) buf := new(bytes.Buffer) cmd := newWorkspaceTeamsListCmd(buf) s.NotNil(cmd) diff --git a/cmd/software/workspace_user_test.go b/cmd/software/workspace_user_test.go index 353baee18..914d00d6c 100644 --- a/cmd/software/workspace_user_test.go +++ b/cmd/software/workspace_user_test.go @@ -59,7 +59,7 @@ func (s *Suite) TestNewWorkspaceUserListCmd() { Header: make(http.Header), } }) - houstonClient = houston.NewClient(client) + houstonClient = houston.NewClient(client, nil) buf := new(bytes.Buffer) cmd := newWorkspaceUserListCmd(buf) s.NotNil(cmd) diff --git a/config/context.go b/config/context.go index 80dc7b8a5..8f5400016 100644 --- a/config/context.go +++ b/config/context.go @@ -5,7 +5,8 @@ import ( "fmt" "os" "strings" - "time" + + "github.com/astronomer/astro-cli/pkg/keychain" ) var ( @@ -31,9 +32,6 @@ type Context struct { OrganizationProduct string `mapstructure:"organization_product"` Workspace string `mapstructure:"workspace"` LastUsedWorkspace string `mapstructure:"last_used_workspace"` - Token string `mapstructure:"token"` - RefreshToken string `mapstructure:"refreshtoken"` - UserEmail string `mapstructure:"user_email"` } // GetCurrentContext looks up current context and gets corresponding Context struct @@ -123,14 +121,11 @@ func (c *Context) SetContext() error { } context := map[string]string{ - "token": c.Token, "domain": c.Domain, "organization": c.Organization, "organization_product": c.OrganizationProduct, "workspace": c.Workspace, "last_used_workspace": c.Workspace, - "refreshtoken": c.RefreshToken, - "user_email": c.UserEmail, } viperHome.Set(contextsKey+"."+key, context) @@ -207,30 +202,56 @@ func (c *Context) DeleteContext() error { return nil } -func (c *Context) SetExpiresIn(value int64) error { - cKey, err := c.GetContextKey() +// MigrateLegacyCredentials reads credential fields (token, refreshtoken, +// user_email, ExpiresIn) directly from viper across all contexts and moves +// them to the provided SecureStore. Each context is migrated atomically: +// if the keychain write fails the context is left in config.yaml and retried +// on the next invocation. Returns the number of contexts migrated. +// +// These viper key names are the legacy field names and must not appear anywhere +// else in the codebase — this function is the only sanctioned reader of them. +func MigrateLegacyCredentials(store keychain.SecureStore) (int, error) { + contexts, err := GetContexts() if err != nil { - return err + return 0, err } - expiretime := time.Now().Add(time.Duration(value) * time.Second) + migrated := 0 + for contextKey, ctx := range contexts.Contexts { + token := viperHome.GetString(fmt.Sprintf("%s.%s.token", contextsKey, contextKey)) + refreshToken := viperHome.GetString(fmt.Sprintf("%s.%s.refreshtoken", contextsKey, contextKey)) + userEmail := viperHome.GetString(fmt.Sprintf("%s.%s.user_email", contextsKey, contextKey)) + expiresAt := viperHome.GetTime(fmt.Sprintf("%s.%s.ExpiresIn", contextsKey, contextKey)) - cfgPath := fmt.Sprintf("%s.%s.%s", contextsKey, cKey, "ExpiresIn") - viperHome.Set(cfgPath, expiretime) - err = saveConfig(viperHome, HomeConfigFile) - if err != nil { - return err - } + if token == "" && refreshToken == "" { + continue + } - return nil -} + creds := keychain.Credentials{ + Token: token, + RefreshToken: refreshToken, + UserEmail: userEmail, + ExpiresAt: expiresAt, + } -func (c *Context) GetExpiresIn() (time.Time, error) { - cKey, err := c.GetContextKey() - if err != nil { - return time.Time{}, err + if err := store.SetCredentials(ctx.Domain, creds); err != nil { + // Don't scrub on failure — will retry next invocation. + continue + } + + // Scrub credential fields from the config map entirely so they don't + // linger as empty strings in the YAML file. + ctxPath := fmt.Sprintf("%s.%s", contextsKey, contextKey) + ctxMap := viperHome.GetStringMap(ctxPath) + for _, field := range []string{"token", "refreshtoken", "user_email", "expiresin"} { + delete(ctxMap, field) + } + viperHome.Set(ctxPath, ctxMap) + if err := saveConfig(viperHome, HomeConfigFile); err != nil { + return migrated, err + } + migrated++ } - cfgPath := fmt.Sprintf("%s.%s.%s", contextsKey, cKey, "ExpiresIn") - return viperHome.GetTime(cfgPath), nil + return migrated, nil } diff --git a/config/context_test.go b/config/context_test.go index 42fefaa55..fba725e23 100644 --- a/config/context_test.go +++ b/config/context_test.go @@ -2,7 +2,6 @@ package config import ( "bytes" - "time" "github.com/spf13/afero" ) @@ -40,7 +39,6 @@ context: example_com contexts: example_com: domain: example.com - token: token last_used_workspace: ck05r3bor07h40d02y2hw4n4v workspace: ck05r3bor07h40d02y2hw4n4v `) @@ -48,7 +46,6 @@ contexts: InitConfig(fs) ctx := Context{ - Token: "token", LastUsedWorkspace: "ck05r3bor07h40d02y2hw4n4v", Workspace: "ck05r3bor07h40d02y2hw4n4v", Domain: "example.com", @@ -74,7 +71,6 @@ context: example_com contexts: example_com: domain: example.com - token: token last_used_workspace: ck05r3bor07h40d02y2hw4n4v workspace: `) @@ -82,7 +78,6 @@ contexts: InitConfig(fs) ctx := Context{ - Token: "token", LastUsedWorkspace: "ck05r3bor07h40d02y2hw4n4v", Workspace: "", Domain: "example.com", @@ -108,7 +103,6 @@ context: example_com contexts: example_com: domain: example.com - token: token last_used_workspace: ck05r3bor07h40d02y2hw4n4v workspace: ck05r3bor07h40d02y2hw4n4v `) @@ -117,7 +111,6 @@ contexts: ctx, err := GetCurrentContext() s.NoError(err) s.Equal("example.com", ctx.Domain) - s.Equal("token", ctx.Token) s.Equal("ck05r3bor07h40d02y2hw4n4v", ctx.Workspace) } @@ -135,12 +128,10 @@ context: example_com contexts: example_com: domain: example.com - token: token last_used_workspace: ck05r3bor07h40d02y2hw4n4v workspace: ck05r3bor07h40d02y2hw4n4v stage_example_com: domain: stage.example.com - token: token last_used_workspace: ck05r3bor07h40d02y2hw4n4w workspace: ck05r3bor07h40d02y2hw4n4w `) @@ -150,7 +141,6 @@ contexts: ctx, err := GetCurrentContext() s.NoError(err) s.Equal("stage.example.com", ctx.Domain) - s.Equal("token", ctx.Token) s.Equal("ck05r3bor07h40d02y2hw4n4w", ctx.Workspace) } @@ -161,13 +151,11 @@ context: test_com contexts: example_com: domain: example.com - token: token last_used_workspace: ck05r3bor07h40d02y2hw4n4v workspace: ck05r3bor07h40d02y2hw4n4v organization: test-org-id test_com: domain: test.com - token: token last_used_workspace: ck05r3bor07h40d02y2hw4n4v workspace: ck05r3bor07h40d02y2hw4n4v organization: test-org-id @@ -196,16 +184,21 @@ func (s *Suite) TestGetContexts() { initTestConfig() ctxs, err := GetContexts() s.NoError(err) - s.Equal(Contexts{Contexts: map[string]Context{"test_com": {"test.com", "test-org-id", "", "ck05r3bor07h40d02y2hw4n4v", "ck05r3bor07h40d02y2hw4n4v", "token", "", ""}, "example_com": {"example.com", "test-org-id", "", "ck05r3bor07h40d02y2hw4n4v", "ck05r3bor07h40d02y2hw4n4v", "token", "", ""}}}, ctxs) + s.Equal(Contexts{Contexts: map[string]Context{ + "test_com": {Domain: "test.com", Organization: "test-org-id", Workspace: "ck05r3bor07h40d02y2hw4n4v", LastUsedWorkspace: "ck05r3bor07h40d02y2hw4n4v"}, + "example_com": {Domain: "example.com", Organization: "test-org-id", Workspace: "ck05r3bor07h40d02y2hw4n4v", LastUsedWorkspace: "ck05r3bor07h40d02y2hw4n4v"}, + }}, ctxs) } func (s *Suite) TestSetContextKey() { initTestConfig() ctx := Context{Domain: "localhost"} - ctx.SetContextKey("token", "test") + err := ctx.SetContextKey("workspace", "ws-123") + s.NoError(err) outCtx, err := ctx.GetContext() s.NoError(err) - s.Equal("test", outCtx.Token) + s.Equal("localhost", outCtx.Domain) + s.Equal("ws-123", outCtx.Workspace) } func (s *Suite) TestSetOrganizationContext() { @@ -227,29 +220,3 @@ func (s *Suite) TestSetOrganizationContext() { s.Contains(err.Error(), "context config invalid, no domain specified") }) } - -func (s *Suite) TestExpiresIn() { - initTestConfig() - ctx := Context{Domain: "localhost"} - err := ctx.SetExpiresIn(12) - s.NoError(err) - - outCtx, err := ctx.GetContext() - s.NoError(err) - - val, err := outCtx.GetExpiresIn() - s.NoError(err) - s.Equal("localhost", outCtx.Domain) - s.True(time.Now().Add(time.Duration(12) * time.Second).After(val)) // now + 12 seconds will always be after expire time, since that is set before -} - -func (s *Suite) TestExpiresInFailure() { - initTestConfig() - ctx := Context{} - err := ctx.SetExpiresIn(1) - s.ErrorIs(err, ErrCtxConfigErr) - - val, err := ctx.GetExpiresIn() - s.ErrorIs(err, ErrCtxConfigErr) - s.Equal(time.Time{}, val) -} diff --git a/config/migrate_test.go b/config/migrate_test.go new file mode 100644 index 000000000..9613001c4 --- /dev/null +++ b/config/migrate_test.go @@ -0,0 +1,131 @@ +package config + +import ( + "testing" + + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/astronomer/astro-cli/pkg/keychain" +) + +func TestMigrateLegacyCredentials_NothingToMigrate(t *testing.T) { + fs := afero.NewMemMapFs() + configRaw := []byte(` +context: astronomer_io +contexts: + astronomer_io: + domain: astronomer.io + workspace: ws-1 +`) + err := afero.WriteFile(fs, HomeConfigFile, configRaw, 0o777) + require.NoError(t, err) + InitConfig(fs) + + store := keychain.NewTestStore() + migrated, err := MigrateLegacyCredentials(store) + require.NoError(t, err) + assert.Equal(t, 0, migrated) + + _, err = store.GetCredentials("astronomer.io") + assert.ErrorIs(t, err, keychain.ErrNotFound) +} + +func TestMigrateLegacyCredentials_SingleContext(t *testing.T) { + fs := afero.NewMemMapFs() + configRaw := []byte(` +context: astronomer_io +contexts: + astronomer_io: + domain: astronomer.io + token: "Bearer old-token" + refreshtoken: "old-refresh" + user_email: "user@example.com" + workspace: ws-1 +`) + err := afero.WriteFile(fs, HomeConfigFile, configRaw, 0o777) + require.NoError(t, err) + InitConfig(fs) + + store := keychain.NewTestStore() + migrated, err := MigrateLegacyCredentials(store) + require.NoError(t, err) + assert.Equal(t, 1, migrated) + + creds, err := store.GetCredentials("astronomer.io") + require.NoError(t, err) + assert.Equal(t, "Bearer old-token", creds.Token) + assert.Equal(t, "old-refresh", creds.RefreshToken) + assert.Equal(t, "user@example.com", creds.UserEmail) + + // Confirm credential fields are fully removed (not just empty strings) + ctxMap := viperHome.GetStringMap("contexts.astronomer_io") + assert.NotContains(t, ctxMap, "token") + assert.NotContains(t, ctxMap, "refreshtoken") + assert.NotContains(t, ctxMap, "user_email") + assert.NotContains(t, ctxMap, "expiresin") + // Non-credential fields survive + assert.Contains(t, ctxMap, "domain") + assert.Contains(t, ctxMap, "workspace") +} + +func TestMigrateLegacyCredentials_MultipleContexts(t *testing.T) { + fs := afero.NewMemMapFs() + configRaw := []byte(` +context: astronomer_io +contexts: + astronomer_io: + domain: astronomer.io + token: "Bearer token-a" + refreshtoken: "refresh-a" + user_email: "a@example.com" + astronomer_stage_io: + domain: astronomer-stage.io + token: "Bearer token-b" + refreshtoken: "refresh-b" + user_email: "b@example.com" +`) + err := afero.WriteFile(fs, HomeConfigFile, configRaw, 0o777) + require.NoError(t, err) + InitConfig(fs) + + store := keychain.NewTestStore() + migrated, err := MigrateLegacyCredentials(store) + require.NoError(t, err) + assert.Equal(t, 2, migrated) + + credsA, err := store.GetCredentials("astronomer.io") + require.NoError(t, err) + assert.Equal(t, "Bearer token-a", credsA.Token) + + credsB, err := store.GetCredentials("astronomer-stage.io") + require.NoError(t, err) + assert.Equal(t, "Bearer token-b", credsB.Token) +} + +func TestMigrateLegacyCredentials_Idempotent(t *testing.T) { + fs := afero.NewMemMapFs() + configRaw := []byte(` +context: astronomer_io +contexts: + astronomer_io: + domain: astronomer.io + token: "Bearer old-token" + refreshtoken: "old-refresh" + user_email: "user@example.com" +`) + err := afero.WriteFile(fs, HomeConfigFile, configRaw, 0o777) + require.NoError(t, err) + InitConfig(fs) + + store := keychain.NewTestStore() + migrated, err := MigrateLegacyCredentials(store) + require.NoError(t, err) + assert.Equal(t, 1, migrated) + + // Second call: nothing in config.yaml to migrate + migrated, err = MigrateLegacyCredentials(store) + require.NoError(t, err) + assert.Equal(t, 0, migrated) +} diff --git a/context/context_test.go b/context/context_test.go index 0989be30c..5f47d79a3 100644 --- a/context/context_test.go +++ b/context/context_test.go @@ -36,10 +36,9 @@ func (s *Suite) TestGetCurrentContext() { testUtil.InitTestConfig(testUtil.LocalPlatform) cluster, err := GetCurrentContext() s.NoError(err) - s.Equal(cluster.Domain, testUtil.GetEnv("HOST", "localhost")) + s.Equal("localhost", cluster.Domain) s.Equal(cluster.Workspace, "ck05r3bor07h40d02y2hw4n4v") s.Equal(cluster.LastUsedWorkspace, "ck05r3bor07h40d02y2hw4n4v") - s.Equal(cluster.Token, "token") } func (s *Suite) TestGetContextKeyValidContextConfig() { diff --git a/go.mod b/go.mod index 087f3bb7c..1b0baea47 100644 --- a/go.mod +++ b/go.mod @@ -68,6 +68,8 @@ require ( 4d63.com/gochecknoglobals v0.2.1 // indirect dario.cat/mergo v1.0.1 // indirect github.com/4meepo/tagalign v1.3.4 // indirect + github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect + github.com/99designs/keyring v1.2.2 // indirect github.com/Abirdcfly/dupword v0.1.3 // indirect github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 // indirect github.com/Antonboom/errname v1.0.0 // indirect @@ -128,10 +130,12 @@ require ( github.com/curioswitch/go-reassign v0.3.0 // indirect github.com/cyphar/filepath-securejoin v0.2.4 // indirect github.com/daixiang0/gci v0.13.5 // indirect + github.com/danieljoos/wincred v1.2.1 // indirect github.com/denis-tingaikin/go-header v0.5.0 // indirect github.com/docker/cli-docs-tool v0.8.0 // indirect github.com/docker/distribution v2.8.3+incompatible // indirect github.com/dprotaso/go-yit v0.0.0-20220510233725-9ba8df137936 // indirect + github.com/dvsekhvalnov/jose2go v1.5.0 // indirect github.com/eiannone/keyboard v0.0.0-20220611211555-0d226195f203 // indirect github.com/emicklei/go-restful/v3 v3.12.1 // indirect github.com/ettle/strcase v0.2.0 // indirect @@ -160,6 +164,7 @@ require ( github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/go-xmlfmt/xmlfmt v1.1.3 // indirect github.com/gobwas/glob v0.2.3 // indirect + github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect github.com/gofrs/uuid v4.4.0+incompatible // indirect github.com/golangci/dupl v0.0.0-20180902072040-3e9179ac440a // indirect github.com/golangci/go-printf-func-name v0.1.0 // indirect @@ -178,6 +183,7 @@ require ( github.com/gostaticanalysis/nilerr v0.1.1 // indirect github.com/gostaticanalysis/testutil v0.5.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 // indirect + github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-version v1.7.0 // indirect github.com/hexops/gotextdiff v1.0.3 // indirect @@ -221,6 +227,7 @@ require ( github.com/moby/sys/userns v0.1.0 // indirect github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect github.com/moricho/tparallel v0.3.2 // indirect + github.com/mtibben/percent v0.2.1 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect github.com/nakabonne/nestif v0.3.1 // indirect diff --git a/go.sum b/go.sum index d2d94b0d1..e0fecd2cb 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,10 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/4meepo/tagalign v1.3.4 h1:P51VcvBnf04YkHzjfclN6BbsopfJR5rxs1n+5zHt+w8= github.com/4meepo/tagalign v1.3.4/go.mod h1:M+pnkHH2vG8+qhE5bVc/zeP7HS/j910Fwa9TUSyZVI0= +github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 h1:/vQbFIOMbk2FiG/kXiLl8BRyzTWDw7gX/Hz7Dd5eDMs= +github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4/go.mod h1:hN7oaIRCjzsZ2dE+yG5k+rsdt3qcwykqK6HVGcKwsw4= +github.com/99designs/keyring v1.2.2 h1:pZd3neh/EmUzWONb35LxQfvuY7kiSXAq3HQd97+XBn0= +github.com/99designs/keyring v1.2.2/go.mod h1:wes/FrByc8j7lFOAGLGSNEg8f/PaI3cgTBqhFkHUrPk= github.com/Abirdcfly/dupword v0.1.3 h1:9Pa1NuAsZvpFPi9Pqkd93I7LIYRURj+A//dFd5tgBeE= github.com/Abirdcfly/dupword v0.1.3/go.mod h1:8VbB2t7e10KRNdwTVoxdBaxla6avbhGzb8sCTygUMhw= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk= @@ -243,6 +247,8 @@ github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxG github.com/daaku/go.zipexe v1.0.0/go.mod h1:z8IiR6TsVLEYKwXAoE/I+8ys/sDkgTzSL0CLnGVd57E= github.com/daixiang0/gci v0.13.5 h1:kThgmH1yBmZSBCh1EJVxQ7JsHpm5Oms0AMed/0LaH4c= github.com/daixiang0/gci v0.13.5/go.mod h1:12etP2OniiIdP4q+kjUGrC/rUagga7ODbqsom5Eo5Yk= +github.com/danieljoos/wincred v1.2.1 h1:dl9cBrupW8+r5250DYkYxocLeZ1Y4vB1kxgtjxw8GQs= +github.com/danieljoos/wincred v1.2.1/go.mod h1:uGaFL9fDn3OLTvzCGulzE+SzjEe5NGlh5FdCcyfPwps= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -286,6 +292,8 @@ github.com/dprotaso/go-yit v0.0.0-20191028211022-135eb7262960/go.mod h1:9HQzr9D/ github.com/dprotaso/go-yit v0.0.0-20220510233725-9ba8df137936 h1:PRxIJD8XjimM5aTknUK9w6DHLDox2r2M3DI4i2pnd3w= github.com/dprotaso/go-yit v0.0.0-20220510233725-9ba8df137936/go.mod h1:ttYvX5qlB+mlV1okblJqcSMtR4c52UKxDiX9GRBS8+Q= github.com/dvsekhvalnov/jose2go v0.0.0-20170216131308-f21a8cedbbae/go.mod h1:7BvyPhdbLxMXIYTFPLsyJRFMsKmOZnQmzh6Gb+uquuM= +github.com/dvsekhvalnov/jose2go v1.5.0 h1:3j8ya4Z4kMCwT5nXIKFSV84YS+HdqSSO0VsTQxaLAeM= +github.com/dvsekhvalnov/jose2go v1.5.0/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= @@ -390,6 +398,8 @@ github.com/go-xmlfmt/xmlfmt v1.1.3 h1:t8Ey3Uy7jDSEisW2K3somuMKIpzktkWptA0iFCnRUW github.com/go-xmlfmt/xmlfmt v1.1.3/go.mod h1:aUCEOzzezBEjDBbFBoSiya/gduyIiWYRP6CnSFIV8AM= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 h1:ZpnhV/YsD2/4cESfV5+Hoeu/iUR3ruzNvZ+yQfO03a0= +github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0= @@ -500,6 +510,8 @@ github.com/gostaticanalysis/testutil v0.5.0 h1:Dq4wT1DdTwTGCQQv3rl3IvD5Ld0E6HiY+ github.com/gostaticanalysis/testutil v0.5.0/go.mod h1:OLQSbuM6zw2EvCcXTz1lVq5unyoNft372msDY0nY5Hs= github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 h1:TmHmbvxPmaegwhDubVz0lICL0J5Ka2vwTzhoePEXsGE= github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0/go.mod h1:qztMSjm835F2bXf+5HKAPIS5qsmQDqZna/PgVt4rWtI= +github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c h1:6rhixN/i8ZofjG1Y75iExal34USq5p+wiN1tpie8IrU= +github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c/go.mod h1:NMPJylDgVpX0MLRlPy15sqSwOFv/U1GZ2m21JhFfek0= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -716,6 +728,8 @@ github.com/moricho/tparallel v0.3.2/go.mod h1:OQ+K3b4Ln3l2TZveGCywybl68glfLEwFGq github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/mreiferson/go-httpclient v0.0.0-20160630210159-31f0106b4474/go.mod h1:OQA4XLvDbMgS8P0CevmM4m9Q3Jq4phKUzcocxuGJ5m8= +github.com/mtibben/percent v0.2.1 h1:5gssi8Nqo8QU/r2pynCm+hBQHpkB/uNK7BJCFogWdzs= +github.com/mtibben/percent v0.2.1/go.mod h1:KG9uO+SZkUp+VkRHsCdYQV3XSZrrSpR3O9ibNBTZrns= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= @@ -1334,6 +1348,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= diff --git a/houston/app_test.go b/houston/app_test.go index a5dd687a4..daa294d71 100644 --- a/houston/app_test.go +++ b/houston/app_test.go @@ -45,7 +45,7 @@ func (s *Suite) TestGetAppConfig() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) config, err := api.GetAppConfig("") s.NoError(err) @@ -68,7 +68,7 @@ func (s *Suite) TestGetAppConfig() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) // reset the local variables appConfig = nil @@ -98,7 +98,7 @@ func (s *Suite) TestGetAppConfig() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.GetAppConfig("") s.EqualError(err, ErrFieldsNotAvailable{}.Error()) @@ -127,7 +127,7 @@ func (s *Suite) TestGetAvailableNamespaces() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) namespaces, err := api.GetAvailableNamespaces(nil) s.NoError(err) @@ -142,7 +142,7 @@ func (s *Suite) TestGetAvailableNamespaces() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.GetAvailableNamespaces(nil) s.Contains(err.Error(), "Internal Server Error") @@ -168,7 +168,7 @@ func (s *Suite) TestGetPlatformVersion() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) version = "0.30.0" versionErr = nil resp, err := api.GetPlatformVersion(nil) @@ -184,7 +184,7 @@ func (s *Suite) TestGetPlatformVersion() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) version = "" versionErr = errMockHouston resp, err := api.GetPlatformVersion(nil) @@ -200,7 +200,7 @@ func (s *Suite) TestGetPlatformVersion() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) version = "" versionErr = nil platformVersion, err := api.GetPlatformVersion(nil) @@ -216,7 +216,7 @@ func (s *Suite) TestGetPlatformVersion() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) version = "" versionErr = nil _, err := api.GetPlatformVersion(nil) diff --git a/houston/auth_test.go b/houston/auth_test.go index 3202913c5..3fa8c121d 100644 --- a/houston/auth_test.go +++ b/houston/auth_test.go @@ -36,7 +36,7 @@ func (s *Suite) TestAuthenticateWithBasicAuth() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) token, err := api.AuthenticateWithBasicAuth(BasicAuthRequest{"username", "password", &ctx}) s.NoError(err) @@ -51,7 +51,7 @@ func (s *Suite) TestAuthenticateWithBasicAuth() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.AuthenticateWithBasicAuth(BasicAuthRequest{"username", "password", &ctx}) s.Contains(err.Error(), "Internal Server Error") @@ -83,7 +83,7 @@ func (s *Suite) TestGetAuthConfig() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) authConfig, err := api.GetAuthConfig(&ctx) s.NoError(err) @@ -98,7 +98,7 @@ func (s *Suite) TestGetAuthConfig() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.GetAuthConfig(&ctx) s.Contains(err.Error(), "Internal Server Error") diff --git a/houston/decorator.go b/houston/decorator.go index 15d08bea4..07b26880c 100644 --- a/houston/decorator.go +++ b/houston/decorator.go @@ -152,7 +152,7 @@ func getVersion() string { // fallback case in which somehow we reach here without getting houston version httpClient := NewHTTPClient() - client := NewClient(httpClient) + client := NewClient(httpClient, nil) version, versionErr = client.GetPlatformVersion(nil) return version diff --git a/houston/deployment_teams_test.go b/houston/deployment_teams_test.go index 57d7cca49..1dbc7147f 100644 --- a/houston/deployment_teams_test.go +++ b/houston/deployment_teams_test.go @@ -37,7 +37,7 @@ func (s *Suite) TestAddDeploymentTeam() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.AddDeploymentTeam(AddDeploymentTeamRequest{"deployment-id", "team-id", "role"}) s.NoError(err) @@ -52,7 +52,7 @@ func (s *Suite) TestAddDeploymentTeam() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.AddDeploymentTeam(AddDeploymentTeamRequest{"deployment-id", "team-id", "role"}) s.Contains(err.Error(), "Internal Server Error") @@ -87,7 +87,7 @@ func (s *Suite) TestDeleteDeploymentTeam() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.RemoveDeploymentTeam(RemoveDeploymentTeamRequest{"deployment-id", "team-id"}) s.NoError(err) @@ -102,7 +102,7 @@ func (s *Suite) TestDeleteDeploymentTeam() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.RemoveDeploymentTeam(RemoveDeploymentTeamRequest{"deployment-id", "team-id"}) s.Contains(err.Error(), "Internal Server Error") @@ -134,7 +134,7 @@ func (s *Suite) TestListDeploymentTeamsAndRoles() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.ListDeploymentTeamsAndRoles("deployment-id") s.NoError(err) @@ -149,7 +149,7 @@ func (s *Suite) TestListDeploymentTeamsAndRoles() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.ListDeploymentTeamsAndRoles("deploymeny-id") s.Contains(err.Error(), "Internal Server Error") @@ -184,7 +184,7 @@ func (s *Suite) TestUpdateDeploymentTeamAndRole() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.UpdateDeploymentTeamRole(UpdateDeploymentTeamRequest{"deployment-id", "team-id", DeploymentAdminRole}) s.NoError(err) @@ -199,7 +199,7 @@ func (s *Suite) TestUpdateDeploymentTeamAndRole() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.UpdateDeploymentTeamRole(UpdateDeploymentTeamRequest{"deployment-id", "team-id", "role"}) s.Contains(err.Error(), "Internal Server Error") diff --git a/houston/deployment_test.go b/houston/deployment_test.go index bc79d356f..b0ee41f69 100644 --- a/houston/deployment_test.go +++ b/houston/deployment_test.go @@ -47,7 +47,7 @@ func (s *Suite) TestCreateDeployment() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) deployment, err := api.CreateDeployment(map[string]interface{}{}) s.NoError(err) @@ -88,7 +88,7 @@ func (s *Suite) TestCreateDeployment() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) deployment, err := api.CreateDeployment(map[string]interface{}{}) s.NoError(err) @@ -103,7 +103,7 @@ func (s *Suite) TestCreateDeployment() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.CreateDeployment(map[string]interface{}{}) s.Contains(err.Error(), "Internal Server Error") @@ -147,7 +147,7 @@ func (s *Suite) TestDeleteDeployment() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) deployment, err := api.DeleteDeployment(DeleteDeploymentRequest{"deployment-id", false}) s.NoError(err) @@ -162,7 +162,7 @@ func (s *Suite) TestDeleteDeployment() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.DeleteDeployment(DeleteDeploymentRequest{"deployment-id", false}) s.Contains(err.Error(), "Internal Server Error") @@ -208,7 +208,7 @@ func (s *Suite) TestListDeployments() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) deploymentList, err := api.ListDeployments(ListDeploymentsRequest{}) s.NoError(err) @@ -223,7 +223,7 @@ func (s *Suite) TestListDeployments() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.ListDeployments(ListDeploymentsRequest{}) s.Contains(err.Error(), "Internal Server Error") @@ -267,7 +267,7 @@ func (s *Suite) TestUpdateDeployment() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) deployment, err := api.UpdateDeployment(map[string]interface{}{}) s.NoError(err) @@ -308,7 +308,7 @@ func (s *Suite) TestUpdateDeployment() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) deployment, err := api.UpdateDeployment(map[string]interface{}{}) s.NoError(err) @@ -323,7 +323,7 @@ func (s *Suite) TestUpdateDeployment() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.UpdateDeployment(map[string]interface{}{}) s.Contains(err.Error(), "Internal Server Error") @@ -367,7 +367,7 @@ func (s *Suite) TestGetDeployment() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) deployment, err := api.GetDeployment("deployment-id") s.NoError(err) @@ -382,7 +382,7 @@ func (s *Suite) TestGetDeployment() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.GetDeployment("deployment-id") s.Contains(err.Error(), "Internal Server Error") @@ -426,7 +426,7 @@ func (s *Suite) TestUpdateDeploymentAirflow() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) deployment, err := api.UpdateDeploymentAirflow(map[string]interface{}{}) s.NoError(err) @@ -441,7 +441,7 @@ func (s *Suite) TestUpdateDeploymentAirflow() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.UpdateDeploymentAirflow(map[string]interface{}{}) s.Contains(err.Error(), "Internal Server Error") @@ -479,7 +479,7 @@ func (s *Suite) TestGetDeploymentConfig() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) deploymentConfig, err := api.GetDeploymentConfig(nil) s.NoError(err) @@ -494,7 +494,7 @@ func (s *Suite) TestGetDeploymentConfig() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.GetDeploymentConfig(nil) s.Contains(err.Error(), "Internal Server Error") @@ -524,7 +524,7 @@ func (s *Suite) TestListDeploymentLogs() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) logs, err := api.ListDeploymentLogs(ListDeploymentLogsRequest{}) s.NoError(err) @@ -539,7 +539,7 @@ func (s *Suite) TestListDeploymentLogs() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.ListDeploymentLogs(ListDeploymentLogsRequest{}) s.Contains(err.Error(), "Internal Server Error") @@ -586,7 +586,7 @@ func (s *Suite) TestUpdateDeploymentRuntime() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) deployment, err := api.UpdateDeploymentRuntime(map[string]interface{}{}) s.NoError(err) @@ -601,7 +601,7 @@ func (s *Suite) TestUpdateDeploymentRuntime() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.UpdateDeploymentRuntime(map[string]interface{}{}) s.Contains(err.Error(), "Internal Server Error") @@ -635,7 +635,7 @@ func (s *Suite) TestCancelUpdateDeploymentRuntime() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) deployment, err := api.CancelUpdateDeploymentRuntime(map[string]interface{}{}) s.NoError(err) @@ -650,7 +650,7 @@ func (s *Suite) TestCancelUpdateDeploymentRuntime() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.CancelUpdateDeploymentRuntime(map[string]interface{}{}) s.Contains(err.Error(), "Internal Server Error") @@ -680,7 +680,7 @@ func (s *Suite) TestUpdateDeploymentImage() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.UpdateDeploymentImage(UpdateDeploymentImageRequest{ReleaseName: mockDeployment.Data.UpdateDeploymentImage.ReleaseName, AirflowVersion: mockDeployment.Data.UpdateDeploymentImage.AirflowVersion}) s.NoError(err) @@ -694,7 +694,7 @@ func (s *Suite) TestUpdateDeploymentImage() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.UpdateDeploymentImage(UpdateDeploymentImageRequest{ReleaseName: mockDeployment.Data.UpdateDeploymentImage.ReleaseName, AirflowVersion: mockDeployment.Data.UpdateDeploymentImage.AirflowVersion}) s.Contains(err.Error(), "Internal Server Error") diff --git a/houston/deployment_user_test.go b/houston/deployment_user_test.go index 2eb62d729..f6ff72788 100644 --- a/houston/deployment_user_test.go +++ b/houston/deployment_user_test.go @@ -40,7 +40,7 @@ func (s *Suite) TestListDeploymentUsers() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.ListDeploymentUsers(ListDeploymentUsersRequest{}) s.NoError(err) @@ -55,7 +55,7 @@ func (s *Suite) TestListDeploymentUsers() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.ListDeploymentUsers(ListDeploymentUsersRequest{}) s.Contains(err.Error(), "Internal Server Error") @@ -94,7 +94,7 @@ func (s *Suite) TestAddDeploymentUser() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.AddDeploymentUser(UpdateDeploymentUserRequest{}) s.NoError(err) @@ -110,7 +110,7 @@ func (s *Suite) TestAddDeploymentUser() { } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.AddDeploymentUser(UpdateDeploymentUserRequest{}) s.Contains(err.Error(), "Internal Server Error") @@ -149,7 +149,7 @@ func (s *Suite) TestUpdateDeploymentUser() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.UpdateDeploymentUser(UpdateDeploymentUserRequest{}) s.NoError(err) @@ -164,7 +164,7 @@ func (s *Suite) TestUpdateDeploymentUser() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.UpdateDeploymentUser(UpdateDeploymentUserRequest{}) s.Contains(err.Error(), "Internal Server Error") @@ -203,7 +203,7 @@ func (s *Suite) TestDeleteDeploymentUser() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.DeleteDeploymentUser(DeleteDeploymentUserRequest{"deployment-id", "email"}) s.NoError(err) @@ -218,7 +218,7 @@ func (s *Suite) TestDeleteDeploymentUser() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.DeleteDeploymentUser(DeleteDeploymentUserRequest{"deployment-id", "email"}) s.Contains(err.Error(), "Internal Server Error") diff --git a/houston/houston.go b/houston/houston.go index 58cfc57f1..5903dc771 100644 --- a/houston/houston.go +++ b/houston/houston.go @@ -107,8 +107,8 @@ type ClientImplementation struct { // NewClient - initialized the Houston Client object with proper HTTP Client configuration // set as a variable so we can change it to return mock houston clients in tests -var NewClient = func(c *httputil.HTTPClient) ClientInterface { - client := newInternalClient(c) +var NewClient = func(c *httputil.HTTPClient, tokenHolder *httputil.TokenHolder) ClientInterface { + client := newInternalClient(c, tokenHolder) return &ClientImplementation{ client: client, } @@ -116,7 +116,8 @@ var NewClient = func(c *httputil.HTTPClient) ClientInterface { // Client containers the logger and HTTPClient used to communicate with the HoustonAPI type Client struct { - HTTPClient *httputil.HTTPClient + HTTPClient *httputil.HTTPClient + tokenHolder *httputil.TokenHolder } func NewHTTPClient() *httputil.HTTPClient { @@ -135,9 +136,10 @@ func NewHTTPClient() *httputil.HTTPClient { } // newInternalClient returns a new Client with the logger and HTTP Client setup. -func newInternalClient(c *httputil.HTTPClient) *Client { +func newInternalClient(c *httputil.HTTPClient, tokenHolder *httputil.TokenHolder) *Client { return &Client{ - HTTPClient: c, + HTTPClient: c, + tokenHolder: tokenHolder, } } @@ -164,7 +166,7 @@ func (r *Request) DoWithClient(api *Client) (*Response, error) { // Do (request) is a wrapper to more easily pass variables to a Client.Do request func (r *Request) Do() (*Response, error) { - return r.DoWithClient(newInternalClient(httputil.NewHTTPClient())) + return r.DoWithClient(newInternalClient(httputil.NewHTTPClient(), nil)) } // Do fetches the current context, and returns Houston API response, error @@ -180,8 +182,10 @@ func (c *Client) Do(doOpts *httputil.DoOptions) (*Response, error) { // DoWithContext executes a query against the Houston API, logging out any errors contained in the response object func (c *Client) DoWithContext(doOpts *httputil.DoOptions, ctx *config.Context) (*Response, error) { // set headers - if ctx.Token != "" { - doOpts.Headers["authorization"] = ctx.Token + if c.tokenHolder != nil { + if tok := c.tokenHolder.Get(); tok != "" { + doOpts.Headers["authorization"] = tok + } } newLogger.Debugf("Request Data: %v\n", string(doOpts.Data)) doOpts.Method = http.MethodPost diff --git a/houston/houston_test.go b/houston/houston_test.go index d6edd2dba..36a17b13a 100644 --- a/houston/houston_test.go +++ b/houston/houston_test.go @@ -11,7 +11,7 @@ import ( ) func (s *Suite) TestNewHoustonClient() { - client := newInternalClient(httputil.NewHTTPClient()) + client := newInternalClient(httputil.NewHTTPClient(), nil) s.NotNil(client, "Can't create new houston Client") } diff --git a/houston/runtime_test.go b/houston/runtime_test.go index 4f45d8eb2..795d8af7e 100644 --- a/houston/runtime_test.go +++ b/houston/runtime_test.go @@ -30,7 +30,7 @@ func (s *Suite) TestGetRuntimeReleases() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) vars := make(map[string]interface{}) resp, err := api.GetRuntimeReleases(vars) @@ -46,7 +46,7 @@ func (s *Suite) TestGetRuntimeReleases() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) vars := make(map[string]interface{}) vars["clusterId"] = "test-cluster-id" @@ -63,7 +63,7 @@ func (s *Suite) TestGetRuntimeReleases() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) vars := make(map[string]interface{}) vars["airflowVersion"] = "2.2.4" @@ -80,7 +80,7 @@ func (s *Suite) TestGetRuntimeReleases() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) vars := make(map[string]interface{}) _, err := api.GetRuntimeReleases(vars) diff --git a/houston/service_account_test.go b/houston/service_account_test.go index 528daaa75..c1b590458 100644 --- a/houston/service_account_test.go +++ b/houston/service_account_test.go @@ -39,7 +39,7 @@ func (s *Suite) TestCreateDeploymentServiceAccount() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.CreateDeploymentServiceAccount(&CreateServiceAccountRequest{}) s.NoError(err) @@ -54,7 +54,7 @@ func (s *Suite) TestCreateDeploymentServiceAccount() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.CreateDeploymentServiceAccount(&CreateServiceAccountRequest{}) s.Contains(err.Error(), "Internal Server Error") @@ -91,7 +91,7 @@ func (s *Suite) TestCreateWorkspaceServiceAccount() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.CreateWorkspaceServiceAccount(&CreateServiceAccountRequest{}) s.NoError(err) @@ -106,7 +106,7 @@ func (s *Suite) TestCreateWorkspaceServiceAccount() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.CreateWorkspaceServiceAccount(&CreateServiceAccountRequest{}) s.Contains(err.Error(), "Internal Server Error") @@ -141,7 +141,7 @@ func (s *Suite) TestDeleteDeploymentServiceAccount() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.DeleteDeploymentServiceAccount(DeleteServiceAccountRequest{"", "deployment-id", "sa-id"}) s.NoError(err) @@ -156,7 +156,7 @@ func (s *Suite) TestDeleteDeploymentServiceAccount() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.DeleteDeploymentServiceAccount(DeleteServiceAccountRequest{"", "deployment-id", "sa-id"}) s.Contains(err.Error(), "Internal Server Error") @@ -191,7 +191,7 @@ func (s *Suite) TestDeleteWorkspaceServiceAccount() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.DeleteWorkspaceServiceAccount(DeleteServiceAccountRequest{"workspace-id", "", "sa-id"}) s.NoError(err) @@ -206,7 +206,7 @@ func (s *Suite) TestDeleteWorkspaceServiceAccount() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.DeleteWorkspaceServiceAccount(DeleteServiceAccountRequest{"workspace-id", "", "sa-id"}) s.Contains(err.Error(), "Internal Server Error") @@ -253,7 +253,7 @@ func (s *Suite) TestListDeploymentServiceAccounts() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.ListDeploymentServiceAccounts("deployment-id") s.NoError(err) @@ -268,7 +268,7 @@ func (s *Suite) TestListDeploymentServiceAccounts() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.ListDeploymentServiceAccounts("deployment-id") s.Contains(err.Error(), "Internal Server Error") @@ -315,7 +315,7 @@ func (s *Suite) TestListWorkspaceServiceAccounts() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.ListWorkspaceServiceAccounts("workspace-id") s.NoError(err) @@ -330,7 +330,7 @@ func (s *Suite) TestListWorkspaceServiceAccounts() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.ListWorkspaceServiceAccounts("workspace-id") s.Contains(err.Error(), "Internal Server Error") diff --git a/houston/teams_test.go b/houston/teams_test.go index 37c2b8e8b..127270d47 100644 --- a/houston/teams_test.go +++ b/houston/teams_test.go @@ -30,7 +30,7 @@ func (s *Suite) TestGetTeam() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.GetTeam("team-id") s.NoError(err) @@ -45,7 +45,7 @@ func (s *Suite) TestGetTeam() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.GetTeam("team-id") s.Contains(err.Error(), "Internal Server Error") @@ -75,7 +75,7 @@ func (s *Suite) TestGetTeamUsers() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.GetTeamUsers("team-id") s.NoError(err) @@ -90,7 +90,7 @@ func (s *Suite) TestGetTeamUsers() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.GetTeamUsers("team-id") s.Contains(err.Error(), "Internal Server Error") @@ -123,7 +123,7 @@ func (s *Suite) TestListTeams() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.ListTeams(ListTeamsRequest{"", 1}) s.NoError(err) @@ -138,7 +138,7 @@ func (s *Suite) TestListTeams() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.ListTeams(ListTeamsRequest{"", 1}) s.Contains(err.Error(), "Internal Server Error") @@ -165,7 +165,7 @@ func (s *Suite) TestCreateTeamSystemRoleBinding() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.CreateTeamSystemRoleBinding(SystemRoleBindingRequest{"test-id", SystemAdminRole}) s.NoError(err) @@ -180,7 +180,7 @@ func (s *Suite) TestCreateTeamSystemRoleBinding() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.CreateTeamSystemRoleBinding(SystemRoleBindingRequest{"test-id", SystemAdminRole}) s.Contains(err.Error(), "Internal Server Error") @@ -207,7 +207,7 @@ func (s *Suite) TestDeleteTeamSystemRoleBinding() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.DeleteTeamSystemRoleBinding(SystemRoleBindingRequest{"test-id", SystemAdminRole}) s.NoError(err) @@ -222,7 +222,7 @@ func (s *Suite) TestDeleteTeamSystemRoleBinding() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.DeleteTeamSystemRoleBinding(SystemRoleBindingRequest{"test-id", SystemAdminRole}) s.Contains(err.Error(), "Internal Server Error") diff --git a/houston/user_test.go b/houston/user_test.go index 31ba421cf..11dcd5000 100644 --- a/houston/user_test.go +++ b/houston/user_test.go @@ -40,7 +40,7 @@ func (s *Suite) TestCreateUser() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.CreateUser(CreateUserRequest{"email", "password"}) s.NoError(err) @@ -55,7 +55,7 @@ func (s *Suite) TestCreateUser() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.CreateUser(CreateUserRequest{"email", "password"}) s.Contains(err.Error(), "Internal Server Error") diff --git a/houston/workspace_teams_test.go b/houston/workspace_teams_test.go index aa2d4ca96..ac1470c0b 100644 --- a/houston/workspace_teams_test.go +++ b/houston/workspace_teams_test.go @@ -34,7 +34,7 @@ func (s *Suite) TestAddWorkspaceTeam() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.AddWorkspaceTeam(AddWorkspaceTeamRequest{"workspace-id", "team-id", "role"}) s.NoError(err) @@ -49,7 +49,7 @@ func (s *Suite) TestAddWorkspaceTeam() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.AddWorkspaceTeam(AddWorkspaceTeamRequest{"workspace-id", "team-id", "role"}) s.Contains(err.Error(), "Internal Server Error") @@ -81,7 +81,7 @@ func (s *Suite) TestDeleteWorkspaceTeam() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.DeleteWorkspaceTeam(DeleteWorkspaceTeamRequest{"workspace-id", "user-id"}) s.NoError(err) @@ -96,7 +96,7 @@ func (s *Suite) TestDeleteWorkspaceTeam() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.DeleteWorkspaceTeam(DeleteWorkspaceTeamRequest{"workspace-id", "user-id"}) s.Contains(err.Error(), "Internal Server Error") @@ -128,7 +128,7 @@ func (s *Suite) TestListWorkspaceTeamsAndRoles() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.ListWorkspaceTeamsAndRoles("workspace-id") s.NoError(err) @@ -143,7 +143,7 @@ func (s *Suite) TestListWorkspaceTeamsAndRoles() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.ListWorkspaceTeamsAndRoles("workspace-id") s.Contains(err.Error(), "Internal Server Error") @@ -169,7 +169,7 @@ func (s *Suite) TestUpdateWorkspaceTeamAndRole() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.UpdateWorkspaceTeamRole(UpdateWorkspaceTeamRoleRequest{"workspace-id", "team-id", WorkspaceAdminRole}) s.NoError(err) @@ -184,7 +184,7 @@ func (s *Suite) TestUpdateWorkspaceTeamAndRole() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.UpdateWorkspaceTeamRole(UpdateWorkspaceTeamRoleRequest{"workspace-id", "team-id", "role"}) s.Contains(err.Error(), "Internal Server Error") @@ -206,7 +206,7 @@ func (s *Suite) TestGetWorkspaceTeamRole() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.GetWorkspaceTeamRole(GetWorkspaceTeamRoleRequest{"workspace-id", "team-id"}) s.NoError(err) @@ -221,7 +221,7 @@ func (s *Suite) TestGetWorkspaceTeamRole() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.GetWorkspaceTeamRole(GetWorkspaceTeamRoleRequest{"workspace-id", "team-id"}) s.Contains(err.Error(), "Internal Server Error") diff --git a/houston/workspace_test.go b/houston/workspace_test.go index bfc8edf74..d77eaad14 100644 --- a/houston/workspace_test.go +++ b/houston/workspace_test.go @@ -44,7 +44,7 @@ func (s *Suite) TestCreateWorkspace() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.CreateWorkspace(CreateWorkspaceRequest{"label", "description"}) s.NoError(err) @@ -59,7 +59,7 @@ func (s *Suite) TestCreateWorkspace() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.CreateWorkspace(CreateWorkspaceRequest{"label", "description"}) s.Contains(err.Error(), "Internal Server Error") @@ -120,7 +120,7 @@ func (s *Suite) TestListWorkspaces() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.ListWorkspaces(nil) s.NoError(err) @@ -135,7 +135,7 @@ func (s *Suite) TestListWorkspaces() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.ListWorkspaces(nil) s.Contains(err.Error(), "Internal Server Error") @@ -196,7 +196,7 @@ func (s *Suite) TestPaginatedListWorkspaces() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.PaginatedListWorkspaces(PaginatedListWorkspaceRequest{10, 0}) s.NoError(err) @@ -211,7 +211,7 @@ func (s *Suite) TestPaginatedListWorkspaces() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.PaginatedListWorkspaces(PaginatedListWorkspaceRequest{10, 0}) s.Contains(err.Error(), "Internal Server Error") @@ -253,7 +253,7 @@ func (s *Suite) TestDeleteWorkspace() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.DeleteWorkspace("workspace-id") s.NoError(err) @@ -268,7 +268,7 @@ func (s *Suite) TestDeleteWorkspace() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.DeleteWorkspace("workspace-id") s.Contains(err.Error(), "Internal Server Error") @@ -310,7 +310,7 @@ func (s *Suite) TestGetWorkspace() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.GetWorkspace("workspace-id") s.NoError(err) @@ -325,7 +325,7 @@ func (s *Suite) TestGetWorkspace() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.GetWorkspace("workspace-id") s.Contains(err.Error(), "Internal Server Error") @@ -357,7 +357,7 @@ func (s *Suite) TestValidateWorkspaceID() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.ValidateWorkspaceID("workspace-id") s.NoError(err) @@ -372,7 +372,7 @@ func (s *Suite) TestValidateWorkspaceID() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.ValidateWorkspaceID("workspace-id") s.Contains(err.Error(), "Internal Server Error") @@ -414,7 +414,7 @@ func (s *Suite) TestUpdateWorkspace() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.UpdateWorkspace(UpdateWorkspaceRequest{"workspace-id", map[string]string{}}) s.NoError(err) @@ -429,7 +429,7 @@ func (s *Suite) TestUpdateWorkspace() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.UpdateWorkspace(UpdateWorkspaceRequest{"workspace-id", map[string]string{}}) s.Contains(err.Error(), "Internal Server Error") diff --git a/houston/workspace_users_test.go b/houston/workspace_users_test.go index 336f35e82..81f18c3f5 100644 --- a/houston/workspace_users_test.go +++ b/houston/workspace_users_test.go @@ -43,7 +43,7 @@ func (s *Suite) TestAddWorkspaceUser() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.AddWorkspaceUser(AddWorkspaceUserRequest{"workspace-id", "email", "role"}) s.NoError(err) @@ -58,7 +58,7 @@ func (s *Suite) TestAddWorkspaceUser() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.AddWorkspaceUser(AddWorkspaceUserRequest{"workspace-id", "email", "role"}) s.Contains(err.Error(), "Internal Server Error") @@ -99,7 +99,7 @@ func (s *Suite) TestDeleteWorkspaceUser() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.DeleteWorkspaceUser(DeleteWorkspaceUserRequest{"workspace-id", "user-id"}) s.NoError(err) @@ -114,7 +114,7 @@ func (s *Suite) TestDeleteWorkspaceUser() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.DeleteWorkspaceUser(DeleteWorkspaceUserRequest{"workspace-id", "user-id"}) s.Contains(err.Error(), "Internal Server Error") @@ -152,7 +152,7 @@ func (s *Suite) TestListWorkspaceUserAndRoles() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.ListWorkspaceUserAndRoles("workspace-id") s.NoError(err) @@ -167,7 +167,7 @@ func (s *Suite) TestListWorkspaceUserAndRoles() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.ListWorkspaceUserAndRoles("workspace-id") s.Contains(err.Error(), "Internal Server Error") @@ -205,7 +205,7 @@ func (s *Suite) TestListWorkspacePaginatedUserAndRoles() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.ListWorkspacePaginatedUserAndRoles(PaginatedWorkspaceUserRolesRequest{"workspace-id", "cursor-id", 100}) s.NoError(err) @@ -220,7 +220,7 @@ func (s *Suite) TestListWorkspacePaginatedUserAndRoles() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.ListWorkspacePaginatedUserAndRoles(PaginatedWorkspaceUserRolesRequest{"workspace-id", "cursor-id", 100}) s.Contains(err.Error(), "Internal Server Error") @@ -246,7 +246,7 @@ func (s *Suite) TestUpdateWorkspaceUserAndRole() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.UpdateWorkspaceUserRole(UpdateWorkspaceUserRoleRequest{"workspace-id", "test@test.com", WorkspaceAdminRole}) s.NoError(err) @@ -261,7 +261,7 @@ func (s *Suite) TestUpdateWorkspaceUserAndRole() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.UpdateWorkspaceUserRole(UpdateWorkspaceUserRoleRequest{"workspace-id", "test@test.com", WorkspaceAdminRole}) s.Contains(err.Error(), "Internal Server Error") @@ -296,7 +296,7 @@ func (s *Suite) TestGetWorkspaceUserRole() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) response, err := api.GetWorkspaceUserRole(GetWorkspaceUserRoleRequest{"workspace-id", "email"}) s.NoError(err) @@ -311,7 +311,7 @@ func (s *Suite) TestGetWorkspaceUserRole() { Header: make(http.Header), } }) - api := NewClient(client) + api := NewClient(client, nil) _, err := api.GetWorkspaceUserRole(GetWorkspaceUserRoleRequest{"workspace-id", "email"}) s.Contains(err.Error(), "Internal Server Error") diff --git a/pkg/httputil/token_holder.go b/pkg/httputil/token_holder.go new file mode 100644 index 000000000..7554f517d --- /dev/null +++ b/pkg/httputil/token_holder.go @@ -0,0 +1,34 @@ +package httputil + +import "sync" + +// TokenHolder holds the current auth token in memory for the duration of a +// command invocation. It is populated by PersistentPreRunE after credentials +// are resolved from the secure store, and read by API client request editors +// on every outbound request. +// +// It is constructed once in NewRootCmd and passed by pointer to both the API +// clients and CreateRootPersistentPreRunE. There is no global state. +type TokenHolder struct { + mu sync.RWMutex + token string +} + +// NewTokenHolder creates a TokenHolder with an initial token value. +func NewTokenHolder(token string) *TokenHolder { + return &TokenHolder{token: token} +} + +// Set stores the token. +func (h *TokenHolder) Set(token string) { + h.mu.Lock() + h.token = token + h.mu.Unlock() +} + +// Get returns the current token. +func (h *TokenHolder) Get() string { + h.mu.RLock() + defer h.mu.RUnlock() + return h.token +} diff --git a/pkg/httputil/token_holder_test.go b/pkg/httputil/token_holder_test.go new file mode 100644 index 000000000..c21a08376 --- /dev/null +++ b/pkg/httputil/token_holder_test.go @@ -0,0 +1,20 @@ +package httputil_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/astronomer/astro-cli/pkg/httputil" +) + +func TestTokenHolder(t *testing.T) { + h := &httputil.TokenHolder{} + assert.Equal(t, "", h.Get()) + + h.Set("Bearer abc") + assert.Equal(t, "Bearer abc", h.Get()) + + h.Set("") + assert.Equal(t, "", h.Get()) +} diff --git a/pkg/keychain/keychain.go b/pkg/keychain/keychain.go new file mode 100644 index 000000000..a4f90f145 --- /dev/null +++ b/pkg/keychain/keychain.go @@ -0,0 +1,79 @@ +package keychain + +import ( + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/99designs/keyring" +) + +const serviceName = "astro-cli" + +// ErrNotFound is returned when no credentials exist for the given domain. +var ErrNotFound = errors.New("credentials not found") + +// SecureStore persists and retrieves authentication credentials +// using the OS-native secure store. +type SecureStore interface { + GetCredentials(domain string) (Credentials, error) + SetCredentials(domain string, creds Credentials) error + DeleteCredentials(domain string) error +} + +// Credentials holds all authentication credentials for a single context. +type Credentials struct { + Token string `json:"token"` + RefreshToken string `json:"refreshtoken"` + UserEmail string `json:"user_email"` + ExpiresAt time.Time `json:"expires_at"` +} + +// keyringStore is the shared SecureStore implementation for macOS and Linux +// Secret Service, backed by a 99designs/keyring.Keyring. +// +// On Windows, see keychain_windows.go for the per-field implementation. +type keyringStore struct { + ring keyring.Keyring +} + +func (s *keyringStore) GetCredentials(domain string) (Credentials, error) { + item, err := s.ring.Get(domain) + if errors.Is(err, keyring.ErrKeyNotFound) { + return Credentials{}, ErrNotFound + } + if err != nil { + return Credentials{}, fmt.Errorf("reading credentials: %w", err) + } + var creds Credentials + if err := json.Unmarshal(item.Data, &creds); err != nil { + return Credentials{}, fmt.Errorf("decoding credentials: %w", err) + } + return creds, nil +} + +func (s *keyringStore) SetCredentials(domain string, creds Credentials) error { + data, err := json.Marshal(creds) + if err != nil { + return fmt.Errorf("encoding credentials: %w", err) + } + if err := s.ring.Set(keyring.Item{Key: domain, Label: "Astro CLI (" + domain + ")", Data: data}); err != nil { + return fmt.Errorf("writing credentials: %w", err) + } + return nil +} + +func (s *keyringStore) DeleteCredentials(domain string) error { + err := s.ring.Remove(domain) + if err == nil || errors.Is(err, keyring.ErrKeyNotFound) { + return nil + } + return fmt.Errorf("deleting credentials: %w", err) +} + +// NewTestStore returns an in-memory SecureStore for use in unit tests. +// It is backed by keyring.NewArrayKeyring which ships with 99designs/keyring. +func NewTestStore() SecureStore { + return &keyringStore{ring: keyring.NewArrayKeyring(nil)} +} diff --git a/pkg/keychain/keychain_darwin.go b/pkg/keychain/keychain_darwin.go new file mode 100644 index 000000000..badbb43d0 --- /dev/null +++ b/pkg/keychain/keychain_darwin.go @@ -0,0 +1,31 @@ +//go:build darwin + +package keychain + +import ( + "fmt" + + "github.com/99designs/keyring" +) + +// New returns a macOS Keychain-backed SecureStore. +// +// Items are stored with per-app ACL (KeychainTrustApplication) so that other +// processes — including the `security` CLI tool — must show a user prompt before +// reading them. After each CLI binary update (binary hash changes), macOS +// re-prompts once on first access. This is expected behavior. +// +// The cachedStore wrapper (see keychain_cached.go) ensures only one keychain +// access per domain per process, so the user sees at most one prompt per +// command invocation. +func New() (SecureStore, error) { + ring, err := keyring.Open(keyring.Config{ + ServiceName: serviceName, + KeychainTrustApplication: true, + KeychainAccessibleWhenUnlocked: true, + }) + if err != nil { + return nil, fmt.Errorf("system keychain unavailable: %w", err) + } + return newCachedStore(&keyringStore{ring: ring}), nil +} diff --git a/pkg/keychain/keychain_file.go b/pkg/keychain/keychain_file.go new file mode 100644 index 000000000..ec8d1d471 --- /dev/null +++ b/pkg/keychain/keychain_file.go @@ -0,0 +1,108 @@ +//go:build !darwin + +package keychain + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" +) + +// fileStore is a plaintext JSON credential store for environments where no +// OS-native secure store is available (Linux without Secret Service, Windows +// before the Credential Manager backend lands). Credentials are written to +// ~/.astro/credentials.json with mode 0600. +// +// Writes go via a temp-file + rename so a crash mid-write cannot corrupt an +// existing credentials file. +type fileStore struct { + path string +} + +func newFileStore() (*fileStore, error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("cannot determine home directory: %w", err) + } + dir := filepath.Join(home, ".astro") + if err := os.MkdirAll(dir, 0o700); err != nil { + return nil, fmt.Errorf("cannot create credentials directory: %w", err) + } + return &fileStore{path: filepath.Join(dir, "credentials.json")}, nil +} + +func writeAtomic(path string, data []byte) error { + tmp, err := os.CreateTemp(filepath.Dir(path), ".credentials-*.json") + if err != nil { + return err + } + tmpPath := tmp.Name() + if _, err := tmp.Write(data); err != nil { + tmp.Close() + os.Remove(tmpPath) + return err + } + if err := tmp.Close(); err != nil { + os.Remove(tmpPath) + return err + } + if err := os.Chmod(tmpPath, 0o600); err != nil { + os.Remove(tmpPath) + return err + } + return os.Rename(tmpPath, path) +} + +func (s *fileStore) read() (map[string]Credentials, error) { + data, err := os.ReadFile(s.path) + if os.IsNotExist(err) { + return map[string]Credentials{}, nil + } + if err != nil { + return nil, fmt.Errorf("reading credentials file: %w", err) + } + var store map[string]Credentials + if err := json.Unmarshal(data, &store); err != nil { + return nil, fmt.Errorf("decoding credentials file: %w", err) + } + return store, nil +} + +func (s *fileStore) write(store map[string]Credentials) error { + data, err := json.Marshal(store) + if err != nil { + return fmt.Errorf("encoding credentials: %w", err) + } + return writeAtomic(s.path, data) +} + +func (s *fileStore) GetCredentials(domain string) (Credentials, error) { + store, err := s.read() + if err != nil { + return Credentials{}, err + } + creds, ok := store[domain] + if !ok { + return Credentials{}, ErrNotFound + } + return creds, nil +} + +func (s *fileStore) SetCredentials(domain string, creds Credentials) error { + store, err := s.read() + if err != nil { + return err + } + store[domain] = creds + return s.write(store) +} + +func (s *fileStore) DeleteCredentials(domain string) error { + store, err := s.read() + if err != nil { + return err + } + delete(store, domain) + return s.write(store) +} diff --git a/pkg/keychain/keychain_linux.go b/pkg/keychain/keychain_linux.go new file mode 100644 index 000000000..53b8c90d9 --- /dev/null +++ b/pkg/keychain/keychain_linux.go @@ -0,0 +1,39 @@ +//go:build linux + +package keychain + +import ( + "fmt" + + "github.com/99designs/keyring" +) + +// New returns a Secret Service-backed SecureStore on Linux. +// +// If no Secret Service daemon is available (e.g. headless CI environments), +// falls back to a plaintext JSON file at ~/.astro/credentials.json with +// 0600 permissions. This matches the current plaintext config.yaml behaviour +// and is intentional — encrypted file fallback is not worth the complexity +// given that CI environments use ASTRO_API_TOKEN anyway. +// +// NOTE: if 99designs/keyring fails to connect to Secret Service in environments +// that DO have it running, replace with godbus/dbus directly: +// https://github.com/godbus/dbus — the SecureStore interface is the only +// change boundary. +func New() (SecureStore, error) { + ring, err := keyring.Open(keyring.Config{ + ServiceName: serviceName, + LibSecretCollectionName: "astro-cli", + KWalletAppID: "astro-cli", + KWalletFolder: "astro-cli", + }) + if err == nil { + return newCachedStore(&keyringStore{ring: ring}), nil + } + // Secret Service unavailable — fall back to plaintext file. + fs, err := newFileStore() + if err != nil { + return nil, fmt.Errorf("credential store unavailable: %w", err) + } + return fs, nil +} diff --git a/pkg/keychain/keychain_test.go b/pkg/keychain/keychain_test.go new file mode 100644 index 000000000..8047d5433 --- /dev/null +++ b/pkg/keychain/keychain_test.go @@ -0,0 +1,68 @@ +package keychain_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/astronomer/astro-cli/pkg/keychain" +) + +func TestSetAndGetCredentials(t *testing.T) { + store := keychain.NewTestStore() + creds := keychain.Credentials{ + Token: "Bearer access-token", + RefreshToken: "refresh-token", + UserEmail: "user@example.com", + ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second), + } + + err := store.SetCredentials("astronomer.io", creds) + require.NoError(t, err) + + got, err := store.GetCredentials("astronomer.io") + require.NoError(t, err) + // Round-trip strips monotonic clock; normalise before comparing. + creds.ExpiresAt = creds.ExpiresAt.UTC() + got.ExpiresAt = got.ExpiresAt.UTC() + assert.Equal(t, creds, got) +} + +func TestGetCredentials_NotFound(t *testing.T) { + store := keychain.NewTestStore() + _, err := store.GetCredentials("notexist.io") + assert.ErrorIs(t, err, keychain.ErrNotFound) +} + +func TestDeleteCredentials(t *testing.T) { + store := keychain.NewTestStore() + creds := keychain.Credentials{Token: "Bearer tok"} + + require.NoError(t, store.SetCredentials("astronomer.io", creds)) + require.NoError(t, store.DeleteCredentials("astronomer.io")) + + _, err := store.GetCredentials("astronomer.io") + assert.ErrorIs(t, err, keychain.ErrNotFound) +} + +func TestDeleteCredentials_NotFound_NoError(t *testing.T) { + store := keychain.NewTestStore() + err := store.DeleteCredentials("notexist.io") + assert.NoError(t, err) +} + +func TestIsolation(t *testing.T) { + store := keychain.NewTestStore() + require.NoError(t, store.SetCredentials("a.io", keychain.Credentials{Token: "token-a"})) + require.NoError(t, store.SetCredentials("b.io", keychain.Credentials{Token: "token-b"})) + + a, err := store.GetCredentials("a.io") + require.NoError(t, err) + assert.Equal(t, "token-a", a.Token) + + b, err := store.GetCredentials("b.io") + require.NoError(t, err) + assert.Equal(t, "token-b", b.Token) +} diff --git a/pkg/keychain/keychain_windows.go b/pkg/keychain/keychain_windows.go new file mode 100644 index 000000000..f0a32cd07 --- /dev/null +++ b/pkg/keychain/keychain_windows.go @@ -0,0 +1,18 @@ +//go:build windows + +package keychain + +import "fmt" + +// New returns a file-backed SecureStore on Windows. +// +// This is a temporary fallback until the Windows Credential Manager backend +// lands — credentials are stored in ~/.astro/credentials.json (mode 0600) +// rather than in Credential Manager. See the follow-up PR for the upgrade. +func New() (SecureStore, error) { + fs, err := newFileStore() + if err != nil { + return nil, fmt.Errorf("credential store unavailable: %w", err) + } + return newCachedStore(fs), nil +} diff --git a/pkg/testing/testing.go b/pkg/testing/testing.go index 09c3bc8cc..2762bbdfd 100644 --- a/pkg/testing/testing.go +++ b/pkg/testing/testing.go @@ -64,7 +64,6 @@ context: %s contexts: %s: domain: %s - token: token last_used_workspace: ck05r3bor07h40d02y2hw4n4v workspace: ck05r3bor07h40d02y2hw4n4v organization: test-org-id diff --git a/software/auth/auth.go b/software/auth/auth.go index 571c571e7..d546d91d4 100644 --- a/software/auth/auth.go +++ b/software/auth/auth.go @@ -13,6 +13,7 @@ import ( "github.com/astronomer/astro-cli/context" "github.com/astronomer/astro-cli/houston" "github.com/astronomer/astro-cli/pkg/input" + "github.com/astronomer/astro-cli/pkg/keychain" "github.com/astronomer/astro-cli/pkg/logger" "github.com/astronomer/astro-cli/software/workspace" ) @@ -84,7 +85,7 @@ func oAuth(oAuthURL string) string { } // RegistryAuth authenticates with the private registry -func RegistryAuth(client houston.ClientInterface, out io.Writer, registryDomain string) error { +func RegistryAuth(client houston.ClientInterface, out io.Writer, registryDomain, token string) error { c, err := context.GetCurrentContext() if err != nil { return err @@ -126,7 +127,7 @@ func RegistryAuth(client houston.ClientInterface, out io.Writer, registryDomain } if !appConfig.Flags.BYORegistryEnabled { - err = registryHandler.Login("user", c.Token) + err = registryHandler.Login("user", token) } else { err = registryHandler.Login("", "") } @@ -161,7 +162,7 @@ func getWorkspaces(client houston.ClientInterface, interactive bool) ([]houston. } // Login handles authentication to houston and registry -func Login(domain string, oAuthOnly bool, username, password, houstonVersion string, client houston.ClientInterface, out io.Writer) error { +func Login(domain string, oAuthOnly bool, username, password, houstonVersion string, store keychain.SecureStore, client houston.ClientInterface, out io.Writer) error { var token string var err error var pageSize int @@ -206,9 +207,12 @@ func Login(domain string, oAuthOnly bool, username, password, houstonVersion str return err } - err = c.SetContextKey("token", token) - if err != nil { - return err + if store == nil { + return fmt.Errorf("credential store not available; cannot save login credentials") + } + // Houston tokens do not have refresh tokens or expiry — only Token is stored. + if err := store.SetCredentials(c.Domain, keychain.Credentials{Token: token}); err != nil { + return fmt.Errorf("storing credentials: %w", err) } workspaces, err := getWorkspaces(client, interactive) @@ -248,7 +252,7 @@ func Login(domain string, oAuthOnly bool, username, password, houstonVersion str } } - err = RegistryAuth(client, out, "") + err = RegistryAuth(client, out, "", token) if err != nil { logger.Debugf("There was an error logging into registry: %s", err.Error()) } @@ -257,20 +261,13 @@ func Login(domain string, oAuthOnly bool, username, password, houstonVersion str } // Logout removes the locally stored token and reset current context -func Logout(domain string) { - c, err := context.GetContext(domain) - if err != nil { - return - } - - err = c.SetContextKey("token", "") - if err != nil { - return +func Logout(domain string, store keychain.SecureStore) { + if err := store.DeleteCredentials(domain); err != nil { + fmt.Printf("Failed to remove credentials: %s\n", err.Error()) } // remove the current context - err = config.ResetCurrentContext() - if err != nil { + if err := config.ResetCurrentContext(); err != nil { fmt.Println("Failed to reset current context: ", err.Error()) return } diff --git a/software/auth/auth_test.go b/software/auth/auth_test.go index d2e7efd14..c26c2917a 100644 --- a/software/auth/auth_test.go +++ b/software/auth/auth_test.go @@ -16,6 +16,7 @@ import ( "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/houston" houstonMocks "github.com/astronomer/astro-cli/houston/mocks" + "github.com/astronomer/astro-cli/pkg/keychain" testUtil "github.com/astronomer/astro-cli/pkg/testing" ) @@ -198,7 +199,7 @@ func (s *Suite) TestRegistryAuthSuccess() { err = ctx.SwitchContext() s.NoError(err) - tt.errAssertion(s.T(), RegistryAuth(houstonMock, out, "")) + tt.errAssertion(s.T(), RegistryAuth(houstonMock, out, "", "")) }) } mockRegistryHandler.AssertExpectations(s.T()) @@ -270,7 +271,7 @@ func (s *Suite) TestRegistryAuthRegistryDomain() { }, nil) out := new(bytes.Buffer) - RegistryAuth(houstonMock, out, tt.registryDomain) + RegistryAuth(houstonMock, out, tt.registryDomain, "") mockRegistryHandler.AssertExpectations(s.T()) }) @@ -292,7 +293,7 @@ func (s *Suite) TestRegistryAuthFailure() { houstonMock := new(houstonMocks.ClientInterface) houstonMock.On("GetAppConfig", "").Return(&houston.AppConfig{Flags: houston.FeatureFlags{BYORegistryEnabled: true}}, nil).Twice() - err := RegistryAuth(houstonMock, out, "") + err := RegistryAuth(houstonMock, out, "", "") s.ErrorIs(err, errMockRegistry) mockRegistryHandler := new(mocks.RegistryHandler) @@ -301,12 +302,12 @@ func (s *Suite) TestRegistryAuthFailure() { return mockRegistryHandler, nil } - err = RegistryAuth(houstonMock, out, "") + err = RegistryAuth(houstonMock, out, "", "") s.NoError(err) houstonMock.On("GetAppConfig", "").Return(&houston.AppConfig{Flags: houston.FeatureFlags{BYORegistryEnabled: false}}, nil).Once() - err = RegistryAuth(houstonMock, out, "") + err = RegistryAuth(houstonMock, out, "", "") s.ErrorIs(err, errMockRegistry) mockRegistryHandler.AssertExpectations(s.T()) @@ -318,7 +319,7 @@ func (s *Suite) TestRegistryAuthFailure() { houstonMock := new(houstonMocks.ClientInterface) houstonMock.On("GetAppConfig", "").Return(nil, errMockHouston).Once() - err := RegistryAuth(houstonMock, out, "") + err := RegistryAuth(houstonMock, out, "", "") s.ErrorIs(err, errMockHouston) houstonMock.AssertExpectations(s.T()) }) @@ -338,7 +339,7 @@ func (s *Suite) TestLoginSuccess() { houstonMock.On("ValidateWorkspaceID", "test-workspace-id").Return(&houston.Workspace{ID: "test-workspace-id"}, nil).Once() out := &bytes.Buffer{} - if !s.NoError(Login("localhost", false, "test", "test", "0.29.0", houstonMock, out)) { + if !s.NoError(Login("localhost", false, "test", "test", "0.29.0", keychain.NewTestStore(), houstonMock, out)) { return } if gotOut := out.String(); !testUtil.StringContains([]string{"localhost"}, gotOut) { @@ -347,7 +348,7 @@ func (s *Suite) TestLoginSuccess() { houstonMock.On("ListWorkspaces", nil).Return([]houston.Workspace{{ID: "ck05r3bor07h40d02y2hw4n4v"}, {ID: "test-workspace-id"}}, nil).Once() out = &bytes.Buffer{} - if s.NoError(Login("localhost", false, "test", "test", "0.30.0", houstonMock, out)) { + if s.NoError(Login("localhost", false, "test", "test", "0.30.0", keychain.NewTestStore(), houstonMock, out)) { return } if gotOut := out.String(); !testUtil.StringContains([]string{"localhost", "test-workspace-id"}, gotOut) { @@ -377,7 +378,7 @@ func (s *Suite) TestLoginSuccess() { houstonMock.On("ValidateWorkspaceID", "ck05r3bor07h40d02y2hw4n4v").Return(&houston.Workspace{}, nil).Once() out := &bytes.Buffer{} - if s.NoError(Login("localhost", false, "test", "test", "0.30.0", houstonMock, out)) { + if s.NoError(Login("localhost", false, "test", "test", "0.30.0", keychain.NewTestStore(), houstonMock, out)) { return } if gotOut := out.String(); !testUtil.StringContains([]string{"localhost", "test-workspace-id"}, gotOut) { @@ -406,7 +407,7 @@ func (s *Suite) TestLoginSuccess() { } out := &bytes.Buffer{} - if s.NoError(Login("localhost", false, "test", "test", "0.30.0", houstonMock, out)) { + if s.NoError(Login("localhost", false, "test", "test", "0.30.0", keychain.NewTestStore(), houstonMock, out)) { return } if gotOut := out.String(); !testUtil.StringContains([]string{"localhost", "ck05r3bor07h40d02y2hw4n4v"}, gotOut) { @@ -438,7 +439,7 @@ func (s *Suite) TestLoginSuccess() { houstonMock.On("ValidateWorkspaceID", "test-workspace-1").Return(&houston.Workspace{}, nil).Once() out := &bytes.Buffer{} - if s.NoError(Login("localhost", false, "test", "test", "0.30.0", houstonMock, out)) { + if s.NoError(Login("localhost", false, "test", "test", "0.30.0", keychain.NewTestStore(), houstonMock, out)) { return } if gotOut := out.String(); !testUtil.StringContains([]string{"localhost", "test-workspace-1"}, gotOut) { @@ -460,7 +461,7 @@ func (s *Suite) TestLoginFailure() { houstonMock.On("GetAuthConfig", mock.Anything).Return(nil, errMockRegistry) out := &bytes.Buffer{} - if !s.ErrorIs(Login("localhost", false, "test", "test", "0.30.0", houstonMock, out), errMockRegistry) { + if !s.ErrorIs(Login("localhost", false, "test", "test", "0.30.0", keychain.NewTestStore(), houstonMock, out), errMockRegistry) { return } if gotOut := out.String(); !testUtil.StringContains([]string{"localhost"}, gotOut) { @@ -475,7 +476,7 @@ func (s *Suite) TestLoginFailure() { houstonMock.On("AuthenticateWithBasicAuth", mock.Anything).Return("", errMockRegistry) out := &bytes.Buffer{} - if s.ErrorIs(Login("localhost", false, "test", "test", "0.30.0", houstonMock, out), errMockRegistry) { + if s.ErrorIs(Login("localhost", false, "test", "test", "0.30.0", keychain.NewTestStore(), houstonMock, out), errMockRegistry) { return } if gotOut := out.String(); !testUtil.StringContains([]string{"localhost"}, gotOut) { @@ -491,7 +492,7 @@ func (s *Suite) TestLoginFailure() { houstonMock.On("ListWorkspaces", nil).Return([]houston.Workspace{}, errMockRegistry).Once() out := &bytes.Buffer{} - if s.ErrorIs(Login("localhost", false, "test", "test", "0.30.0", houstonMock, out), errMockRegistry) { + if s.ErrorIs(Login("localhost", false, "test", "test", "0.30.0", keychain.NewTestStore(), houstonMock, out), errMockRegistry) { return } if gotOut := out.String(); !testUtil.StringContains([]string{"localhost"}, gotOut) { @@ -508,7 +509,7 @@ func (s *Suite) TestLoginFailure() { houstonMock.On("GetAppConfig", "").Return(&houston.AppConfig{Flags: houston.FeatureFlags{BYORegistryEnabled: false}}, nil) out := &bytes.Buffer{} - if s.NoError(Login("dev.astro.io", false, "test", "test", "0.30.0", houstonMock, out)) { + if s.NoError(Login("dev.astro.io", false, "test", "test", "0.30.0", keychain.NewTestStore(), houstonMock, out)) { return } if gotOut := out.String(); !testUtil.StringContains([]string{"dev.astro.io", "No default workspace detected"}, gotOut) { @@ -531,7 +532,7 @@ func (s *Suite) TestLoginFailure() { } out := &bytes.Buffer{} - if s.NoError(Login("test.astro.io", false, "test", "test", "0.30.0", houstonMock, out)) { + if s.NoError(Login("test.astro.io", false, "test", "test", "0.30.0", keychain.NewTestStore(), houstonMock, out)) { return } if gotOut := out.String(); !testUtil.StringContains([]string{"test.astro.io", "Failed to authenticate to the registry"}, gotOut) { @@ -565,7 +566,7 @@ func (s *Suite) TestLogout() { } for _, tt := range tests { s.Run(tt.name, func() { - Logout(tt.args.domain) + Logout(tt.args.domain, keychain.NewTestStore()) }) } } diff --git a/software/deploy/deploy.go b/software/deploy/deploy.go index 793a07480..08d513d6a 100644 --- a/software/deploy/deploy.go +++ b/software/deploy/deploy.go @@ -15,10 +15,12 @@ import ( "github.com/astronomer/astro-cli/airflow" "github.com/astronomer/astro-cli/airflow/types" "github.com/astronomer/astro-cli/config" + "github.com/astronomer/astro-cli/context" "github.com/astronomer/astro-cli/docker" "github.com/astronomer/astro-cli/houston" "github.com/astronomer/astro-cli/pkg/fileutil" "github.com/astronomer/astro-cli/pkg/input" + "github.com/astronomer/astro-cli/pkg/keychain" "github.com/astronomer/astro-cli/pkg/logger" "github.com/astronomer/astro-cli/pkg/printutil" "github.com/astronomer/astro-cli/software/auth" @@ -76,7 +78,7 @@ var tab = printutil.Table{ Header: []string{"#", "LABEL", "DEPLOYMENT NAME", "WORKSPACE", "DEPLOYMENT ID"}, } -func Airflow(houstonClient houston.ClientInterface, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { +func Airflow(houstonClient houston.ClientInterface, store keychain.SecureStore, path, deploymentID, wsID string, ignoreCacheDeploy, prompt bool, description string, isImageOnlyDeploy bool, imageName string) (string, error) { deploymentID, deployments, err := getDeploymentIDForCurrentCommand(houstonClient, wsID, deploymentID, prompt) if err != nil { return deploymentID, err @@ -126,7 +128,7 @@ func Airflow(houstonClient houston.ClientInterface, path, deploymentID, wsID str fmt.Printf(houstonDeploymentPrompt, releaseName) // Build the image to deploy - err = buildPushDockerImage(houstonClient, &c, deploymentInfo, releaseName, path, nextTag, cloudDomain, byoRegistryDomain, ignoreCacheDeploy, byoRegistryEnabled, description, imageName) + err = buildPushDockerImage(houstonClient, store, deploymentInfo, releaseName, path, nextTag, cloudDomain, byoRegistryDomain, ignoreCacheDeploy, byoRegistryEnabled, description, imageName) if err != nil { return deploymentID, err } @@ -204,13 +206,17 @@ func UpdateDeploymentImage(houstonClient houston.ClientInterface, deploymentID, return deploymentID, err } -func pushDockerImage(byoRegistryEnabled bool, deploymentInfo *houston.Deployment, byoRegistryDomain, name, nextTag, cloudDomain string, imageHandler airflow.ImageHandler, houstonClient houston.ClientInterface, c *config.Context, customImageName string) error { +func pushDockerImage(byoRegistryEnabled bool, deploymentInfo *houston.Deployment, byoRegistryDomain, name, nextTag, cloudDomain string, imageHandler airflow.ImageHandler, houstonClient houston.ClientInterface, store keychain.SecureStore, customImageName string) error { var registry, remoteImage, token string if byoRegistryEnabled { registry = byoRegistryDomain remoteImage = fmt.Sprintf("%s:%s", registry, fmt.Sprintf("%s-%s", name, nextTag)) } else { - token = c.Token + if ctx, err := context.GetCurrentContext(); err == nil { + if creds, err := store.GetCredentials(ctx.Domain); err == nil { + token = creds.Token + } + } platformVersion, _ := houstonClient.GetPlatformVersion(nil) if versions.GreaterThanOrEqualTo(platformVersion, "1.0.0") { var err error @@ -219,7 +225,7 @@ func pushDockerImage(byoRegistryEnabled bool, deploymentInfo *houston.Deployment return err } // Switch to per deployment registry login - err = auth.RegistryAuth(houstonClient, os.Stdout, registry) + err = auth.RegistryAuth(houstonClient, os.Stdout, registry, token) if err != nil { logger.Debugf("There was an error logging into registry: %s", err.Error()) return err @@ -228,7 +234,6 @@ func pushDockerImage(byoRegistryEnabled bool, deploymentInfo *houston.Deployment } else { registry = registryDomainPrefix + cloudDomain remoteImage = fmt.Sprintf("%s/%s", registry, airflow.ImageName(name, nextTag)) - token = c.Token } } if customImageName != "" { @@ -320,14 +325,14 @@ func getGetTagFromImageName(imageName string) string { return "" } -func buildPushDockerImage(houstonClient houston.ClientInterface, c *config.Context, deploymentInfo *houston.Deployment, name, path, nextTag, cloudDomain, byoRegistryDomain string, ignoreCacheDeploy, byoRegistryEnabled bool, description, customImageName string) error { +func buildPushDockerImage(houstonClient houston.ClientInterface, store keychain.SecureStore, deploymentInfo *houston.Deployment, name, path, nextTag, cloudDomain, byoRegistryDomain string, ignoreCacheDeploy, byoRegistryEnabled bool, description, customImageName string) error { imageName := airflow.ImageName(name, "latest") imageHandler := imageHandlerInit(imageName) err := buildDockerImage(ignoreCacheDeploy, deploymentInfo, customImageName, path, imageHandler, houstonClient, description) if err != nil { return err } - return pushDockerImage(byoRegistryEnabled, deploymentInfo, byoRegistryDomain, name, nextTag, cloudDomain, imageHandler, houstonClient, c, customImageName) + return pushDockerImage(byoRegistryEnabled, deploymentInfo, byoRegistryDomain, name, nextTag, cloudDomain, imageHandler, houstonClient, store, customImageName) } func getAirflowUILink(deploymentID string, deploymentURLs []houston.DeploymentURL) string { @@ -468,7 +473,7 @@ func getDagDeployURL(deploymentInfo *houston.Deployment) string { return "" } -func DagsOnlyDeploy(houstonClient houston.ClientInterface, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { +func DagsOnlyDeploy(houstonClient houston.ClientInterface, store keychain.SecureStore, wsID, deploymentID, dagsParentPath string, dagDeployURL *string, cleanUpFiles bool, description string) error { deploymentID, _, err := getDeploymentIDForCurrentCommandVar(houstonClient, wsID, deploymentID, deploymentID == "") if err != nil { return err @@ -540,8 +545,13 @@ func DagsOnlyDeploy(houstonClient houston.ClientInterface, wsID, deploymentID, d c, _ := config.GetCurrentContext() + var token string + if creds, err := store.GetCredentials(c.Domain); err == nil { + token = creds.Token + } + headers := map[string]string{ - "authorization": c.Token, + "authorization": token, } uploadFileArgs := fileutil.UploadFileArguments{ diff --git a/software/deploy/deploy_test.go b/software/deploy/deploy_test.go index a9a53a719..f58550866 100644 --- a/software/deploy/deploy_test.go +++ b/software/deploy/deploy_test.go @@ -22,6 +22,7 @@ import ( "github.com/astronomer/astro-cli/houston" houston_mocks "github.com/astronomer/astro-cli/houston/mocks" "github.com/astronomer/astro-cli/pkg/fileutil" + "github.com/astronomer/astro-cli/pkg/keychain" testUtil "github.com/astronomer/astro-cli/pkg/testing" ) @@ -162,7 +163,7 @@ func (s *Suite) TestBuildPushDockerImageSuccessWithTagWarning() { s.houstonMock.On("GetDeploymentConfig", nil).Return(mockedDeploymentConfig, nil) s.houstonMock.On("GetPlatformVersion", mock.Anything).Return("1.0.0", nil).Once() - err := buildPushDockerImage(s.houstonMock, &config.Context{}, mockDeployment, "test", "./testfiles/", "test", "test", "", false, false, description, "") + err := buildPushDockerImage(s.houstonMock, keychain.NewTestStore(), mockDeployment, "test", "./testfiles/", "test", "test", "", false, false, description, "") s.NoError(err) } @@ -189,7 +190,7 @@ func (s *Suite) TestBuildPushDockerImageSuccessWithImageRepoWarning() { s.houstonMock.On("GetRuntimeReleases", vars).Return(houston.RuntimeReleases{}, nil) s.houstonMock.On("GetPlatformVersion", mock.Anything).Return("1.0.0", nil).Once() - err := buildPushDockerImage(s.houstonMock, &config.Context{}, mockDeployment, "test", "./testfiles/", "test", "test", "", false, false, description, "") + err := buildPushDockerImage(s.houstonMock, keychain.NewTestStore(), mockDeployment, "test", "./testfiles/", "test", "test", "", false, false, description, "") s.NoError(err) } @@ -230,7 +231,7 @@ func (s *Suite) TestBuildPushDockerImageSuccessWithBYORegistry() { s.houstonMock.On("GetRuntimeReleases", vars).Return(houston.RuntimeReleases{}, nil) s.houstonMock.On("UpdateDeploymentImage", houston.UpdateDeploymentImageRequest{ReleaseName: "test", Image: "test.registry.io:test-test", AirflowVersion: "1.10.12", RuntimeVersion: ""}).Return(nil, nil) - err := buildPushDockerImage(s.houstonMock, &config.Context{}, mockDeployment, "test", "./testfiles/", "test", "test", "test.registry.io", false, true, description, "") + err := buildPushDockerImage(s.houstonMock, keychain.NewTestStore(), mockDeployment, "test", "./testfiles/", "test", "test", "test.registry.io", false, true, description, "") s.NoError(err) expectedLabel := deployRevisionDescriptionLabel + "=" + description @@ -262,7 +263,7 @@ func (s *Suite) TestBuildPushDockerImageSuccessWithBYORegistry() { } config.CFG.ShaAsTag.SetHomeString("true") defer config.CFG.ShaAsTag.SetHomeString("false") - err = buildPushDockerImage(s.houstonMock, &config.Context{}, mockDeployment, "test", "./testfiles/", "test", "test", "test.registry.io", false, true, description, "") + err = buildPushDockerImage(s.houstonMock, keychain.NewTestStore(), mockDeployment, "test", "./testfiles/", "test", "test", "test.registry.io", false, true, description, "") s.NoError(err) expectedLabel = deployRevisionDescriptionLabel + "=" + description assert.Contains(s.T(), capturedBuildConfig.Labels, expectedLabel) @@ -292,14 +293,14 @@ func (s *Suite) TestBuildPushDockerImageSuccessWithBYORegistryAndCustomImageName s.houstonMock.On("GetRuntimeReleases", vars).Return(houston.RuntimeReleases{}, nil) s.houstonMock.On("UpdateDeploymentImage", houston.UpdateDeploymentImageRequest{ReleaseName: "test", Image: "test.registry.io:latest", AirflowVersion: "1.10.12", RuntimeVersion: "12.2.0"}).Return(nil, nil) - err := buildPushDockerImage(s.houstonMock, &config.Context{}, mockDeployment, "test", "./testfiles/", "test", "test", "test.registry.io", false, true, description, customImageName) + err := buildPushDockerImage(s.houstonMock, keychain.NewTestStore(), mockDeployment, "test", "./testfiles/", "test", "test", "test.registry.io", false, true, description, customImageName) s.NoError(err) } func (s *Suite) TestBuildPushDockerImageFailure() { // invalid dockerfile test dockerfile = "Dockerfile.invalid" - err := buildPushDockerImage(nil, &config.Context{}, mockDeployment, "test", "./testfiles/", "test", "test", "", false, false, description, "") + err := buildPushDockerImage(nil, keychain.NewTestStore(), mockDeployment, "test", "./testfiles/", "test", "test", "", false, false, description, "") s.EqualError(err, "failed to parse dockerfile: testfiles/Dockerfile.invalid: when using JSON array syntax, arrays must be comprised of strings only") dockerfile = "Dockerfile" @@ -313,7 +314,7 @@ func (s *Suite) TestBuildPushDockerImageFailure() { vars["clusterId"] = "" s.houstonMock.On("GetRuntimeReleases", vars).Return(houston.RuntimeReleases{}, nil) // houston GetDeploymentConfig call failure - err = buildPushDockerImage(s.houstonMock, &config.Context{}, mockDeployment, "test", "./testfiles/", "test", "test", "", false, false, description, "") + err = buildPushDockerImage(s.houstonMock, keychain.NewTestStore(), mockDeployment, "test", "./testfiles/", "test", "test", "", false, false, description, "") s.Error(err, errMockHouston) s.houstonMock.On("GetDeploymentConfig", nil).Return(mockedDeploymentConfig, nil).Twice() @@ -324,7 +325,7 @@ func (s *Suite) TestBuildPushDockerImageFailure() { } // build error test case - err = buildPushDockerImage(s.houstonMock, &config.Context{}, mockDeployment, "test", "./testfiles/", "test", "test", "", false, false, description, "") + err = buildPushDockerImage(s.houstonMock, keychain.NewTestStore(), mockDeployment, "test", "./testfiles/", "test", "test", "", false, false, description, "") s.Error(err, errSomeContainerIssue.Error()) s.mockImageHandler.AssertExpectations(s.T()) @@ -337,7 +338,7 @@ func (s *Suite) TestBuildPushDockerImageFailure() { s.houstonMock.On("GetPlatformVersion", mock.Anything).Return("1.0.0", nil).Once() // push error test case - err = buildPushDockerImage(s.houstonMock, &config.Context{}, mockDeployment, "test", "./testfiles/", "test", "test", "", false, false, description, "") + err = buildPushDockerImage(s.houstonMock, keychain.NewTestStore(), mockDeployment, "test", "./testfiles/", "test", "test", "", false, false, description, "") s.Error(err, errSomeContainerIssue.Error()) } @@ -430,13 +431,13 @@ func (s *Suite) TestGetDagDeployURL() { func (s *Suite) TestAirflowFailure() { // No workspace ID test case - _, err := Airflow(nil, "", "", "", false, false, description, false, "") + _, err := Airflow(nil, keychain.NewTestStore(), "", "", "", false, false, description, false, "") s.ErrorIs(err, ErrNoWorkspaceID) // houston GetWorkspace failure case s.houstonMock.On("GetWorkspace", mock.Anything).Return(nil, errMockHouston).Once() - _, err = Airflow(s.houstonMock, "", "", "test-workspace-id", false, false, description, false, "") + _, err = Airflow(s.houstonMock, keychain.NewTestStore(), "", "", "test-workspace-id", false, false, description, false, "") s.ErrorIs(err, errMockHouston) s.houstonMock.AssertExpectations(s.T()) @@ -444,7 +445,7 @@ func (s *Suite) TestAirflowFailure() { s.houstonMock.On("GetWorkspace", mock.Anything).Return(&houston.Workspace{}, nil) s.houstonMock.On("ListDeployments", mock.Anything).Return(nil, errMockHouston).Once() - _, err = Airflow(s.houstonMock, "", "", "test-workspace-id", false, false, description, false, "") + _, err = Airflow(s.houstonMock, keychain.NewTestStore(), "", "", "test-workspace-id", false, false, description, false, "") s.ErrorIs(err, errMockHouston) s.houstonMock.AssertExpectations(s.T()) @@ -455,36 +456,36 @@ func (s *Suite) TestAirflowFailure() { // config GetCurrentContext failure case config.ResetCurrentContext() - _, err = Airflow(s.houstonMock, "", "", "test-workspace-id", false, false, description, false, "") + _, err = Airflow(s.houstonMock, keychain.NewTestStore(), "", "", "test-workspace-id", false, false, description, false, "") s.EqualError(err, "no context set, have you authenticated to Astro or Astro Private Cloud? Run astro login and try again") context.Switch("localhost") // Invalid deployment name case - _, err = Airflow(s.houstonMock, "", "test-deployment-id", "test-workspace-id", false, false, description, false, "") + _, err = Airflow(s.houstonMock, keychain.NewTestStore(), "", "test-deployment-id", "test-workspace-id", false, false, description, false, "") s.ErrorIs(err, errInvalidDeploymentID) // No deployment in the current workspace case - _, err = Airflow(s.houstonMock, "", "", "test-workspace-id", false, false, description, false, "") + _, err = Airflow(s.houstonMock, keychain.NewTestStore(), "", "", "test-workspace-id", false, false, description, false, "") s.ErrorIs(err, errDeploymentNotFound) s.houstonMock.AssertExpectations(s.T()) // Invalid deployment selection case s.houstonMock.On("ListDeployments", mock.Anything).Return([]houston.Deployment{{ID: "test-deployment-id"}}, nil) - _, err = Airflow(s.houstonMock, "", "", "test-workspace-id", false, false, description, false, "") + _, err = Airflow(s.houstonMock, keychain.NewTestStore(), "", "", "test-workspace-id", false, false, description, false, "") s.ErrorIs(err, errInvalidDeploymentSelected) // return error When houston get deployment throws an error s.houstonMock.On("ListDeployments", mock.Anything).Return([]houston.Deployment{{ID: "test-deployment-id"}}, nil) s.houstonMock.On("GetDeployment", mock.Anything).Return(nil, errMockHouston).Once() - _, err = Airflow(s.houstonMock, "", "test-deployment-id", "test-workspace-id", false, false, description, false, "") + _, err = Airflow(s.houstonMock, keychain.NewTestStore(), "", "test-deployment-id", "test-workspace-id", false, false, description, false, "") s.Equal(err.Error(), "failed to get deployment info: "+errMockHouston.Error()) // buildPushDockerImage failure case s.houstonMock.On("GetDeployment", "test-deployment-id").Return(&houston.Deployment{ClusterID: "test-cluster-id"}, nil) s.houstonMock.On("GetAppConfig", "test-cluster-id").Return(&houston.AppConfig{}, nil) dockerfile = "Dockerfile.invalid" - _, err = Airflow(s.houstonMock, "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, false, "") + _, err = Airflow(s.houstonMock, keychain.NewTestStore(), "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, false, "") dockerfile = "Dockerfile" s.Error(err) s.Contains(err.Error(), "failed to parse dockerfile") @@ -523,7 +524,7 @@ func (s *Suite) TestAirflowSuccess() { vars["clusterId"] = "test-cluster-id" s.houstonMock.On("GetRuntimeReleases", vars).Return(mockRuntimeReleases, nil) - _, err := Airflow(s.houstonMock, "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, false, "") + _, err := Airflow(s.houstonMock, keychain.NewTestStore(), "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, false, "") s.NoError(err) } @@ -567,7 +568,7 @@ func (s *Suite) TestAirflowSuccessForBYORegistry() { s.houstonMock.On("GetRuntimeReleases", vars).Return(mockRuntimeReleases, nil) s.houstonMock.On("UpdateDeploymentImage", mock.Anything).Return(&houston.UpdateDeploymentImageResp{}, nil).Once() - _, err := Airflow(s.houstonMock, "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, false, "") + _, err := Airflow(s.houstonMock, keychain.NewTestStore(), "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, false, "") s.NoError(err) } @@ -591,7 +592,7 @@ func (s *Suite) TestAirflowFailureForNoBYORegistryDomain() { }, }, nil).Once() - _, err := Airflow(s.houstonMock, "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, false, "") + _, err := Airflow(s.houstonMock, keychain.NewTestStore(), "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, false, "") s.ErrorIs(err, ErrBYORegistryDomainNotSet) } @@ -636,7 +637,7 @@ func (s *Suite) TestAirflowSuccessForImageOnly() { vars["clusterId"] = "test-cluster-id" s.houstonMock.On("GetRuntimeReleases", vars).Return(mockRuntimeReleases, nil) - _, err := Airflow(s.houstonMock, "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, true, "") + _, err := Airflow(s.houstonMock, keychain.NewTestStore(), "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, true, "") s.NoError(err) } @@ -681,7 +682,7 @@ func (s *Suite) TestAirflowSuccessForImageName() { vars["clusterId"] = "test-cluster-id" s.houstonMock.On("GetRuntimeReleases", vars).Return(mockRuntimeReleases, nil) - _, err := Airflow(s.houstonMock, "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, true, customImageName) + _, err := Airflow(s.houstonMock, keychain.NewTestStore(), "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, true, customImageName) s.NoError(err) } @@ -707,7 +708,7 @@ func (s *Suite) TestAirflowFailForImageNameWhenImageHasNoRuntimeLabel() { s.houstonMock.On("GetDeployment", "test-deployment-id").Return(deployment, nil).Once() s.houstonMock.On("GetAppConfig", "test-cluster-id").Return(&houston.AppConfig{}, nil).Once() - _, err := Airflow(s.houstonMock, "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, true, customImageName) + _, err := Airflow(s.houstonMock, keychain.NewTestStore(), "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, true, customImageName) s.Error(err, ErrNoRuntimeLabelOnCustomImage) } @@ -732,7 +733,7 @@ func (s *Suite) TestAirflowFailureForImageOnly() { s.houstonMock.On("GetDeployment", "test-deployment-id").Return(deployment, nil).Once() s.houstonMock.On("GetAppConfig", "test-cluster-id").Return(&houston.AppConfig{}, nil).Once() - _, err := Airflow(s.houstonMock, "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, true, "") + _, err := Airflow(s.houstonMock, keychain.NewTestStore(), "./testfiles/", "test-deployment-id", "test-workspace-id", false, false, description, true, "") s.Error(err, ErrDeploymentTypeIncorrectForImageOnly) } @@ -758,7 +759,7 @@ func (s *Suite) TestDeployDagsOnlyFailure() { } s.houstonMock.On("GetDeployment", deploymentID).Return(deployment, nil).Once() s.houstonMock.On("GetAppConfig", deployment.ClusterID).Return(appConfig, nil).Once() - err := DagsOnlyDeploy(s.houstonMock, wsID, deploymentID, config.WorkingPath, nil, false, description) + err := DagsOnlyDeploy(s.houstonMock, keychain.NewTestStore(), wsID, deploymentID, config.WorkingPath, nil, false, description) s.ErrorIs(err, ErrDagOnlyDeployDisabledInConfig) }) @@ -766,7 +767,7 @@ func (s *Suite) TestDeployDagsOnlyFailure() { getDeploymentIDForCurrentCommandVar = func(houstonClient houston.ClientInterface, wsID, deploymentID string, prompt bool) (string, []houston.Deployment, error) { return deploymentID, nil, errDeploymentNotFound } - err := DagsOnlyDeploy(s.houstonMock, wsID, deploymentID, config.WorkingPath, nil, false, description) + err := DagsOnlyDeploy(s.houstonMock, keychain.NewTestStore(), wsID, deploymentID, config.WorkingPath, nil, false, description) s.ErrorIs(err, errDeploymentNotFound) }) @@ -775,7 +776,7 @@ func (s *Suite) TestDeployDagsOnlyFailure() { return deploymentID, nil, nil } s.houstonMock.On("GetDeployment", deploymentID).Return(nil, errMockHouston).Once() - err := DagsOnlyDeploy(s.houstonMock, wsID, deploymentID, config.WorkingPath, nil, false, description) + err := DagsOnlyDeploy(s.houstonMock, keychain.NewTestStore(), wsID, deploymentID, config.WorkingPath, nil, false, description) s.ErrorContains(err, "failed to get deployment info: some houston error") }) @@ -799,7 +800,7 @@ func (s *Suite) TestDeployDagsOnlyFailure() { } s.houstonMock.On("GetDeployment", deploymentID).Return(deployment, nil).Once() s.houstonMock.On("GetAppConfig", deployment.ClusterID).Return(appConfig, nil).Once() - err := DagsOnlyDeploy(s.houstonMock, wsID, deploymentID, config.WorkingPath, nil, false, description) + err := DagsOnlyDeploy(s.houstonMock, keychain.NewTestStore(), wsID, deploymentID, config.WorkingPath, nil, false, description) s.ErrorIs(err, ErrDagOnlyDeployNotEnabledForDeployment) }) @@ -824,7 +825,7 @@ func (s *Suite) TestDeployDagsOnlyFailure() { config.ResetCurrentContext() s.houstonMock.On("GetAppConfig", deployment.ClusterID).Return(appConfig, nil).Once() - err := DagsOnlyDeploy(s.houstonMock, wsID, deploymentID, config.WorkingPath, nil, false, description) + err := DagsOnlyDeploy(s.houstonMock, keychain.NewTestStore(), wsID, deploymentID, config.WorkingPath, nil, false, description) s.EqualError(err, "could not get current context! Error: no context set, have you authenticated to Astro or Astro Private Cloud? Run astro login and try again") context.Switch("localhost") }) @@ -849,7 +850,7 @@ func (s *Suite) TestDeployDagsOnlyFailure() { } s.houstonMock.On("GetDeployment", deploymentID).Return(deployment, nil).Once() s.houstonMock.On("GetAppConfig", deployment.ClusterID).Return(appConfig, nil).Once() - err := DagsOnlyDeploy(s.houstonMock, wsID, deploymentID, config.WorkingPath, nil, false, description) + err := DagsOnlyDeploy(s.houstonMock, keychain.NewTestStore(), wsID, deploymentID, config.WorkingPath, nil, false, description) s.ErrorIs(err, errInvalidDeploymentID) }) @@ -892,7 +893,7 @@ func (s *Suite) TestDeployDagsOnlyFailure() { defer os.RemoveAll("dags") s.houstonMock.On("GetAppConfig", deployment.ClusterID).Return(appConfig, nil).Once() - err = DagsOnlyDeploy(s.houstonMock, wsID, deploymentID, ".", nil, false, description) + err = DagsOnlyDeploy(s.houstonMock, keychain.NewTestStore(), wsID, deploymentID, ".", nil, false, description) s.EqualError(err, ErrEmptyDagFolderUserCancelledOperation.Error()) // assert that no tar or gz file exists @@ -954,7 +955,7 @@ func (s *Suite) TestDeployDagsOnlyFailure() { defer server.Close() s.houstonMock.On("GetAppConfig", deployment.ClusterID).Return(appConfig, nil).Once() - err = DagsOnlyDeploy(s.houstonMock, wsID, deploymentID, ".", &server.URL, false, description) + err = DagsOnlyDeploy(s.houstonMock, keychain.NewTestStore(), wsID, deploymentID, ".", &server.URL, false, description) s.NoError(err) // Validate that dags.tar file was created @@ -1004,7 +1005,7 @@ func (s *Suite) TestDeployDagsOnlyFailure() { defer testUtil.MockUserInput(s.T(), "y")() s.houstonMock.On("GetAppConfig", deployment.ClusterID).Return(appConfig, nil).Once() - err = DagsOnlyDeploy(s.houstonMock, wsID, deploymentID, "./dags", nil, false, description) + err = DagsOnlyDeploy(s.houstonMock, keychain.NewTestStore(), wsID, deploymentID, "./dags", nil, false, description) s.EqualError(err, "open dags/dags.tar: no such file or directory") // assert that no tar or gz file exists @@ -1049,7 +1050,7 @@ func (s *Suite) TestDeployDagsOnlyFailure() { } s.houstonMock.On("GetAppConfig", deployment.ClusterID).Return(appConfig, nil).Once() - err = DagsOnlyDeploy(s.houstonMock, wsID, deploymentID, ".", nil, false, description) + err = DagsOnlyDeploy(s.houstonMock, keychain.NewTestStore(), wsID, deploymentID, ".", nil, false, description) s.ErrorIs(err, gzipMockError) // Validate that dags.tar file was created @@ -1110,7 +1111,7 @@ func (s *Suite) TestDeployDagsOnlyFailure() { defer server.Close() s.houstonMock.On("GetAppConfig", deployment.ClusterID).Return(appConfig, nil).Once() - err = DagsOnlyDeploy(s.houstonMock, wsID, deploymentID, ".", &server.URL, false, description) + err = DagsOnlyDeploy(s.houstonMock, keychain.NewTestStore(), wsID, deploymentID, ".", &server.URL, false, description) s.NoError(err) // Validate that dags.tar file was created @@ -1170,7 +1171,7 @@ func (s *Suite) TestDeployDagsOnlyFailure() { defer server.Close() s.houstonMock.On("GetAppConfig", deployment.ClusterID).Return(appConfig, nil).Once() - err = DagsOnlyDeploy(s.houstonMock, wsID, deploymentID, ".", &server.URL, true, description) + err = DagsOnlyDeploy(s.houstonMock, keychain.NewTestStore(), wsID, deploymentID, ".", &server.URL, true, description) s.NoError(err) // assert that no tar or gz file exists diff --git a/software/deployment/logs.go b/software/deployment/logs.go index b9936b7a3..63c63b503 100644 --- a/software/deployment/logs.go +++ b/software/deployment/logs.go @@ -33,7 +33,7 @@ func Log(deploymentID, component, search string, since time.Duration, client hou return nil } -func SubscribeDeploymentLog(deploymentID, component, search string, since time.Duration) error { +func SubscribeDeploymentLog(deploymentID, component, search, token string, since time.Duration) error { // Calculate timestamp as now - since e.g: // (2019-04-02 17:51:03.780819 +0000 UTC - 2 mins) = 2019-04-02 17:49:03.780819 +0000 UTC timestamp := time.Now().UTC().Add(-since) @@ -43,7 +43,7 @@ func SubscribeDeploymentLog(deploymentID, component, search string, since time.D return err } - err = subscribe(cl.Token, cl.GetSoftwareWebsocketURL(), request) + err = subscribe(token, cl.GetSoftwareWebsocketURL(), request) if err != nil { return err } diff --git a/software/deployment/logs_test.go b/software/deployment/logs_test.go index 88a67066a..e2262bb72 100644 --- a/software/deployment/logs_test.go +++ b/software/deployment/logs_test.go @@ -41,7 +41,7 @@ func (s *Suite) TestSubscribeDeploymentLog() { return nil } - err := SubscribeDeploymentLog("test-id", "test-component", "test", 0) + err := SubscribeDeploymentLog("test-id", "test-component", "test", "test-token", 0) s.NoError(err) }) @@ -50,7 +50,7 @@ func (s *Suite) TestSubscribeDeploymentLog() { return errMock } - err := SubscribeDeploymentLog("test-id", "test-component", "test", 0) + err := SubscribeDeploymentLog("test-id", "test-component", "test", "test-token", 0) s.ErrorIs(err, errMock) }) } From 83bcfaea99479703c870df2bf09b625a724c4000 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Thu, 9 Apr 2026 13:34:55 +0100 Subject: [PATCH 2/5] fixup! Store auth credentials in the OS keychain (macOS + Linux) --- pkg/keychain/keychain_cached.go | 55 +++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 pkg/keychain/keychain_cached.go diff --git a/pkg/keychain/keychain_cached.go b/pkg/keychain/keychain_cached.go new file mode 100644 index 000000000..ca6a7e8ae --- /dev/null +++ b/pkg/keychain/keychain_cached.go @@ -0,0 +1,55 @@ +package keychain + +import "sync" + +// cachedStore wraps a SecureStore and caches credentials in memory so that +// repeated reads of the same domain within a single process only hit the +// underlying store (and trigger an OS keychain prompt) once. +type cachedStore struct { + inner SecureStore + mu sync.RWMutex + cache map[string]Credentials +} + +func newCachedStore(inner SecureStore) SecureStore { + return &cachedStore{inner: inner, cache: make(map[string]Credentials)} +} + +func (c *cachedStore) GetCredentials(domain string) (Credentials, error) { + c.mu.RLock() + creds, ok := c.cache[domain] + c.mu.RUnlock() + if ok { + return creds, nil + } + + creds, err := c.inner.GetCredentials(domain) + if err != nil { + return Credentials{}, err + } + + c.mu.Lock() + c.cache[domain] = creds + c.mu.Unlock() + return creds, nil +} + +func (c *cachedStore) SetCredentials(domain string, creds Credentials) error { + if err := c.inner.SetCredentials(domain, creds); err != nil { + return err + } + c.mu.Lock() + c.cache[domain] = creds + c.mu.Unlock() + return nil +} + +func (c *cachedStore) DeleteCredentials(domain string) error { + if err := c.inner.DeleteCredentials(domain); err != nil { + return err + } + c.mu.Lock() + delete(c.cache, domain) + c.mu.Unlock() + return nil +} From f1c29a5e44bcabbf0f44eec797127189c2079044 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Thu, 9 Apr 2026 15:06:50 +0100 Subject: [PATCH 3/5] fixup! Store auth credentials in the OS keychain (macOS + Linux) --- .golangci.yml | 2 + cmd/root_hooks_test.go | 22 ++++ pkg/keychain/keychain_cached_test.go | 167 +++++++++++++++++++++++++++ pkg/keychain/keychain_file_test.go | 126 ++++++++++++++++++++ pkg/keychain/keychain_linux.go | 2 +- pkg/keychain/keychain_linux_test.go | 19 +++ 6 files changed, 337 insertions(+), 1 deletion(-) create mode 100644 cmd/root_hooks_test.go create mode 100644 pkg/keychain/keychain_cached_test.go create mode 100644 pkg/keychain/keychain_file_test.go create mode 100644 pkg/keychain/keychain_linux_test.go diff --git a/.golangci.yml b/.golangci.yml index aba7ee4d2..44b75355d 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -68,6 +68,8 @@ linters: - "2" - "10" - "64" + - "0o600" + - "0o700" nolintlint: require-explanation: false require-specific: false diff --git a/cmd/root_hooks_test.go b/cmd/root_hooks_test.go new file mode 100644 index 000000000..ed1fb661d --- /dev/null +++ b/cmd/root_hooks_test.go @@ -0,0 +1,22 @@ +package cmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/keychain" + testUtil "github.com/astronomer/astro-cli/pkg/testing" +) + +func TestLoadSoftwareToken_LoadsToken(t *testing.T) { + testUtil.InitTestConfig(testUtil.SoftwarePlatform) + store := keychain.NewTestStore() + err := store.SetCredentials("astronomer_dev.com", keychain.Credentials{Token: "test-token"}) + assert.NoError(t, err) + + holder := &httputil.TokenHolder{} + loadSoftwareToken(store, holder) + assert.Equal(t, "test-token", holder.Get()) +} diff --git a/pkg/keychain/keychain_cached_test.go b/pkg/keychain/keychain_cached_test.go new file mode 100644 index 000000000..986d0a99e --- /dev/null +++ b/pkg/keychain/keychain_cached_test.go @@ -0,0 +1,167 @@ +package keychain + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockStore struct { + mock.Mock +} + +func (m *mockStore) GetCredentials(domain string) (Credentials, error) { + args := m.Called(domain) + return args.Get(0).(Credentials), args.Error(1) +} + +func (m *mockStore) SetCredentials(domain string, creds Credentials) error { + return m.Called(domain, creds).Error(0) +} + +func (m *mockStore) DeleteCredentials(domain string) error { + return m.Called(domain).Error(0) +} + +func TestCachedStore_Get(t *testing.T) { + t.Run("cache populated on first get, second get uses cache", func(t *testing.T) { + inner := new(mockStore) + inner.On("GetCredentials", "example.com").Return(Credentials{Token: "tok"}, nil) + store := newCachedStore(inner) + + got, err := store.GetCredentials("example.com") + require.NoError(t, err) + assert.Equal(t, "tok", got.Token) + + got, err = store.GetCredentials("example.com") + require.NoError(t, err) + assert.Equal(t, "tok", got.Token) + + inner.AssertNumberOfCalls(t, "GetCredentials", 1) + }) + + t.Run("different domains cached independently", func(t *testing.T) { + inner := new(mockStore) + inner.On("GetCredentials", "a.io").Return(Credentials{Token: "a"}, nil) + inner.On("GetCredentials", "b.io").Return(Credentials{Token: "b"}, nil) + store := newCachedStore(inner) + + _, err := store.GetCredentials("a.io") + require.NoError(t, err) + + b, err := store.GetCredentials("b.io") + require.NoError(t, err) + assert.Equal(t, "b", b.Token) + + inner.AssertNumberOfCalls(t, "GetCredentials", 2) + }) + + t.Run("inner error propagated and not cached", func(t *testing.T) { + inner := new(mockStore) + inner.On("GetCredentials", "example.com").Return(Credentials{}, errors.New("keyring locked")).Once() + inner.On("GetCredentials", "example.com").Return(Credentials{Token: "recovered"}, nil).Once() + store := newCachedStore(inner) + + _, err := store.GetCredentials("example.com") + require.ErrorContains(t, err, "keyring locked") + + got, err := store.GetCredentials("example.com") + require.NoError(t, err) + assert.Equal(t, "recovered", got.Token) + + inner.AssertNumberOfCalls(t, "GetCredentials", 2) + }) +} + +func TestCachedStore_Set(t *testing.T) { + t.Run("write-through then served from cache", func(t *testing.T) { + inner := new(mockStore) + creds := Credentials{Token: "new-tok"} + inner.On("SetCredentials", "example.com", creds).Return(nil) + store := newCachedStore(inner) + + require.NoError(t, store.SetCredentials("example.com", creds)) + + got, err := store.GetCredentials("example.com") + require.NoError(t, err) + assert.Equal(t, "new-tok", got.Token) + + inner.AssertNotCalled(t, "GetCredentials") + }) + + t.Run("inner set error propagated and cache not updated", func(t *testing.T) { + inner := new(mockStore) + inner.On("SetCredentials", "example.com", mock.Anything).Return(errors.New("disk full")) + inner.On("GetCredentials", "example.com").Return(Credentials{}, ErrNotFound) + store := newCachedStore(inner) + + err := store.SetCredentials("example.com", Credentials{Token: "tok"}) + require.ErrorContains(t, err, "disk full") + + _, err = store.GetCredentials("example.com") + assert.ErrorIs(t, err, ErrNotFound) + }) +} + +func TestCachedStore_Delete(t *testing.T) { + t.Run("invalidates cache so next get hits inner", func(t *testing.T) { + inner := new(mockStore) + inner.On("GetCredentials", "example.com").Return(Credentials{Token: "tok"}, nil).Once() + inner.On("GetCredentials", "example.com").Return(Credentials{}, ErrNotFound).Once() + inner.On("DeleteCredentials", "example.com").Return(nil) + store := newCachedStore(inner) + + _, err := store.GetCredentials("example.com") + require.NoError(t, err) + + require.NoError(t, store.DeleteCredentials("example.com")) + + _, err = store.GetCredentials("example.com") + assert.ErrorIs(t, err, ErrNotFound) + + inner.AssertNumberOfCalls(t, "GetCredentials", 2) + }) + + t.Run("does not affect other domains", func(t *testing.T) { + inner := new(mockStore) + inner.On("GetCredentials", "a.io").Return(Credentials{Token: "a"}, nil) + inner.On("GetCredentials", "b.io").Return(Credentials{Token: "b"}, nil) + inner.On("DeleteCredentials", "a.io").Return(nil) + store := newCachedStore(inner) + + _, err := store.GetCredentials("a.io") + require.NoError(t, err) + _, err = store.GetCredentials("b.io") + require.NoError(t, err) + + require.NoError(t, store.DeleteCredentials("a.io")) + + got, err := store.GetCredentials("b.io") + require.NoError(t, err) + assert.Equal(t, "b", got.Token) + + inner.AssertNumberOfCalls(t, "GetCredentials", 2) + }) + + t.Run("inner delete error propagated and cache preserved", func(t *testing.T) { + inner := new(mockStore) + inner.On("GetCredentials", "example.com").Return(Credentials{Token: "tok"}, nil) + inner.On("DeleteCredentials", "example.com").Return(errors.New("permission denied")) + store := newCachedStore(inner) + + _, err := store.GetCredentials("example.com") + require.NoError(t, err) + + err = store.DeleteCredentials("example.com") + require.ErrorContains(t, err, "permission denied") + + got, err := store.GetCredentials("example.com") + require.NoError(t, err) + assert.Equal(t, "tok", got.Token) + + inner.AssertNumberOfCalls(t, "GetCredentials", 1) + }) +} diff --git a/pkg/keychain/keychain_file_test.go b/pkg/keychain/keychain_file_test.go new file mode 100644 index 000000000..b151cbf7b --- /dev/null +++ b/pkg/keychain/keychain_file_test.go @@ -0,0 +1,126 @@ +//go:build !darwin + +package keychain + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFileStore_CRUD(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T, s *fileStore) + }{ + { + name: "get from missing file returns ErrNotFound", + run: func(t *testing.T, s *fileStore) { + _, err := s.GetCredentials("example.com") + assert.ErrorIs(t, err, ErrNotFound) + }, + }, + { + name: "set and get round-trip", + run: func(t *testing.T, s *fileStore) { + creds := Credentials{Token: "tok", UserEmail: "a@b.com"} + require.NoError(t, s.SetCredentials("example.com", creds)) + + got, err := s.GetCredentials("example.com") + require.NoError(t, err) + assert.Equal(t, creds, got) + }, + }, + { + name: "get missing domain returns ErrNotFound", + run: func(t *testing.T, s *fileStore) { + require.NoError(t, s.SetCredentials("a.io", Credentials{Token: "a"})) + + _, err := s.GetCredentials("b.io") + assert.ErrorIs(t, err, ErrNotFound) + }, + }, + { + name: "set overwrites existing", + run: func(t *testing.T, s *fileStore) { + require.NoError(t, s.SetCredentials("x.io", Credentials{Token: "old"})) + require.NoError(t, s.SetCredentials("x.io", Credentials{Token: "new"})) + + got, err := s.GetCredentials("x.io") + require.NoError(t, err) + assert.Equal(t, "new", got.Token) + }, + }, + { + name: "delete then get returns ErrNotFound", + run: func(t *testing.T, s *fileStore) { + require.NoError(t, s.SetCredentials("x.io", Credentials{Token: "tok"})) + require.NoError(t, s.DeleteCredentials("x.io")) + + _, err := s.GetCredentials("x.io") + assert.ErrorIs(t, err, ErrNotFound) + }, + }, + { + name: "delete preserves other domains", + run: func(t *testing.T, s *fileStore) { + require.NoError(t, s.SetCredentials("a.io", Credentials{Token: "a"})) + require.NoError(t, s.SetCredentials("b.io", Credentials{Token: "b"})) + require.NoError(t, s.DeleteCredentials("a.io")) + + got, err := s.GetCredentials("b.io") + require.NoError(t, err) + assert.Equal(t, "b", got.Token) + }, + }, + { + name: "file has 0600 permissions after set", + run: func(t *testing.T, s *fileStore) { + require.NoError(t, s.SetCredentials("x.io", Credentials{Token: "tok"})) + + info, err := os.Stat(s.path) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0o600), info.Mode().Perm()) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &fileStore{path: filepath.Join(t.TempDir(), "credentials.json")} + tt.run(t, s) + }) + } +} + +func TestFileStore_CorruptJSON(t *testing.T) { + s := &fileStore{path: filepath.Join(t.TempDir(), "credentials.json")} + require.NoError(t, os.WriteFile(s.path, []byte("not json{{{"), 0o600)) + + _, err := s.GetCredentials("example.com") + require.Error(t, err) + assert.Contains(t, err.Error(), "decoding credentials file") +} + +func TestWriteAtomic(t *testing.T) { + t.Run("writes file with expected content and permissions", func(t *testing.T) { + path := filepath.Join(t.TempDir(), "test.json") + require.NoError(t, writeAtomic(path, []byte(`{"key":"value"}`))) + + data, err := os.ReadFile(path) + require.NoError(t, err) + assert.Equal(t, `{"key":"value"}`, string(data)) + + info, err := os.Stat(path) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0o600), info.Mode().Perm()) + }) + + t.Run("error when target dir does not exist", func(t *testing.T) { + path := filepath.Join(t.TempDir(), "nodir", "file.json") + err := writeAtomic(path, []byte("data")) + require.Error(t, err) + }) +} diff --git a/pkg/keychain/keychain_linux.go b/pkg/keychain/keychain_linux.go index 53b8c90d9..93fb752d1 100644 --- a/pkg/keychain/keychain_linux.go +++ b/pkg/keychain/keychain_linux.go @@ -12,7 +12,7 @@ import ( // // If no Secret Service daemon is available (e.g. headless CI environments), // falls back to a plaintext JSON file at ~/.astro/credentials.json with -// 0600 permissions. This matches the current plaintext config.yaml behaviour +// 0600 permissions. This matches the current plaintext config.yaml behavior // and is intentional — encrypted file fallback is not worth the complexity // given that CI environments use ASTRO_API_TOKEN anyway. // diff --git a/pkg/keychain/keychain_linux_test.go b/pkg/keychain/keychain_linux_test.go new file mode 100644 index 000000000..b75d163c7 --- /dev/null +++ b/pkg/keychain/keychain_linux_test.go @@ -0,0 +1,19 @@ +//go:build linux + +package keychain + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// CI runs in a Docker container with no D-Bus / Secret Service, so New() +// always takes the fileStore fallback path. We can't test the keyring +// success path without a running Secret Service daemon. +func TestNew_FallsBackToFileStore(t *testing.T) { + store, err := New() + require.NoError(t, err) + assert.IsType(t, &fileStore{}, store) +} From 90bbf462a3b25884fd4f1aa85f4622b57bcc0924 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Fri, 10 Apr 2026 14:09:30 +0100 Subject: [PATCH 4/5] fixup! Store auth credentials in the OS keychain (macOS + Linux) - Add tests for cachedStore, fileStore, loadSoftwareToken, and linux New() - Restrict Linux keyring backends to SecretService/KWallet only (keyctl doesn't persist across reboots, pass/file prompt for passphrases) - Remove stale NOTE comment from keychain_linux.go - Add store==nil guard to software Logout (reachable via login/logout which skip the storeErr check in pre-run hook) - Ignore 0o600/0o700 file permission literals in golangci-lint mnd Co-Authored-By: Claude Sonnet 4.6 --- pkg/keychain/keychain_linux.go | 14 +++++++++----- pkg/keychain/keychain_linux_test.go | 4 ++-- software/auth/auth.go | 4 +++- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/pkg/keychain/keychain_linux.go b/pkg/keychain/keychain_linux.go index 93fb752d1..978f96f89 100644 --- a/pkg/keychain/keychain_linux.go +++ b/pkg/keychain/keychain_linux.go @@ -15,17 +15,21 @@ import ( // 0600 permissions. This matches the current plaintext config.yaml behavior // and is intentional — encrypted file fallback is not worth the complexity // given that CI environments use ASTRO_API_TOKEN anyway. -// -// NOTE: if 99designs/keyring fails to connect to Secret Service in environments -// that DO have it running, replace with godbus/dbus directly: -// https://github.com/godbus/dbus — the SecureStore interface is the only -// change boundary. func New() (SecureStore, error) { ring, err := keyring.Open(keyring.Config{ ServiceName: serviceName, LibSecretCollectionName: "astro-cli", KWalletAppID: "astro-cli", KWalletFolder: "astro-cli", + // Only allow persistent, non-interactive backends. KeyCtl stores + // credentials in kernel memory that doesn't survive reboot. Pass + // and File prompt for passphrases, which breaks non-interactive + // CLI usage. When neither desktop backend is available we fall + // through to our own fileStore below. + AllowedBackends: []keyring.BackendType{ + keyring.SecretServiceBackend, + keyring.KWalletBackend, + }, }) if err == nil { return newCachedStore(&keyringStore{ring: ring}), nil diff --git a/pkg/keychain/keychain_linux_test.go b/pkg/keychain/keychain_linux_test.go index b75d163c7..1ff206c40 100644 --- a/pkg/keychain/keychain_linux_test.go +++ b/pkg/keychain/keychain_linux_test.go @@ -10,8 +10,8 @@ import ( ) // CI runs in a Docker container with no D-Bus / Secret Service, so New() -// always takes the fileStore fallback path. We can't test the keyring -// success path without a running Secret Service daemon. +// falls back to a fileStore. We can't test the keyring success path +// without a running Secret Service or KWallet daemon. func TestNew_FallsBackToFileStore(t *testing.T) { store, err := New() require.NoError(t, err) diff --git a/software/auth/auth.go b/software/auth/auth.go index d546d91d4..e1a3995a7 100644 --- a/software/auth/auth.go +++ b/software/auth/auth.go @@ -262,7 +262,9 @@ func Login(domain string, oAuthOnly bool, username, password, houstonVersion str // Logout removes the locally stored token and reset current context func Logout(domain string, store keychain.SecureStore) { - if err := store.DeleteCredentials(domain); err != nil { + if store == nil { + fmt.Println("Warning: credential store not available; local credentials may not be cleared") + } else if err := store.DeleteCredentials(domain); err != nil { fmt.Printf("Failed to remove credentials: %s\n", err.Error()) } From a723c464feaac68bcff039cf425ce695e418d72c Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Fri, 10 Apr 2026 17:52:14 +0100 Subject: [PATCH 5/5] fixup! Store auth credentials in the OS keychain (macOS + Linux) --- airflow-client/airflow-client.go | 13 +- airflow-client/airflow-client_test.go | 117 +++++++++--------- astro-client-core/client.go | 7 +- astro-client-core/client.test.go | 3 +- astro-client-iam-core/client.go | 3 +- astro-client-iam-core/client.test.go | 3 +- astro-client-platform-core/client.go | 3 +- cloud/auth/auth.go | 11 +- cloud/deploy/bundle.go | 4 +- cloud/deploy/bundle_test.go | 12 +- cloud/deploy/deploy.go | 9 +- cloud/deploy/deploy_test.go | 23 +++- cloud/deployment/deployment.go | 28 +++-- cloud/deployment/deployment_test.go | 97 ++++++++------- cloud/deployment/deployment_variable.go | 2 +- cloud/deployment/fromfile/fromfile.go | 11 +- cloud/deployment/fromfile/fromfile_test.go | 117 +++++++++--------- cloud/deployment/workerqueue/workerqueue.go | 6 +- cloud/platformclient/client.go | 3 +- cloud/platformclient/client_test.go | 3 +- cmd/api/airflow.go | 13 +- cmd/api/airflow_test.go | 14 +-- cmd/api/api.go | 12 +- cmd/api/cloud.go | 14 +-- cmd/auth.go | 22 ++-- cmd/auth_test.go | 4 +- cmd/cloud/dbt.go | 1 + cmd/cloud/deploy.go | 1 + cmd/cloud/deployment.go | 6 +- cmd/cloud/remote.go | 1 + cmd/cloud/root.go | 5 +- cmd/cloud/root_test.go | 2 +- cmd/cloud/setup.go | 39 +++--- cmd/cloud/setup_test.go | 68 +++++----- cmd/root.go | 23 ++-- cmd/root_hooks.go | 18 +-- cmd/root_hooks_test.go | 4 +- houston/houston.go | 19 +-- .../credentials.go} | 16 +-- .../credentials_test.go} | 8 +- 40 files changed, 414 insertions(+), 351 deletions(-) rename pkg/{httputil/token_holder.go => credentials/credentials.go} (59%) rename pkg/{httputil/token_holder_test.go => credentials/credentials_test.go} (55%) diff --git a/airflow-client/airflow-client.go b/airflow-client/airflow-client.go index 65e77cb7e..a762cb521 100644 --- a/airflow-client/airflow-client.go +++ b/airflow-client/airflow-client.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/go-retryablehttp" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/httputil" ) @@ -40,14 +41,14 @@ type Client interface { // Client containers the logger and HTTPClient used to communicate with the Astronomer API type HTTPClient struct { *httputil.HTTPClient - tokenHolder *httputil.TokenHolder + creds *credentials.CurrentCredentials } // NewAirflowClient returns a new Client with the logger and HTTP client setup. -func NewAirflowClient(c *httputil.HTTPClient, tokenHolder *httputil.TokenHolder) *HTTPClient { +func NewAirflowClient(c *httputil.HTTPClient, creds *credentials.CurrentCredentials) *HTTPClient { return &HTTPClient{ - HTTPClient: c, - tokenHolder: tokenHolder, + HTTPClient: c, + creds: creds, } } @@ -243,8 +244,8 @@ func checkRetryPolicy(method string) retryablehttp.CheckRetry { } func (c *HTTPClient) DoAirflowClient(doOpts *httputil.DoOptions) (*Response, error) { - if c.tokenHolder != nil { - if tok := c.tokenHolder.Get(); tok != "" { + if c.creds != nil { + if tok := c.creds.Get(); tok != "" { if doOpts.Headers == nil { doOpts.Headers = map[string]string{} } diff --git a/airflow-client/airflow-client_test.go b/airflow-client/airflow-client_test.go index c764a3ffb..27abd40ad 100644 --- a/airflow-client/airflow-client_test.go +++ b/airflow-client/airflow-client_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/suite" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/httputil" testUtil "github.com/astronomer/astro-cli/pkg/testing" ) @@ -46,8 +47,8 @@ func (s *Suite) TestDoAirflowClient() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) doOpts := &httputil.DoOptions{ Path: "/test", Headers: map[string]string{ @@ -111,8 +112,8 @@ func (s *Suite) TestGetConnections() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) response, err := airflowClient.GetConnections("test-airflow-url") s.NoError(err) @@ -143,8 +144,8 @@ func (s *Suite) TestGetConnections() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) _, err := airflowClient.GetConnections("test-airflow-url") s.Error(err) @@ -177,8 +178,8 @@ func (s *Suite) TestUpdateConnection() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.UpdateConnection("test-airflow-url", mockConn) s.NoError(err) @@ -192,8 +193,8 @@ func (s *Suite) TestUpdateConnection() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.UpdateConnection("test-airflow-url", mockConn) s.Error(err) @@ -208,8 +209,8 @@ func (s *Suite) TestUpdateConnection() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) // Pass a nil connection to force JSON marshal error err := airflowClient.UpdateConnection("test-airflow-url", mockConn) @@ -225,8 +226,8 @@ func (s *Suite) TestUpdateConnection() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.UpdateConnection("test-airflow-url", mockConn) s.Error(err) @@ -259,8 +260,8 @@ func (s *Suite) TestCreateConnection() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.CreateConnection("test-airflow-url", mockConn) s.NoError(err) @@ -274,8 +275,8 @@ func (s *Suite) TestCreateConnection() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.CreateConnection("test-airflow-url", mockConn) s.Error(err) @@ -290,8 +291,8 @@ func (s *Suite) TestCreateConnection() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) // Pass a nil connection to force JSON marshal error err := airflowClient.CreateConnection("test-airflow-url", nil) @@ -307,8 +308,8 @@ func (s *Suite) TestCreateConnection() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.CreateConnection("test-airflow-url", mockConn) s.Error(err) @@ -341,8 +342,8 @@ func (s *Suite) TestCreateVariable() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.CreateVariable("test-airflow-url", *mockVar) s.NoError(err) @@ -356,8 +357,8 @@ func (s *Suite) TestCreateVariable() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.CreateVariable("test-airflow-url", *mockVar) s.Error(err) @@ -372,8 +373,8 @@ func (s *Suite) TestCreateVariable() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.CreateVariable("test-airflow-url", Variable{Key: "", Value: "test-value"}) s.Error(err) @@ -399,8 +400,8 @@ func (s *Suite) TestGetVariables() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) response, err := airflowClient.GetVariables("test-airflow-url") s.NoError(err) @@ -450,8 +451,8 @@ func (s *Suite) TestUpdateVariable() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.UpdateVariable("test-airflow-url", *mockVar) s.NoError(err) @@ -465,8 +466,8 @@ func (s *Suite) TestUpdateVariable() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.UpdateVariable("test-airflow-url", *mockVar) s.Error(err) @@ -481,8 +482,8 @@ func (s *Suite) TestUpdateVariable() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.UpdateVariable("test-airflow-url", Variable{Key: "", Value: "test-value"}) s.Error(err) @@ -532,8 +533,8 @@ func (s *Suite) TestCreatePool() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.CreatePool("test-airflow-url", *mockPool) s.NoError(err) @@ -547,8 +548,8 @@ func (s *Suite) TestCreatePool() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.CreatePool("test-airflow-url", *mockPool) s.Error(err) @@ -563,8 +564,8 @@ func (s *Suite) TestCreatePool() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) // Pass a nil pool to force JSON marshal error err := airflowClient.CreatePool("test-airflow-url", *mockPool) @@ -580,8 +581,8 @@ func (s *Suite) TestCreatePool() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.CreatePool("test-airflow-url", *mockPool) s.Error(err) @@ -614,8 +615,8 @@ func (s *Suite) TestUpdatePool() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.UpdatePool("test-airflow-url", *mockPool) s.NoError(err) @@ -646,8 +647,8 @@ func (s *Suite) TestUpdatePool() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err = airflowClient.UpdatePool("test-airflow-url", defaultPool) s.NoError(err) @@ -661,8 +662,8 @@ func (s *Suite) TestUpdatePool() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.UpdatePool("test-airflow-url", *mockPool) s.Error(err) @@ -677,8 +678,8 @@ func (s *Suite) TestUpdatePool() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) // Pass a nil pool to force JSON marshal error err := airflowClient.UpdatePool("test-airflow-url", Pool{}) @@ -694,8 +695,8 @@ func (s *Suite) TestUpdatePool() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) err := airflowClient.UpdatePool("test-airflow-url", *mockPool) s.Error(err) @@ -720,8 +721,8 @@ func (s *Suite) TestGetPools() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) response, err := airflowClient.GetPools("test-airflow-url") s.NoError(err) @@ -754,8 +755,8 @@ func (s *Suite) TestGetPools() { Header: make(http.Header), } }) - th := httputil.NewTokenHolder("token") - airflowClient := NewAirflowClient(client, th) + creds := credentials.New("token") + airflowClient := NewAirflowClient(client, creds) response, err := airflowClient.GetPools("test-airflow-url") s.Error(err) diff --git a/astro-client-core/client.go b/astro-client-core/client.go index d9ee55873..368edf60a 100644 --- a/astro-client-core/client.go +++ b/astro-client-core/client.go @@ -2,6 +2,7 @@ package astrocore import ( "github.com/astronomer/astro-cli/context" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/httputil" ) @@ -13,9 +14,9 @@ var NormalizeAPIError = httputil.NormalizeAPIError type CoreClient = ClientWithResponsesInterface // NewCoreClient creates an API client for astro core services. -// The provided TokenHolder is read on every request — set it via -// TokenHolder.Set after credentials are resolved in PersistentPreRunE. -func NewCoreClient(c *httputil.HTTPClient, holder *httputil.TokenHolder) *ClientWithResponses { +// The provided CurrentCredentials is read on every request — set it via +// CurrentCredentials.Set after credentials are resolved in PersistentPreRunE. +func NewCoreClient(c *httputil.HTTPClient, holder *credentials.CurrentCredentials) *ClientWithResponses { cl, _ := NewClientWithResponses("", WithHTTPClient(c.HTTPClient), WithRequestEditorFn(httputil.NewRequestEditorFn(func() (string, string, error) { ctx, err := context.GetCurrentContext() if err != nil { diff --git a/astro-client-core/client.test.go b/astro-client-core/client.test.go index ec950ffe0..4a57ff1e3 100644 --- a/astro-client-core/client.test.go +++ b/astro-client-core/client.test.go @@ -5,10 +5,11 @@ import ( "github.com/stretchr/testify/assert" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/httputil" ) func TestNewCoreClient(t *testing.T) { - client := NewCoreClient(httputil.NewHTTPClient(), &httputil.TokenHolder{}) + client := NewCoreClient(httputil.NewHTTPClient(), &credentials.CurrentCredentials{}) assert.NotNil(t, client, "Can't create new Astro Core client") } diff --git a/astro-client-iam-core/client.go b/astro-client-iam-core/client.go index 1a820de41..5ad06d57e 100644 --- a/astro-client-iam-core/client.go +++ b/astro-client-iam-core/client.go @@ -2,6 +2,7 @@ package astroiamcore import ( "github.com/astronomer/astro-cli/context" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/httputil" ) @@ -12,7 +13,7 @@ var NormalizeAPIError = httputil.NormalizeAPIError // a shorter alias type CoreClient = ClientWithResponsesInterface -func NewIamCoreClient(c *httputil.HTTPClient, holder *httputil.TokenHolder) *ClientWithResponses { +func NewIamCoreClient(c *httputil.HTTPClient, holder *credentials.CurrentCredentials) *ClientWithResponses { cl, _ := NewClientWithResponses("", WithHTTPClient(c.HTTPClient), WithRequestEditorFn(httputil.NewRequestEditorFn(func() (string, string, error) { ctx, err := context.GetCurrentContext() if err != nil { diff --git a/astro-client-iam-core/client.test.go b/astro-client-iam-core/client.test.go index a41af28a6..4c1a88bc6 100644 --- a/astro-client-iam-core/client.test.go +++ b/astro-client-iam-core/client.test.go @@ -5,10 +5,11 @@ import ( "github.com/stretchr/testify/assert" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/httputil" ) func TestNewIamCoreClient(t *testing.T) { - client := NewIamCoreClient(httputil.NewHTTPClient(), &httputil.TokenHolder{}) + client := NewIamCoreClient(httputil.NewHTTPClient(), &credentials.CurrentCredentials{}) assert.NotNil(t, client, "Can't create new Astro IAM Core client") } diff --git a/astro-client-platform-core/client.go b/astro-client-platform-core/client.go index 610d9bfcc..98df7097b 100644 --- a/astro-client-platform-core/client.go +++ b/astro-client-platform-core/client.go @@ -2,6 +2,7 @@ package astroplatformcore import ( "github.com/astronomer/astro-cli/context" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/httputil" ) @@ -12,7 +13,7 @@ var NormalizeAPIError = httputil.NormalizeAPIError // a shorter alias type CoreClient = ClientWithResponsesInterface -func NewPlatformCoreClient(c *httputil.HTTPClient, holder *httputil.TokenHolder) *ClientWithResponses { +func NewPlatformCoreClient(c *httputil.HTTPClient, holder *credentials.CurrentCredentials) *ClientWithResponses { cl, _ := NewClientWithResponses("", WithHTTPClient(c.HTTPClient), WithRequestEditorFn(httputil.NewRequestEditorFn(func() (string, string, error) { ctx, err := context.GetCurrentContext() if err != nil { diff --git a/cloud/auth/auth.go b/cloud/auth/auth.go index 2db1e3e2a..4ea5986ad 100644 --- a/cloud/auth/auth.go +++ b/cloud/auth/auth.go @@ -22,6 +22,7 @@ import ( "github.com/astronomer/astro-cli/context" "github.com/astronomer/astro-cli/pkg/ansi" "github.com/astronomer/astro-cli/pkg/astroauth" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/domainutil" "github.com/astronomer/astro-cli/pkg/httputil" "github.com/astronomer/astro-cli/pkg/keychain" @@ -341,7 +342,7 @@ func CheckUserSession(c *config.Context, coreClient astrocore.CoreClient, platfo } // Login handles authentication to astronomer api and registry -func Login(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { +func Login(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { var res Result domain = domainutil.FormatDomain(domain) authConfig, err := FetchDomainAuthConfig(domain) @@ -388,7 +389,7 @@ func Login(domain, token string, store keychain.SecureStore, tokenHolder *httput return err } - creds := keychain.Credentials{ + keyCreds := keychain.Credentials{ Token: "Bearer " + res.AccessToken, RefreshToken: res.RefreshToken, UserEmail: res.UserEmail, @@ -397,11 +398,11 @@ func Login(domain, token string, store keychain.SecureStore, tokenHolder *httput if store == nil { return fmt.Errorf("credential store not available; cannot save login credentials") } - if err := store.SetCredentials(domain, creds); err != nil { + if err := store.SetCredentials(domain, keyCreds); err != nil { return fmt.Errorf("storing credentials: %w", err) } - if tokenHolder != nil { - tokenHolder.Set(creds.Token) + if creds != nil { + creds.Set(keyCreds.Token) } fmt.Printf("Logging in as %s\n", ansi.Green(res.UserEmail)) diff --git a/cloud/deploy/bundle.go b/cloud/deploy/bundle.go index accfeeb9b..d06f441e9 100644 --- a/cloud/deploy/bundle.go +++ b/cloud/deploy/bundle.go @@ -14,6 +14,7 @@ import ( astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" "github.com/astronomer/astro-cli/cloud/deployment" "github.com/astronomer/astro-cli/config" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/fileutil" "github.com/astronomer/astro-cli/pkg/git" "github.com/astronomer/astro-cli/pkg/logger" @@ -29,6 +30,7 @@ type DeployBundleInput struct { WaitTime time.Duration PlatformCoreClient astroplatformcore.CoreClient CoreClient astrocore.CoreClient + Creds *credentials.CurrentCredentials } func DeployBundle(input *DeployBundleInput) error { @@ -44,7 +46,7 @@ func DeployBundle(input *DeployBundleInput) error { } // if CI/CD is enforced, check the subject can deploy - if currentDeployment.IsCicdEnforced && !canCiCdDeploy("Bearer "+os.Getenv("ASTRO_API_TOKEN")) { + if currentDeployment.IsCicdEnforced && !canCiCdDeploy(input.Creds) { return fmt.Errorf(errCiCdEnforcementUpdate, currentDeployment.Name) } diff --git a/cloud/deploy/bundle_test.go b/cloud/deploy/bundle_test.go index be0a9f66c..418c17f1a 100644 --- a/cloud/deploy/bundle_test.go +++ b/cloud/deploy/bundle_test.go @@ -19,6 +19,7 @@ import ( astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" astroplatformcore_mocks "github.com/astronomer/astro-cli/astro-client-platform-core/mocks" "github.com/astronomer/astro-cli/config" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/git" testUtil "github.com/astronomer/astro-cli/pkg/testing" ) @@ -40,7 +41,7 @@ func TestBundles(t *testing.T) { } func (s *BundleSuite) TestBundleDeploy_Success() { - canCiCdDeploy = func(token string) bool { + canCiCdDeploy = func(creds *credentials.CurrentCredentials) bool { return true } @@ -57,6 +58,7 @@ func (s *BundleSuite) TestBundleDeploy_Success() { Description: "test-description", PlatformCoreClient: s.mockPlatformCoreClient, CoreClient: s.mockCoreClient, + Creds: &credentials.CurrentCredentials{}, } mockGetDeployment(s.mockPlatformCoreClient, true, true) @@ -82,7 +84,7 @@ func (s *BundleSuite) TestBundleDeploy_Success() { } func (s *BundleSuite) TestBundleDeploy_CiCdIncompatible() { - canCiCdDeploy = func(token string) bool { + canCiCdDeploy = func(creds *credentials.CurrentCredentials) bool { return false } @@ -90,6 +92,7 @@ func (s *BundleSuite) TestBundleDeploy_CiCdIncompatible() { DeploymentID: "test-deployment-id", PlatformCoreClient: s.mockPlatformCoreClient, CoreClient: s.mockCoreClient, + Creds: &credentials.CurrentCredentials{}, } mockGetDeployment(s.mockPlatformCoreClient, true, true) @@ -105,6 +108,7 @@ func (s *BundleSuite) TestBundleDeploy_DagDeployDisabled() { input := &DeployBundleInput{ PlatformCoreClient: s.mockPlatformCoreClient, CoreClient: s.mockCoreClient, + Creds: &credentials.CurrentCredentials{}, } mockGetDeployment(s.mockPlatformCoreClient, false, false) @@ -124,6 +128,7 @@ func (s *BundleSuite) TestBundleDeploy_GitMetadataRetrieved() { BundlePath: gitPath, PlatformCoreClient: s.mockPlatformCoreClient, CoreClient: s.mockCoreClient, + Creds: &credentials.CurrentCredentials{}, } mockGetDeployment(s.mockPlatformCoreClient, true, false) @@ -167,6 +172,7 @@ func (s *BundleSuite) TestBundleDeploy_GitHasUncommittedChanges() { BundlePath: gitPath, PlatformCoreClient: s.mockPlatformCoreClient, CoreClient: s.mockCoreClient, + Creds: &credentials.CurrentCredentials{}, } mockGetDeployment(s.mockPlatformCoreClient, true, false) @@ -206,6 +212,7 @@ func (s *BundleSuite) TestBundleDeploy_GitMetadataDisabledViaConfig() { BundlePath: gitPath, PlatformCoreClient: s.mockPlatformCoreClient, CoreClient: s.mockCoreClient, + Creds: &credentials.CurrentCredentials{}, } mockGetDeployment(s.mockPlatformCoreClient, true, false) @@ -236,6 +243,7 @@ func (s *BundleSuite) TestBundleDeploy_BundleUploadUrlMissing() { input := &DeployBundleInput{ PlatformCoreClient: s.mockPlatformCoreClient, CoreClient: s.mockCoreClient, + Creds: &credentials.CurrentCredentials{}, } mockGetDeployment(s.mockPlatformCoreClient, true, false) diff --git a/cloud/deploy/deploy.go b/cloud/deploy/deploy.go index f06f65ec8..0249726a4 100644 --- a/cloud/deploy/deploy.go +++ b/cloud/deploy/deploy.go @@ -23,6 +23,7 @@ import ( "github.com/astronomer/astro-cli/docker" "github.com/astronomer/astro-cli/pkg/ansi" "github.com/astronomer/astro-cli/pkg/azure" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/fileutil" "github.com/astronomer/astro-cli/pkg/httputil" "github.com/astronomer/astro-cli/pkg/input" @@ -112,6 +113,7 @@ type InputDeploy struct { Description string BuildSecretString string Force bool + Creds *credentials.CurrentCredentials } // InputClientDeploy contains inputs for client image deployments @@ -121,6 +123,7 @@ type InputClientDeploy struct { Platform string BuildSecretString string DeploymentID string + Creds *credentials.CurrentCredentials } const accessYourDeploymentFmt = ` @@ -224,7 +227,7 @@ func Deploy(deployInput InputDeploy, platformCoreClient astroplatformcore.CoreCl } if deployInfo.cicdEnforcement { - if !canCiCdDeploy("Bearer " + os.Getenv("ASTRO_API_TOKEN")) { + if !canCiCdDeploy(deployInput.Creds) { return fmt.Errorf(errCiCdEnforcementUpdate, deployInfo.name) //nolint } } @@ -413,7 +416,7 @@ func Deploy(deployInput InputDeploy, platformCoreClient astroplatformcore.CoreCl imageHandler := airflowImageHandler(deployInfo.deployImage) fmt.Println("Pushing image to Astronomer registry") - _, err = imageHandler.Push(remoteImage, registryUsername, "Bearer "+os.Getenv("ASTRO_API_TOKEN"), false) + _, err = imageHandler.Push(remoteImage, registryUsername, deployInput.Creds.Get(), false) if err != nil { return err } @@ -969,7 +972,7 @@ func DeployClientImage(deployInput InputClientDeploy, platformCoreClient astropl } baseImageRegistry := config.CFG.RemoteBaseImageRegistry.GetString() fmt.Printf("Authenticating with base image registry: %s\n", baseImageRegistry) - err := airflow.DockerLogin(baseImageRegistry, registryUsername, "Bearer "+os.Getenv("ASTRO_API_TOKEN")) + err := airflow.DockerLogin(baseImageRegistry, registryUsername, deployInput.Creds.Get()) if err != nil { fmt.Println("Failed to authenticate with Astronomer registry that contains the base agent image used in the Dockerfile.client file.") fmt.Println("This could be because either your token has expired or you don't have permission to pull the base agent image.") diff --git a/cloud/deploy/deploy_test.go b/cloud/deploy/deploy_test.go index 498971f21..a0ce7e06e 100644 --- a/cloud/deploy/deploy_test.go +++ b/cloud/deploy/deploy_test.go @@ -20,6 +20,7 @@ import ( astroplatformcore_mocks "github.com/astronomer/astro-cli/astro-client-platform-core/mocks" "github.com/astronomer/astro-cli/cloud/deployment" "github.com/astronomer/astro-cli/config" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/fileutil" "github.com/astronomer/astro-cli/pkg/httputil" testUtil "github.com/astronomer/astro-cli/pkg/testing" @@ -179,6 +180,7 @@ func TestDeployWithoutDagsDeploySuccess(t *testing.T) { Prompt: true, WaitForStatus: false, Dags: false, + Creds: &credentials.CurrentCredentials{}, } testUtil.InitTestConfig(testUtil.LocalPlatform) config.CFG.ShowWarnings.SetHomeString("false") @@ -279,6 +281,7 @@ func TestDeployOnRemoteExecutionDeployment(t *testing.T) { Prompt: true, WaitForStatus: false, Dags: false, + Creds: &credentials.CurrentCredentials{}, } testUtil.InitTestConfig(testUtil.LocalPlatform) config.CFG.ShowWarnings.SetHomeString("false") @@ -383,10 +386,11 @@ func TestDeployOnCiCdEnforcedDeployment(t *testing.T) { Prompt: true, WaitForStatus: false, Dags: false, + Creds: &credentials.CurrentCredentials{}, } testUtil.InitTestConfig(testUtil.LocalPlatform) config.CFG.ShowWarnings.SetHomeString("false") - canCiCdDeploy = func(astroAPIToken string) bool { + canCiCdDeploy = func(creds *credentials.CurrentCredentials) bool { return false } @@ -420,6 +424,7 @@ func TestDeployWithDagsDeploySuccess(t *testing.T) { Prompt: true, WaitForStatus: false, Dags: false, + Creds: &credentials.CurrentCredentials{}, } testUtil.InitTestConfig(testUtil.LocalPlatform) config.CFG.ShowWarnings.SetHomeString("false") @@ -522,6 +527,7 @@ func TestDeployWithDagsDeploySuccess(t *testing.T) { Prompt: true, WaitForStatus: false, Dags: false, + Creds: &credentials.CurrentCredentials{}, } defer testUtil.MockUserInput(t, "1")() err = Deploy(deployInput, mockPlatformCoreClient, mockCoreClient) @@ -552,6 +558,7 @@ func TestDagsDeploySuccess(t *testing.T) { Dags: true, WaitForStatus: false, DagsPath: "./testfiles/dags", + Creds: &credentials.CurrentCredentials{}, } testUtil.InitTestConfig(testUtil.LocalPlatform) config.CFG.ShowWarnings.SetHomeString("false") @@ -637,6 +644,7 @@ func TestImageOnlyDeploySuccess(t *testing.T) { Image: true, WaitForStatus: false, DagsPath: "./testfiles/dags", + Creds: &credentials.CurrentCredentials{}, } testUtil.InitTestConfig(testUtil.LocalPlatform) config.CFG.ShowWarnings.SetHomeString("false") @@ -700,6 +708,7 @@ func TestNoDagsDeploy(t *testing.T) { Prompt: true, WaitForStatus: false, Dags: true, + Creds: &credentials.CurrentCredentials{}, } defer testUtil.MockUserInput(t, "1")() err = Deploy(deployInput, mockPlatformCoreClient, mockCoreClient) @@ -734,6 +743,7 @@ func TestNoDagsDeployForceSkipsPrompt(t *testing.T) { WsID: ws, Dags: true, Force: true, + Creds: &credentials.CurrentCredentials{}, } err = Deploy(deployInput, mockPlatformCoreClient, mockCoreClient) assert.NoError(t, err) @@ -780,6 +790,7 @@ func TestNoDagsImageDeployForceSkipsPrompt(t *testing.T) { EnvFile: "./testfiles/.env", Dags: false, Force: true, + Creds: &credentials.CurrentCredentials{}, } err = Deploy(deployInput, mockPlatformCoreClient, mockCoreClient) assert.NoError(t, err) @@ -805,6 +816,7 @@ func TestDagsDeployFailed(t *testing.T) { Prompt: true, WaitForStatus: false, Dags: true, + Creds: &credentials.CurrentCredentials{}, } mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(3) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(6) @@ -868,6 +880,7 @@ func TestDeployFailure(t *testing.T) { Prompt: true, WaitForStatus: false, Dags: false, + Creds: &credentials.CurrentCredentials{}, } defer testUtil.MockUserInput(t, "y")() @@ -945,6 +958,7 @@ func TestDeployMonitoringDAGNonHosted(t *testing.T) { Prompt: true, Dags: true, DagsPath: "./testfiles/dags", + Creds: &credentials.CurrentCredentials{}, } testUtil.InitTestConfig(testUtil.LocalPlatform) config.CFG.ShowWarnings.SetHomeString("false") @@ -1028,6 +1042,7 @@ func TestDeployNoMonitoringDAGHosted(t *testing.T) { Prompt: true, Dags: true, DagsPath: "./testfiles/dags", + Creds: &credentials.CurrentCredentials{}, } testUtil.InitTestConfig(testUtil.LocalPlatform) config.CFG.ShowWarnings.SetHomeString("false") @@ -1428,6 +1443,7 @@ func TestDeployClientImage(t *testing.T) { deployInput := InputClientDeploy{ Path: tempDir, BuildSecretString: "", + Creds: credentials.New("Bearer test-token"), } err = DeployClientImage(deployInput, nil) @@ -1458,6 +1474,7 @@ func TestDeployClientImage(t *testing.T) { deployInput := InputClientDeploy{ Path: "/test/path", BuildSecretString: "", + Creds: &credentials.CurrentCredentials{}, } err = DeployClientImage(deployInput, nil) @@ -1486,6 +1503,7 @@ func TestDeployClientImage(t *testing.T) { deployInput := InputClientDeploy{ Path: "/test/path", BuildSecretString: "", + Creds: &credentials.CurrentCredentials{}, } err = DeployClientImage(deployInput, nil) @@ -1542,6 +1560,7 @@ func TestDeployClientImage(t *testing.T) { deployInput := InputClientDeploy{ Path: tempDir, BuildSecretString: "", + Creds: &credentials.CurrentCredentials{}, } err = DeployClientImage(deployInput, nil) @@ -1599,6 +1618,7 @@ func TestDeployClientImage(t *testing.T) { deployInput := InputClientDeploy{ Path: tempDir, BuildSecretString: "", + Creds: &credentials.CurrentCredentials{}, } err = DeployClientImage(deployInput, nil) @@ -1646,6 +1666,7 @@ func TestDeployClientImage(t *testing.T) { Path: tempDir, ImageName: "custom-image:tag", BuildSecretString: "", + Creds: &credentials.CurrentCredentials{}, } err = DeployClientImage(deployInput, nil) diff --git a/cloud/deployment/deployment.go b/cloud/deployment/deployment.go index 3e7229da9..02ded5ac7 100644 --- a/cloud/deployment/deployment.go +++ b/cloud/deployment/deployment.go @@ -21,6 +21,7 @@ import ( "github.com/astronomer/astro-cli/cloud/workspace" "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/pkg/ansi" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/domainutil" "github.com/astronomer/astro-cli/pkg/httputil" "github.com/astronomer/astro-cli/pkg/input" @@ -125,21 +126,26 @@ func deploymentTableConfig(fromAllWorkspaces bool, ws string) *output.TableConfi ) } -func CanCiCdDeploy(bearerToken string) bool { - token := strings.Split(bearerToken, " ")[1] // Stripping Bearer - // Parse the token to peek at the custom claims - claims, err := parseToken(token) +func CanCiCdDeploy(creds *credentials.CurrentCredentials) bool { + if creds == nil { + return false + } + bearerToken := creds.Get() + if bearerToken == "" { + return false + } + parts := strings.SplitN(bearerToken, " ", 2) + if len(parts) < 2 { + return false + } + claims, err := parseToken(parts[1]) if err != nil { fmt.Println("Unable to Parse Token") return false } // Only API Tokens and API Keys have permissions - if len(claims.Permissions) > 0 { - return true - } - - return false + return len(claims.Permissions) > 0 } // deploymentToInfo converts a deployment to DeploymentInfo for structured output @@ -870,7 +876,7 @@ func HealthPoll(deploymentID, ws string, sleepTime, tickNum, timeoutNum int, pla } // TODO (https://github.com/astronomer/astro-cli/issues/1709): move these input arguments to a struct, and drop the nolint -func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, executor, schedulerSize, highAvailability, developmentMode, cicdEnforcement, defaultTaskPodCpu, defaultTaskPodMemory, resourceQuotaCpu, resourceQuotaMemory, workloadIdentity string, schedulerAU, schedulerReplicas int, wQueueList []astroplatformcore.WorkerQueueRequest, hybridQueueList []astroplatformcore.HybridWorkerQueueRequest, newEnvironmentVariables []astroplatformcore.DeploymentEnvironmentVariableRequest, allowedIpAddressRanges *[]string, taskLogBucket *string, taskLogUrlPattern *string, force bool, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient) error { //nolint +func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, executor, schedulerSize, highAvailability, developmentMode, cicdEnforcement, defaultTaskPodCpu, defaultTaskPodMemory, resourceQuotaCpu, resourceQuotaMemory, workloadIdentity string, schedulerAU, schedulerReplicas int, wQueueList []astroplatformcore.WorkerQueueRequest, hybridQueueList []astroplatformcore.HybridWorkerQueueRequest, newEnvironmentVariables []astroplatformcore.DeploymentEnvironmentVariableRequest, allowedIpAddressRanges *[]string, taskLogBucket *string, taskLogUrlPattern *string, force bool, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, creds *credentials.CurrentCredentials) error { //nolint var queueCreateUpdate, confirmWithUser bool // get deployment currentDeployment, err := GetDeployment(ws, deploymentID, deploymentName, false, nil, platformCoreClient, coreClient) @@ -893,7 +899,7 @@ func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, exec isCicdEnforced = true } if !force && isCicdEnforced && dagDeploy != "" { - if !canCiCdDeploy("Bearer " + os.Getenv("ASTRO_API_TOKEN")) { + if !canCiCdDeploy(creds) { fmt.Printf("\nWarning: You are trying to update the dag deploy setting with ci-cd enforcement enabled. Once the setting is updated, you will not be able to deploy your dags using the CLI. Until you deploy your dags, dags will not be visible in the UI nor will new tasks start." + "\nAfter the setting is updated, either disable cicd enforcement and then deploy your dags OR deploy your dags via CICD or using API Tokens.") y, _ := input.Confirm("\n\nAre you sure you want to continue?") diff --git a/cloud/deployment/deployment_test.go b/cloud/deployment/deployment_test.go index f6094543a..3098dfe3f 100644 --- a/cloud/deployment/deployment_test.go +++ b/cloud/deployment/deployment_test.go @@ -21,6 +21,7 @@ import ( astroplatformcore_mocks "github.com/astronomer/astro-cli/astro-client-platform-core/mocks" "github.com/astronomer/astro-cli/cloud/organization" "github.com/astronomer/astro-cli/context" + "github.com/astronomer/astro-cli/pkg/credentials" testUtil "github.com/astronomer/astro-cli/pkg/testing" "github.com/astronomer/astro-cli/pkg/util" ) @@ -1591,13 +1592,13 @@ func (s *Suite) TestCanCiCdDeploy() { return &mockClaims, nil } - canDeploy := CanCiCdDeploy("bearer token") + canDeploy := CanCiCdDeploy(credentials.New("bearer token")) s.Equal(canDeploy, false) parseToken = func(astroAPIToken string) (*util.CustomClaims, error) { return nil, errMock } - canDeploy = CanCiCdDeploy("bearer token") + canDeploy = CanCiCdDeploy(credentials.New("bearer token")) s.Equal(canDeploy, false) permissions = []string{ @@ -1612,7 +1613,7 @@ func (s *Suite) TestCanCiCdDeploy() { return &mockClaims, nil } - canDeploy = CanCiCdDeploy("bearer token") + canDeploy = CanCiCdDeploy(credentials.New("bearer token")) s.Equal(canDeploy, true) } @@ -1668,7 +1669,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "1")() // success with hybrid type in this test nothing is being change just ensuring that dag deploy stays true. Addtionally no deployment id/name is given so user input is needed to select one - err := Update("", "", ws, "", "", "", CeleryExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, true, mockCoreClient, mockPlatformCoreClient) + err := Update("", "", ws, "", "", "", CeleryExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, true, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) s.Equal(deploymentResponse.JSON200.IsDagDeployEnabled, dagDeployEnabled) @@ -1677,7 +1678,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "y")() // success updating the kubernetes executor on hybrid type. deployment name is given - err = Update("test-id-1", "", ws, "", "", "", KubeExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("test-id-1", "", ws, "", "", "", KubeExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // change type to standard @@ -1689,7 +1690,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "y")() // success with standard type and deployment name input and dag deploy stays the same - err = Update("test-id-1", "", ws, "", "", "", CeleryExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("test-id-1", "", ws, "", "", "", CeleryExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) s.Equal(deploymentResponse.JSON200.IsDagDeployEnabled, dagDeployEnabled) @@ -1697,7 +1698,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "y")() // success updating to kubernetes executor on standard type - err = Update("test-id-1", "", ws, "", "", "", KubeExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("test-id-1", "", ws, "", "", "", KubeExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // change type to dedicatd @@ -1707,7 +1708,7 @@ func (s *Suite) TestUpdate() { //nolint // defer testUtil.MockUserInput(t, "1")() // success with dedicated type no changes made asserts that dag deploy stays the same - err = Update("test-id-1", "", ws, "", "", "", CeleryExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("test-id-1", "", ws, "", "", "", CeleryExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) s.Equal(deploymentResponse.JSON200.IsDagDeployEnabled, dagDeployEnabled) @@ -1715,7 +1716,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "y")() // success with dedicated updating to kubernetes executor - err = Update("test-id-1", "", ws, "", "", "", KubeExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("test-id-1", "", ws, "", "", "", KubeExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) s.Run("successfully update schedulerSize and highAvailability and CICDEnforement", func() { @@ -1735,7 +1736,7 @@ func (s *Suite) TestUpdate() { //nolint // Mock user input for deployment name defer testUtil.MockUserInput(s.T(), "1")() // success with standard type with name - err := Update("", "test", ws, "", "", "enable", CeleryExecutor, "medium", "disable", "", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("", "test", ws, "", "", "enable", CeleryExecutor, "medium", "disable", "", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // change type to dedicatd @@ -1746,20 +1747,20 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "1")() // success with dedicated type - err = Update("", "", ws, "", "test-1", "enable", CeleryExecutor, "medium", "enable", "", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("", "", ws, "", "test-1", "enable", CeleryExecutor, "medium", "enable", "", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // success with large scheduler size - err = Update("", "", ws, "", "test-1", "enable", CeleryExecutor, "large", "enable", "", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("", "", ws, "", "test-1", "enable", CeleryExecutor, "large", "enable", "", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // success with extra large scheduler size - err = Update("", "", ws, "", "test-1", "enable", CeleryExecutor, "extra_large", "enable", "", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("", "", ws, "", "test-1", "enable", CeleryExecutor, "extra_large", "enable", "", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // success with hybrid type with id deploymentResponse.JSON200.Type = &hybridType - err = Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // Mock user input for deployment name @@ -1767,7 +1768,7 @@ func (s *Suite) TestUpdate() { //nolint // success with hybrid type with id deploymentResponse.JSON200.Executor = &executorKubernetes - err = Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) @@ -1785,7 +1786,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "1")() // success with standard type with name - err := Update("", "test", ws, "", "", "enable", CeleryExecutor, "medium", "disable", "disable", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("", "test", ws, "", "", "enable", CeleryExecutor, "medium", "disable", "disable", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // change type to dedicatd and set development mode to false @@ -1796,7 +1797,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "1")() // success with dedicated type - err = Update("", "", ws, "", "test-1", "enable", CeleryExecutor, "medium", "disable", "enable", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("", "", ws, "", "test-1", "enable", CeleryExecutor, "medium", "disable", "enable", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) @@ -1829,7 +1830,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "1")() // success with standard type with name - err := Update("", "test", ws, "", "", "enable", CeleryExecutor, "medium", "disable", "disable", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("", "test", ws, "", "", "enable", CeleryExecutor, "medium", "disable", "disable", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // mock os.Stdin @@ -1837,7 +1838,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "1")() // success with standard type - err = Update("", "", ws, "", "test-1", "enable", CeleryExecutor, "medium", "enable", "disable", "disable", "", "", "2CPU", "2Gi", "", 0, 0, nil, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("", "", ws, "", "test-1", "enable", CeleryExecutor, "medium", "enable", "disable", "disable", "", "", "2CPU", "2Gi", "", 0, 0, nil, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) @@ -1870,7 +1871,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "1")() // success with dedicated type with name - err := Update("", "test", ws, "", "", "enable", CeleryExecutor, "medium", "disable", "disable", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("", "test", ws, "", "", "enable", CeleryExecutor, "medium", "disable", "disable", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // mock os.Stdin @@ -1878,7 +1879,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "1")() // success with dedicated type - err = Update("", "", ws, "", "test-1", "enable", CeleryExecutor, "medium", "disable", "enable", "disable", "", "", "2CPU", "2Gi", "", 0, 0, nil, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("", "", ws, "", "test-1", "enable", CeleryExecutor, "medium", "disable", "enable", "disable", "", "", "2CPU", "2Gi", "", 0, 0, nil, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) @@ -1887,35 +1888,35 @@ func (s *Suite) TestUpdate() { //nolint mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(2) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(2) - err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.ErrorIs(err, errMock) mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(1) deploymentResponse.JSON200.Type = &hybridType - err = Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "10Gi", "2CPU", "10Gi", "", 100, 100, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "10Gi", "2CPU", "10Gi", "", 100, 100, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.ErrorIs(err, ErrInvalidResourceRequest) }) s.Run("list deployments failure", func() { mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, errMock).Times(1) - err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.ErrorIs(err, errMock) }) s.Run("invalid deployment id", func() { mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, errMock).Times(1) // list deployment error - err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.ErrorIs(err, errMock) mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(3) // invalid id - err = Update("invalid-id", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("invalid-id", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.ErrorContains(err, "the Deployment specified was not found in this workspace.") // invalid name - err = Update("", "", ws, "update", "invalid-name", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("", "", ws, "update", "invalid-name", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.ErrorContains(err, "the Deployment specified was not found in this workspace.") // mock os.Stdin @@ -1923,7 +1924,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "0")() // invalid selection - err = Update("", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.ErrorContains(err, "invalid Deployment selected") }) @@ -1938,7 +1939,7 @@ func (s *Suite) TestUpdate() { //nolint // Mock user input for deployment name defer testUtil.MockUserInput(s.T(), "n")() - err := Update("test-id-1", "", ws, "update", "", "disable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("test-id-1", "", ws, "update", "", "disable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) @@ -1949,7 +1950,7 @@ func (s *Suite) TestUpdate() { //nolint mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Times(1) - err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.ErrorIs(err, errMock) s.NotContains(err.Error(), organization.AstronomerConnectionErrMsg) }) @@ -1959,7 +1960,7 @@ func (s *Suite) TestUpdate() { //nolint deploymentResponse.JSON200.IsDagDeployEnabled = true mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(1) - err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) @@ -1967,12 +1968,12 @@ func (s *Suite) TestUpdate() { //nolint mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(1) - canCiCdDeploy = func(astroAPIToken string) bool { + canCiCdDeploy = func(creds *credentials.CurrentCredentials) bool { return false } defer testUtil.MockUserInput(s.T(), "n")() - err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "enable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "enable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) @@ -1981,7 +1982,7 @@ func (s *Suite) TestUpdate() { //nolint mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(1) - err := Update("test-id-1", "", ws, "update", "", "disable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("test-id-1", "", ws, "update", "", "disable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) @@ -2002,7 +2003,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "y")() - err := Update("test-id-1", "", ws, "update", "", "", KubeExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("test-id-1", "", ws, "update", "", "", KubeExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // change type to standard @@ -2010,7 +2011,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "y")() // test update with standard type - err = Update("test-id-1", "", ws, "update", "", "", KubeExecutor, "medium", "enable", "", "disable", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("test-id-1", "", ws, "update", "", "", KubeExecutor, "medium", "enable", "", "disable", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // change type to standard @@ -2018,7 +2019,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "y")() // test update with standard type - err = Update("test-id-1", "", ws, "update", "", "", KubeExecutor, "medium", "enable", "", "disable", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("test-id-1", "", ws, "update", "", "", KubeExecutor, "medium", "enable", "", "disable", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) @@ -2037,7 +2038,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "y")() // test update with standard type - err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // change type to standard @@ -2045,14 +2046,14 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "y")() // test update with standard type - err = Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) defer testUtil.MockUserInput(s.T(), "y")() // test update with hybrid type deploymentResponse.JSON200.Type = &hybridType - err = Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) @@ -2068,7 +2069,7 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "n")() - err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) @@ -2096,7 +2097,7 @@ func (s *Suite) TestUpdate() { //nolint mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Once() mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponseWithNoNodePools, nil).Once() - err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) @@ -2119,7 +2120,7 @@ func (s *Suite) TestUpdate() { //nolint mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Once() // Call the Update function with a non-empty workload ID - err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "small", "enable", "", "disable", "", "", "", "", mockWorkloadIdentity, 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, true, mockCoreClient, mockPlatformCoreClient) + err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "small", "enable", "", "disable", "", "", "", "", mockWorkloadIdentity, 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, true, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) s.Run("update deployment to change executor to/from ASTRO executor", func() { @@ -2134,26 +2135,26 @@ func (s *Suite) TestUpdate() { //nolint // CELERY -> ASTRO deploymentResponse.JSON200.Executor = &executorCelery defer testUtil.MockUserInput(s.T(), "y")() - err := Update("test-id-1", "", ws, "", "", "", AstroExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err := Update("test-id-1", "", ws, "", "", "", AstroExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // ASTRO -> CELERY astroExecutor := astroplatformcore.DeploymentExecutorASTRO deploymentResponse.JSON200.Executor = &astroExecutor defer testUtil.MockUserInput(s.T(), "y")() - err = Update("test-id-1", "", ws, "", "", "", CeleryExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("test-id-1", "", ws, "", "", "", CeleryExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // ASTRO -> KUBERNETES deploymentResponse.JSON200.Executor = &astroExecutor defer testUtil.MockUserInput(s.T(), "y")() - err = Update("test-id-1", "", ws, "", "", "", KubeExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("test-id-1", "", ws, "", "", "", KubeExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) // KUBERNETES -> ASTRO deploymentResponse.JSON200.Executor = &executorKubernetes defer testUtil.MockUserInput(s.T(), "y")() - err = Update("test-id-1", "", ws, "", "", "", AstroExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient) + err = Update("test-id-1", "", ws, "", "", "", AstroExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, nil, nil, nil, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) s.Run("update deployment to change remote execution config", func() { @@ -2176,7 +2177,7 @@ func (s *Suite) TestUpdate() { //nolint newTaskLogURLPattern := "new-task-log-url-pattern" newAllowedIPAddressRanges := []string{"1.2.3.5/32"} defer testUtil.MockUserInput(s.T(), "y")() - err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, &newAllowedIPAddressRanges, &newTaskLogBucket, &newTaskLogURLPattern, false, mockCoreClient, mockPlatformCoreClient) + err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, &newAllowedIPAddressRanges, &newTaskLogBucket, &newTaskLogURLPattern, false, mockCoreClient, mockPlatformCoreClient, nil) s.NoError(err) }) } diff --git a/cloud/deployment/deployment_variable.go b/cloud/deployment/deployment_variable.go index 824491504..92a5f9b79 100644 --- a/cloud/deployment/deployment_variable.go +++ b/cloud/deployment/deployment_variable.go @@ -139,7 +139,7 @@ func VariableModify( } // update deployment - err = Update(currentDeployment.Id, "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", 0, 0, []astroplatformcore.WorkerQueueRequest{}, []astroplatformcore.HybridWorkerQueueRequest{}, newEnvironmentVariables, nil, nil, nil, false, coreClient, platformCoreClient) + err = Update(currentDeployment.Id, "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", 0, 0, []astroplatformcore.WorkerQueueRequest{}, []astroplatformcore.HybridWorkerQueueRequest{}, newEnvironmentVariables, nil, nil, nil, false, coreClient, platformCoreClient, nil) if err != nil { return err } diff --git a/cloud/deployment/fromfile/fromfile.go b/cloud/deployment/fromfile/fromfile.go index 4ddb2e225..78346943c 100644 --- a/cloud/deployment/fromfile/fromfile.go +++ b/cloud/deployment/fromfile/fromfile.go @@ -21,6 +21,7 @@ import ( "github.com/astronomer/astro-cli/cloud/organization" "github.com/astronomer/astro-cli/cloud/workspace" "github.com/astronomer/astro-cli/config" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/input" ) @@ -53,7 +54,7 @@ const ( // CreateOrUpdate takes a file and creates a deployment with the confiuration specified in the file. // inputFile can be in yaml or json format // It returns an error if any required information is missing or incorrectly specified. -func CreateOrUpdate(inputFile, action string, astroPlatformCore astroplatformcore.CoreClient, coreClient astrocore.CoreClient, out io.Writer, waitForStatus bool, waitTime time.Duration, force bool) error { //nolint +func CreateOrUpdate(inputFile, action string, astroPlatformCore astroplatformcore.CoreClient, coreClient astrocore.CoreClient, out io.Writer, waitForStatus bool, waitTime time.Duration, force bool, creds *credentials.CurrentCredentials) error { //nolint var ( err error errHelp, clusterID, workspaceID, outputFormat string @@ -128,7 +129,7 @@ func CreateOrUpdate(inputFile, action string, astroPlatformCore astroplatformcor } // this deployment does not exist so create it // transform formattedDeployment to DeploymentCreateInput - err = createOrUpdateDeployment(&formattedDeployment, clusterID, workspaceID, createAction, &astroplatformcore.Deployment{}, nodePools, dagDeploy, envVars, coreClient, astroPlatformCore, waitForStatus, waitTime, force) + err = createOrUpdateDeployment(&formattedDeployment, clusterID, workspaceID, createAction, &astroplatformcore.Deployment{}, nodePools, dagDeploy, envVars, coreClient, astroPlatformCore, waitForStatus, waitTime, force, creds) if err != nil { return err } @@ -171,7 +172,7 @@ func CreateOrUpdate(inputFile, action string, astroPlatformCore astroplatformcor return fmt.Errorf("%w \n failed to %s alert emails", err, action) } // transform formattedDeployment to DeploymentUpdateInput - err = createOrUpdateDeployment(&formattedDeployment, clusterID, workspaceID, updateAction, &existingDeployment, nodePools, dagDeploy, envVars, coreClient, astroPlatformCore, waitForStatus, waitTime, force) + err = createOrUpdateDeployment(&formattedDeployment, clusterID, workspaceID, updateAction, &existingDeployment, nodePools, dagDeploy, envVars, coreClient, astroPlatformCore, waitForStatus, waitTime, force, creds) if err != nil { return err } @@ -196,7 +197,7 @@ func CreateOrUpdate(inputFile, action string, astroPlatformCore astroplatformcor // It returns an error if node pool id could not be found for the worker type. // //nolint:dupl -func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, clusterID, workspaceID, action string, existingDeployment *astroplatformcore.Deployment, nodePools []astroplatformcore.NodePool, dagDeploy bool, envVars []astroplatformcore.DeploymentEnvironmentVariableRequest, coreClient astrocore.CoreClient, astroPlatformCore astroplatformcore.CoreClient, waitForStatus bool, waitTime time.Duration, force bool) error { //nolint +func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, clusterID, workspaceID, action string, existingDeployment *astroplatformcore.Deployment, nodePools []astroplatformcore.NodePool, dagDeploy bool, envVars []astroplatformcore.DeploymentEnvironmentVariableRequest, coreClient astrocore.CoreClient, astroPlatformCore astroplatformcore.CoreClient, waitForStatus bool, waitTime time.Duration, force bool, creds *credentials.CurrentCredentials) error { //nolint var ( defaultOptions astroplatformcore.WorkerQueueOptions configOptions astroplatformcore.DeploymentOptions @@ -694,7 +695,7 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c } // update deployment if !force && deploymentFromFile.Deployment.Configuration.APIKeyOnlyDeployments && dagDeploy { - if !canCiCdDeploy("Bearer " + os.Getenv("ASTRO_API_TOKEN")) { + if !canCiCdDeploy(creds) { fmt.Printf("\nWarning: You are trying to update dag deploy setting on a deployment with ci-cd enforcement enabled. You will not be able to deploy your dags using the CLI and that dags will not be visible in the UI and new tasks will not start." + "\nEither disable ci-cd enforcement or please cancel this operation and use API Tokens instead.") y, _ := input.Confirm("\n\nAre you sure you want to continue?") diff --git a/cloud/deployment/fromfile/fromfile_test.go b/cloud/deployment/fromfile/fromfile_test.go index c15f03b3c..458ac7a39 100644 --- a/cloud/deployment/fromfile/fromfile_test.go +++ b/cloud/deployment/fromfile/fromfile_test.go @@ -18,6 +18,7 @@ import ( astroplatformcore_mocks "github.com/astronomer/astro-cli/astro-client-platform-core/mocks" "github.com/astronomer/astro-cli/cloud/deployment" "github.com/astronomer/astro-cli/cloud/deployment/inspect" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/fileutil" testUtil "github.com/astronomer/astro-cli/pkg/testing" ) @@ -430,7 +431,7 @@ func (s *Suite) TestCreateOrUpdate() { ) s.Run("returns an error if file does not exist", func() { - err = CreateOrUpdate("deployment.yaml", "create", nil, nil, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", nil, nil, nil, false, 0*time.Second, false, nil) s.ErrorContains(err, "open deployment.yaml: no such file or directory") }) s.Run("returns an error if file exists but user provides incorrect path", func() { @@ -439,7 +440,7 @@ func (s *Suite) TestCreateOrUpdate() { err = fileutil.WriteStringToFile(filePath, data) s.NoError(err) defer afero.NewOsFs().RemoveAll("./2") - err = CreateOrUpdate("1/deployment.yaml", "create", nil, nil, nil, false, 0*time.Second, false) + err = CreateOrUpdate("1/deployment.yaml", "create", nil, nil, nil, false, 0*time.Second, false, nil) s.ErrorContains(err, "open 1/deployment.yaml: no such file or directory") }) s.Run("returns an error if file is empty", func() { @@ -447,7 +448,7 @@ func (s *Suite) TestCreateOrUpdate() { data = "" fileutil.WriteStringToFile(filePath, data) defer afero.NewOsFs().Remove(filePath) - err = CreateOrUpdate("deployment.yaml", "create", nil, nil, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", nil, nil, nil, false, 0*time.Second, false, nil) s.ErrorIs(err, errEmptyFile) s.ErrorContains(err, "deployment.yaml has no content") }) @@ -456,7 +457,7 @@ func (s *Suite) TestCreateOrUpdate() { data = "test" fileutil.WriteStringToFile(filePath, data) defer afero.NewOsFs().Remove(filePath) - err = CreateOrUpdate("deployment.yaml", "create", nil, nil, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", nil, nil, nil, false, 0*time.Second, false, nil) s.ErrorContains(err, "error unmarshaling JSON:") }) s.Run("returns an error if required fields are missing", func() { @@ -511,7 +512,7 @@ deployment: ` fileutil.WriteStringToFile(filePath, data) defer afero.NewOsFs().Remove(filePath) - err = CreateOrUpdate("deployment.yaml", "create", nil, nil, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", nil, nil, nil, false, 0*time.Second, false, nil) s.ErrorContains(err, "missing required field: deployment.configuration.name") }) s.Run("returns an error if getting context fails", func() { @@ -572,7 +573,7 @@ deployment: fileutil.WriteStringToFile(filePath, data) defer afero.NewOsFs().Remove(filePath) - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false, nil) s.ErrorContains(err, "no context set") }) s.Run("returns an error if cluster does not exist", func() { @@ -633,7 +634,7 @@ deployment: fileutil.WriteStringToFile(filePath, data) defer afero.NewOsFs().Remove(filePath) mockPlatformCoreClient.On("ListClustersWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListClustersResponse, nil).Once() - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false, nil) s.ErrorIs(err, errNotFound) mockCoreClient.AssertExpectations(s.T()) }) @@ -695,7 +696,7 @@ deployment: fileutil.WriteStringToFile(filePath, data) defer afero.NewOsFs().Remove(filePath) mockPlatformCoreClient.On("ListClustersWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListClustersResponse, errTest).Once() - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false, nil) s.ErrorIs(err, errTest) }) s.Run("returns an error if listing deployment fails", func() { @@ -758,7 +759,7 @@ deployment: mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Times(1) mockPlatformCoreClient.On("ListClustersWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListClustersResponse, nil).Once() mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, errTest).Times(1) - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false, nil) s.ErrorIs(err, errTest) mockCoreClient.AssertExpectations(s.T()) }) @@ -831,7 +832,7 @@ deployment: mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Once() - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) s.NoError(err) s.NotNil(out) mockCoreClient.AssertExpectations(s.T()) @@ -901,7 +902,7 @@ deployment: mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Once() - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) s.NoError(err) s.NotNil(out) mockCoreClient.AssertExpectations(s.T()) @@ -985,7 +986,7 @@ deployment: mockPlatformCoreClient.On("CreateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockCreateDeploymentResponse, nil).Once() mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, errTest).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(1) - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false, nil) s.ErrorIs(err, errTest) mockCoreClient.AssertExpectations(s.T()) }) @@ -1054,7 +1055,7 @@ deployment: mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Once() - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) s.NoError(err) s.Contains(out.String(), "configuration:\n name: test-deployment-label") s.Contains(out.String(), "metadata:\n deployment_id: test-deployment-id") @@ -1111,7 +1112,7 @@ deployment: mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Once() - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) s.NoError(err) s.Contains(out.String(), "configuration:\n name: test-deployment-label") s.Contains(out.String(), "metadata:\n deployment_id: test-deployment-id") @@ -1184,7 +1185,7 @@ deployment: description: hibernation schedule 1 enabled: true ` - canCiCdDeploy = func(astroAPIToken string) bool { + canCiCdDeploy = func(creds *credentials.CurrentCredentials) bool { return true } fileutil.WriteStringToFile(filePath, data) @@ -1213,7 +1214,7 @@ deployment: )).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, "test-deployment-id").Return(&deploymentResponse, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Once() - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) s.NoError(err) s.Contains(out.String(), "configuration:\n name: test-deployment-label") s.Contains(out.String(), "metadata:\n deployment_id: test-deployment-id") @@ -1277,7 +1278,7 @@ deployment: - test1@test.com - test2@test.com ` - canCiCdDeploy = func(astroAPIToken string) bool { + canCiCdDeploy = func(creds *credentials.CurrentCredentials) bool { return true } fileutil.WriteStringToFile(filePath, data) @@ -1302,7 +1303,7 @@ deployment: )).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, "test-deployment-id").Return(&deploymentResponseRemoteExecution, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Once() - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) s.NoError(err) s.Contains(out.String(), "configuration:\n name: test-deployment-label") s.Contains(out.String(), "metadata:\n deployment_id: test-deployment-id") @@ -1380,7 +1381,7 @@ deployment: }` fileutil.WriteStringToFile(filePath, data) defer afero.NewOsFs().Remove(filePath) - canCiCdDeploy = func(astroAPIToken string) bool { + canCiCdDeploy = func(creds *credentials.CurrentCredentials) bool { return true } mockPlatformCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(2) @@ -1392,7 +1393,7 @@ deployment: mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Once() - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) s.NoError(err) s.Contains(out.String(), "\"configuration\": {\n \"name\": \"test-deployment-label\"") s.Contains(out.String(), "\"metadata\": {\n \"deployment_id\": \"test-deployment-id\"") @@ -1492,7 +1493,7 @@ deployment: )).Return(&mockCreateDeploymentResponse, nil).Once() mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) s.NoError(err) s.Contains(out.String(), "\"configuration\": {\n \"name\": \"test-deployment-label\"") s.Contains(out.String(), "\"metadata\": {\n \"deployment_id\": \"test-deployment-id\"") @@ -1594,7 +1595,7 @@ deployment: )).Return(&mockCreateDeploymentResponse, nil).Once() mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) s.NoError(err) s.Contains(out.String(), "\"configuration\": {\n \"name\": \"test-deployment-label\"") s.Contains(out.String(), "\"metadata\": {\n \"deployment_id\": \"test-deployment-id\"") @@ -1660,7 +1661,7 @@ deployment: defer afero.NewOsFs().Remove(filePath) mockPlatformCoreClient.On("ListClustersWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListClustersResponse, nil).Once() mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&EmptyListWorkspacesResponseOK, errTest).Times(1) - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false, nil) s.ErrorIs(err, errTest) mockCoreClient.AssertExpectations(s.T()) }) @@ -1723,7 +1724,7 @@ deployment: mockPlatformCoreClient.On("ListClustersWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListClustersResponse, nil).Once() mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsCreateResponse, nil).Times(1) mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Times(1) - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false, nil) s.ErrorContains(err, "deployment: test-deployment-label already exists: use deployment update --deployment-file deployment.yaml instead") mockCoreClient.AssertExpectations(s.T()) }) @@ -1802,7 +1803,7 @@ deployment: mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Times(1) mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(1) - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false, nil) s.Error(err) s.ErrorContains(err, "worker queue option is invalid: worker concurrency") mockCoreClient.AssertExpectations(s.T()) @@ -1887,7 +1888,7 @@ deployment: mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsCreateResponse, nil).Times(1) mockPlatformCoreClient.On("CreateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockCreateDeploymentResponse, nil).Once() mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, errCreateFailed).Once() - err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false, nil) s.ErrorIs(err, errCreateFailed) mockCoreClient.AssertExpectations(s.T()) }) @@ -1954,7 +1955,7 @@ deployment: mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Once() - err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) s.NoError(err) s.Contains(out.String(), "configuration:\n name: test-deployment-label") s.Contains(out.String(), "\n description: description 1") @@ -2010,7 +2011,7 @@ deployment: mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Once() - err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) s.NoError(err) s.Contains(out.String(), "configuration:\n name: test-deployment-label") s.Contains(out.String(), "\n description: description 1") @@ -2098,7 +2099,7 @@ deployment: )).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) - err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) s.NoError(err) s.Contains(out.String(), "configuration:\n name: test-deployment-label") s.Contains(out.String(), "\n description: description 1") @@ -2183,7 +2184,7 @@ deployment: mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Once() mockPlatformCoreClient.On("ListClustersWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListClustersResponse, nil).Once() - err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) s.NoError(err) s.Contains(out.String(), "configuration:\n name: test-deployment-label") s.Contains(out.String(), "\n description: description 1") @@ -2258,11 +2259,11 @@ deployment: mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Once() - canCiCdDeploy = func(astroAPIToken string) bool { + canCiCdDeploy = func(creds *credentials.CurrentCredentials) bool { return false } - err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) defer testUtil.MockUserInput(s.T(), "n")() s.NoError(err) mockCoreClient.AssertExpectations(s.T()) @@ -2335,11 +2336,11 @@ deployment: mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Once() - canCiCdDeploy = func(astroAPIToken string) bool { + canCiCdDeploy = func(creds *credentials.CurrentCredentials) bool { return false } - err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, true) + err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, true, nil) s.NoError(err) s.Contains(out.String(), "configuration:\n name: test-deployment-label") s.Contains(out.String(), "metadata:\n deployment_id: test-deployment-id") @@ -2425,7 +2426,7 @@ deployment: mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Once() - err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) s.NoError(err) s.Contains(out.String(), "test-deployment-label") s.Contains(out.String(), "description 1") @@ -2516,7 +2517,7 @@ deployment: )).Return(&mockUpdateDeploymentResponse, nil) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) - err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out, false, 0*time.Second, false, nil) s.NoError(err) s.Contains(out.String(), "configuration:\n name: test-deployment-label") s.Contains(out.String(), "\n description: description 1") @@ -2584,7 +2585,7 @@ deployment: mockPlatformCoreClient.On("ListClustersWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListClustersResponse, nil).Once() mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(2) - err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false, nil) s.ErrorContains(err, "deployment: test-deployment-label does not exist: use deployment create --deployment-file deployment.yaml instead") mockCoreClient.AssertExpectations(s.T()) }) @@ -2664,7 +2665,7 @@ deployment: mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsCreateResponse, nil).Times(2) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(1) - err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false, nil) s.Error(err) s.ErrorContains(err, "worker queue option is invalid: worker concurrency") mockCoreClient.AssertExpectations(s.T()) @@ -2746,7 +2747,7 @@ deployment: mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, errUpdateFailed).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(1) - err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false) + err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, nil, false, 0*time.Second, false, nil) s.ErrorIs(err, errUpdateFailed) s.ErrorContains(err, "failed to update deployment with input") mockCoreClient.AssertExpectations(s.T()) @@ -2819,7 +2820,7 @@ deployment: s.Require().NoError(err) // Run createOrUpdateDeployment with waitForStatus=true, allow time for 2 polls - err = createOrUpdateDeployment(&fd, "", "ws-id", createAction, &astroplatformcore.Deployment{}, nil, false, nil, nil, nil, true, 3*time.Second, false) + err = createOrUpdateDeployment(&fd, "", "ws-id", createAction, &astroplatformcore.Deployment{}, nil, false, nil, nil, nil, true, 3*time.Second, false, nil) s.NoError(err) s.Equal(2, callCount, "expected two polling iterations before becoming healthy") }) @@ -2845,7 +2846,7 @@ deployment: err := yaml.Unmarshal([]byte(minimalDeploymentYAML), &fd) s.Require().NoError(err) - err = createOrUpdateDeployment(&fd, "", "ws-id", createAction, &astroplatformcore.Deployment{}, nil, false, nil, nil, nil, false, 0*time.Second, false) + err = createOrUpdateDeployment(&fd, "", "ws-id", createAction, &astroplatformcore.Deployment{}, nil, false, nil, nil, nil, false, 0*time.Second, false, nil) s.NoError(err) s.Equal(0, callCount, "expected no polling when waitForStatus is false") }) @@ -2876,7 +2877,7 @@ deployment: err := yaml.Unmarshal([]byte(minimalDeploymentYAML), &fd) s.Require().NoError(err) - err = createOrUpdateDeployment(&fd, "", "ws-id", createAction, &astroplatformcore.Deployment{}, nil, false, nil, nil, nil, true, 100*time.Millisecond, false) + err = createOrUpdateDeployment(&fd, "", "ws-id", createAction, &astroplatformcore.Deployment{}, nil, false, nil, nil, nil, true, 100*time.Millisecond, false, nil) s.ErrorIs(err, deployment.ErrTimedOut, "expected ErrTimedOut when deployment does not become healthy") }) @@ -2907,7 +2908,7 @@ deployment: err := yaml.Unmarshal([]byte(minimalDeploymentYAML), &fd) s.Require().NoError(err) - err = createOrUpdateDeployment(&fd, "", "ws-id", createAction, &astroplatformcore.Deployment{}, nil, false, nil, nil, nil, true, 3*time.Second, false) + err = createOrUpdateDeployment(&fd, "", "ws-id", createAction, &astroplatformcore.Deployment{}, nil, false, nil, nil, nil, true, 3*time.Second, false, nil) s.ErrorIs(err, apiErr, "expected error returned from CoreGetDeployment to propagate") }) } @@ -2966,7 +2967,7 @@ func (s *Suite) TestGetCreateOrUpdateInput() { }, } - err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false) + err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false, nil) s.ErrorContains(err, "worker_type: test-worker-8 does not exist in cluster: test-cluster") mockCoreClient.AssertExpectations(s.T()) }) @@ -3012,7 +3013,7 @@ func (s *Suite) TestGetCreateOrUpdateInput() { }, } mockPlatformCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(1) - err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false) + err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false, nil) s.ErrorContains(err, "worker queue option is invalid: min worker count") mockCoreClient.AssertExpectations(s.T()) }) @@ -3058,7 +3059,7 @@ func (s *Suite) TestGetCreateOrUpdateInput() { }, } mockPlatformCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, errTest).Times(1) - err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false) + err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false, nil) s.ErrorIs(err, errTest) mockCoreClient.AssertExpectations(s.T()) }) @@ -3120,7 +3121,7 @@ func (s *Suite) TestGetCreateOrUpdateInput() { } mockPlatformCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(1) mockPlatformCoreClient.On("CreateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockCreateDeploymentResponse, nil).Once() - err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false) + err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false, nil) s.NoError(err) mockCoreClient.AssertExpectations(s.T()) }) @@ -3158,7 +3159,7 @@ func (s *Suite) TestGetCreateOrUpdateInput() { }, } mockPlatformCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(1) - err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false) + err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false, nil) s.ErrorContains(err, "don't use 'worker_queues' to update default queue with KubernetesExecutor") mockCoreClient.AssertExpectations(s.T()) }) @@ -3175,7 +3176,7 @@ func (s *Suite) TestGetCreateOrUpdateInput() { deploymentFromFile.Deployment.Configuration.DagDeployEnabled = &dagDeploy mockPlatformCoreClient.On("CreateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockCreateDeploymentResponse, nil).Once() - err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, nil, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false) + err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, nil, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false, nil) s.NoError(err) mockCoreClient.AssertExpectations(s.T()) }) @@ -3205,7 +3206,7 @@ func (s *Suite) TestGetCreateOrUpdateInput() { } mockPlatformCoreClient.On("CreateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockCreateDeploymentResponse, nil).Once() - err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false) + err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false, nil) s.NoError(err) mockCoreClient.AssertExpectations(s.T()) }) @@ -3271,7 +3272,7 @@ func (s *Suite) TestGetCreateOrUpdateInput() { mockPlatformCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(1) mockPlatformCoreClient.On("CreateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockCreateDeploymentResponse, nil).Once() - err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false) + err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &astroplatformcore.Deployment{}, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false, nil) s.NoError(err) mockCoreClient.AssertExpectations(s.T()) }) @@ -3307,7 +3308,7 @@ func (s *Suite) TestGetCreateOrUpdateInput() { } mockPlatformCoreClient.On("CreateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockCreateDeploymentResponse, nil).Once() - err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &existingDeployment, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false) + err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "create", &existingDeployment, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false, nil) s.NoError(err) mockCoreClient.AssertExpectations(s.T()) }) @@ -3328,7 +3329,7 @@ func (s *Suite) TestGetCreateOrUpdateInput() { Name: "test-deployment", ClusterId: &clusterID, } - err = createOrUpdateDeployment(&deploymentFromFile, "diff-cluster", workspaceID, "update", &existingDeployment, nil, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false) + err = createOrUpdateDeployment(&deploymentFromFile, "diff-cluster", workspaceID, "update", &existingDeployment, nil, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false, nil) s.ErrorIs(err, errNotPermitted) s.ErrorContains(err, "changing an existing deployment's cluster is not permitted") mockCoreClient.AssertExpectations(s.T()) @@ -3368,7 +3369,7 @@ func (s *Suite) TestGetCreateOrUpdateInput() { } mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) - err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "update", &existingDeployment, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false) + err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "update", &existingDeployment, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false, nil) s.NoError(err) mockCoreClient.AssertExpectations(s.T()) }) @@ -3409,7 +3410,7 @@ func (s *Suite) TestGetCreateOrUpdateInput() { } mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) - err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "update", &existingDeployment, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false) + err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "update", &existingDeployment, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false, nil) s.NoError(err) mockCoreClient.AssertExpectations(s.T()) }) @@ -3459,7 +3460,7 @@ func (s *Suite) TestGetCreateOrUpdateInput() { } mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) - err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "update", &existingDeployment, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false) + err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "update", &existingDeployment, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false, nil) s.NoError(err) mockCoreClient.AssertExpectations(s.T()) }) @@ -3509,7 +3510,7 @@ func (s *Suite) TestGetCreateOrUpdateInput() { } mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) - err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "update", &existingDeployment, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false) + err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "update", &existingDeployment, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false, nil) s.NoError(err) mockCoreClient.AssertExpectations(s.T()) }) @@ -3583,7 +3584,7 @@ func (s *Suite) TestGetCreateOrUpdateInput() { mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(1) - err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "update", &existingDeployment, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false) + err = createOrUpdateDeployment(&deploymentFromFile, clusterID, workspaceID, "update", &existingDeployment, existingPools, dagDeploy, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, mockCoreClient, mockPlatformCoreClient, false, 0*time.Second, false, nil) s.NoError(err) mockCoreClient.AssertExpectations(s.T()) }) diff --git a/cloud/deployment/workerqueue/workerqueue.go b/cloud/deployment/workerqueue/workerqueue.go index 39519ccf9..b812cc68c 100644 --- a/cloud/deployment/workerqueue/workerqueue.go +++ b/cloud/deployment/workerqueue/workerqueue.go @@ -231,7 +231,7 @@ func CreateOrUpdate(ws, deploymentID, deploymentName, name, action, workerType s } } // update the deployment with the new list of worker queues - err = deployment.Update(requestedDeployment.Id, "", ws, "", "", "", "", "", "", "", "", "", "", "", "", "", 0, 0, listToCreate, hybridListToCreate, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, nil, nil, nil, true, coreClient, platformCoreClient) + err = deployment.Update(requestedDeployment.Id, "", ws, "", "", "", "", "", "", "", "", "", "", "", "", "", 0, 0, listToCreate, hybridListToCreate, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, nil, nil, nil, true, coreClient, platformCoreClient, nil) if err != nil { return err } @@ -558,7 +558,7 @@ func Delete(ws, deploymentID, deploymentName, name string, force bool, platformC } } // update the deployment with the new list - err = deployment.Update(requestedDeployment.Id, "", ws, "", "", "", "", "", "", "", "", "", "", "", "", "", 0, 0, workerQueuesToKeep, hybridWorkerQueuesToKeep, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, nil, nil, nil, true, coreClient, platformCoreClient) + err = deployment.Update(requestedDeployment.Id, "", ws, "", "", "", "", "", "", "", "", "", "", "", "", "", 0, 0, workerQueuesToKeep, hybridWorkerQueuesToKeep, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, nil, nil, nil, true, coreClient, platformCoreClient, nil) if err != nil { return err } @@ -580,7 +580,7 @@ func Delete(ws, deploymentID, deploymentName, name string, force bool, platformC } } // update the deployment with the new list - err = deployment.Update(requestedDeployment.Id, "", ws, "", "", "", "", "", "", "", "", "", "", "", "", "", 0, 0, workerQueuesToKeep, hybridWorkerQueuesToKeep, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, nil, nil, nil, true, coreClient, platformCoreClient) + err = deployment.Update(requestedDeployment.Id, "", ws, "", "", "", "", "", "", "", "", "", "", "", "", "", 0, 0, workerQueuesToKeep, hybridWorkerQueuesToKeep, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, nil, nil, nil, true, coreClient, platformCoreClient, nil) if err != nil { return err } diff --git a/cloud/platformclient/client.go b/cloud/platformclient/client.go index b1b10294b..e9f06ddbd 100644 --- a/cloud/platformclient/client.go +++ b/cloud/platformclient/client.go @@ -2,10 +2,11 @@ package platformclient import ( astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/httputil" ) // NewPlatformCoreClient creates an API client for Astro platform core services. -func NewPlatformCoreClient(c *httputil.HTTPClient, holder *httputil.TokenHolder) *astroplatformcore.ClientWithResponses { +func NewPlatformCoreClient(c *httputil.HTTPClient, holder *credentials.CurrentCredentials) *astroplatformcore.ClientWithResponses { return astroplatformcore.NewPlatformCoreClient(c, holder) } diff --git a/cloud/platformclient/client_test.go b/cloud/platformclient/client_test.go index 414c0dd4e..db35ba211 100644 --- a/cloud/platformclient/client_test.go +++ b/cloud/platformclient/client_test.go @@ -5,10 +5,11 @@ import ( "github.com/stretchr/testify/assert" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/httputil" ) func TestNewPlatformCoreClient(t *testing.T) { - client := NewPlatformCoreClient(httputil.NewHTTPClient(), &httputil.TokenHolder{}) + client := NewPlatformCoreClient(httputil.NewHTTPClient(), &credentials.CurrentCredentials{}) assert.NotNil(t, client, "Can't create new Astro Platform Core client") } diff --git a/cmd/api/airflow.go b/cmd/api/airflow.go index 60a7c2163..74491c593 100644 --- a/cmd/api/airflow.go +++ b/cmd/api/airflow.go @@ -18,6 +18,7 @@ import ( "github.com/astronomer/astro-cli/cloud/platformclient" "github.com/astronomer/astro-cli/context" "github.com/astronomer/astro-cli/pkg/ansi" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/httputil" "github.com/astronomer/astro-cli/pkg/openapi" ) @@ -43,19 +44,19 @@ type AirflowOptions struct { // Internal detectedVersion string // The Airflow version being used (detected or overridden) CredentialsExplicit bool // true when --username or --password was explicitly passed - tokenHolder *httputil.TokenHolder + creds *credentials.CurrentCredentials } // NewAirflowCmd creates the 'astro api airflow' command. // //nolint:dupl -func NewAirflowCmd(out io.Writer, tokenHolder *httputil.TokenHolder) *cobra.Command { +func NewAirflowCmd(out io.Writer, creds *credentials.CurrentCredentials) *cobra.Command { opts := &AirflowOptions{ RequestOptions: RequestOptions{ Out: out, ErrOut: os.Stderr, }, - tokenHolder: tokenHolder, + creds: creds, } cmd := &cobra.Command{ @@ -460,7 +461,7 @@ func resolveDeploymentAirflowURL(opts *AirflowOptions) (baseURL, authToken strin } // Check for token - if opts.tokenHolder == nil || opts.tokenHolder.Get() == "" { + if opts.creds == nil || opts.creds.Get() == "" { return "", "", fmt.Errorf("not authenticated. Run 'astro login' to authenticate") } @@ -474,7 +475,7 @@ func resolveDeploymentAirflowURL(opts *AirflowOptions) (baseURL, authToken strin } // Create platform client - platformCoreClient := platformclient.NewPlatformCoreClient(httputil.NewHTTPClient(), opts.tokenHolder) + platformCoreClient := platformclient.NewPlatformCoreClient(httputil.NewHTTPClient(), opts.creds) // Fetch deployment dep, err := deployment.CoreGetDeployment(orgID, opts.DeploymentID, platformCoreClient) @@ -493,7 +494,7 @@ func resolveDeploymentAirflowURL(opts *AirflowOptions) (baseURL, authToken strin airflowURL = "https://" + airflowURL } - return airflowURL, opts.tokenHolder.Get(), nil + return airflowURL, opts.creds.Get(), nil } // runAirflowInteractive runs the airflow API command in interactive mode. diff --git a/cmd/api/airflow_test.go b/cmd/api/airflow_test.go index d78b1ec29..4c083ae43 100644 --- a/cmd/api/airflow_test.go +++ b/cmd/api/airflow_test.go @@ -14,7 +14,7 @@ import ( astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" "github.com/astronomer/astro-cli/cloud/deployment" "github.com/astronomer/astro-cli/config" - "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/openapi" testUtil "github.com/astronomer/astro-cli/pkg/testing" ) @@ -169,10 +169,10 @@ func TestResolveAirflowAPIURL_DeploymentID_Success(t *testing.T) { }, nil } - th := httputil.NewTokenHolder("test-token") + creds := credentials.New("test-token") opts := &AirflowOptions{ DeploymentID: "test-deployment-id", - tokenHolder: th, + creds: creds, } baseURL, authToken, err := resolveAirflowAPIURL(opts) @@ -209,11 +209,11 @@ func TestResolveAirflowAPIURL_DeploymentID_WithOrgOverride(t *testing.T) { }, nil } - th := httputil.NewTokenHolder("test-token") + creds := credentials.New("test-token") opts := &AirflowOptions{ DeploymentID: "test-deployment-id", OrganizationID: "override-org", - tokenHolder: th, + creds: creds, } baseURL, authToken, err := resolveAirflowAPIURL(opts) @@ -246,10 +246,10 @@ func TestResolveAirflowAPIURL_DeploymentID_NoAirflowURL(t *testing.T) { }, nil } - th := httputil.NewTokenHolder("test-token") + creds := credentials.New("test-token") opts := &AirflowOptions{ DeploymentID: "test-deployment-id", - tokenHolder: th, + creds: creds, } _, _, err = resolveAirflowAPIURL(opts) diff --git a/cmd/api/api.go b/cmd/api/api.go index 51f49561a..a18a8b295 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -8,16 +8,16 @@ import ( "github.com/fatih/color" "github.com/spf13/cobra" - "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/credentials" ) // NewAPICmd creates the parent 'astro api' command. -func NewAPICmd(tokenHolder *httputil.TokenHolder) *cobra.Command { - return NewAPICmdWithOutput(os.Stdout, tokenHolder) +func NewAPICmd(creds *credentials.CurrentCredentials) *cobra.Command { + return NewAPICmdWithOutput(os.Stdout, creds) } // NewAPICmdWithOutput creates the parent 'astro api' command with a custom output writer. -func NewAPICmdWithOutput(out io.Writer, tokenHolder *httputil.TokenHolder) *cobra.Command { +func NewAPICmdWithOutput(out io.Writer, creds *credentials.CurrentCredentials) *cobra.Command { var noColor bool cmd := &cobra.Command{ @@ -70,8 +70,8 @@ Use "astro api [command] --help" for more information about a command.`, cmd.PersistentFlags().BoolVar(&noColor, "no-color", false, "Disable colorized output") - cmd.AddCommand(NewAirflowCmd(out, tokenHolder)) - cmd.AddCommand(NewCloudCmd(out, tokenHolder)) + cmd.AddCommand(NewAirflowCmd(out, creds)) + cmd.AddCommand(NewCloudCmd(out, creds)) cmd.AddCommand(NewRegistryCmd(out)) return cmd diff --git a/cmd/api/cloud.go b/cmd/api/cloud.go index 4f5312474..5120c18bd 100644 --- a/cmd/api/cloud.go +++ b/cmd/api/cloud.go @@ -13,8 +13,8 @@ import ( "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/context" "github.com/astronomer/astro-cli/pkg/ansi" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/domainutil" - "github.com/astronomer/astro-cli/pkg/httputil" "github.com/astronomer/astro-cli/pkg/openapi" ) @@ -23,19 +23,19 @@ type CloudOptions struct { RequestOptions SpecURL string // hidden flag: alternative OpenAPI spec URL SpecTokenEnvVar string // hidden flag: env var name containing auth token for spec fetch - tokenHolder *httputil.TokenHolder + creds *credentials.CurrentCredentials } // NewCloudCmd creates the 'astro api cloud' command. // //nolint:dupl -func NewCloudCmd(out io.Writer, tokenHolder *httputil.TokenHolder) *cobra.Command { +func NewCloudCmd(out io.Writer, creds *credentials.CurrentCredentials) *cobra.Command { opts := &CloudOptions{ RequestOptions: RequestOptions{ Out: out, ErrOut: os.Stderr, }, - tokenHolder: tokenHolder, + creds: creds, } cmd := &cobra.Command{ @@ -161,7 +161,7 @@ func runCloud(opts *CloudOptions) error { } // Check for token - if opts.tokenHolder == nil || opts.tokenHolder.Get() == "" { + if opts.creds == nil || opts.creds.Get() == "" { return fmt.Errorf("not authenticated. Run 'astro login' to authenticate") } @@ -234,11 +234,11 @@ func runCloud(opts *CloudOptions) error { // Generate curl command if requested if opts.GenerateCurl { - return generateCurl(opts.Out, method, url, opts.tokenHolder.Get(), opts.RequestHeaders, params, opts.RequestInputFile) + return generateCurl(opts.Out, method, url, opts.creds.Get(), opts.RequestHeaders, params, opts.RequestInputFile) } // Build and execute the request - return executeRequest(&opts.RequestOptions, method, url, opts.tokenHolder.Get(), params) + return executeRequest(&opts.RequestOptions, method, url, opts.creds.Get(), params) } // isOperationID checks if the input looks like an operation ID rather than a path. diff --git a/cmd/auth.go b/cmd/auth.go index 0a31e4581..90881a85e 100644 --- a/cmd/auth.go +++ b/cmd/auth.go @@ -12,8 +12,8 @@ import ( astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" cloudAuth "github.com/astronomer/astro-cli/cloud/auth" "github.com/astronomer/astro-cli/context" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/domainutil" - "github.com/astronomer/astro-cli/pkg/httputil" "github.com/astronomer/astro-cli/pkg/keychain" softwareAuth "github.com/astronomer/astro-cli/software/auth" ) @@ -30,8 +30,8 @@ var ( ) // newLoginCommand is a top-level alias for "astro auth login" kept for backward compatibility. -func newLoginCommand(store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) *cobra.Command { - cmd := newAuthLoginCommand(store, tokenHolder, coreClient, platformCoreClient, out) +func newLoginCommand(store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) *cobra.Command { + cmd := newAuthLoginCommand(store, creds, coreClient, platformCoreClient, out) cmd.Long = "Authenticate to Astro or Astro Private Cloud. This is an alias for 'astro auth login'." return cmd } @@ -43,7 +43,7 @@ func newLogoutCommand(store keychain.SecureStore, out io.Writer) *cobra.Command return cmd } -func login(cmd *cobra.Command, args []string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) error { +func login(cmd *cobra.Command, args []string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) error { // Silence Usage as we have now validated command input cmd.SilenceUsage = true @@ -57,15 +57,15 @@ func login(cmd *cobra.Command, args []string, store keychain.SecureStore, tokenH } return softwareLogin(args[0], oAuth, "", "", houstonVersion, store, houstonClient, out) } - return cloudLogin(args[0], token, store, tokenHolder, coreClient, platformCoreClient, out, shouldDisplayLoginLink) + return cloudLogin(args[0], token, store, creds, coreClient, platformCoreClient, out, shouldDisplayLoginLink) } // Log back into the current context in case no domain is passed ctx, err := context.GetCurrentContext() if err != nil || ctx.Domain == "" { // Default case when no domain is passed, and error getting current context - return cloudLogin(domainutil.DefaultDomain, token, store, tokenHolder, coreClient, platformCoreClient, out, shouldDisplayLoginLink) + return cloudLogin(domainutil.DefaultDomain, token, store, creds, coreClient, platformCoreClient, out, shouldDisplayLoginLink) } else if context.IsCloudDomain(ctx.Domain) { - return cloudLogin(ctx.Domain, token, store, tokenHolder, coreClient, platformCoreClient, out, shouldDisplayLoginLink) + return cloudLogin(ctx.Domain, token, store, creds, coreClient, platformCoreClient, out, shouldDisplayLoginLink) } return softwareLogin(ctx.Domain, oAuth, "", "", houstonVersion, store, houstonClient, out) } @@ -93,28 +93,28 @@ func logout(cmd *cobra.Command, args []string, store keychain.SecureStore, out i return nil } -func newAuthRootCmd(store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) *cobra.Command { +func newAuthRootCmd(store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) *cobra.Command { cmd := &cobra.Command{ Use: "auth", Short: "Manage authentication to Astronomer", Long: "Commands for authenticating to Astro or Astro Private Cloud", } cmd.AddCommand( - newAuthLoginCommand(store, tokenHolder, coreClient, platformCoreClient, out), + newAuthLoginCommand(store, creds, coreClient, platformCoreClient, out), newAuthLogoutCommand(store, out), newAuthTokenCommand(store, out), ) return cmd } -func newAuthLoginCommand(store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) *cobra.Command { +func newAuthLoginCommand(store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) *cobra.Command { cmd := &cobra.Command{ Use: "login [BASEDOMAIN]", Short: "Log in to Astronomer", Long: "Authenticate to Astro or Astro Private Cloud", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return login(cmd, args, store, tokenHolder, coreClient, platformCoreClient, out) + return login(cmd, args, store, creds, coreClient, platformCoreClient, out) }, } diff --git a/cmd/auth_test.go b/cmd/auth_test.go index 24e26e2f9..f64bd5e24 100644 --- a/cmd/auth_test.go +++ b/cmd/auth_test.go @@ -11,7 +11,7 @@ import ( astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/houston" - "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/keychain" testUtil "github.com/astronomer/astro-cli/pkg/testing" ) @@ -28,7 +28,7 @@ func (s *CmdSuite) TestLogin() { cloudDomain := "astronomer.io" softwareDomain := "astronomer_dev.com" - cloudLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + cloudLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { s.Equal(cloudDomain, domain) return nil } diff --git a/cmd/cloud/dbt.go b/cmd/cloud/dbt.go index 201187806..ced491972 100644 --- a/cmd/cloud/dbt.go +++ b/cmd/cloud/dbt.go @@ -138,6 +138,7 @@ func deployDbt(cmd *cobra.Command, args []string) error { WaitTime: waitTime, PlatformCoreClient: platformCoreClient, CoreClient: astroCoreClient, + Creds: creds, } return DeployBundle(deployBundleInput) } diff --git a/cmd/cloud/deploy.go b/cmd/cloud/deploy.go index c62de34ac..0da8355f0 100644 --- a/cmd/cloud/deploy.go +++ b/cmd/cloud/deploy.go @@ -169,6 +169,7 @@ func deploy(cmd *cobra.Command, args []string) error { Description: deployDescription, BuildSecretString: BuildSecretString, Force: forceDeploy, + Creds: creds, } return DeployImage(deployInput, platformCoreClient, astroCoreClient) diff --git a/cmd/cloud/deployment.go b/cmd/cloud/deployment.go index 559379f5f..36349bf86 100644 --- a/cmd/cloud/deployment.go +++ b/cmd/cloud/deployment.go @@ -820,7 +820,7 @@ func deploymentCreate(cmd *cobra.Command, _ []string, out io.Writer) error { //n if disallowedFlagSet { return errFlag } - return fromfile.CreateOrUpdate(inputFile, cmd.Name(), platformCoreClient, astroCoreClient, out, waitForStatus, waitTimeForDeployment, forceUpdate) + return fromfile.CreateOrUpdate(inputFile, cmd.Name(), platformCoreClient, astroCoreClient, out, waitForStatus, waitTimeForDeployment, forceUpdate, creds) } if dagDeploy != "" && !(dagDeploy == enable || dagDeploy == disable) { @@ -869,7 +869,7 @@ func deploymentUpdate(cmd *cobra.Command, args []string, out io.Writer) error { // other flags were requested return errFlag } - return fromfile.CreateOrUpdate(inputFile, cmd.Name(), platformCoreClient, astroCoreClient, out, false, 0*time.Second, forceUpdate) + return fromfile.CreateOrUpdate(inputFile, cmd.Name(), platformCoreClient, astroCoreClient, out, false, 0*time.Second, forceUpdate, creds) } if dagDeploy != "" && !(dagDeploy == enable || dagDeploy == disable) { return errors.New("Invalid --dag-deploy value") @@ -905,7 +905,7 @@ func deploymentUpdate(cmd *cobra.Command, args []string, out io.Writer) error { deploymentID = args[0] } - return deployment.Update(deploymentID, label, ws, description, deploymentName, dagDeploy, executor, schedulerSize, highAvailability, developmentMode, cicdEnforcement, defaultTaskPodCPU, defaultTaskPodMemory, resourceQuotaCPU, resourceQuotaMemory, workloadIdentity, updateSchedulerAU, updateSchedulerReplicas, []astroplatformcore.WorkerQueueRequest{}, []astroplatformcore.HybridWorkerQueueRequest{}, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, allowedIPAddressRanges, taskLogBucket, taskLogURLPattern, forceUpdate, astroCoreClient, platformCoreClient) + return deployment.Update(deploymentID, label, ws, description, deploymentName, dagDeploy, executor, schedulerSize, highAvailability, developmentMode, cicdEnforcement, defaultTaskPodCPU, defaultTaskPodMemory, resourceQuotaCPU, resourceQuotaMemory, workloadIdentity, updateSchedulerAU, updateSchedulerReplicas, []astroplatformcore.WorkerQueueRequest{}, []astroplatformcore.HybridWorkerQueueRequest{}, []astroplatformcore.DeploymentEnvironmentVariableRequest{}, allowedIPAddressRanges, taskLogBucket, taskLogURLPattern, forceUpdate, astroCoreClient, platformCoreClient, creds) } func validateCICD() error { diff --git a/cmd/cloud/remote.go b/cmd/cloud/remote.go index 100560535..fe74e552c 100644 --- a/cmd/cloud/remote.go +++ b/cmd/cloud/remote.go @@ -84,6 +84,7 @@ func remoteDeploy(cmd *cobra.Command, args []string) error { Platform: remotePlatform, BuildSecretString: buildSecretString, DeploymentID: remoteDeploymentID, + Creds: creds, } return cloud.DeployClientImage(deployInput, platformCoreClient) diff --git a/cmd/cloud/root.go b/cmd/cloud/root.go index e7ab48538..85ab8499a 100644 --- a/cmd/cloud/root.go +++ b/cmd/cloud/root.go @@ -9,6 +9,7 @@ import ( astrocore "github.com/astronomer/astro-cli/astro-client-core" astroiamcore "github.com/astronomer/astro-cli/astro-client-iam-core" astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" + "github.com/astronomer/astro-cli/pkg/credentials" ) var ( @@ -16,14 +17,16 @@ var ( astroCoreIamClient astroiamcore.CoreClient platformCoreClient astroplatformcore.CoreClient airflowAPIClient airflow.Client + creds *credentials.CurrentCredentials ) // AddCmds adds all the command initialized in this package for the cmd package to import -func AddCmds(astroPlatformCoreClient astroplatformcore.CoreClient, coreClient astrocore.CoreClient, airflowClient airflow.Client, iamCoreClient astroiamcore.CoreClient, out io.Writer) []*cobra.Command { +func AddCmds(astroPlatformCoreClient astroplatformcore.CoreClient, coreClient astrocore.CoreClient, airflowClient airflow.Client, iamCoreClient astroiamcore.CoreClient, c *credentials.CurrentCredentials, out io.Writer) []*cobra.Command { astroCoreClient = coreClient platformCoreClient = astroPlatformCoreClient astroCoreIamClient = iamCoreClient airflowAPIClient = airflowClient + creds = c return []*cobra.Command{ NewDeployCmd(), newDeploymentRootCmd(out), diff --git a/cmd/cloud/root_test.go b/cmd/cloud/root_test.go index b3312279a..fee4fa922 100644 --- a/cmd/cloud/root_test.go +++ b/cmd/cloud/root_test.go @@ -12,7 +12,7 @@ import ( func TestAddCmds(t *testing.T) { testUtil.InitTestConfig(testUtil.LocalPlatform) buf := new(bytes.Buffer) - cmds := AddCmds(nil, nil, nil, nil, buf) + cmds := AddCmds(nil, nil, nil, nil, nil, buf) for cmdIdx := range cmds { assert.Contains(t, []string{"deployment", "deploy DEPLOYMENT-ID", "workspace", "user", "organization", "dbt", "ide", "remote"}, cmds[cmdIdx].Use) } diff --git a/cmd/cloud/setup.go b/cmd/cloud/setup.go index 223a99f53..e02682d66 100644 --- a/cmd/cloud/setup.go +++ b/cmd/cloud/setup.go @@ -21,6 +21,7 @@ import ( "github.com/astronomer/astro-cli/cloud/deployment" "github.com/astronomer/astro-cli/cloud/organization" "github.com/astronomer/astro-cli/context" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/httputil" "github.com/astronomer/astro-cli/pkg/keychain" "github.com/astronomer/astro-cli/pkg/logger" @@ -66,7 +67,7 @@ type CustomClaims struct { } //nolint:gocognit -func Setup(cmd *cobra.Command, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, platformCoreClient astroplatformcore.CoreClient, coreClient astrocore.CoreClient) error { +func Setup(cmd *cobra.Command, store keychain.SecureStore, creds *credentials.CurrentCredentials, platformCoreClient astroplatformcore.CoreClient, coreClient astrocore.CoreClient) error { // If the user is trying to login or logout no need to go through auth setup. if cmd.CalledAs() == "login" || cmd.CalledAs() == "logout" { return nil @@ -110,7 +111,7 @@ func Setup(cmd *cobra.Command, store keychain.SecureStore, tokenHolder *httputil } // Check for APITokens before API keys or refresh tokens - apiToken, err := checkAPIToken(isDeploymentFile, tokenHolder, platformCoreClient) + apiToken, err := checkAPIToken(isDeploymentFile, creds, platformCoreClient) if err != nil { return err } @@ -119,14 +120,14 @@ func Setup(cmd *cobra.Command, store keychain.SecureStore, tokenHolder *httputil } // run auth setup for any command that requires auth - apiKey, err := checkAPIKeys(platformCoreClient, tokenHolder, isDeploymentFile) + apiKey, err := checkAPIKeys(platformCoreClient, creds, isDeploymentFile) if err != nil { return err } if apiKey { return nil } - err = checkToken(store, tokenHolder, coreClient, platformCoreClient, os.Stdout) + err = checkToken(store, creds, coreClient, platformCoreClient, os.Stdout) if err != nil { return err } @@ -134,30 +135,30 @@ func Setup(cmd *cobra.Command, store keychain.SecureStore, tokenHolder *httputil return nil } -func checkToken(store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) error { +func checkToken(store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer) error { c, err := context.GetCurrentContext() if err != nil { return err } - creds, err := store.GetCredentials(c.Domain) - if err != nil || creds.Token == "" { - return authLogin(c.Domain, "", store, tokenHolder, coreClient, platformCoreClient, out, false) + keyCreds, err := store.GetCredentials(c.Domain) + if err != nil || keyCreds.Token == "" { + return authLogin(c.Domain, "", store, creds, coreClient, platformCoreClient, out, false) } - if isExpired(creds.ExpiresAt, accessTokenExpThreshold) { + if isExpired(keyCreds.ExpiresAt, accessTokenExpThreshold) { authConfig, err := auth.FetchDomainAuthConfig(c.Domain) if err != nil { return err } - res, err := refresh(creds.RefreshToken, authConfig) + res, err := refresh(keyCreds.RefreshToken, authConfig) if err != nil { - return authLogin(c.Domain, "", store, tokenHolder, coreClient, platformCoreClient, out, false) + return authLogin(c.Domain, "", store, creds, coreClient, platformCoreClient, out, false) } newCreds := keychain.Credentials{ Token: "Bearer " + res.AccessToken, - RefreshToken: creds.RefreshToken, - UserEmail: creds.UserEmail, + RefreshToken: keyCreds.RefreshToken, + UserEmail: keyCreds.UserEmail, ExpiresAt: time.Now().Add(time.Duration(res.ExpiresIn) * time.Second), } if res.RefreshToken != "" { @@ -166,11 +167,11 @@ func checkToken(store keychain.SecureStore, tokenHolder *httputil.TokenHolder, c if err := store.SetCredentials(c.Domain, newCreds); err != nil { return err } - tokenHolder.Set(newCreds.Token) + creds.Set(newCreds.Token) return nil } - tokenHolder.Set(creds.Token) + creds.Set(keyCreds.Token) return nil } @@ -219,7 +220,7 @@ func refresh(refreshToken string, authConfig auth.Config) (TokenResponse, error) return tokenRes, nil } -func checkAPIKeys(platformCoreClient astroplatformcore.CoreClient, tokenHolder *httputil.TokenHolder, isDeploymentFile bool) (bool, error) { +func checkAPIKeys(platformCoreClient astroplatformcore.CoreClient, creds *credentials.CurrentCredentials, isDeploymentFile bool) (bool, error) { // check os variables astronomerKeyID := os.Getenv("ASTRONOMER_KEY_ID") astronomerKeySecret := os.Getenv("ASTRONOMER_KEY_SECRET") @@ -300,7 +301,7 @@ func checkAPIKeys(platformCoreClient astroplatformcore.CoreClient, tokenHolder * return false, errors.New(tokenRes.ErrorDescription) } - tokenHolder.Set("Bearer " + tokenRes.AccessToken) + creds.Set("Bearer " + tokenRes.AccessToken) orgs, err := organization.ListOrganizations(platformCoreClient) if err != nil { @@ -330,7 +331,7 @@ func checkAPIKeys(platformCoreClient astroplatformcore.CoreClient, tokenHolder * return true, nil } -func checkAPIToken(isDeploymentFile bool, tokenHolder *httputil.TokenHolder, platformCoreClient astroplatformcore.CoreClient) (bool, error) { +func checkAPIToken(isDeploymentFile bool, creds *credentials.CurrentCredentials, platformCoreClient astroplatformcore.CoreClient) (bool, error) { // check os variables astroAPIToken := os.Getenv("ASTRO_API_TOKEN") if astroAPIToken == "" { @@ -367,7 +368,7 @@ func checkAPIToken(isDeploymentFile bool, tokenHolder *httputil.TokenHolder, pla } } - tokenHolder.Set("Bearer " + astroAPIToken) + creds.Set("Bearer " + astroAPIToken) // Parse the token to peek at the custom claims claims, err := parseAPIToken(astroAPIToken) diff --git a/cmd/cloud/setup_test.go b/cmd/cloud/setup_test.go index 1e1bb6fea..6f609cefe 100644 --- a/cmd/cloud/setup_test.go +++ b/cmd/cloud/setup_test.go @@ -20,7 +20,7 @@ import ( astroplatformcore_mocks "github.com/astronomer/astro-cli/astro-client-platform-core/mocks" "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/context" - "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/keychain" testUtil "github.com/astronomer/astro-cli/pkg/testing" "github.com/astronomer/astro-cli/pkg/util" @@ -69,11 +69,11 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "astro"} rootCmd.AddCommand(cmd) - authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } - err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, nil, nil) + err = Setup(cmd, keychain.NewTestStore(), &credentials.CurrentCredentials{}, nil, nil) assert.NoError(t, err) }) @@ -90,11 +90,11 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "astro"} rootCmd.AddCommand(cmd) - authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } - err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, nil, nil) + err = Setup(cmd, keychain.NewTestStore(), &credentials.CurrentCredentials{}, nil, nil) assert.NoError(t, err) }) @@ -170,11 +170,11 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "deployment"} rootCmd.AddCommand(cmd) - authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } - err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, nil, nil) + err = Setup(cmd, keychain.NewTestStore(), &credentials.CurrentCredentials{}, nil, nil) assert.NoError(t, err) }) @@ -190,11 +190,11 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "astro"} rootCmd.AddCommand(cmd) - authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } - err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, nil, nil) + err = Setup(cmd, keychain.NewTestStore(), &credentials.CurrentCredentials{}, nil, nil) assert.NoError(t, err) }) @@ -243,7 +243,7 @@ func TestSetup(t *testing.T) { t.Setenv("ASTRO_API_TOKEN", "token") - err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, mockPlatformCoreClient, mockCoreClient) + err = Setup(cmd, keychain.NewTestStore(), &credentials.CurrentCredentials{}, mockPlatformCoreClient, mockCoreClient) assert.NoError(t, err) mockPlatformCoreClient.AssertExpectations(t) }) @@ -267,7 +267,7 @@ func TestSetup(t *testing.T) { t.Setenv("ASTRO_API_TOKEN", "bad token") - err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, mockPlatformCoreClient, mockCoreClient) + err = Setup(cmd, keychain.NewTestStore(), &credentials.CurrentCredentials{}, mockPlatformCoreClient, mockCoreClient) assert.Error(t, err) }) @@ -284,13 +284,13 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "astro"} rootCmd.AddCommand(cmd) - authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } t.Setenv("ASTRO_API_TOKEN", "") - err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, mockPlatformCoreClient, mockCoreClient) + err = Setup(cmd, keychain.NewTestStore(), &credentials.CurrentCredentials{}, mockPlatformCoreClient, mockCoreClient) assert.NoError(t, err) }) @@ -320,7 +320,7 @@ func TestSetup(t *testing.T) { rootCmd := &cobra.Command{Use: "astro"} rootCmd.AddCommand(cmd) - authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -341,7 +341,7 @@ func TestSetup(t *testing.T) { Header: make(http.Header), } }) - err = Setup(cmd, keychain.NewTestStore(), &httputil.TokenHolder{}, mockPlatformCoreClient, mockCoreClient) + err = Setup(cmd, keychain.NewTestStore(), &credentials.CurrentCredentials{}, mockPlatformCoreClient, mockCoreClient) assert.NoError(t, err) mockPlatformCoreClient.AssertExpectations(t) mockCoreClient.AssertExpectations(t) @@ -371,7 +371,7 @@ func TestCheckAPIKeys(t *testing.T) { mockPlatformCoreClient.On("ListOrganizationsWithResponse", mock.Anything, mock.Anything).Return(&mockOrgsResponse, nil).Once() mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Once() - authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -398,7 +398,7 @@ func TestCheckAPIKeys(t *testing.T) { err = context.Switch(domain) assert.NoError(t, err) - holder := &httputil.TokenHolder{} + holder := &credentials.CurrentCredentials{} _, err = checkAPIKeys(mockPlatformCoreClient, holder, false) assert.NoError(t, err) mockPlatformCoreClient.AssertExpectations(t) @@ -411,20 +411,20 @@ func TestCheckToken(t *testing.T) { mockPlatformCoreClient := new(astroplatformcore_mocks.ClientWithResponsesInterface) t.Run("test check token", func(t *testing.T) { mockCoreClient := new(astrocore_mocks.ClientWithResponsesInterface) - authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } - holder := &httputil.TokenHolder{} + holder := &credentials.CurrentCredentials{} err := checkToken(keychain.NewTestStore(), holder, mockCoreClient, mockPlatformCoreClient, nil) assert.NoError(t, err) }) t.Run("trigger login when no token is found", func(t *testing.T) { mockCoreClient := new(astrocore_mocks.ClientWithResponsesInterface) - authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return errorLogin } - holder := &httputil.TokenHolder{} + holder := &credentials.CurrentCredentials{} err := checkToken(keychain.NewTestStore(), holder, mockCoreClient, mockPlatformCoreClient, nil) assert.Contains(t, err.Error(), "failed to login") }) @@ -434,7 +434,7 @@ func TestCheckToken(t *testing.T) { store := keychain.NewTestStore() _ = store.SetCredentials("astronomer.io", keychain.Credentials{Token: "Bearer tok", ExpiresAt: time.Now().Add(time.Hour)}) - holder := &httputil.TokenHolder{} + holder := &credentials.CurrentCredentials{} err := checkToken(store, holder, mockCoreClient, mockPlatformCoreClient, nil) assert.NoError(t, err) assert.Equal(t, "Bearer tok", holder.Get()) @@ -477,7 +477,7 @@ func TestCheckAPIToken(t *testing.T) { mockPlatformCoreClient := new(astroplatformcore_mocks.ClientWithResponsesInterface) t.Run("test context switch", func(t *testing.T) { - authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -494,13 +494,13 @@ func TestCheckAPIToken(t *testing.T) { err := context.Switch(domain) assert.NoError(t, err) - holder := &httputil.TokenHolder{} + holder := &credentials.CurrentCredentials{} _, err = checkAPIToken(true, holder, mockPlatformCoreClient) assert.NoError(t, err) }) t.Run("failed to parse api token", func(t *testing.T) { - authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -517,12 +517,12 @@ func TestCheckAPIToken(t *testing.T) { err := context.Switch(domain) assert.NoError(t, err) - holder := &httputil.TokenHolder{} + holder := &credentials.CurrentCredentials{} _, err = checkAPIToken(true, holder, mockPlatformCoreClient) assert.Error(t, err) }) t.Run("unable to fetch current context", func(t *testing.T) { - authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -536,7 +536,7 @@ func TestCheckAPIToken(t *testing.T) { err := config.ResetCurrentContext() assert.NoError(t, err) - holder := &httputil.TokenHolder{} + holder := &credentials.CurrentCredentials{} _, err = checkAPIToken(true, holder, mockPlatformCoreClient) assert.NoError(t, err) }) @@ -556,7 +556,7 @@ func TestCheckAPIToken(t *testing.T) { }, } - authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -573,7 +573,7 @@ func TestCheckAPIToken(t *testing.T) { err := context.Switch(domain) assert.NoError(t, err) - holder := &httputil.TokenHolder{} + holder := &credentials.CurrentCredentials{} _, err = checkAPIToken(false, holder, mockPlatformCoreClient) assert.ErrorIs(t, err, errNotAPIToken) }) @@ -596,7 +596,7 @@ func TestCheckAPIToken(t *testing.T) { }, } - authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -613,7 +613,7 @@ func TestCheckAPIToken(t *testing.T) { err := context.Switch(domain) assert.NoError(t, err) - holder := &httputil.TokenHolder{} + holder := &credentials.CurrentCredentials{} _, err = checkAPIToken(true, holder, mockPlatformCoreClient) assert.ErrorIs(t, err, errExpiredAPIToken) }) @@ -635,7 +635,7 @@ func TestCheckAPIToken(t *testing.T) { }, } - authLogin = func(domain, token string, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + authLogin = func(domain, token string, store keychain.SecureStore, creds *credentials.CurrentCredentials, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { return nil } @@ -652,7 +652,7 @@ func TestCheckAPIToken(t *testing.T) { err := context.Switch(domain) assert.NoError(t, err) - holder := &httputil.TokenHolder{} + holder := &credentials.CurrentCredentials{} _, err = checkAPIToken(true, holder, mockPlatformCoreClient) assert.NoError(t, err) }) diff --git a/cmd/root.go b/cmd/root.go index b6494de18..4521bbe43 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -19,6 +19,7 @@ import ( "github.com/astronomer/astro-cli/houston" "github.com/astronomer/astro-cli/internal/telemetry" "github.com/astronomer/astro-cli/pkg/ansi" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/httputil" "github.com/astronomer/astro-cli/pkg/keychain" ) @@ -38,16 +39,16 @@ const ( // NewRootCmd adds all of the primary commands for the cli func NewRootCmd() *cobra.Command { var err error - tokenHolder := &httputil.TokenHolder{} + creds := &credentials.CurrentCredentials{} store, storeErr := newSecureStore() httpClient := houston.NewHTTPClient() - houstonClient = houston.NewClient(httpClient, tokenHolder) + houstonClient = houston.NewClient(httpClient, creds) - airflowClient := airflowclient.NewAirflowClient(httputil.NewHTTPClient(), tokenHolder) - astroCoreClient := astrocore.NewCoreClient(httputil.NewHTTPClient(), tokenHolder) - astroCoreIamClient := astroiamcore.NewIamCoreClient(httputil.NewHTTPClient(), tokenHolder) - platformCoreClient := platformclient.NewPlatformCoreClient(httputil.NewHTTPClient(), tokenHolder) + airflowClient := airflowclient.NewAirflowClient(httputil.NewHTTPClient(), creds) + astroCoreClient := astrocore.NewCoreClient(httputil.NewHTTPClient(), creds) + astroCoreIamClient := astroiamcore.NewIamCoreClient(httputil.NewHTTPClient(), creds) + platformCoreClient := platformclient.NewPlatformCoreClient(httputil.NewHTTPClient(), creds) ctx := cloudPlatform isCloudCtx := context.IsCloudContext() @@ -79,29 +80,29 @@ Welcome to the Astro CLI, the modern command line interface for data orchestrati } return utils.ChainRunEs( SetupLogging, - CreateRootPersistentPreRunE(storeErr, store, tokenHolder, astroCoreClient, platformCoreClient), + CreateRootPersistentPreRunE(storeErr, store, creds, astroCoreClient, platformCoreClient), telemetry.CreateTrackingHook(), )(cmd, args) }, } rootCmd.AddCommand( - newLoginCommand(store, tokenHolder, astroCoreClient, platformCoreClient, os.Stdout), + newLoginCommand(store, creds, astroCoreClient, platformCoreClient, os.Stdout), newLogoutCommand(store, os.Stdout), - newAuthRootCmd(store, tokenHolder, astroCoreClient, platformCoreClient, os.Stdout), + newAuthRootCmd(store, creds, astroCoreClient, platformCoreClient, os.Stdout), newVersionCommand(), newDevRootCmd(platformCoreClient, astroCoreClient, store), newContextCmd(os.Stdout), newConfigRootCmd(os.Stdout), newRunCommand(), - api.NewAPICmd(tokenHolder), + api.NewAPICmd(creds), newTelemetryCmd(os.Stdout), newTelemetrySendCmd(), ) if context.IsCloudContext() { // Include all the commands to be exposed for cloud users rootCmd.AddCommand( - cloudCmd.AddCmds(platformCoreClient, astroCoreClient, airflowClient, astroCoreIamClient, os.Stdout)..., + cloudCmd.AddCmds(platformCoreClient, astroCoreClient, airflowClient, astroCoreIamClient, creds, os.Stdout)..., ) } else { // Include all the commands to be exposed for software users rootCmd.AddCommand( diff --git a/cmd/root_hooks.go b/cmd/root_hooks.go index 8225463b2..9cf160319 100644 --- a/cmd/root_hooks.go +++ b/cmd/root_hooks.go @@ -16,7 +16,7 @@ import ( softwareCmd "github.com/astronomer/astro-cli/cmd/software" "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/context" - "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/keychain" "github.com/astronomer/astro-cli/version" ) @@ -29,7 +29,7 @@ func SetupLogging(_ *cobra.Command, _ []string) error { // CreateRootPersistentPreRunE takes clients as arguments and returns a cobra // pre-run hook that sets up the context and checks for the latest version. -func CreateRootPersistentPreRunE(storeErr error, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, astroCoreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient) func(cmd *cobra.Command, args []string) error { +func CreateRootPersistentPreRunE(storeErr error, store keychain.SecureStore, creds *credentials.CurrentCredentials, astroCoreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient) func(cmd *cobra.Command, args []string) error { return func(cmd *cobra.Command, args []string) error { // login/logout don't need existing credentials, skip auth setup if cmd.CalledAs() == "login" || cmd.CalledAs() == "logout" { @@ -56,19 +56,19 @@ func CreateRootPersistentPreRunE(storeErr error, store keychain.SecureStore, tok } if context.IsCloudContext() { - if err := handleCloudSetup(cmd, store, tokenHolder, platformCoreClient, astroCoreClient); err != nil { + if err := handleCloudSetup(cmd, store, creds, platformCoreClient, astroCoreClient); err != nil { return err } } else { - loadSoftwareToken(store, tokenHolder) + loadSoftwareToken(store, creds) } softwareCmd.PrintDebugLogs() return nil } } -func handleCloudSetup(cmd *cobra.Command, store keychain.SecureStore, tokenHolder *httputil.TokenHolder, platformCoreClient astroplatformcore.CoreClient, astroCoreClient astrocore.CoreClient) error { - err := cloudCmd.Setup(cmd, store, tokenHolder, platformCoreClient, astroCoreClient) +func handleCloudSetup(cmd *cobra.Command, store keychain.SecureStore, creds *credentials.CurrentCredentials, platformCoreClient astroplatformcore.CoreClient, astroCoreClient astrocore.CoreClient) error { + err := cloudCmd.Setup(cmd, store, creds, platformCoreClient, astroCoreClient) if err == nil { return nil } @@ -82,7 +82,7 @@ func handleCloudSetup(cmd *cobra.Command, store keychain.SecureStore, tokenHolde return nil } -func loadSoftwareToken(store keychain.SecureStore, tokenHolder *httputil.TokenHolder) { +func loadSoftwareToken(store keychain.SecureStore, creds *credentials.CurrentCredentials) { if store == nil { return } @@ -90,7 +90,7 @@ func loadSoftwareToken(store keychain.SecureStore, tokenHolder *httputil.TokenHo if err != nil { return } - if creds, credErr := store.GetCredentials(c.Domain); credErr == nil { - tokenHolder.Set(creds.Token) + if keyCreds, credErr := store.GetCredentials(c.Domain); credErr == nil { + creds.Set(keyCreds.Token) } } diff --git a/cmd/root_hooks_test.go b/cmd/root_hooks_test.go index ed1fb661d..84633f660 100644 --- a/cmd/root_hooks_test.go +++ b/cmd/root_hooks_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/keychain" testUtil "github.com/astronomer/astro-cli/pkg/testing" ) @@ -16,7 +16,7 @@ func TestLoadSoftwareToken_LoadsToken(t *testing.T) { err := store.SetCredentials("astronomer_dev.com", keychain.Credentials{Token: "test-token"}) assert.NoError(t, err) - holder := &httputil.TokenHolder{} + holder := &credentials.CurrentCredentials{} loadSoftwareToken(store, holder) assert.Equal(t, "test-token", holder.Get()) } diff --git a/houston/houston.go b/houston/houston.go index 5903dc771..166e01afb 100644 --- a/houston/houston.go +++ b/houston/houston.go @@ -15,6 +15,7 @@ import ( "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/context" + "github.com/astronomer/astro-cli/pkg/credentials" "github.com/astronomer/astro-cli/pkg/httputil" ) @@ -107,8 +108,8 @@ type ClientImplementation struct { // NewClient - initialized the Houston Client object with proper HTTP Client configuration // set as a variable so we can change it to return mock houston clients in tests -var NewClient = func(c *httputil.HTTPClient, tokenHolder *httputil.TokenHolder) ClientInterface { - client := newInternalClient(c, tokenHolder) +var NewClient = func(c *httputil.HTTPClient, creds *credentials.CurrentCredentials) ClientInterface { + client := newInternalClient(c, creds) return &ClientImplementation{ client: client, } @@ -116,8 +117,8 @@ var NewClient = func(c *httputil.HTTPClient, tokenHolder *httputil.TokenHolder) // Client containers the logger and HTTPClient used to communicate with the HoustonAPI type Client struct { - HTTPClient *httputil.HTTPClient - tokenHolder *httputil.TokenHolder + HTTPClient *httputil.HTTPClient + creds *credentials.CurrentCredentials } func NewHTTPClient() *httputil.HTTPClient { @@ -136,10 +137,10 @@ func NewHTTPClient() *httputil.HTTPClient { } // newInternalClient returns a new Client with the logger and HTTP Client setup. -func newInternalClient(c *httputil.HTTPClient, tokenHolder *httputil.TokenHolder) *Client { +func newInternalClient(c *httputil.HTTPClient, creds *credentials.CurrentCredentials) *Client { return &Client{ - HTTPClient: c, - tokenHolder: tokenHolder, + HTTPClient: c, + creds: creds, } } @@ -182,8 +183,8 @@ func (c *Client) Do(doOpts *httputil.DoOptions) (*Response, error) { // DoWithContext executes a query against the Houston API, logging out any errors contained in the response object func (c *Client) DoWithContext(doOpts *httputil.DoOptions, ctx *config.Context) (*Response, error) { // set headers - if c.tokenHolder != nil { - if tok := c.tokenHolder.Get(); tok != "" { + if c.creds != nil { + if tok := c.creds.Get(); tok != "" { doOpts.Headers["authorization"] = tok } } diff --git a/pkg/httputil/token_holder.go b/pkg/credentials/credentials.go similarity index 59% rename from pkg/httputil/token_holder.go rename to pkg/credentials/credentials.go index 7554f517d..6390a06cb 100644 --- a/pkg/httputil/token_holder.go +++ b/pkg/credentials/credentials.go @@ -1,33 +1,33 @@ -package httputil +package credentials import "sync" -// TokenHolder holds the current auth token in memory for the duration of a +// CurrentCredentials holds the current auth token in memory for the duration of a // command invocation. It is populated by PersistentPreRunE after credentials // are resolved from the secure store, and read by API client request editors // on every outbound request. // // It is constructed once in NewRootCmd and passed by pointer to both the API // clients and CreateRootPersistentPreRunE. There is no global state. -type TokenHolder struct { +type CurrentCredentials struct { mu sync.RWMutex token string } -// NewTokenHolder creates a TokenHolder with an initial token value. -func NewTokenHolder(token string) *TokenHolder { - return &TokenHolder{token: token} +// New creates a CurrentCredentials with an initial token value. +func New(token string) *CurrentCredentials { + return &CurrentCredentials{token: token} } // Set stores the token. -func (h *TokenHolder) Set(token string) { +func (h *CurrentCredentials) Set(token string) { h.mu.Lock() h.token = token h.mu.Unlock() } // Get returns the current token. -func (h *TokenHolder) Get() string { +func (h *CurrentCredentials) Get() string { h.mu.RLock() defer h.mu.RUnlock() return h.token diff --git a/pkg/httputil/token_holder_test.go b/pkg/credentials/credentials_test.go similarity index 55% rename from pkg/httputil/token_holder_test.go rename to pkg/credentials/credentials_test.go index c21a08376..132299adf 100644 --- a/pkg/httputil/token_holder_test.go +++ b/pkg/credentials/credentials_test.go @@ -1,15 +1,15 @@ -package httputil_test +package credentials_test import ( "testing" "github.com/stretchr/testify/assert" - "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/credentials" ) -func TestTokenHolder(t *testing.T) { - h := &httputil.TokenHolder{} +func TestCurrentCredentials(t *testing.T) { + h := &credentials.CurrentCredentials{} assert.Equal(t, "", h.Get()) h.Set("Bearer abc")