diff --git a/cli/internal/aks/deploy/cilium.go b/cli/internal/aks/deploy/cilium.go index 38fb546..a3131f6 100644 --- a/cli/internal/aks/deploy/cilium.go +++ b/cli/internal/aks/deploy/cilium.go @@ -3,14 +3,12 @@ package deploy import ( "context" "errors" - "fmt" "log" - "net/url" "os" "os/exec" "github.com/Azure/aks-flex/plugin/pkg/util/config" - "k8s.io/client-go/tools/clientcmd" + "github.com/Azure/aks-flex/plugin/pkg/util/k8s" ) var ciliumInstallInstruction = errors.New( @@ -32,7 +30,8 @@ func deployCilium( kubeconfigFile string, cfg *config.Config, ) error { - k8sServiceHost, k8sServicePort, err := kubeconfigAPIServer(kubeconfigFile) + clusterContext := cfg.ClusterName + "-admin" + k8sServiceHost, k8sServicePort, err := k8s.APIServerFromKubeconfigFile(kubeconfigFile, clusterContext) if err != nil { return err } @@ -41,7 +40,7 @@ func deployCilium( ctx, "cilium", "install", "--kubeconfig", kubeconfigFile, - "--context", cfg.ClusterName+"-admin", + "--context", clusterContext, "--namespace", "kube-system", "--datapath-mode", "aks-byocni", "--helm-set", "aksbyocni.enabled=true", @@ -58,52 +57,7 @@ func deployCilium( "KUBECONFIG="+kubeconfigFile, "PATH="+os.Getenv("PATH"), ) - log.Printf("Running: cilium install --kubeconfig %s --context %s --namespace kube-system --datapath-mode aks-byocni --helm-set aksbyocni.enabled=true --helm-set cluster.name=%s --helm-set operator.replicas=1 --helm-set kubeProxyReplacement=true --helm-set k8sServiceHost=%s --helm-set k8sServicePort=%s", kubeconfigFile, cfg.ClusterName+"-admin", cfg.ClusterName, k8sServiceHost, k8sServicePort) + log.Printf("Running: cilium install --kubeconfig %s --context %s --namespace kube-system --datapath-mode aks-byocni --helm-set aksbyocni.enabled=true --helm-set cluster.name=%s --helm-set operator.replicas=1 --helm-set kubeProxyReplacement=true --helm-set k8sServiceHost=%s --helm-set k8sServicePort=%s", kubeconfigFile, clusterContext, cfg.ClusterName, k8sServiceHost, k8sServicePort) return cmd.Run() } - -func kubeconfigAPIServer(kubeconfigFile string) (string, string, error) { - kcfg, err := clientcmd.LoadFromFile(kubeconfigFile) - if err != nil { - return "", "", fmt.Errorf("loading kubeconfig for cilium install: %w", err) - } - - ctxName := kcfg.CurrentContext - if ctxName == "" { - return "", "", errors.New("kubeconfig missing current context") - } - - ctxCfg, ok := kcfg.Contexts[ctxName] - if !ok || ctxCfg == nil { - return "", "", fmt.Errorf("kubeconfig missing context %q", ctxName) - } - - clusterCfg, ok := kcfg.Clusters[ctxCfg.Cluster] - if !ok || clusterCfg == nil { - return "", "", fmt.Errorf("kubeconfig missing cluster %q", ctxCfg.Cluster) - } - - u, err := url.Parse(clusterCfg.Server) - if err != nil { - return "", "", fmt.Errorf("parsing API server URL %q: %w", clusterCfg.Server, err) - } - - hostname := u.Hostname() - port := u.Port() - if hostname == "" { - return "", "", fmt.Errorf("API server URL missing hostname: %q", clusterCfg.Server) - } - if port == "" { - switch u.Scheme { - case "https": - port = "443" - case "http": - port = "80" - default: - return "", "", fmt.Errorf("API server URL missing port and unsupported scheme %q", u.Scheme) - } - } - - return hostname, port, nil -} diff --git a/cli/internal/aks/deploy/cilium_test.go b/cli/internal/aks/deploy/cilium_test.go index c5ec79e..c8706fb 100644 --- a/cli/internal/aks/deploy/cilium_test.go +++ b/cli/internal/aks/deploy/cilium_test.go @@ -4,6 +4,8 @@ import ( "os" "path/filepath" "testing" + + k8sutil "github.com/Azure/aks-flex/plugin/pkg/util/k8s" ) func TestKubeconfigAPIServer(t *testing.T) { @@ -29,9 +31,9 @@ users: t.Fatalf("write kubeconfig: %v", err) } - host, port, err := kubeconfigAPIServer(kubeconfig) + host, port, err := k8sutil.APIServerFromKubeconfigFile(kubeconfig, "") if err != nil { - t.Fatalf("kubeconfigAPIServer returned error: %v", err) + t.Fatalf("APIServerFromKubeconfigFile returned error: %v", err) } if host != "example.hcp.eastus2.azmk8s.io" { t.Fatalf("unexpected host %q", host) diff --git a/cli/internal/config/configcmd/defaults.go b/cli/internal/config/configcmd/defaults.go index a3f310b..d366141 100644 --- a/cli/internal/config/configcmd/defaults.go +++ b/cli/internal/config/configcmd/defaults.go @@ -2,10 +2,8 @@ package configcmd import ( "context" - "encoding/json" "fmt" "os" - "path/filepath" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" @@ -45,7 +43,7 @@ func OrPlaceholder(val string) string { // to placeholder values that the user must replace manually. func DefaultKubeadmConfig(ctx context.Context) *kubeadm.Config { credOptions := &azidentity.AzureCLICredentialOptions{} - if tenantID := azureConfigTenantID(); tenantID != "" { + if tenantID := config.AzureTenantID(); tenantID != "" { credOptions.TenantID = tenantID } credentials, err := azidentity.NewAzureCLICredential(credOptions) @@ -70,37 +68,3 @@ func DefaultKubeadmConfig(ctx context.Context) *kubeadm.Config { } return cfg } - -func azureConfigTenantID() string { - azureConfigDir := os.Getenv("AZURE_CONFIG_DIR") - if azureConfigDir == "" { - azureConfigDir = filepath.Join(os.Getenv("HOME"), ".azure") - } - - b, err := os.ReadFile(filepath.Join(azureConfigDir, "azureProfile.json")) - if err != nil { - return "" - } - - var profile struct { - Subscriptions []struct { - IsDefault bool `json:"isDefault"` - TenantID string `json:"tenantId"` - } `json:"subscriptions"` - } - if err := json.Unmarshal(b, &profile); err != nil { - return "" - } - - for _, sub := range profile.Subscriptions { - if sub.IsDefault && sub.TenantID != "" { - return sub.TenantID - } - } - - if len(profile.Subscriptions) == 1 { - return profile.Subscriptions[0].TenantID - } - - return "" -} diff --git a/cli/internal/config/configcmd/defaults_test.go b/cli/internal/config/configcmd/defaults_test.go index 1fd7af8..1ab3551 100644 --- a/cli/internal/config/configcmd/defaults_test.go +++ b/cli/internal/config/configcmd/defaults_test.go @@ -1,24 +1 @@ package configcmd - -import ( - "os" - "path/filepath" - "testing" -) - -func TestAzureConfigTenantIDUsesAzureConfigDirProfile(t *testing.T) { - home := t.TempDir() - t.Setenv("HOME", home) - azureConfigDir := filepath.Join(t.TempDir(), "azure-profile") - if err := os.MkdirAll(azureConfigDir, 0o755); err != nil { - t.Fatalf("mkdir azure config dir: %v", err) - } - if err := os.WriteFile(filepath.Join(azureConfigDir, "azureProfile.json"), []byte(`{"subscriptions":[{"id":"sub","isDefault":true,"tenantId":"tenant-123"}]}`), 0o600); err != nil { - t.Fatalf("write azureProfile.json: %v", err) - } - t.Setenv("AZURE_CONFIG_DIR", azureConfigDir) - - if got := azureConfigTenantID(); got != "tenant-123" { - t.Fatalf("unexpected tenant id %q", got) - } -} diff --git a/plugin/pkg/util/config/config.go b/plugin/pkg/util/config/config.go index 23c5bab..92f7e2c 100644 --- a/plugin/pkg/util/config/config.go +++ b/plugin/pkg/util/config/config.go @@ -3,12 +3,10 @@ package config import ( "fmt" "os" - "path/filepath" + "os/exec" "regexp" "strconv" "strings" - - "gopkg.in/ini.v1" ) var ( @@ -102,27 +100,26 @@ func (c *Config) validate() error { return nil } +// AzureTenantID returns the tenant ID of the current Azure CLI account by +// running `az account show --query 'tenantId' -o tsv`. +func AzureTenantID() string { + out, err := exec.Command("az", "account", "show", "--query", "tenantId", "-o", "tsv").Output() + if err != nil { + return "" + } + return strings.TrimSpace(string(out)) +} + func defaultSubscriptionID() string { if subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID"); subscriptionID != "" { return subscriptionID } - azureConfigDir := os.Getenv("AZURE_CONFIG_DIR") - if azureConfigDir == "" { - azureConfigDir = filepath.Join(os.Getenv("HOME"), ".azure") - } - - b, err := os.ReadFile(filepath.Join(azureConfigDir, "clouds.config")) - if err != nil { - return "" - } - - f, err := ini.Load(b) + out, err := exec.Command("az", "account", "show", "--query", "id", "-o", "tsv").Output() if err != nil { return "" } - - return f.Section("AzureCloud").Key("subscription").String() + return strings.TrimSpace(string(out)) } func defaultResourceGroupName() string { diff --git a/plugin/pkg/util/config/config_test.go b/plugin/pkg/util/config/config_test.go index 1710aa7..b4b295a 100644 --- a/plugin/pkg/util/config/config_test.go +++ b/plugin/pkg/util/config/config_test.go @@ -3,25 +3,42 @@ package config import ( "os" "path/filepath" + "runtime" "testing" ) -func TestDefaultSubscriptionIDHonorsAzureConfigDir(t *testing.T) { +func TestDefaultSubscriptionIDUsesAZCLI(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping shell-script fake az on Windows") + } t.Setenv("AZURE_SUBSCRIPTION_ID", "") - home := t.TempDir() - t.Setenv("HOME", home) - azureConfigDir := filepath.Join(t.TempDir(), "azure-custom") - if err := os.MkdirAll(azureConfigDir, 0o755); err != nil { - t.Fatalf("mkdir azureConfigDir: %v", err) - } - if err := os.WriteFile(filepath.Join(azureConfigDir, "clouds.config"), []byte("[AzureCloud]\nsubscription = 11111111-2222-3333-4444-555555555555\n"), 0o600); err != nil { - t.Fatalf("write clouds.config: %v", err) + dir := t.TempDir() + fakeAZ := filepath.Join(dir, "az") + if err := os.WriteFile(fakeAZ, []byte("#!/bin/sh\necho '11111111-2222-3333-4444-555555555555'\n"), 0o755); err != nil { + t.Fatalf("write fake az: %v", err) } - t.Setenv("AZURE_CONFIG_DIR", azureConfigDir) + t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH")) got := defaultSubscriptionID() if got != "11111111-2222-3333-4444-555555555555" { t.Fatalf("unexpected subscription id %q", got) } } + +func TestAzureTenantIDUsesAZCLI(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping shell-script fake az on Windows") + } + + dir := t.TempDir() + fakeAZ := filepath.Join(dir, "az") + if err := os.WriteFile(fakeAZ, []byte("#!/bin/sh\necho 'tenant-from-az'\n"), 0o755); err != nil { + t.Fatalf("write fake az: %v", err) + } + t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH")) + + if got := AzureTenantID(); got != "tenant-from-az" { + t.Fatalf("unexpected tenant id %q", got) + } +} diff --git a/plugin/pkg/util/k8s/k8s.go b/plugin/pkg/util/k8s/k8s.go index 1069925..981271d 100644 --- a/plugin/pkg/util/k8s/k8s.go +++ b/plugin/pkg/util/k8s/k8s.go @@ -2,6 +2,9 @@ package k8s import ( "context" + "errors" + "fmt" + "net/url" "os" "path/filepath" @@ -119,3 +122,54 @@ func MergeKubeconfigInto(ctx context.Context, credentials azcore.TokenCredential return os.WriteFile(path, content, 0600) } + +// APIServerFromKubeconfigFile returns the API server hostname and port from +// the kubeconfig file at path. If contextName is non-empty it is used to +// select the context; otherwise the file's current-context is used. +func APIServerFromKubeconfigFile(path, contextName string) (host, port string, err error) { + kcfg, err := clientcmd.LoadFromFile(path) + if err != nil { + return "", "", fmt.Errorf("loading kubeconfig for API server: %w", err) + } + + ctxName := contextName + if ctxName == "" { + ctxName = kcfg.CurrentContext + } + if ctxName == "" { + return "", "", errors.New("kubeconfig missing current context") + } + + ctxCfg, ok := kcfg.Contexts[ctxName] + if !ok || ctxCfg == nil { + return "", "", fmt.Errorf("kubeconfig missing context %q", ctxName) + } + + clusterCfg, ok := kcfg.Clusters[ctxCfg.Cluster] + if !ok || clusterCfg == nil { + return "", "", fmt.Errorf("kubeconfig missing cluster %q", ctxCfg.Cluster) + } + + u, err := url.Parse(clusterCfg.Server) + if err != nil { + return "", "", fmt.Errorf("parsing API server URL %q: %w", clusterCfg.Server, err) + } + + hostname := u.Hostname() + p := u.Port() + if hostname == "" { + return "", "", fmt.Errorf("API server URL missing hostname: %q", clusterCfg.Server) + } + if p == "" { + switch u.Scheme { + case "https": + p = "443" + case "http": + p = "80" + default: + return "", "", fmt.Errorf("API server URL missing port and unsupported scheme %q", u.Scheme) + } + } + + return hostname, p, nil +}