diff --git a/.github/workflows/multi-profile-e2e.yml b/.github/workflows/multi-profile-e2e.yml new file mode 100644 index 00000000..61a1d9ae --- /dev/null +++ b/.github/workflows/multi-profile-e2e.yml @@ -0,0 +1,54 @@ +name: Multi Profile E2E + +on: + pull_request: + push: + workflow_dispatch: + +permissions: + contents: read + +concurrency: + group: multi-profile-e2e-${{ github.ref }} + cancel-in-progress: true + +jobs: + multi-profile-e2e: + name: Multi Profile E2E + runs-on: ubuntu-latest + timeout-minutes: 15 + env: + MULTI_PROFILE_E2E_LOG: .tmp-bin/multi-profile-e2e.log + + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Run isolated multi-profile chain + shell: bash + run: | + set -o pipefail + mkdir -p .tmp-bin + bash scripts/dev/test-multi-profile-e2e.sh --keep-workdir | tee "$MULTI_PROFILE_E2E_LOG" + { + echo "### Multi Profile E2E" + echo "- Command: \`bash scripts/dev/test-multi-profile-e2e.sh --keep-workdir\`" + echo "- Scope: isolated auth/profile storage, profile switch/use, one-shot profile override, CSV multi-profile aggregation, legacy migration" + echo "- Result: passed" + } >> "$GITHUB_STEP_SUMMARY" + + - name: Upload debug artifacts + if: failure() + uses: actions/upload-artifact@v4 + with: + name: multi-profile-e2e-debug + path: | + .tmp-bin/multi-profile-e2e.*/out + .tmp-bin/multi-profile-e2e.log + if-no-files-found: ignore + retention-days: 3 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b888bc8f..e54347b3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -31,6 +31,9 @@ jobs: - name: Install archive tooling run: sudo apt-get update && sudo apt-get install -y zip unzip + - name: Multi Profile E2E + run: bash scripts/dev/test-multi-profile-e2e.sh + - name: Install rcodesign (ad-hoc sign darwin binaries from Linux) run: | set -eu diff --git a/go.mod b/go.mod index 4c591985..3527f9bc 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,13 @@ go 1.25.8 require ( github.com/RealAlexandreAI/json-repair v0.0.15 + github.com/charmbracelet/bubbletea v1.3.6 github.com/charmbracelet/huh v1.0.0 + github.com/charmbracelet/lipgloss v1.1.0 github.com/fatih/color v1.18.0 github.com/google/uuid v1.6.0 github.com/itchyny/gojq v0.12.18 + github.com/muesli/termenv v0.16.0 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/spf13/cobra v1.10.2 github.com/zalando/go-keyring v0.2.8 @@ -21,9 +24,7 @@ require ( github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/catppuccin/go v0.3.0 // indirect github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 // indirect - github.com/charmbracelet/bubbletea v1.3.6 // indirect github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect - github.com/charmbracelet/lipgloss v1.1.0 // indirect github.com/charmbracelet/x/ansi v0.9.3 // indirect github.com/charmbracelet/x/cellbuf v0.0.13 // indirect github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect @@ -44,7 +45,6 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect - github.com/muesli/termenv v0.16.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect golang.org/x/sync v0.20.0 // indirect diff --git a/internal/app/auth_command.go b/internal/app/auth_command.go index 5923b64d..29580af1 100644 --- a/internal/app/auth_command.go +++ b/internal/app/auth_command.go @@ -39,11 +39,12 @@ import ( ) type authLoginConfig struct { - Token string - Force bool - Device bool - Recommend bool - Yes bool + Token string + Force bool + Device bool + Recommend bool + Yes bool + TargetCorpID string } type authLoginGuideAction string @@ -109,10 +110,11 @@ func newAuthLoginCommand(patCaller edition.ToolCaller) *cobra.Command { 否则 OAuth 回调会跳到本机不可达的 127.0.0.1 链接,授权完成后无法回写 token。 示例: - dws auth login # 本机登录后选择推荐/全部权限与授权业务域 + dws auth login # 本机登录并新增/刷新一个组织 profile + dws auth login --profile # 指定本次授权目标组织,不持久切换当前组织 dws auth login --recommend # 无交互批量授权服务端推荐权限 dws auth login --device # SSH 远程 / 无头环境登录 (设备流) - dws auth login --force # 强制重新登录 (忽略缓存 token) + dws auth login --force # 兼容保留;login 默认已忽略缓存并进入授权流程 dws auth login --token xxx # 使用指定 token`, DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, args []string) error { @@ -154,8 +156,9 @@ func newAuthLoginCommand(patCaller edition.ToolCaller) *cobra.Command { provider := authpkg.NewOAuthProvider(configDir, nil) provider.Output = cmd.ErrOrStderr() provider.NoBrowser, _ = cmd.Flags().GetBool("no-browser") + provider.TargetCorpID = cfg.TargetCorpID configureOAuthProviderCompatibility(provider, configDir) - tokenData, err = provider.Login(loginCtx, cfg.Force) + tokenData, err = provider.Login(loginCtx, authLoginForcesAuthorization(cfg)) if err != nil { return apperrors.NewAuth(fmt.Sprintf("dingtalk login failed: %v", err)) } @@ -163,6 +166,11 @@ func newAuthLoginCommand(patCaller edition.ToolCaller) *cobra.Command { ResetRuntimeTokenCache() clearCompatCache() + if tokenData != nil && strings.TrimSpace(tokenData.CorpID) != "" { + _ = enrichAuthLoginProfileFromContact(cmd.Context(), configDir, patCaller, tokenData) + ResetRuntimeTokenCache() + clearCompatCache() + } w := cmd.OutOrStdout() runPostLoginAuthorization := func() error { @@ -217,7 +225,7 @@ func newAuthLoginCommand(patCaller edition.ToolCaller) *cobra.Command { if err := runPostLoginAuthorization(); err != nil { return err } - return writeAuthLoginJSON(w, tokenData, cfg.Force) + return writeAuthLoginJSON(w, tokenData, authLoginForcesAuthorization(cfg)) } // Default table output @@ -225,7 +233,7 @@ func newAuthLoginCommand(patCaller edition.ToolCaller) *cobra.Command { return err } fmt.Fprintln(w) - if !cfg.Device && tokenData != nil && tokenData.IsAccessTokenValid() && !cfg.Force { + if !cfg.Device && tokenData != nil && tokenData.IsAccessTokenValid() && !authLoginForcesAuthorization(cfg) { fmt.Fprintln(w, authLoginStatusLine("Token 有效,无需重新登录")) } else { fmt.Fprintln(w, authLoginStatusLine("登录成功!")) @@ -250,7 +258,7 @@ func newAuthLoginCommand(patCaller edition.ToolCaller) *cobra.Command { } cmd.Flags().String("token", "", "Access token") cmd.Flags().Bool("device", false, "Use device authorization flow") - cmd.Flags().Bool("force", false, "Force interactive login (ignore cached token)") + cmd.Flags().Bool("force", false, "兼容保留;login 默认已忽略缓存并进入授权流程") cmd.Flags().Bool("recommend", false, "登录成功后无交互批量授权服务端推荐权限") // Hidden compatibility flags cmd.Flags().String("redirect-url", "", "Loopback redirect URL") @@ -373,58 +381,67 @@ func selectLoginRecommendScopeMode() (pat.LoginRecommendScopeMode, error) { } func newAuthLogoutCommand() *cobra.Command { - return &cobra.Command{ - Use: "logout", - Short: "清除认证信息", + cmd := &cobra.Command{ + Use: "logout", + Short: "清除认证信息(默认退出所有组织)", + Long: `清除本机钉钉登录态。 + +默认退出所有已登录组织 profile;指定 --profile 时只退出该组织,不影响其他组织。`, + Example: ` dws auth logout + dws auth logout --profile + dws auth logout --profile "钉钉"`, DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, args []string) error { configDir := defaultConfigDir() + profileSelector, err := cmd.Flags().GetString("profile") + if err != nil { + return apperrors.NewInternal("failed to read --profile") + } revokeCtx, cancel := context.WithTimeout(cmd.Context(), 15*time.Second) defer cancel() - _ = authpkg.RevokeTokenRemote(revokeCtx) - - // Load token data to get associated clientId before deletion - var storedClientID string - if tokenData, err := authpkg.LoadTokenData(configDir); err == nil && tokenData != nil { - storedClientID = tokenData.ClientID - } - - if err := authpkg.DeleteTokenData(configDir); err != nil { - return apperrors.NewInternal(fmt.Sprintf("failed to clear token data: %v", err)) - } - // Clean up associated client secret and app token from keychain - if storedClientID != "" { - _ = authpkg.DeleteClientSecret(storedClientID) - _ = authpkg.DeleteAppTokenData(storedClientID) - } - // Also try cleaning app token using appKey from app config - if appKey, _ := authpkg.ResolveAppCredentials(configDir); appKey != "" && appKey != storedClientID { - _ = authpkg.DeleteAppTokenData(appKey) + if strings.TrimSpace(profileSelector) != "" { + if err := logoutOneProfile(cmd, revokeCtx, configDir, profileSelector); err != nil { + return err + } + } else { + if err := logoutAllProfiles(cmd, revokeCtx, configDir); err != nil { + return err + } } - // Clean up app credentials (app.json + keychain secret) - _ = authpkg.DeleteAppConfig(configDir) - _ = os.Remove(filepath.Join(configDir, "mcp_url")) - _ = os.Remove(filepath.Join(configDir, "token")) - _ = os.Remove(filepath.Join(configDir, "token.json")) ResetRuntimeTokenCache() clearCompatCache() w := cmd.OutOrStdout() - fmt.Fprintln(w, "[OK] 已清除所有认证信息") + fmt.Fprintln(w, "[OK] 已清除认证信息") if !edition.Get().IsEmbedded { fmt.Fprintln(w, "请运行 dws auth login --recommend 重新登录") } return nil }, } + cmd.Flags().String("profile", "", "指定要退出的 profile 名或 corpId") + return cmd } func newAuthStatusCommand() *cobra.Command { - return &cobra.Command{ - Use: "status", - Short: "查看认证状态", + cmd := &cobra.Command{ + Use: "status", + Short: "查看认证状态", + Long: `查看当前或指定组织 profile 的认证状态。 + +指定 --profile 时只读取并刷新被选中的 token slot,不会修改 currentProfile。`, + Example: ` dws auth status + dws auth status --profile + dws auth status --profile "钉钉" + dws auth status --profile --format json`, DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, args []string) error { configDir := defaultConfigDir() + profileSelector, err := cmd.Flags().GetString("profile") + if err != nil { + return apperrors.NewInternal("failed to read --profile") + } + restoreProfile := pushRuntimeProfile(profileSelector) + defer restoreProfile() authenticated := false refreshed := false @@ -444,6 +461,8 @@ func newAuthStatusCommand() *cobra.Command { } } else if edition.Get().AutoPurgeToken { _ = authpkg.DeleteTokenData(configDir) + } else if tokenData != nil { + _ = authpkg.MarkProfileStatus(configDir, tokenData.CorpID, authpkg.ProfileStatusExpired) } } if authStatusAuthenticated(tokenData) { @@ -467,6 +486,12 @@ func newAuthStatusCommand() *cobra.Command { fmt.Fprintf(w, "%-16s%s\n", "状态:", "已登录 ✅") } if tokenData != nil { + if tokenData.CorpName != "" { + fmt.Fprintf(w, "%-16s%s\n", "企业:", tokenData.CorpName) + } + if tokenData.CorpID != "" { + fmt.Fprintf(w, "%-16s%s\n", "企业 ID:", tokenData.CorpID) + } if tokenData.IsRefreshTokenValid() { fmt.Fprintf(w, "%-16s%s\n", "Refresh Token:", "有效 ✅") } else { @@ -485,6 +510,74 @@ func newAuthStatusCommand() *cobra.Command { return nil }, } + cmd.Flags().String("profile", "", "指定要查看的 profile 名或 corpId") + return cmd +} + +func logoutOneProfile(_ *cobra.Command, ctx context.Context, configDir, selector string) error { + if _, err := authpkg.ResolveProfile(configDir, selector); err != nil { + return apperrors.NewValidation(err.Error()) + } + restoreProfile := pushRuntimeProfile(selector) + defer restoreProfile() + _ = authpkg.RevokeTokenRemote(ctx) + if err := authpkg.DeleteTokenDataForProfile(configDir, selector); err != nil { + return apperrors.NewInternal(fmt.Sprintf("failed to clear token data: %v", err)) + } + return nil +} + +func logoutAllProfiles(_ *cobra.Command, ctx context.Context, configDir string) error { + if err := authpkg.EnsureProfilesMigration(configDir); err != nil { + return apperrors.NewInternal(fmt.Sprintf("failed to migrate profiles: %v", err)) + } + cfg, err := authpkg.LoadProfiles(configDir) + if err != nil { + return apperrors.NewInternal(fmt.Sprintf("failed to load profiles: %v", err)) + } + if cfg == nil || len(cfg.Profiles) == 0 { + _ = authpkg.RevokeTokenRemote(ctx) + } else { + for _, profile := range cfg.Profiles { + restoreProfile := pushRuntimeProfile(profile.CorpID) + _ = authpkg.RevokeTokenRemote(ctx) + restoreProfile() + } + } + if err := authpkg.DeleteAllTokenData(configDir); err != nil { + return apperrors.NewInternal(fmt.Sprintf("failed to clear token data: %v", err)) + } + return nil +} + +func pushRuntimeProfile(selector string) func() { + selector = strings.TrimSpace(selector) + if selector == "" { + return func() {} + } + previous := authpkg.RuntimeProfile() + authpkg.SetRuntimeProfile(selector) + return func() { + authpkg.SetRuntimeProfile(previous) + } +} + +func cleanupAuthConfigIfNoProfiles(configDir string) { + cfg, err := authpkg.LoadProfiles(configDir) + if err == nil && len(cfg.Profiles) > 0 { + return + } + if authpkg.TokenDataExistsKeychain() { + return + } + appKey, _ := authpkg.ResolveAppCredentials(configDir) + if appKey != "" { + _ = authpkg.DeleteAppTokenData(appKey) + } + _ = authpkg.DeleteAppConfig(configDir) + _ = os.Remove(filepath.Join(configDir, "mcp_url")) + _ = os.Remove(filepath.Join(configDir, "token")) + _ = authpkg.DeleteTokenMarker(configDir) } func newAuthExportCommand() *cobra.Command { @@ -683,11 +776,12 @@ func newAuthResetCommand() *cobra.Command { DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, args []string) error { configDir := defaultConfigDir() - if err := authpkg.DeleteTokenData(configDir); err != nil { + if err := authpkg.DeleteAllTokenData(configDir); err != nil { return apperrors.NewInternal(fmt.Sprintf("failed to reset token data: %v", err)) } _ = os.Remove(filepath.Join(configDir, "mcp_url")) _ = os.Remove(filepath.Join(configDir, "token")) + _ = authpkg.DeleteAppConfig(configDir) ResetRuntimeTokenCache() clearCompatCache() w := cmd.OutOrStdout() @@ -958,18 +1052,153 @@ func resolveAuthLoginConfig(cmd *cobra.Command) (authLoginConfig, error) { return authLoginConfig{}, apperrors.NewInternal("failed to read --recommend") } yes := false + profileSelector := "" if cmd.Root() != nil { yes, _ = cmd.Root().PersistentFlags().GetBool("yes") + profileSelector, _ = cmd.Root().PersistentFlags().GetString("profile") + } + targetCorpID, err := resolveAuthLoginTargetCorpID(defaultConfigDir(), profileSelector) + if err != nil { + return authLoginConfig{}, err } return authLoginConfig{ - Token: strings.TrimSpace(token), - Force: force, - Device: device, - Recommend: recommend, - Yes: yes, + Token: strings.TrimSpace(token), + Force: force, + Device: device, + Recommend: recommend, + Yes: yes, + TargetCorpID: targetCorpID, }, nil } +func authLoginForcesAuthorization(_ authLoginConfig) bool { + return true +} + +func resolveAuthLoginTargetCorpID(configDir, selector string) (string, error) { + selector = strings.TrimSpace(selector) + if selector == "" { + return "", nil + } + if profile, err := authpkg.ResolveProfile(configDir, selector); err == nil && profile != nil { + return strings.TrimSpace(profile.CorpID), nil + } + if strings.HasPrefix(selector, "ding") { + return selector, nil + } + return "", apperrors.NewValidation(fmt.Sprintf("profile %q not found", selector)) +} + +type contactProfileIdentity struct { + CorpID string + CorpName string + UserID string + UserName string +} + +func enrichAuthLoginProfileFromContact(ctx context.Context, configDir string, caller edition.ToolCaller, data *authpkg.TokenData) error { + if caller == nil || data == nil { + return nil + } + corpID := strings.TrimSpace(data.CorpID) + if corpID == "" { + return nil + } + if strings.TrimSpace(data.CorpName) != "" && strings.TrimSpace(data.UserID) != "" && strings.TrimSpace(data.UserName) != "" { + return nil + } + + restoreProfile := pushRuntimeProfile(corpID) + defer restoreProfile() + ResetRuntimeTokenCache() + + result, err := caller.CallTool(ctx, "contact", "get_current_user_profile", map[string]any{ + "profile": corpID, + }) + if err != nil { + return err + } + identity, ok := contactProfileIdentityFromToolResult(result) + if !ok { + return nil + } + if identity.CorpID != "" && identity.CorpID != corpID { + return fmt.Errorf("contact profile corpId %q does not match login corpId %q", identity.CorpID, corpID) + } + + updated := *data + if identity.CorpName != "" { + updated.CorpName = identity.CorpName + } + if identity.UserID != "" { + updated.UserID = identity.UserID + } + if identity.UserName != "" { + updated.UserName = identity.UserName + } + if updated.CorpName == data.CorpName && updated.UserID == data.UserID && updated.UserName == data.UserName { + return nil + } + if err := authpkg.SaveTokenData(configDir, &updated); err != nil { + return err + } + *data = updated + return nil +} + +func contactProfileIdentityFromToolResult(result *edition.ToolResult) (contactProfileIdentity, bool) { + if result == nil { + return contactProfileIdentity{}, false + } + for _, block := range result.Content { + if strings.TrimSpace(block.Text) == "" { + continue + } + if identity, ok := contactProfileIdentityFromJSON([]byte(block.Text)); ok { + return identity, true + } + } + return contactProfileIdentity{}, false +} + +func contactProfileIdentityFromJSON(data []byte) (contactProfileIdentity, bool) { + var payload struct { + Result []struct { + OrgEmployeeModel struct { + CorpID string `json:"corpId"` + OrgName string `json:"orgName"` + UserID string `json:"userId"` + UserIDLower string `json:"userid"` + OrgUserName string `json:"orgUserName"` + Name string `json:"name"` + } `json:"orgEmployeeModel"` + } `json:"result"` + } + if err := json.Unmarshal(data, &payload); err != nil { + return contactProfileIdentity{}, false + } + if len(payload.Result) == 0 { + return contactProfileIdentity{}, false + } + org := payload.Result[0].OrgEmployeeModel + identity := contactProfileIdentity{ + CorpID: strings.TrimSpace(org.CorpID), + CorpName: strings.TrimSpace(org.OrgName), + UserID: firstNonEmptyString(org.UserID, org.UserIDLower), + UserName: firstNonEmptyString(org.OrgUserName, org.Name), + } + return identity, identity.CorpID != "" || identity.CorpName != "" || identity.UserID != "" || identity.UserName != "" +} + +func firstNonEmptyString(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} + func authStatusAuthenticated(data *authpkg.TokenData) bool { if data == nil { return false diff --git a/internal/app/auth_command_test.go b/internal/app/auth_command_test.go index 2844ab70..74db1272 100644 --- a/internal/app/auth_command_test.go +++ b/internal/app/auth_command_test.go @@ -184,6 +184,161 @@ func TestAuthStatusRefreshFailureLeavesStoredTokenIntact(t *testing.T) { } } +func TestAuthStatusTableIncludesCorpName(t *testing.T) { + setupAuthLogoutProfiles(t, authLogoutTestToken("corp_primary")) + + cmd := NewRootCommand() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"--format", "table", "auth", "status"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("auth status --format table error = %v\noutput:\n%s", err, out.String()) + } + for _, want := range []string{"企业:", "corp_primary org", "企业 ID:", "corp_primary"} { + if !bytes.Contains(out.Bytes(), []byte(want)) { + t.Fatalf("auth status table missing %q in output:\n%s", want, out.String()) + } + } +} + +func TestAuthStatusProfileOverrideDoesNotSwitchCurrentProfile(t *testing.T) { + configDir := setupAuthLogoutProfiles(t, + authLogoutTestToken("corp_primary"), + authLogoutTestToken("corp_secondary"), + ) + + cmd := NewRootCommand() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"--format", "table", "auth", "status", "--profile", "corp_primary"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("auth status --profile error = %v\noutput:\n%s", err, out.String()) + } + for _, want := range []string{"corp_primary org", "corp_primary"} { + if !bytes.Contains(out.Bytes(), []byte(want)) { + t.Fatalf("auth status --profile output missing %q:\n%s", want, out.String()) + } + } + if bytes.Contains(out.Bytes(), []byte("corp_secondary org")) { + t.Fatalf("auth status --profile should render selected profile, got:\n%s", out.String()) + } + cfg, err := authpkg.LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if cfg.CurrentProfile != "corp_secondary" { + t.Fatalf("currentProfile = %q, want unchanged corp_secondary", cfg.CurrentProfile) + } +} + +func TestAuthLogoutDefaultDeletesAllProfilesAndPreservesAppConfig(t *testing.T) { + configDir := setupAuthLogoutProfiles(t, + authLogoutTestToken("corp_primary"), + authLogoutTestToken("corp_secondary"), + ) + if err := authpkg.SaveAppConfig(configDir, &authpkg.AppConfig{ + ClientID: "client-app", + ClientSecret: authpkg.PlainSecret("secret-app"), + }); err != nil { + t.Fatalf("SaveAppConfig() error = %v", err) + } + + originalTransport := http.DefaultTransport + t.Cleanup(func() { + http.DefaultTransport = originalTransport + }) + http.DefaultTransport = roundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, errors.New("remote revoke disabled in unit test") + }) + + cmd := NewRootCommand() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"auth", "logout"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("auth logout error = %v\noutput:\n%s", err, out.String()) + } + for _, want := range []string{"[OK] 已清除认证信息", "重新登录"} { + if !strings.Contains(out.String(), want) { + t.Fatalf("auth logout output missing %q:\n%s", want, out.String()) + } + } + + cfg, err := authpkg.LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if cfg.PrimaryProfile != "" || cfg.CurrentProfile != "" || cfg.PreviousProfile != "" || len(cfg.Profiles) != 0 { + t.Fatalf("profiles after logout = %#v, want empty", cfg) + } + if authpkg.TokenDataExistsKeychainForCorpID("corp_primary") { + t.Fatal("primary profile token should be deleted") + } + if authpkg.TokenDataExistsKeychainForCorpID("corp_secondary") { + t.Fatal("secondary profile token should be deleted") + } + if authpkg.TokenDataExistsKeychain() { + t.Fatal("legacy auth-token mirror should be deleted") + } + appConfig, err := authpkg.LoadAppConfig(configDir) + if err != nil { + t.Fatalf("LoadAppConfig() error = %v", err) + } + if appConfig == nil || appConfig.ClientID != "client-app" { + t.Fatalf("app config after logout = %#v, want preserved client-app", appConfig) + } +} + +func TestAuthLogoutProfileDeletesOnlySelectedProfile(t *testing.T) { + configDir := setupAuthLogoutProfiles(t, + authLogoutTestToken("corp_primary"), + authLogoutTestToken("corp_secondary"), + ) + + originalTransport := http.DefaultTransport + t.Cleanup(func() { + http.DefaultTransport = originalTransport + }) + http.DefaultTransport = roundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, errors.New("remote revoke disabled in unit test") + }) + + cmd := NewRootCommand() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"auth", "logout", "--profile", "corp_primary"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("auth logout --profile corp_primary error = %v\noutput:\n%s", err, out.String()) + } + cfg, err := authpkg.LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if cfg.PrimaryProfile != "corp_secondary" || cfg.CurrentProfile != "corp_secondary" { + t.Fatalf("profiles pointers = primary %q current %q, want corp_secondary/corp_secondary", cfg.PrimaryProfile, cfg.CurrentProfile) + } + if len(cfg.Profiles) != 1 || cfg.Profiles[0].CorpID != "corp_secondary" { + t.Fatalf("profiles = %#v, want only corp_secondary retained", cfg.Profiles) + } + if authpkg.TokenDataExistsKeychainForCorpID("corp_primary") { + t.Fatal("selected primary profile token should be deleted") + } + if !authpkg.TokenDataExistsKeychainForCorpID("corp_secondary") { + t.Fatal("unselected secondary profile token should be retained") + } + loaded, err := authpkg.LoadTokenData(configDir) + if err != nil { + t.Fatalf("LoadTokenData() error = %v", err) + } + if loaded.CorpID != "corp_secondary" || loaded.AccessToken != "access-corp_secondary" { + t.Fatalf("default token = (%q, %q), want retained secondary token", loaded.CorpID, loaded.AccessToken) + } +} + func TestAuthLoginPostLoginTUIModeRespectsRecommendAndFormat(t *testing.T) { newRoot := func(t *testing.T) *cobra.Command { t.Helper() @@ -297,6 +452,15 @@ func TestResolveAuthLoginConfigReadsInheritedYes(t *testing.T) { } } +func TestAuthLoginForcesAuthorizationByDefault(t *testing.T) { + if !authLoginForcesAuthorization(authLoginConfig{}) { + t.Fatal("auth login should force authorization by default so each login can add an organization profile") + } + if !authLoginForcesAuthorization(authLoginConfig{Force: false}) { + t.Fatal("Force=false should still force authorization") + } +} + func TestAuthLoginRecommendSkipsPostLoginTUI(t *testing.T) { t.Setenv(keychain.DisableKeychainEnv, "1") t.Setenv(keychain.StorageDirEnv, t.TempDir()) @@ -578,6 +742,53 @@ func TestAuthLoginDefaultTUIRunsAfterLoginTokenSaved(t *testing.T) { } } +func TestEnrichAuthLoginProfileFromContactPersistsCorpName(t *testing.T) { + t.Setenv(keychain.DisableKeychainEnv, "1") + t.Setenv(keychain.StorageDirEnv, t.TempDir()) + configDir := t.TempDir() + t.Setenv("DWS_CONFIG_DIR", configDir) + + token := &authpkg.TokenData{ + AccessToken: "access-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(time.Hour), + RefreshExpAt: time.Now().Add(24 * time.Hour), + CorpID: "ding32fff839a3e0105d", + ClientID: "client-id", + Source: "mcp", + } + if err := authpkg.SaveTokenData(configDir, token); err != nil { + t.Fatalf("SaveTokenData() error = %v", err) + } + + fake := &authLoginRecommendSequenceCaller{responses: []string{ + `{"success":true,"result":[{"orgEmployeeModel":{"corpId":"ding32fff839a3e0105d","orgName":"钉钉(中国)信息技术有限公司","userId":"011352590165863362195","orgUserName":"玄玦(主用钉)"}}]}`, + }} + if err := enrichAuthLoginProfileFromContact(context.Background(), configDir, fake, token); err != nil { + t.Fatalf("enrichAuthLoginProfileFromContact() error = %v", err) + } + if token.CorpName != "钉钉(中国)信息技术有限公司" { + t.Fatalf("token corpName = %q, want 钉钉(中国)信息技术有限公司", token.CorpName) + } + if token.UserID != "011352590165863362195" || token.UserName != "玄玦(主用钉)" { + t.Fatalf("token user identity = (%q, %q), want contact result", token.UserID, token.UserName) + } + + loaded, err := authpkg.LoadTokenDataForProfile(configDir, "ding32fff839a3e0105d") + if err != nil { + t.Fatalf("LoadTokenDataForProfile() error = %v", err) + } + if loaded.CorpName != "钉钉(中国)信息技术有限公司" { + t.Fatalf("persisted corpName = %q, want 钉钉(中国)信息技术有限公司", loaded.CorpName) + } + if len(fake.tools) != 1 || fake.tools[0] != "get_current_user_profile" { + t.Fatalf("tool calls = %v, want get_current_user_profile", fake.tools) + } + if got := fake.args[0]["profile"]; got != "ding32fff839a3e0105d" { + t.Fatalf("contact profile arg = %#v, want ding32fff839a3e0105d", got) + } +} + type roundTripFunc func(*http.Request) (*http.Response, error) func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { @@ -642,3 +853,41 @@ func stringSliceArgEqual(got any, want []string) bool { return false } } + +func setupAuthLogoutProfiles(t *testing.T, tokens ...*authpkg.TokenData) string { + t.Helper() + root := t.TempDir() + configDir := filepath.Join(root, "config") + t.Setenv(keychain.DisableKeychainEnv, "1") + t.Setenv(keychain.StorageDirEnv, filepath.Join(root, "keychain")) + t.Setenv("DWS_CONFIG_DIR", configDir) + authpkg.SetRuntimeProfile("") + ResetRuntimeTokenCache() + clearCompatCache() + t.Cleanup(func() { + authpkg.SetRuntimeProfile("") + ResetRuntimeTokenCache() + clearCompatCache() + }) + + for _, token := range tokens { + if err := authpkg.SaveTokenData(configDir, token); err != nil { + t.Fatalf("SaveTokenData(%s) error = %v", token.CorpID, err) + } + } + return configDir +} + +func authLogoutTestToken(corpID string) *authpkg.TokenData { + return &authpkg.TokenData{ + AccessToken: "access-" + corpID, + RefreshToken: "refresh-" + corpID, + ExpiresAt: time.Now().Add(time.Hour), + RefreshExpAt: time.Now().Add(24 * time.Hour), + CorpID: corpID, + CorpName: corpID + " org", + UserID: "user-" + corpID, + UserName: "User " + corpID, + ClientID: "client-" + corpID, + } +} diff --git a/internal/app/flags.go b/internal/app/flags.go index 1a869c27..1918002c 100644 --- a/internal/app/flags.go +++ b/internal/app/flags.go @@ -29,6 +29,7 @@ type GlobalFlags struct { JQ string Mock bool Output string + Profile string Timeout int Token string Verbose bool @@ -46,6 +47,7 @@ func bindPersistentFlags(cmd *cobra.Command, flags *GlobalFlags) { cmd.PersistentFlags().BoolVar(&flags.Mock, "mock", false, "使用 Mock 数据 (开发调试用)") cmd.PersistentFlags().StringVarP(&flags.Output, "output", "o", "", "Write command output to a file") _ = cmd.PersistentFlags().MarkHidden("output") + cmd.PersistentFlags().StringVar(&flags.Profile, "profile", "", "一次性指定本次命令使用的组织 profile 名或 corpId;多个按 CSV 逗号分隔,如 corpA,corpB") cmd.PersistentFlags().IntVar(&flags.Timeout, "timeout", 30, "HTTP 请求超时时间 (秒)") cmd.PersistentFlags().StringVar(&flags.Token, "token", "", "Override the configured API token") _ = cmd.PersistentFlags().MarkHidden("token") diff --git a/internal/app/help_source_test.go b/internal/app/help_source_test.go index 62763d68..c548e5fa 100644 --- a/internal/app/help_source_test.go +++ b/internal/app/help_source_test.go @@ -163,11 +163,16 @@ func TestRootHelpUsesMCPOnlySummary(t *testing.T) { t.Fatalf("root help missing %q:\n%s", want, got) } } - for _, unwanted := range []string{"快速开始:", "更多信息:", "auth 认证管理", "Flags:"} { + for _, unwanted := range []string{"快速开始:", "更多信息:", "auth 认证管理"} { if strings.Contains(got, unwanted) { t.Fatalf("root help unexpectedly contains %q:\n%s", unwanted, got) } } + for _, want := range []string{"Global Flags:", "--profile"} { + if !strings.Contains(got, want) { + t.Fatalf("root help missing %q:\n%s", want, got) + } + } } func TestRootHelpCustomizationDoesNotAffectSubcommandHelp(t *testing.T) { @@ -217,6 +222,60 @@ func TestRootHelpCustomizationDoesNotAffectSubcommandHelp(t *testing.T) { } } +func TestProfileHelpDocumentsMultiProfileUsage(t *testing.T) { + got := executeHelpForTest(t, "profile", "switch", "--help") + for _, want := range []string{ + "切换默认组织 profile", + "需要只影响单次业务命令时,请使用全局 --profile", + "dws profile switch --corpId ", + "dws --profile contact user get-self", + "--corpId string", + "--name string", + } { + if !strings.Contains(got, want) { + t.Fatalf("profile switch help missing %q:\n%s", want, got) + } + } + + got = executeHelpForTest(t, "profile", "list", "--help") + for _, want := range []string{ + "列出本机已登录的所有组织 profile", + "dws profile list --format json", + } { + if !strings.Contains(got, want) { + t.Fatalf("profile list help missing %q:\n%s", want, got) + } + } +} + +func TestAuthHelpDocumentsProfileUsage(t *testing.T) { + got := executeHelpForTest(t, "auth", "login", "--help") + if !strings.Contains(got, "dws auth login --profile ") { + t.Fatalf("auth login help missing --profile example:\n%s", got) + } + + got = executeHelpForTest(t, "auth", "status", "--help") + for _, want := range []string{ + "查看当前或指定组织 profile 的认证状态", + "只读取并刷新被选中的 token slot", + "dws auth status --profile ", + } { + if !strings.Contains(got, want) { + t.Fatalf("auth status help missing %q:\n%s", want, got) + } + } + + got = executeHelpForTest(t, "auth", "logout", "--help") + for _, want := range []string{ + "默认退出所有已登录组织 profile", + "dws auth logout --profile ", + } { + if !strings.Contains(got, want) { + t.Fatalf("auth logout help missing %q:\n%s", want, got) + } + } +} + func TestRootCommandRegistersUpgradeCommand(t *testing.T) { root := NewRootCommand() if cmd := lookupCommand(root, "upgrade"); cmd == nil { @@ -224,6 +283,22 @@ func TestRootCommandRegistersUpgradeCommand(t *testing.T) { } } +func executeHelpForTest(t *testing.T, args ...string) string { + t.Helper() + t.Setenv(cli.CatalogFixtureEnv, "") + t.Setenv(cli.CacheDirEnv, t.TempDir()) + + root := NewRootCommand() + var out bytes.Buffer + root.SetOut(&out) + root.SetErr(&out) + root.SetArgs(args) + if err := root.Execute(); err != nil { + t.Fatalf("Execute(%v) error = %v\noutput:\n%s", args, err, out.String()) + } + return out.String() +} + func discoveryServerEntry(command, description string, groups, toolOverrides map[string]any) map[string]any { cliMeta := map[string]any{ "id": command, diff --git a/internal/app/multi_profile_runner_test.go b/internal/app/multi_profile_runner_test.go new file mode 100644 index 00000000..f4ff87c5 --- /dev/null +++ b/internal/app/multi_profile_runner_test.go @@ -0,0 +1,148 @@ +package app + +import ( + "context" + "strings" + "testing" + + authpkg "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/internal/auth" + "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/internal/executor" +) + +func TestRuntimeRunnerAggregatesCommaSeparatedProfiles(t *testing.T) { + setupAuthLogoutProfiles(t, + authLogoutTestToken("corp_a"), + authLogoutTestToken("corp_b"), + ) + authpkg.SetRuntimeProfile("corp_a, corp_b") + + runner := &runtimeRunner{fallback: multiProfileFallbackRunner{}} + result, err := runner.Run(context.Background(), executor.Invocation{ + Kind: "helper_invocation", + CanonicalProduct: "contact", + Tool: "get_current_user_profile", + Params: map[string]any{"limit": 10}, + }) + if err != nil { + t.Fatalf("Run() error = %v", err) + } + if got := authpkg.RuntimeProfile(); got != "corp_a, corp_b" { + t.Fatalf("runtime profile after Run = %q, want restored raw selector", got) + } + + content := result.Response["content"].(map[string]any) + if content["multiProfile"] != true { + t.Fatalf("multiProfile = %#v, want true", content["multiProfile"]) + } + if content["success"] != true { + t.Fatalf("success = %#v, want true", content["success"]) + } + profiles := content["profiles"].([]any) + if len(profiles) != 2 { + t.Fatalf("profiles len = %d, want 2", len(profiles)) + } + for i, wantCorpID := range []string{"corp_a", "corp_b"} { + entry := profiles[i].(map[string]any) + if entry["corpId"] != wantCorpID { + t.Fatalf("profiles[%d].corpId = %#v, want %q", i, entry["corpId"], wantCorpID) + } + if entry["ok"] != true { + t.Fatalf("profiles[%d].ok = %#v, want true", i, entry["ok"]) + } + resultPayload := entry["result"].(map[string]any) + if resultPayload["runtimeProfile"] != wantCorpID { + t.Fatalf("profiles[%d].result.runtimeProfile = %#v, want %q", i, resultPayload["runtimeProfile"], wantCorpID) + } + } +} + +func TestRuntimeRunnerDeduplicatesCommaSeparatedProfilesByCorpID(t *testing.T) { + configDir := setupAuthLogoutProfiles(t, authLogoutTestToken("corp_a"), authLogoutTestToken("corp_b")) + authpkg.SetRuntimeProfile("corp_a, corp_a org,corp_b") + + selections, multi, err := resolveMultiProfileSelections(configDir, authpkg.RuntimeProfile()) + if err != nil { + t.Fatalf("resolveMultiProfileSelections() error = %v", err) + } + if !multi { + t.Fatal("multi = false, want true") + } + if len(selections) != 2 { + t.Fatalf("selections len = %d, want 2", len(selections)) + } + if selections[0].Profile.CorpID != "corp_a" || selections[1].Profile.CorpID != "corp_b" { + t.Fatalf("resolved corp IDs = %q, %q; want corp_a, corp_b", selections[0].Profile.CorpID, selections[1].Profile.CorpID) + } +} + +func TestRuntimeRunnerKeepsSingleProfileBehavior(t *testing.T) { + setupAuthLogoutProfiles(t, authLogoutTestToken("corp_a"), authLogoutTestToken("corp_b")) + authpkg.SetRuntimeProfile("corp_a") + + runner := &runtimeRunner{fallback: multiProfileFallbackRunner{}} + result, err := runner.Run(context.Background(), executor.Invocation{ + Kind: "helper_invocation", + CanonicalProduct: "contact", + Tool: "get_current_user_profile", + }) + if err != nil { + t.Fatalf("Run() error = %v", err) + } + if _, ok := result.Response["content"].(map[string]any)["multiProfile"]; ok { + t.Fatalf("single profile unexpectedly returned aggregate content: %#v", result.Response) + } + if got := authpkg.RuntimeProfile(); got != "corp_a" { + t.Fatalf("runtime profile after Run = %q, want corp_a", got) + } +} + +func TestCommaNamedProfileStillResolvesAsSingleProfile(t *testing.T) { + configDir := setupAuthLogoutProfiles(t, authLogoutTestToken("corp_comma"), authLogoutTestToken("corp_other")) + cfg, err := authpkg.LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + for i := range cfg.Profiles { + if cfg.Profiles[i].CorpID == "corp_comma" { + cfg.Profiles[i].Name = "alpha,beta" + } + } + if err := authpkg.SaveProfiles(configDir, cfg); err != nil { + t.Fatalf("SaveProfiles() error = %v", err) + } + + selections, multi, err := resolveMultiProfileSelections(configDir, "alpha,beta") + if err != nil { + t.Fatalf("resolveMultiProfileSelections() error = %v", err) + } + if multi { + t.Fatalf("multi = true, want false; selections=%#v", selections) + } +} + +func TestCommaSeparatedProfileRejectsEmptySelector(t *testing.T) { + configDir := setupAuthLogoutProfiles(t, authLogoutTestToken("corp_a"), authLogoutTestToken("corp_b")) + + _, _, err := resolveMultiProfileSelections(configDir, "corp_a,,corp_b") + if err == nil { + t.Fatal("resolveMultiProfileSelections() error = nil, want validation error") + } + if !strings.Contains(err.Error(), "empty profile selector") { + t.Fatalf("error = %q, want empty profile selector", err.Error()) + } +} + +type multiProfileFallbackRunner struct{} + +func (multiProfileFallbackRunner) Run(_ context.Context, invocation executor.Invocation) (executor.Result, error) { + invocation.Implemented = true + return executor.Result{ + Invocation: invocation, + Response: map[string]any{ + "content": map[string]any{ + "runtimeProfile": authpkg.RuntimeProfile(), + "tool": invocation.Tool, + }, + }, + }, nil +} diff --git a/internal/app/p1_shared_install_test.go b/internal/app/p1_shared_install_test.go new file mode 100644 index 00000000..226fd586 --- /dev/null +++ b/internal/app/p1_shared_install_test.go @@ -0,0 +1,83 @@ +package app + +import ( + "bytes" + "os" + "path/filepath" + "testing" +) + +// writeMultiSkillSrc creates a fake multi skill source tree with the given +// subdir names, each containing a minimal SKILL.md. +func writeMultiSkillSrc(t *testing.T, names ...string) string { + t.Helper() + src := t.TempDir() + for _, n := range names { + dir := filepath.Join(src, n) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "SKILL.md"), []byte("# "+n+"\n"), 0o644); err != nil { + t.Fatal(err) + } + } + return src +} + +func contains(ss []string, want string) bool { + for _, s := range ss { + if s == want { + return true + } + } + return false +} + +// dws-shared must ship even when --skill narrows the set to a single product. +func TestP1SharedAlwaysIncludedWithSkillFilter(t *testing.T) { + src := writeMultiSkillSrc(t, "dws-shared", "dingtalk-aitable", "dingtalk-calendar") + all, err := listMultiSkillNames(src) + if err != nil { + t.Fatal(err) + } + if !contains(all, "dws-shared") { + t.Fatalf("listMultiSkillNames did not enumerate dws-shared: %v", all) + } + filtered, err := filterMultiSkillNames(all, []string{"aitable"}, nil) + if err != nil { + t.Fatal(err) + } + if contains(filtered, "dws-shared") { + t.Fatalf("precondition: filter should drop dws-shared for -s aitable: %v", filtered) + } + final := ensureMandatorySharedSkill(filtered, all) + if !contains(final, "dws-shared") { + t.Fatalf("ensureMandatorySharedSkill must re-add dws-shared: %v", final) + } + + // Actually install with the filtered+mandatory set and assert dws-shared landed. + dest := t.TempDir() + var out, errOut bytes.Buffer + if _, _, err := installMultiSkillToHomes(src, final, []string{dest}, &out, &errOut); err != nil { + t.Fatalf("install: %v (%s)", err, errOut.String()) + } + if _, err := os.Stat(filepath.Join(dest, "dws-shared", "SKILL.md")); err != nil { + t.Fatalf("dws-shared not installed with -s aitable: %v", err) + } + if _, err := os.Stat(filepath.Join(dest, "dingtalk-aitable", "SKILL.md")); err != nil { + t.Fatalf("dingtalk-aitable not installed: %v", err) + } +} + +// When the source has no dws-shared (older layout), nothing is forced. +func TestP1SharedNoopWhenAbsent(t *testing.T) { + src := writeMultiSkillSrc(t, "dingtalk-aitable") + all, err := listMultiSkillNames(src) + if err != nil { + t.Fatal(err) + } + final := ensureMandatorySharedSkill([]string{"dingtalk-aitable"}, all) + if contains(final, "dws-shared") { + t.Fatalf("must not invent dws-shared when source lacks it: %v", final) + } +} diff --git a/internal/app/profile_args_test.go b/internal/app/profile_args_test.go new file mode 100644 index 00000000..031992df --- /dev/null +++ b/internal/app/profile_args_test.go @@ -0,0 +1,82 @@ +package app + +import ( + "os" + "reflect" + "testing" +) + +func TestNormalizeProfileFlagArgsAcceptsUnquotedCommaContinuation(t *testing.T) { + cases := []struct { + name string + args []string + want []string + }{ + { + name: "root profile before command", + args: []string{"--mock", "--profile", "corpA,", "corpB", "contact", "user", "get-self"}, + want: []string{"--mock", "--profile", "corpA,corpB", "contact", "user", "get-self"}, + }, + { + name: "profile after leaf command", + args: []string{"contact", "user", "get-self", "--profile", "corpA,", "corpB", "--format", "json"}, + want: []string{"contact", "user", "get-self", "--profile", "corpA,corpB", "--format", "json"}, + }, + { + name: "equals form", + args: []string{"--profile=corpA,", "corpB", "contact", "user", "get-self"}, + want: []string{"--profile=corpA,corpB", "contact", "user", "get-self"}, + }, + { + name: "three profiles", + args: []string{"--profile", "corpA,", "corpB,", "corpC", "contact", "user", "get-self"}, + want: []string{"--profile", "corpA,corpB,corpC", "contact", "user", "get-self"}, + }, + { + name: "already quoted by shell remains unchanged", + args: []string{"--profile", "corpA, corpB", "contact", "user", "get-self"}, + want: []string{"--profile", "corpA, corpB", "contact", "user", "get-self"}, + }, + { + name: "single profile remains unchanged", + args: []string{"--profile", "corpA", "contact", "user", "get-self"}, + want: []string{"--profile", "corpA", "contact", "user", "get-self"}, + }, + { + name: "trailing comma before next flag remains validation input", + args: []string{"--profile", "corpA,", "--format", "json", "contact", "user", "get-self"}, + want: []string{"--profile", "corpA,", "--format", "json", "contact", "user", "get-self"}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, _ := normalizeProfileFlagArgs(tc.args) + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("normalizeProfileFlagArgs() = %#v, want %#v", got, tc.want) + } + }) + } +} + +func TestPreparseProfileFlagUsesNormalizedProfileArgs(t *testing.T) { + got := preparseProfileFlag([]string{"--profile", "corpA,", "corpB", "contact", "user", "get-self"}) + if got != "corpA,corpB" { + t.Fatalf("preparseProfileFlag() = %q, want corpA,corpB", got) + } +} + +func TestNormalizeProcessProfileArgsRestoresOriginalArgv(t *testing.T) { + oldArgs := os.Args + t.Cleanup(func() { os.Args = oldArgs }) + + os.Args = []string{"dws", "--profile", "corpA,", "corpB", "contact", "user", "get-self"} + restore := normalizeProcessProfileArgs() + if want := []string{"dws", "--profile", "corpA,corpB", "contact", "user", "get-self"}; !reflect.DeepEqual(os.Args, want) { + t.Fatalf("os.Args after normalize = %#v, want %#v", os.Args, want) + } + restore() + if want := []string{"dws", "--profile", "corpA,", "corpB", "contact", "user", "get-self"}; !reflect.DeepEqual(os.Args, want) { + t.Fatalf("os.Args after restore = %#v, want %#v", os.Args, want) + } +} diff --git a/internal/app/profile_command.go b/internal/app/profile_command.go new file mode 100644 index 00000000..ab7166bb --- /dev/null +++ b/internal/app/profile_command.go @@ -0,0 +1,747 @@ +// Copyright 2026 Alibaba Group +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package app + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "sort" + "strings" + "time" + + authpkg "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/internal/auth" + apperrors "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/internal/errors" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/muesli/termenv" + "github.com/spf13/cobra" +) + +func newProfileCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "profile", + Short: "组织 profile 管理", + Long: `管理本机已登录的钉钉组织 profile。 + +每个 profile 对应一个已授权组织。业务命令可通过全局 --profile 临时指定组织, +profile switch/use 才会持久修改默认组织上下文。`, + Example: ` dws profile list + dws profile switch + dws profile switch + dws profile switch - + dws --profile contact user get-self`, + Args: cobra.NoArgs, + TraverseChildren: true, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cmd.Help() + }, + } + cmd.AddCommand(newProfileListCommand(), newProfileSwitchCommand(), newProfileUseCommand()) + return cmd +} + +func newProfileListCommand() *cobra.Command { + return &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "列出已登录组织 profile", + Long: "列出本机已登录的所有组织 profile,包含当前组织、主组织、组织名、corpId、状态和用户信息。", + Example: ` dws profile list + dws profile list --format json`, + Args: cobra.NoArgs, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + configDir := defaultConfigDir() + if err := authpkg.EnsureProfilesMigration(configDir); err != nil { + return apperrors.NewInternal(fmt.Sprintf("failed to migrate profiles: %v", err)) + } + cfg, err := authpkg.LoadProfiles(configDir) + if err != nil { + return apperrors.NewInternal(fmt.Sprintf("failed to load profiles: %v", err)) + } + format, _ := cmd.Root().PersistentFlags().GetString("format") + if strings.EqualFold(strings.TrimSpace(format), "json") { + return writeProfileListJSON(cmd.OutOrStdout(), cfg) + } + writeProfileListTable(cmd.OutOrStdout(), cfg) + return nil + }, + } +} + +func newProfileUseCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "use [name|corpId|-]", + Short: "切换当前组织 profile(兼容 profile switch)", + Long: "兼容命令,语义等同于 dws profile switch。可用组织名、profile 名、corpId 或 - 切回上一个组织。", + Example: ` dws profile use + dws profile use --name "钉钉" + dws profile use -`, + Args: cobra.MaximumNArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return runProfileSwitchCommand(cmd, args) + }, + } + addProfileSwitchSelectorFlags(cmd) + return cmd +} + +func newProfileSwitchCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "switch [name|corpId|-]", + Short: "切换当前组织 profile", + Long: `切换默认组织 profile,并记录 previousProfile 以支持 dws profile switch - 快速切回。 + +不带参数时,交互终端会展示组织选择器;非交互环境请显式传入组织名、profile 名或 corpId。 +需要只影响单次业务命令时,请使用全局 --profile。`, + Example: ` dws profile switch + dws profile switch + dws profile switch --corpId + dws profile switch --name "钉钉" + dws profile switch - + dws --profile contact user get-self`, + Args: cobra.MaximumNArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return runProfileSwitchCommand(cmd, args) + }, + } + addProfileSwitchSelectorFlags(cmd) + return cmd +} + +func addProfileSwitchSelectorFlags(cmd *cobra.Command) { + cmd.Flags().String("corpId", "", "按 corpId 直接切换组织 profile") + cmd.Flags().String("corp-id", "", "按 corpId 直接切换组织 profile") + cmd.Flags().String("corpid", "", "按 corpId 直接切换组织 profile") + cmd.Flags().String("corp", "", "按 corpId 直接切换组织 profile") + cmd.Flags().String("name", "", "按组织名或 profile 名直接切换组织 profile") + _ = cmd.Flags().MarkHidden("corp-id") + _ = cmd.Flags().MarkHidden("corpid") + _ = cmd.Flags().MarkHidden("corp") +} + +var ( + profileSwitchSelector = selectProfileSwitchProfile + profileSwitchInteractiveTerminal = isInteractiveTerminal +) + +const ( + profileSwitchVisibleOptions = 5 + profileSwitchCellPadding = 1 + profileSwitchOrgWidth = 34 + profileSwitchStatusWidth = 10 +) + +var profileSwitchRenderer = newProfileSwitchRenderer() + +func newProfileSwitchRenderer() *lipgloss.Renderer { + renderer := lipgloss.NewRenderer(io.Discard) + renderer.SetColorProfile(termenv.TrueColor) + renderer.SetHasDarkBackground(true) + return renderer +} + +func runProfileSwitchCommand(cmd *cobra.Command, args []string) error { + configDir := defaultConfigDir() + selector, err := profileSwitchSelectorFromCommand(cmd, args) + if err != nil { + return err + } + usedTUI := false + if selector == "" { + selector, err = profileSwitchSelector(cmd, configDir) + if err != nil { + return err + } + usedTUI = true + } + return switchProfileAndWrite(cmd, configDir, selector, usedTUI) +} + +func profileSwitchSelectorFromCommand(cmd *cobra.Command, args []string) (string, error) { + selectors := make([]string, 0, 2) + if len(args) > 0 { + selectors = append(selectors, strings.TrimSpace(args[0])) + } + for _, name := range []string{"corpId", "corp-id", "corpid", "corp", "name"} { + value, changed := changedStringFlag(cmd, name) + if !changed { + continue + } + if value == "" { + return "", apperrors.NewValidation(fmt.Sprintf("--%s 不能为空", name)) + } + selectors = append(selectors, value) + } + if len(selectors) == 0 { + return "", nil + } + selector := selectors[0] + for _, candidate := range selectors[1:] { + if candidate != selector { + return "", apperrors.NewValidation("只能指定一个组织选择器,请使用位置参数或 --corpId/--name 其中一种") + } + } + return selector, nil +} + +func changedStringFlag(cmd *cobra.Command, name string) (string, bool) { + if cmd == nil || cmd.Flags() == nil { + return "", false + } + flag := cmd.Flags().Lookup(name) + if flag == nil || !flag.Changed { + return "", false + } + return strings.TrimSpace(flag.Value.String()), true +} + +func switchProfileAndWrite(cmd *cobra.Command, configDir, selector string, usedTUI bool) error { + var ( + profile *authpkg.Profile + err error + ) + if strings.TrimSpace(selector) == "-" { + profile, err = authpkg.UsePreviousProfile(configDir) + } else { + profile, err = authpkg.SetCurrentProfile(configDir, selector) + } + if err != nil { + return apperrors.NewValidation(err.Error()) + } + ResetRuntimeTokenCache() + clearCompatCache() + format, _ := cmd.Root().PersistentFlags().GetString("format") + if strings.EqualFold(strings.TrimSpace(format), "json") && !(usedTUI && authLoginAllowsInteractiveDefault(cmd, format)) { + cfg, loadErr := authpkg.LoadProfiles(configDir) + if loadErr != nil { + return apperrors.NewInternal(fmt.Sprintf("failed to load profiles: %v", loadErr)) + } + return writeProfileUseJSON(cmd.OutOrStdout(), profile, cfg) + } + fmt.Fprintln(cmd.OutOrStdout(), profileUseMessage(profile)) + return nil +} + +func selectProfileSwitchProfile(cmd *cobra.Command, configDir string) (string, error) { + if !profileSwitchInteractiveTerminal() { + return "", apperrors.NewValidation("profile selector required in non-interactive mode; use dws profile switch ") + } + if err := authpkg.EnsureProfilesMigration(configDir); err != nil { + return "", apperrors.NewInternal(fmt.Sprintf("failed to migrate profiles: %v", err)) + } + cfg, err := authpkg.LoadProfiles(configDir) + if err != nil { + return "", apperrors.NewInternal(fmt.Sprintf("failed to load profiles: %v", err)) + } + if cfg == nil || len(cfg.Profiles) == 0 { + return "", apperrors.NewValidation("未找到已登录 profile,请先运行 dws auth login") + } + choice := strings.TrimSpace(cfg.CurrentProfile) + if choice == "" { + choice = strings.TrimSpace(cfg.PrimaryProfile) + } + if choice == "" { + choice = cfg.Profiles[0].CorpID + } + return runProfileSwitchTUI(cmd, cfg, choice) +} + +func runProfileSwitchTUI(cmd *cobra.Command, cfg *authpkg.ProfilesConfig, selectedCorpID string) (string, error) { + model := newProfileSwitchTUIModel(cfg, selectedCorpID) + program := tea.NewProgram( + model, + tea.WithAltScreen(), + tea.WithInput(cmd.InOrStdin()), + tea.WithOutput(cmd.ErrOrStderr()), + tea.WithContext(cmd.Context()), + ) + finalModel, err := program.Run() + if err != nil { + if errors.Is(err, tea.ErrInterrupted) { + return "", apperrors.NewValidation("组织选择中止: user aborted") + } + return "", apperrors.NewInternal(fmt.Sprintf("failed to run profile selector: %v", err)) + } + final, ok := finalModel.(profileSwitchTUIModel) + if !ok || final.aborted || !final.submitted { + return "", apperrors.NewValidation("组织选择中止: user aborted") + } + return final.selectedCorpID(), nil +} + +type profileSwitchTUIModel struct { + cfg *authpkg.ProfilesConfig + profiles []authpkg.Profile + selected int + offset int + submitted bool + aborted bool +} + +func newProfileSwitchTUIModel(cfg *authpkg.ProfilesConfig, selectedCorpID string) profileSwitchTUIModel { + model := profileSwitchTUIModel{cfg: cfg} + if cfg != nil { + model.profiles = profileSwitchSortedProfiles(cfg.Profiles) + } + model.selected = profileSwitchProfileIndex(model.profiles, selectedCorpID) + if model.selected < 0 { + model.selected = 0 + } + model.ensureSelectedVisible() + return model +} + +func profileSwitchSortedProfiles(profiles []authpkg.Profile) []authpkg.Profile { + sorted := append([]authpkg.Profile(nil), profiles...) + sort.SliceStable(sorted, func(i, j int) bool { + left, leftOK := profileSwitchSortTime(sorted[i]) + right, rightOK := profileSwitchSortTime(sorted[j]) + if leftOK && rightOK && !left.Equal(right) { + return left.After(right) + } + if leftOK != rightOK { + return leftOK + } + return false + }) + return sorted +} + +func profileSwitchSortTime(p authpkg.Profile) (time.Time, bool) { + for _, raw := range []string{p.LastLoginAt, p.UpdatedAt, p.LastUsedAt} { + if t, ok := parseProfileSwitchTime(raw); ok { + return t, true + } + } + return time.Time{}, false +} + +func parseProfileSwitchTime(raw string) (time.Time, bool) { + raw = strings.TrimSpace(raw) + if raw == "" { + return time.Time{}, false + } + t, err := time.Parse(time.RFC3339, raw) + if err != nil { + return time.Time{}, false + } + return t, true +} + +func (m profileSwitchTUIModel) Init() tea.Cmd { + return nil +} + +func (m profileSwitchTUIModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyMsg: + switch msg.String() { + case "ctrl+c", "esc", "q": + m.aborted = true + return m, tea.Quit + case "up", "k": + if m.selected > 0 { + m.selected-- + m.ensureSelectedVisible() + } + case "down", "j": + if m.selected < len(m.profiles)-1 { + m.selected++ + m.ensureSelectedVisible() + } + case "enter": + m.submitted = true + return m, tea.Quit + } + } + return m, nil +} + +func (m profileSwitchTUIModel) View() string { + var b strings.Builder + title := profileSwitchTitleStyle().Render("选择要切换的组织") + hint := profileSwitchMutedStyle().Render("全部已登录 profile,↑↓ 选择,Enter 确认") + b.WriteString(title) + b.WriteString("\n") + b.WriteString(hint) + b.WriteString("\n\n") + b.WriteString(m.tableView()) + b.WriteString("\n") + b.WriteString(profileSwitchMutedStyle().Render("↑/k up • ↓/j down • enter submit • esc cancel")) + return b.String() +} + +func (m profileSwitchTUIModel) tableView() string { + rows := []string{ + profileSwitchBorder("┌", "┬", "┐"), + profileSwitchStyledTableLine("组织名", "本地状态", profileSwitchHeaderStyle()), + profileSwitchBorder("├", "┼", "┤"), + } + for i := 0; i < profileSwitchVisibleOptions; i++ { + idx := m.offset + i + if idx >= 0 && idx < len(m.profiles) { + rows = append(rows, m.profileRow(idx)) + continue + } + rows = append(rows, profileSwitchStyledTableLine("", "", profileSwitchNormalRowStyle())) + } + rows = append(rows, profileSwitchBorder("└", "┴", "┘")) + return strings.Join(rows, "\n") +} + +func (m profileSwitchTUIModel) profileRow(idx int) string { + profile := m.profiles[idx] + org, status := profileSwitchProfileCells(profile, m.cfg) + style := profileSwitchNormalRowStyle() + if idx == m.selected { + org = "› " + org + style = profileSwitchSelectedRowStyle() + } else { + org = " " + org + } + return profileSwitchStyledTableLine(org, status, style) +} + +func (m *profileSwitchTUIModel) ensureSelectedVisible() { + if len(m.profiles) == 0 { + m.selected = 0 + m.offset = 0 + return + } + if m.selected < 0 { + m.selected = 0 + } + if m.selected >= len(m.profiles) { + m.selected = len(m.profiles) - 1 + } + if m.selected < m.offset { + m.offset = m.selected + } + if m.selected >= m.offset+profileSwitchVisibleOptions { + m.offset = m.selected - profileSwitchVisibleOptions + 1 + } + maxOffset := len(m.profiles) - profileSwitchVisibleOptions + if maxOffset < 0 { + maxOffset = 0 + } + if m.offset > maxOffset { + m.offset = maxOffset + } + if m.offset < 0 { + m.offset = 0 + } +} + +func (m profileSwitchTUIModel) selectedCorpID() string { + if m.selected < 0 || m.selected >= len(m.profiles) { + return "" + } + return strings.TrimSpace(m.profiles[m.selected].CorpID) +} + +func profileSwitchProfileIndex(profiles []authpkg.Profile, corpID string) int { + corpID = strings.TrimSpace(corpID) + for i, p := range profiles { + if strings.TrimSpace(p.CorpID) == corpID { + return i + } + } + return -1 +} + +func profileSwitchOptionLabel(p authpkg.Profile, cfg *authpkg.ProfilesConfig) string { + org, status := profileSwitchProfileCells(p, cfg) + if status == "" { + return org + } + return strings.Join([]string{org, status}, " | ") +} + +func profileSwitchProfileCells(p authpkg.Profile, cfg *authpkg.ProfilesConfig) (string, string) { + return profileOrgName(p), profileSwitchProfileStatus(p, cfg) +} + +func profileSwitchProfileStatus(p authpkg.Profile, cfg *authpkg.ProfilesConfig) string { + if cfg != nil && p.CorpID == cfg.CurrentProfile { + return "当前组织" + } + return "" +} + +func profileSwitchBorder(left, sep, right string) string { + segments := []string{ + strings.Repeat("─", profileSwitchCellWidth(profileSwitchOrgWidth)), + strings.Repeat("─", profileSwitchCellWidth(profileSwitchStatusWidth)), + } + return profileSwitchBorderStyle().Render(left + strings.Join(segments, sep) + right) +} + +func profileSwitchTableLine(org, status string) string { + cells := []string{ + profileSwitchTableCell(org, profileSwitchOrgWidth), + profileSwitchTableCell(status, profileSwitchStatusWidth), + } + return "│" + strings.Join(cells, "│") + "│" +} + +func profileSwitchStyledTableLine(org, status string, style lipgloss.Style) string { + cells := []string{ + style.Render(profileSwitchTableCell(org, profileSwitchOrgWidth)), + style.Render(profileSwitchTableCell(status, profileSwitchStatusWidth)), + } + return profileSwitchTableSeparator() + strings.Join(cells, profileSwitchTableSeparator()) + profileSwitchTableSeparator() +} + +func profileSwitchTableSeparator() string { + return profileSwitchBorderStyle().Render("│") +} + +func profileSwitchTableCell(value string, width int) string { + clipped := clipProfileDisplayCell(strings.TrimSpace(value), width) + padding := strings.Repeat(" ", profileSwitchCellPadding) + return padding + padProfileDisplayCell(clipped, width) + padding +} + +func padProfileDisplayCell(value string, width int) string { + padding := width - lipgloss.Width(value) + if padding < 0 { + padding = 0 + } + return value + strings.Repeat(" ", padding) +} + +func profileSwitchCellWidth(contentWidth int) int { + return contentWidth + profileSwitchCellPadding*2 +} + +func profileSwitchSelectedRowStyle() lipgloss.Style { + return lipgloss.NewStyle().Renderer(profileSwitchRenderer).Foreground(lipgloss.Color("#69B1FF")).Bold(true) +} + +func profileSwitchNormalRowStyle() lipgloss.Style { + return lipgloss.NewStyle().Renderer(profileSwitchRenderer).Foreground(lipgloss.Color("#FFFFFF")) +} + +func profileSwitchHeaderStyle() lipgloss.Style { + return profileSwitchMutedStyle().Bold(true) +} + +func profileSwitchBorderStyle() lipgloss.Style { + return lipgloss.NewStyle().Renderer(profileSwitchRenderer).Foreground(lipgloss.Color("#2F3B52")) +} + +func profileSwitchTitleStyle() lipgloss.Style { + return lipgloss.NewStyle().Renderer(profileSwitchRenderer).Foreground(lipgloss.Color("#69B1FF")).Bold(true) +} + +func profileSwitchMutedStyle() lipgloss.Style { + return lipgloss.NewStyle().Renderer(profileSwitchRenderer).Foreground(lipgloss.Color("#8A96A8")) +} + +type profileListResponse struct { + Success bool `json:"success"` + PrimaryProfile string `json:"primaryProfile,omitempty"` + CurrentProfile string `json:"currentProfile,omitempty"` + PreviousProfile string `json:"previousProfile,omitempty"` + Profiles []profileView `json:"profiles"` +} + +type profileUseResponse struct { + Success bool `json:"success"` + Profile profileView `json:"profile"` +} + +type profileView struct { + CorpID string `json:"corpId"` + CorpName string `json:"corpName"` + UserID string `json:"userId,omitempty"` + UserName string `json:"userName,omitempty"` + ClientID string `json:"clientId,omitempty"` + Status string `json:"status,omitempty"` + AuthorizedDomains []string `json:"authorizedDomains,omitempty"` + ExpiresAt string `json:"expiresAt,omitempty"` + RefreshExpAt string `json:"refreshExpAt,omitempty"` + LastLoginAt string `json:"lastLoginAt,omitempty"` + LastUsedAt string `json:"lastUsedAt,omitempty"` + IsPrimary bool `json:"isPrimary"` + IsCurrent bool `json:"isCurrent"` +} + +func writeProfileListJSON(w io.Writer, cfg *authpkg.ProfilesConfig) error { + resp := profileListResponse{ + Success: true, + PrimaryProfile: cfg.PrimaryProfile, + CurrentProfile: cfg.CurrentProfile, + PreviousProfile: cfg.PreviousProfile, + Profiles: profileViews(cfg), + } + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + return enc.Encode(resp) +} + +func writeProfileUseJSON(w io.Writer, profile *authpkg.Profile, cfg *authpkg.ProfilesConfig) error { + resp := profileUseResponse{Success: true} + if profile != nil { + primaryProfile := "" + currentProfile := "" + if cfg != nil { + primaryProfile = cfg.PrimaryProfile + currentProfile = cfg.CurrentProfile + } + resp.Profile = profileViewFromProfile(*profile, primaryProfile, currentProfile) + } + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + return enc.Encode(resp) +} + +func writeProfileListTable(w io.Writer, cfg *authpkg.ProfilesConfig) { + if cfg == nil || len(cfg.Profiles) == 0 { + fmt.Fprintln(w, "未找到已登录 profile") + return + } + fmt.Fprintf(w, "%-3s %-3s %-28s %-34s %-10s %s\n", "CUR", "PRI", "ORG_NAME", "CORP_ID", "STATUS", "USER") + for _, p := range cfg.Profiles { + current := "" + if p.CorpID == cfg.CurrentProfile { + current = "*" + } + primary := "" + if p.CorpID == cfg.PrimaryProfile { + primary = "*" + } + user := p.UserName + if user == "" { + user = p.UserID + } + status := p.Status + if status == "" { + status = authpkg.ProfileStatusActive + } + fmt.Fprintf( + w, + "%-3s %-3s %-28s %-34s %-10s %s\n", + current, + primary, + clipProfileCell(profileOrgName(p), 28), + clipProfileCell(p.CorpID, 34), + status, + user, + ) + } +} + +func profileUseMessage(profile *authpkg.Profile) string { + if profile == nil { + return "[OK] 当前 profile 已切换" + } + corpID := strings.TrimSpace(profile.CorpID) + orgName := strings.TrimSpace(profile.CorpName) + if orgName == "" { + orgName = profileOrgName(*profile) + } + return fmt.Sprintf("[OK] 当前组织: %s (%s)", orgName, corpID) +} + +func profileOrgName(p authpkg.Profile) string { + if v := strings.TrimSpace(p.CorpName); v != "" { + return v + } + if v := strings.TrimSpace(p.Name); v != "" { + return v + } + return strings.TrimSpace(p.CorpID) +} + +func profileViews(cfg *authpkg.ProfilesConfig) []profileView { + if cfg == nil { + return nil + } + views := make([]profileView, 0, len(cfg.Profiles)) + for _, p := range cfg.Profiles { + views = append(views, profileViewFromProfile(p, cfg.PrimaryProfile, cfg.CurrentProfile)) + } + return views +} + +func profileViewFromProfile(p authpkg.Profile, primaryProfile, currentProfile string) profileView { + return profileView{ + CorpID: p.CorpID, + CorpName: profileOrgName(p), + UserID: p.UserID, + UserName: p.UserName, + ClientID: p.ClientID, + Status: p.Status, + AuthorizedDomains: p.AuthorizedDomains, + ExpiresAt: p.ExpiresAt, + RefreshExpAt: p.RefreshExpAt, + LastLoginAt: p.LastLoginAt, + LastUsedAt: p.LastUsedAt, + IsPrimary: p.CorpID == primaryProfile, + IsCurrent: p.CorpID == currentProfile, + } +} + +func clipProfileCell(value string, limit int) string { + if limit <= 0 { + return "" + } + runes := []rune(value) + if len(runes) <= limit { + return value + } + if limit <= 3 { + return string(runes[:limit]) + } + return string(runes[:limit-3]) + "..." +} + +func clipProfileDisplayCell(value string, limit int) string { + if limit <= 0 { + return "" + } + if lipgloss.Width(value) <= limit { + return value + } + if limit <= 3 { + var b strings.Builder + for _, r := range value { + rw := lipgloss.Width(string(r)) + if lipgloss.Width(b.String())+rw > limit { + break + } + b.WriteRune(r) + } + return b.String() + } + target := limit - 3 + var b strings.Builder + width := 0 + for _, r := range value { + rw := lipgloss.Width(string(r)) + if width+rw > target { + break + } + b.WriteRune(r) + width += rw + } + return b.String() + "..." +} diff --git a/internal/app/profile_command_test.go b/internal/app/profile_command_test.go new file mode 100644 index 00000000..e4a28cd9 --- /dev/null +++ b/internal/app/profile_command_test.go @@ -0,0 +1,582 @@ +// Copyright 2026 Alibaba Group +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package app + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + "testing" + + authpkg "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/internal/auth" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/spf13/cobra" +) + +func TestWriteProfileUseJSONKeepsPrimaryAndCurrentDistinct(t *testing.T) { + profile := &authpkg.Profile{ + Name: "B Org", + CorpID: "corp_b", + CorpName: "B Org", + Status: authpkg.ProfileStatusActive, + } + cfg := &authpkg.ProfilesConfig{ + PrimaryProfile: "corp_a", + CurrentProfile: "corp_b", + } + var buf bytes.Buffer + if err := writeProfileUseJSON(&buf, profile, cfg); err != nil { + t.Fatalf("writeProfileUseJSON() error = %v", err) + } + var resp profileUseResponse + if err := json.Unmarshal(buf.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if bytes.Contains(buf.Bytes(), []byte(`"name"`)) { + t.Fatalf("profile use JSON should not contain name when corpName is present:\n%s", buf.String()) + } + if resp.Profile.CorpName != "B Org" { + t.Fatalf("corpName = %q, want B Org", resp.Profile.CorpName) + } + if !resp.Profile.IsCurrent { + t.Fatalf("isCurrent = false, want true") + } + if resp.Profile.IsPrimary { + t.Fatalf("isPrimary = true, want false") + } +} + +func TestProfileListRootCommandJSONIncludesCorpName(t *testing.T) { + setupAuthLogoutProfiles(t, + authLogoutTestToken("corp_primary"), + authLogoutTestToken("corp_secondary"), + ) + + cmd := NewRootCommand() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"--format", "json", "profile", "list"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("profile list --format json error = %v\noutput:\n%s", err, out.String()) + } + var resp profileListResponse + if err := json.Unmarshal(out.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal() error = %v\noutput:\n%s", err, out.String()) + } + if !resp.Success { + t.Fatal("success = false, want true") + } + if resp.PrimaryProfile != "corp_primary" || resp.CurrentProfile != "corp_secondary" || resp.PreviousProfile != "corp_primary" { + t.Fatalf("profile pointers = primary %q current %q previous %q, want corp_primary/corp_secondary/corp_primary", resp.PrimaryProfile, resp.CurrentProfile, resp.PreviousProfile) + } + if len(resp.Profiles) != 2 { + t.Fatalf("profiles len = %d, want 2", len(resp.Profiles)) + } + if bytes.Contains(out.Bytes(), []byte(`"name"`)) { + t.Fatalf("profile list JSON should not contain name when corpName is present:\n%s", out.String()) + } + for _, p := range resp.Profiles { + if p.CorpName == "" { + t.Fatalf("profile %s missing corpName in JSON response: %#v", p.CorpID, p) + } + } +} + +func TestProfileUseRootCommandSwitchesOrganizationAndLegacyMirror(t *testing.T) { + configDir := setupAuthLogoutProfiles(t, + authLogoutTestToken("corp_primary"), + authLogoutTestToken("corp_secondary"), + ) + + cmd := NewRootCommand() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"--format", "table", "profile", "use", "corp_primary"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("profile use corp_primary error = %v\noutput:\n%s", err, out.String()) + } + if !bytes.Contains(out.Bytes(), []byte("组织: corp_primary org")) { + t.Fatalf("profile use output should include organization name:\n%s", out.String()) + } + cfg, err := authpkg.LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if cfg.CurrentProfile != "corp_primary" || cfg.PreviousProfile != "corp_secondary" { + t.Fatalf("profile pointers = current %q previous %q, want corp_primary/corp_secondary", cfg.CurrentProfile, cfg.PreviousProfile) + } + legacyToken, err := authpkg.LoadTokenData(configDir) + if err != nil { + t.Fatalf("LoadTokenData() error = %v", err) + } + if legacyToken.CorpID != "corp_primary" { + t.Fatalf("legacy token corp = %q, want corp_primary", legacyToken.CorpID) + } + + cmd = NewRootCommand() + out.Reset() + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"--format", "table", "profile", "use", "-"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("profile use - error = %v\noutput:\n%s", err, out.String()) + } + if !bytes.Contains(out.Bytes(), []byte("组织: corp_secondary org")) { + t.Fatalf("profile use - output should include organization name:\n%s", out.String()) + } + cfg, err = authpkg.LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if cfg.CurrentProfile != "corp_secondary" || cfg.PreviousProfile != "corp_primary" { + t.Fatalf("profile pointers = current %q previous %q, want corp_secondary/corp_primary", cfg.CurrentProfile, cfg.PreviousProfile) + } + legacyToken, err = authpkg.LoadTokenData(configDir) + if err != nil { + t.Fatalf("LoadTokenData() error = %v", err) + } + if legacyToken.CorpID != "corp_secondary" { + t.Fatalf("legacy token corp = %q, want corp_secondary", legacyToken.CorpID) + } +} + +func TestProfileSwitchRootCommandSwitchesPrimaryOrganizationAndLegacyMirror(t *testing.T) { + configDir := setupAuthLogoutProfiles(t, + authLogoutTestToken("corp_primary"), + authLogoutTestToken("corp_secondary"), + ) + + cmd := NewRootCommand() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"--format", "table", "profile", "switch", "corp_primary"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("profile switch corp_primary error = %v\noutput:\n%s", err, out.String()) + } + if !bytes.Contains(out.Bytes(), []byte("组织: corp_primary org")) { + t.Fatalf("profile switch output should include organization name:\n%s", out.String()) + } + cfg, err := authpkg.LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if cfg.CurrentProfile != "corp_primary" || cfg.PreviousProfile != "corp_secondary" { + t.Fatalf("profile pointers = current %q previous %q, want corp_primary/corp_secondary", cfg.CurrentProfile, cfg.PreviousProfile) + } + legacyToken, err := authpkg.LoadTokenData(configDir) + if err != nil { + t.Fatalf("LoadTokenData() error = %v", err) + } + if legacyToken.CorpID != "corp_primary" { + t.Fatalf("legacy token corp = %q, want corp_primary", legacyToken.CorpID) + } +} + +func TestProfileSwitchRootCommandSupportsCorpIDFlag(t *testing.T) { + configDir := setupAuthLogoutProfiles(t, + authLogoutTestToken("corp_primary"), + authLogoutTestToken("corp_secondary"), + ) + + cmd := NewRootCommand() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"--format", "table", "profile", "switch", "--corpId", "corp_primary"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("profile switch --corpId error = %v\noutput:\n%s", err, out.String()) + } + cfg, err := authpkg.LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if cfg.CurrentProfile != "corp_primary" { + t.Fatalf("currentProfile = %q, want corp_primary", cfg.CurrentProfile) + } + + cmd = NewRootCommand() + out.Reset() + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"--format", "table", "profile", "use", "--corp", "corp_secondary"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("profile use --corp error = %v\noutput:\n%s", err, out.String()) + } + cfg, err = authpkg.LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if cfg.CurrentProfile != "corp_secondary" { + t.Fatalf("currentProfile = %q, want corp_secondary", cfg.CurrentProfile) + } +} + +func TestProfileSwitchRootCommandRejectsConflictingSelectors(t *testing.T) { + setupAuthLogoutProfiles(t, + authLogoutTestToken("corp_primary"), + authLogoutTestToken("corp_secondary"), + ) + + cmd := NewRootCommand() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"profile", "switch", "corp_primary", "--corpId", "corp_secondary"}) + err := cmd.Execute() + if err == nil { + t.Fatalf("profile switch with conflicting selectors succeeded\noutput:\n%s", out.String()) + } + if !strings.Contains(err.Error(), "只能指定一个组织选择器") { + t.Fatalf("error = %v, want conflicting selector validation", err) + } +} + +func TestProfileSwitchNoArgsUsesTUISelector(t *testing.T) { + configDir := setupAuthLogoutProfiles(t, + authLogoutTestToken("corp_primary"), + authLogoutTestToken("corp_secondary"), + ) + oldSelector := profileSwitchSelector + t.Cleanup(func() { + profileSwitchSelector = oldSelector + }) + called := false + profileSwitchSelector = func(cmd *cobra.Command, gotConfigDir string) (string, error) { + called = true + if gotConfigDir != configDir { + t.Fatalf("configDir = %q, want %q", gotConfigDir, configDir) + } + return "corp_primary", nil + } + + cmd := NewRootCommand() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"profile", "switch"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("profile switch error = %v\noutput:\n%s", err, out.String()) + } + if !called { + t.Fatal("profile switch without args did not invoke TUI selector") + } + if !bytes.Contains(out.Bytes(), []byte("组织: corp_primary org")) { + t.Fatalf("profile switch TUI path should use human output by default:\n%s", out.String()) + } + cfg, err := authpkg.LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if cfg.CurrentProfile != "corp_primary" { + t.Fatalf("currentProfile = %q, want corp_primary", cfg.CurrentProfile) + } +} + +func TestProfileSwitchOptionLabelUsesOnlyOrganizationAndCurrentState(t *testing.T) { + cfg := &authpkg.ProfilesConfig{ + PrimaryProfile: "corp_primary", + CurrentProfile: "corp_secondary", + Profiles: []authpkg.Profile{ + { + CorpID: "corp_primary", + CorpName: "第一组织", + UserName: "alice", + Status: authpkg.ProfileStatusActive, + }, + { + CorpID: "corp_secondary", + CorpName: "第二组织", + UserName: "bob", + Status: authpkg.ProfileStatusActive, + }, + }, + } + primary := profileSwitchOptionLabel(cfg.Profiles[0], cfg) + current := profileSwitchOptionLabel(cfg.Profiles[1], cfg) + for _, label := range []string{primary, current} { + if strings.Contains(label, "\n") { + t.Fatalf("profile switch label contains newline: %q", label) + } + } + if !strings.Contains(primary, "第一组织") { + t.Fatalf("primary option missing organization name: %q", primary) + } + if !strings.Contains(current, "当前组织") { + t.Fatalf("current option missing current marker: %q", current) + } + for _, unwanted := range []string{"alice", "bob", "已登录", "主组织", "corp_primary", "corp_secondary"} { + if strings.Contains(primary, unwanted) || strings.Contains(current, unwanted) { + t.Fatalf("profile switch option should not contain %q: %q / %q", unwanted, primary, current) + } + } +} + +func TestProfileSwitchTUIViewUsesFixedOuterTable(t *testing.T) { + cfg := profileSwitchTestConfig(2) + model := newProfileSwitchTUIModel(cfg, "corp_00") + view := model.tableView() + if lines := strings.Split(view, "\n"); len(lines) != profileSwitchVisibleOptions+4 { + t.Fatalf("table line count = %d, want %d:\n%s", len(lines), profileSwitchVisibleOptions+4, view) + } + for _, want := range []string{"┌", "┬", "┐", "├", "┼", "┤", "└", "┴", "┘", "组织名", "本地状态"} { + if !strings.Contains(view, want) { + t.Fatalf("profile switch table missing %q in:\n%s", want, view) + } + } + for _, unwanted := range []string{"CORP_ID", "ORGANIZATION", "STATUS"} { + if strings.Contains(view, unwanted) { + t.Fatalf("profile switch table should not contain %q:\n%s", unwanted, view) + } + } + if got := strings.Count(view, "│"); got != (profileSwitchVisibleOptions+1)*3 { + t.Fatalf("table vertical separators = %d, want %d\n%s", got, (profileSwitchVisibleOptions+1)*3, view) + } + for _, profile := range cfg.Profiles { + if got := strings.Count(view, profile.CorpID); got != 0 { + t.Fatalf("profile corpId %s appears %d times, want hidden:\n%s", profile.CorpID, got, view) + } + } +} + +func TestProfileSwitchTUISortsLatestLoggedInProfilesFirst(t *testing.T) { + cfg := &authpkg.ProfilesConfig{ + PrimaryProfile: "old", + CurrentProfile: "old", + Profiles: []authpkg.Profile{ + {CorpID: "old", CorpName: "旧组织", LastLoginAt: "2026-06-26T10:00:00+08:00"}, + {CorpID: "new", CorpName: "新组织", LastLoginAt: "2026-06-26T12:00:00+08:00"}, + {CorpID: "fallback", CorpName: "兜底组织", UpdatedAt: "2026-06-26T11:00:00+08:00"}, + }, + } + model := newProfileSwitchTUIModel(cfg, "old") + gotOrder := []string{model.profiles[0].CorpID, model.profiles[1].CorpID, model.profiles[2].CorpID} + wantOrder := []string{"new", "fallback", "old"} + if strings.Join(gotOrder, ",") != strings.Join(wantOrder, ",") { + t.Fatalf("profile order = %v, want %v", gotOrder, wantOrder) + } + if got := model.selectedCorpID(); got != "old" { + t.Fatalf("selectedCorpID = %q, want old", got) + } +} + +func TestProfileSwitchTUIArrowKeysMoveSelectionWithoutDuplicatingRows(t *testing.T) { + cfg := profileSwitchTestConfig(7) + model := newProfileSwitchTUIModel(cfg, "corp_00") + for step := 0; step < 6; step++ { + view := model.tableView() + if got := strings.Count(view, "›"); got != 1 { + t.Fatalf("step %d selected cursor count = %d, want 1:\n%s", step, got, view) + } + for _, profile := range cfg.Profiles { + name := profileOrgName(profile) + if got := strings.Count(view, name); got > 1 { + t.Fatalf("step %d profile %s appears %d times, want at most once:\n%s", step, name, got, view) + } + } + next, _ := model.Update(tea.KeyMsg{Type: tea.KeyDown}) + model = next.(profileSwitchTUIModel) + } + if model.selected != 6 || model.offset != 2 { + t.Fatalf("selection after down keys = selected %d offset %d, want 6/2", model.selected, model.offset) + } +} + +func TestProfileSwitchTableRowsKeepFixedDisplayWidth(t *testing.T) { + rows := []string{ + profileSwitchTableLine("组织名", "本地状态"), + profileSwitchTableLine("› 钉钉(中国)信息技术有限公司", "当前组织"), + profileSwitchTableLine(" ACME", ""), + profileSwitchTableLine("", ""), + profileSwitchStyledTableLine("组织名", "本地状态", profileSwitchHeaderStyle()), + profileSwitchStyledTableLine("› 钉钉(中国)信息技术有限公司", "当前组织", profileSwitchSelectedRowStyle()), + profileSwitchStyledTableLine(" ACME", "", profileSwitchNormalRowStyle()), + profileSwitchStyledTableLine("", "", profileSwitchNormalRowStyle()), + } + wantWidth := lipgloss.Width(rows[0]) + for i, row := range rows { + if got := lipgloss.Width(row); got != wantWidth { + t.Fatalf("row[%d] width = %d, want %d: %q", i, got, wantWidth, row) + } + if got := strings.Count(row, "│"); got != 3 { + t.Fatalf("row[%d] separator count = %d, want 3: %q", i, got, row) + } + } +} + +func TestProfileSwitchOptionLabelHidesCorpID(t *testing.T) { + const corpID = "ding8196cd9a2b2405da24f2f5cc6abecb85" + cfg := &authpkg.ProfilesConfig{ + PrimaryProfile: corpID, + CurrentProfile: corpID, + } + label := profileSwitchOptionLabel(authpkg.Profile{ + CorpID: corpID, + CorpName: "钉钉", + }, cfg) + for _, want := range []string{"钉钉", "当前组织"} { + if !strings.Contains(label, want) { + t.Fatalf("profile switch label missing %q in %q", want, label) + } + } + for _, unwanted := range []string{"ding8196", "cb85", "主组织"} { + if strings.Contains(label, unwanted) { + t.Fatalf("profile switch label should not contain %q in %q", unwanted, label) + } + } +} + +func profileSwitchTestConfig(count int) *authpkg.ProfilesConfig { + cfg := &authpkg.ProfilesConfig{ + PrimaryProfile: "corp_00", + CurrentProfile: "corp_00", + } + for i := 0; i < count; i++ { + corpID := fmt.Sprintf("corp_%02d", i) + cfg.Profiles = append(cfg.Profiles, authpkg.Profile{ + CorpID: corpID, + CorpName: fmt.Sprintf("组织%02d", i), + Status: authpkg.ProfileStatusActive, + }) + } + return cfg +} + +func TestAuthCommandDoesNotExposeSwitch(t *testing.T) { + cmd := NewRootCommand() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"auth", "switch"}) + err := cmd.Execute() + if err == nil { + t.Fatalf("auth switch succeeded, want unknown command error\noutput:\n%s", out.String()) + } + if !strings.Contains(err.Error(), `unknown command "switch" for "dws auth"`) { + t.Fatalf("error = %v, want auth switch unknown command", err) + } +} + +func TestProfileUseNoArgsUsesTUISelector(t *testing.T) { + configDir := setupAuthLogoutProfiles(t, + authLogoutTestToken("corp_primary"), + authLogoutTestToken("corp_secondary"), + ) + oldSelector := profileSwitchSelector + t.Cleanup(func() { + profileSwitchSelector = oldSelector + }) + profileSwitchSelector = func(cmd *cobra.Command, gotConfigDir string) (string, error) { + if gotConfigDir != configDir { + t.Fatalf("configDir = %q, want %q", gotConfigDir, configDir) + } + return "corp_primary", nil + } + + cmd := NewRootCommand() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"profile", "use"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("profile use error = %v\noutput:\n%s", err, out.String()) + } + if !bytes.Contains(out.Bytes(), []byte("组织: corp_primary org")) { + t.Fatalf("profile use TUI path should use human output by default:\n%s", out.String()) + } + cfg, err := authpkg.LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if cfg.CurrentProfile != "corp_primary" { + t.Fatalf("currentProfile = %q, want corp_primary", cfg.CurrentProfile) + } +} + +func TestProfileSwitchSelectorRequiresInteractiveTerminal(t *testing.T) { + oldInteractive := profileSwitchInteractiveTerminal + t.Cleanup(func() { + profileSwitchInteractiveTerminal = oldInteractive + }) + profileSwitchInteractiveTerminal = func() bool { return false } + + _, err := selectProfileSwitchProfile(nil, t.TempDir()) + if err == nil { + t.Fatal("selectProfileSwitchProfile() succeeded, want validation error") + } + if !bytes.Contains([]byte(err.Error()), []byte("profile selector required")) { + t.Fatalf("error = %v, want profile selector hint", err) + } +} + +func TestWriteProfileListTableIncludesCorpName(t *testing.T) { + cfg := &authpkg.ProfilesConfig{ + PrimaryProfile: "corp_a", + CurrentProfile: "corp_b", + Profiles: []authpkg.Profile{ + { + Name: "DingTalk China", + CorpID: "corp_a", + CorpName: "钉钉(中国)信息技术有限公司", + UserName: "alice", + Status: authpkg.ProfileStatusActive, + }, + { + Name: "B Org", + CorpID: "corp_b", + CorpName: "B 组织", + UserID: "bob-id", + }, + }, + } + var buf bytes.Buffer + writeProfileListTable(&buf, cfg) + out := buf.String() + for _, want := range []string{ + "ORG_NAME", + "钉钉(中国)信息技术有限公司", + "B 组织", + "corp_a", + "corp_b", + } { + if !bytes.Contains(buf.Bytes(), []byte(want)) { + t.Fatalf("profile list table missing %q in output:\n%s", want, out) + } + } + for _, unwanted := range []string{"PROFILE", "DingTalk China"} { + if bytes.Contains(buf.Bytes(), []byte(unwanted)) { + t.Fatalf("profile list table should not contain %q in output:\n%s", unwanted, out) + } + } +} + +func TestProfileUseMessageIncludesCorpName(t *testing.T) { + got := profileUseMessage(&authpkg.Profile{ + Name: "DingTalk China", + CorpID: "ding8196", + CorpName: "钉钉(中国)信息技术有限公司", + }) + for _, want := range []string{"当前组织: 钉钉(中国)信息技术有限公司", "ding8196"} { + if !bytes.Contains([]byte(got), []byte(want)) { + t.Fatalf("profileUseMessage() missing %q in %q", want, got) + } + } + if bytes.Contains([]byte(got), []byte("DingTalk China")) { + t.Fatalf("profileUseMessage() should not include profile name when corpName is present: %q", got) + } +} diff --git a/internal/app/profile_product_command_test.go b/internal/app/profile_product_command_test.go new file mode 100644 index 00000000..795ff1f8 --- /dev/null +++ b/internal/app/profile_product_command_test.go @@ -0,0 +1,158 @@ +package app + +import ( + "bytes" + "context" + "sync" + "testing" + + authpkg "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/internal/auth" + "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/internal/compat" + "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/internal/executor" + "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/internal/market" + "github.com/spf13/cobra" +) + +func TestProductCommandsAcceptGlobalProfileFlag(t *testing.T) { + const selectedProfile = "corp_profile_matrix" + + products := []struct { + name string + path []string + tool string + }{ + {name: "aitable", path: []string{"aitable", "profile-test", "probe"}, tool: "aitable_profile_probe"}, + {name: "attendance", path: []string{"attendance", "profile-test", "probe"}, tool: "attendance_profile_probe"}, + {name: "calendar", path: []string{"calendar", "profile-test", "probe"}, tool: "calendar_profile_probe"}, + {name: "contact", path: []string{"contact", "profile-test", "probe"}, tool: "contact_profile_probe"}, + {name: "devdoc", path: []string{"devdoc", "profile-test", "probe"}, tool: "devdoc_profile_probe"}, + {name: "ding", path: []string{"ding", "profile-test", "probe"}, tool: "ding_profile_probe"}, + {name: "report", path: []string{"report", "profile-test", "probe"}, tool: "report_profile_probe"}, + {name: "todo", path: []string{"todo", "profile-test", "probe"}, tool: "todo_profile_probe"}, + } + + descriptors := make([]market.ServerDescriptor, 0, len(products)) + for _, product := range products { + descriptors = append(descriptors, profileFlagProductDescriptor(product.name, product.tool)) + } + + capture := &profileFlagRunner{} + oldLoadDynamicCommands := loadDynamicCommandsFn + loadDynamicCommandsFn = func(_ context.Context, _ executor.Runner) []*cobra.Command { + SetDynamicServers(descriptors) + return compat.BuildDynamicCommands(descriptors, capture, nil, nil) + } + authpkg.SetRuntimeProfile("") + ResetRuntimeTokenCache() + t.Cleanup(func() { + loadDynamicCommandsFn = oldLoadDynamicCommands + SetDynamicServers(nil) + authpkg.SetRuntimeProfile("") + ResetRuntimeTokenCache() + }) + + for _, product := range products { + t.Run(product.name, func(t *testing.T) { + capture.reset() + authpkg.SetRuntimeProfile("") + + cmd := NewRootCommand() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + args := append([]string{"-f", "json"}, product.path...) + args = append(args, "--profile", selectedProfile) + cmd.SetArgs(args) + + // Arrange / Act: execute a product command with root --profile after the leaf. + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute(%v) error = %v\noutput:\n%s", args, err, out.String()) + } + + // Assert: the product tool runs under the selected profile without leaking it as a business arg. + call := capture.last() + if call == nil { + t.Fatal("expected product command to invoke runner") + } + if call.product != product.name { + t.Fatalf("canonical product = %q, want %q", call.product, product.name) + } + if call.tool != product.tool { + t.Fatalf("tool = %q, want %q", call.tool, product.tool) + } + if call.profile != selectedProfile { + t.Fatalf("runtime profile at execution = %q, want %q", call.profile, selectedProfile) + } + if _, ok := call.params["profile"]; ok { + t.Fatalf("--profile leaked into business params: %#v", call.params) + } + }) + } +} + +func profileFlagProductDescriptor(product, tool string) market.ServerDescriptor { + return market.ServerDescriptor{ + Key: product, + DisplayName: product, + Endpoint: "https://example.invalid/" + product, + CLI: market.CLIOverlay{ + ID: product, + Command: product, + Groups: map[string]market.CLIGroupDef{ + "profile-test": {Description: "profile-test"}, + }, + ToolOverrides: map[string]market.CLIToolOverride{ + tool: { + CLIName: "probe", + Group: "profile-test", + Description: tool, + RejectPositional: true, + }, + }, + }, + } +} + +type profileFlagCall struct { + product string + tool string + profile string + params map[string]any +} + +type profileFlagRunner struct { + mu sync.Mutex + calls []profileFlagCall +} + +func (r *profileFlagRunner) Run(_ context.Context, invocation executor.Invocation) (executor.Result, error) { + r.mu.Lock() + defer r.mu.Unlock() + params := make(map[string]any, len(invocation.Params)) + for key, value := range invocation.Params { + params[key] = value + } + r.calls = append(r.calls, profileFlagCall{ + product: invocation.CanonicalProduct, + tool: invocation.Tool, + profile: authpkg.RuntimeProfile(), + params: params, + }) + return executor.Result{Invocation: invocation}, nil +} + +func (r *profileFlagRunner) reset() { + r.mu.Lock() + defer r.mu.Unlock() + r.calls = nil +} + +func (r *profileFlagRunner) last() *profileFlagCall { + r.mu.Lock() + defer r.mu.Unlock() + if len(r.calls) == 0 { + return nil + } + call := r.calls[len(r.calls)-1] + return &call +} diff --git a/internal/app/root.go b/internal/app/root.go index a23c03e3..fa6b0971 100644 --- a/internal/app/root.go +++ b/internal/app/root.go @@ -67,6 +67,9 @@ func Execute() (exitCode int) { } }() + restoreArgs := normalizeProcessProfileArgs() + defer restoreArgs() + timing := NewTimingCollector() defer func() { StopAllStdioClients() // Ensure child processes are terminated on exit @@ -298,6 +301,7 @@ func NewRootCommandWithEngine(rootCtx context.Context, engine *pipeline.Engine) rootCtx = context.Background() } flags := &GlobalFlags{} + authpkg.SetRuntimeProfile(preparseProfileFlag(os.Args[1:])) loader := cli.EnvironmentLoader{ LookupEnv: os.LookupEnv, CatalogBaseURLOverride: DiscoveryBaseURL(), @@ -321,6 +325,7 @@ func NewRootCommandWithEngine(rootCtx context.Context, engine *pipeline.Engine) return cmd.Help() }, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + authpkg.SetRuntimeProfile(flags.Profile) // Apply OAuth credential overrides from CLI flags (highest priority). if flags.ClientID != "" { authpkg.SetClientID(flags.ClientID) @@ -358,6 +363,7 @@ func NewRootCommandWithEngine(rootCtx context.Context, engine *pipeline.Engine) utilityCommands := []*cobra.Command{ newAuthCommand(patCaller), + newProfileCommand(), newAPICommand(flags), newSkillCommand(), newCacheCommand(), @@ -404,6 +410,85 @@ func NewRootCommandWithEngine(rootCtx context.Context, engine *pipeline.Engine) return root } +func preparseProfileFlag(args []string) string { + args, _ = normalizeProfileFlagArgs(args) + for i := 0; i < len(args); i++ { + arg := strings.TrimSpace(args[i]) + switch { + case arg == "--profile" && i+1 < len(args): + return strings.TrimSpace(args[i+1]) + case strings.HasPrefix(arg, "--profile="): + return strings.TrimSpace(strings.TrimPrefix(arg, "--profile=")) + } + } + return "" +} + +func normalizeProcessProfileArgs() func() { + original := append([]string(nil), os.Args...) + if len(os.Args) > 1 { + if normalized, changed := normalizeProfileFlagArgs(os.Args[1:]); changed { + os.Args = append([]string{os.Args[0]}, normalized...) + } + } + return func() { + os.Args = original + } +} + +func normalizeProfileFlagArgs(args []string) ([]string, bool) { + if len(args) == 0 { + return args, false + } + out := make([]string, 0, len(args)) + for i := 0; i < len(args); i++ { + arg := args[i] + trimmed := strings.TrimSpace(arg) + switch { + case trimmed == "--profile": + out = append(out, arg) + if i+1 >= len(args) { + continue + } + value, next := collectProfileFlagValue(args[i+1], args, i+2) + out = append(out, value) + i = next - 1 + case strings.HasPrefix(trimmed, "--profile="): + value, next := collectProfileFlagValue(strings.TrimPrefix(trimmed, "--profile="), args, i+1) + out = append(out, "--profile="+value) + i = next - 1 + default: + out = append(out, arg) + } + } + return out, argsChanged(args, out) +} + +func collectProfileFlagValue(first string, args []string, next int) (string, int) { + parts := []string{strings.TrimSpace(first)} + for len(parts) > 0 && strings.HasSuffix(strings.TrimSpace(parts[len(parts)-1]), ",") && next < len(args) { + candidate := strings.TrimSpace(args[next]) + if candidate == "" || strings.HasPrefix(candidate, "-") { + break + } + parts = append(parts, candidate) + next++ + } + return strings.Join(parts, ""), next +} + +func argsChanged(before, after []string) bool { + if len(before) != len(after) { + return true + } + for i := range before { + if before[i] != after[i] { + return true + } + } + return false +} + func newAuthCommand(patCaller edition.ToolCaller) *cobra.Command { return buildAuthCommand(patCaller) } @@ -802,6 +887,7 @@ func hideNonDirectRuntimeCommands(root *cobra.Command) { "completion": true, "skill": true, "plugin": true, + "profile": true, "version": true, "help": true, "recovery": true, @@ -828,7 +914,7 @@ func hideNonDirectRuntimeCommands(root *cobra.Command) { // by a malicious or misconfigured plugin. var reservedCommands = map[string]bool{ "auth": true, "api": true, "login": true, "logout": true, - "plugin": true, "skill": true, "cache": true, + "plugin": true, "profile": true, "skill": true, "cache": true, "config": true, "doctor": true, "completion": true, "recovery": true, "upgrade": true, "version": true, "schema": true, "mcp": true, "help": true, diff --git a/internal/app/root_execute_test.go b/internal/app/root_execute_test.go index 4a2e57b0..7ec88e8b 100644 --- a/internal/app/root_execute_test.go +++ b/internal/app/root_execute_test.go @@ -263,7 +263,7 @@ func TestRootHelpDoesNotRequirePINOrLogin(t *testing.T) { if !strings.Contains(out.String(), "Discovered MCP Services:") { t.Fatalf("root help output missing MCP summary:\n%s", out.String()) } - for _, want := range []string{"Utility Commands:", "skill", "auth", "version"} { + for _, want := range []string{"Utility Commands:", "skill", "auth", "profile", "version", "Global Flags:", "--profile"} { if !strings.Contains(out.String(), want) { t.Fatalf("root help output missing %q:\n%s", want, out.String()) } diff --git a/internal/app/root_help.go b/internal/app/root_help.go index 552142bc..451b1129 100644 --- a/internal/app/root_help.go +++ b/internal/app/root_help.go @@ -9,6 +9,7 @@ import ( "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/internal/tui" "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/pkg/edition" "github.com/spf13/cobra" + "github.com/spf13/pflag" ) func configureRootHelp(root *cobra.Command) { @@ -86,6 +87,7 @@ func renderRootHelp(root *cobra.Command) { _ = tw.Flush() _, _ = fmt.Fprintln(w) } + renderRootGlobalFlags(root) _, _ = fmt.Fprintf(w, "%s %s\n", tui.Key("Next"), `Use "dws --help" for more information about a discovered MCP service or "dws --help" for utility commands.`) // Render root.Long after the command list so agents see the upgrade @@ -99,6 +101,53 @@ func renderRootHelp(root *cobra.Command) { } } +func renderRootGlobalFlags(root *cobra.Command) { + if root == nil { + return + } + flags := visiblePersistentFlags(root) + if len(flags) == 0 { + return + } + w := root.OutOrStdout() + _, _ = fmt.Fprintln(w, tui.Section("Global Flags:")) + _, _ = fmt.Fprintln(w) + tw := tabwriter.NewWriter(w, 0, 0, 2, ' ', 0) + for _, flag := range flags { + _, _ = fmt.Fprintf(tw, " %s\t%s\n", formatRootFlag(flag), tui.Dim(strings.TrimSpace(flag.Usage))) + } + _ = tw.Flush() + _, _ = fmt.Fprintln(w) +} + +func visiblePersistentFlags(root *cobra.Command) []*pflag.Flag { + if root == nil { + return nil + } + flags := make([]*pflag.Flag, 0) + root.PersistentFlags().VisitAll(func(flag *pflag.Flag) { + if flag == nil || flag.Hidden { + return + } + flags = append(flags, flag) + }) + return flags +} + +func formatRootFlag(flag *pflag.Flag) string { + if flag == nil { + return "" + } + name := "--" + flag.Name + if flag.Value != nil && flag.Value.Type() != "bool" { + name += " " + flag.Value.Type() + } + if flag.Shorthand == "" { + return " " + name + } + return "-" + flag.Shorthand + ", " + name +} + func commandShort(cmd *cobra.Command) string { if cmd == nil { return "" diff --git a/internal/app/runner.go b/internal/app/runner.go index 100d6bf7..49b525c2 100644 --- a/internal/app/runner.go +++ b/internal/app/runner.go @@ -162,6 +162,18 @@ func (r *runtimeRunner) Run(ctx context.Context, invocation executor.Invocation) // invocations within the same process free. logHostOwnedPATDecisionOnce() + selections, multi, err := resolveMultiProfileSelections(defaultConfigDir(), authpkg.RuntimeProfile()) + if err != nil { + return executor.Result{}, apperrors.NewValidation(err.Error()) + } + if multi { + return r.runMultiProfile(ctx, invocation, selections) + } + + return r.runSingle(ctx, invocation, true) +} + +func (r *runtimeRunner) runSingle(ctx context.Context, invocation executor.Invocation, prefetchToken bool) (executor.Result, error) { if r.loader == nil || r.transport == nil { return r.fallback.Run(ctx, invocation) } @@ -179,7 +191,9 @@ func (r *runtimeRunner) Run(ctx context.Context, invocation executor.Invocation) // Prefetch the Keychain token in the background. Keychain access costs // ~70ms on macOS; starting it here lets the load overlap with endpoint // resolution and catalog loading below. - go getCachedRuntimeToken(ctx) + if prefetchToken { + go getCachedRuntimeToken(ctx) + } if shouldUseDirectRuntime(invocation) { if endpoint, ok := directRuntimeEndpoint(invocation.CanonicalProduct, invocation.Tool); ok { @@ -239,6 +253,144 @@ func (r *runtimeRunner) Run(ctx context.Context, invocation executor.Invocation) return r.executeInvocation(ctx, endpoint, invocation) } +type multiProfileSelection struct { + Selector string + Profile authpkg.Profile +} + +func resolveMultiProfileSelections(configDir, rawSelector string) ([]multiProfileSelection, bool, error) { + rawSelector = strings.TrimSpace(rawSelector) + if rawSelector == "" || !strings.Contains(rawSelector, ",") { + return nil, false, nil + } + if p, err := authpkg.ResolveProfile(configDir, rawSelector); err == nil && p != nil { + return nil, false, nil + } + + parts := strings.Split(rawSelector, ",") + selections := make([]multiProfileSelection, 0, len(parts)) + seen := make(map[string]bool, len(parts)) + for _, part := range parts { + selector := strings.TrimSpace(part) + if selector == "" { + return nil, false, fmt.Errorf("--profile contains an empty profile selector: %q", rawSelector) + } + profile, err := authpkg.ResolveProfile(configDir, selector) + if err != nil { + return nil, false, err + } + if profile == nil { + return nil, false, fmt.Errorf("profile %q not found", selector) + } + if seen[profile.CorpID] { + continue + } + seen[profile.CorpID] = true + selections = append(selections, multiProfileSelection{ + Selector: selector, + Profile: *profile, + }) + } + if len(selections) == 0 { + return nil, false, nil + } + return selections, true, nil +} + +func (r *runtimeRunner) runMultiProfile(ctx context.Context, invocation executor.Invocation, selections []multiProfileSelection) (executor.Result, error) { + previousProfile := authpkg.RuntimeProfile() + defer authpkg.SetRuntimeProfile(previousProfile) + + entries := make([]any, 0, len(selections)) + succeeded := 0 + failed := 0 + + for _, selection := range selections { + authpkg.SetRuntimeProfile(selection.Profile.CorpID) + result, err := r.runSingle(ctx, cloneInvocation(invocation), false) + + entry := map[string]any{ + "selector": selection.Selector, + "corpId": selection.Profile.CorpID, + "corpName": selection.Profile.CorpName, + "ok": err == nil, + } + if err != nil { + failed++ + entry["error"] = multiProfileErrorPayload(err) + } else { + succeeded++ + if payload := multiProfileResultPayload(result); payload != nil { + entry["result"] = payload + } + if result.Response != nil { + if endpoint, ok := result.Response["endpoint"]; ok { + entry["endpoint"] = endpoint + } + } + } + entries = append(entries, entry) + } + + invocation.Implemented = true + return executor.Result{ + Invocation: invocation, + Response: map[string]any{ + "content": map[string]any{ + "success": failed == 0, + "multiProfile": true, + "summary": map[string]any{ + "total": len(selections), + "succeeded": succeeded, + "failed": failed, + }, + "profiles": entries, + }, + }, + }, nil +} + +func cloneInvocation(invocation executor.Invocation) executor.Invocation { + cloned := invocation + if invocation.Params != nil { + cloned.Params = make(map[string]any, len(invocation.Params)) + for key, value := range invocation.Params { + cloned.Params[key] = value + } + } + return cloned +} + +func multiProfileResultPayload(result executor.Result) any { + if result.Response == nil { + return nil + } + if content, ok := result.Response["content"]; ok { + return content + } + return result.Response +} + +func multiProfileErrorPayload(err error) map[string]any { + payload := map[string]any{ + "message": err.Error(), + } + var typed *apperrors.Error + if errors.As(err, &typed) { + payload["category"] = string(typed.Category) + if typed.Reason != "" { + payload["reason"] = typed.Reason + } + if typed.Operation != "" { + payload["operation"] = typed.Operation + } + if code := typed.ExitCode(); code != 0 { + payload["exitCode"] = code + } + } + return payload +} + // handleCatalogMiss decides what to do when discovery catalog does not cover the // requested product / tool and no `directRuntimeEndpoint` match fired earlier. // @@ -604,28 +756,40 @@ func resolveRuntimeAuthToken(ctx context.Context, explicitToken string) string { // Cached token state for process lifetime var ( - cachedRuntimeToken string - cachedRuntimeTokenOnce sync.Once + cachedRuntimeTokenMu sync.Mutex + cachedRuntimeTokens = map[string]string{} ) // getCachedRuntimeToken returns a cached access token, loading it only once per process. // This avoids repeated Keychain access which takes ~70ms each time. func getCachedRuntimeToken(ctx context.Context) string { - cachedRuntimeTokenOnce.Do(func() { - loadStart := time.Now() - defer func() { RecordTiming(ctx, "auth_keychain", time.Since(loadStart)) }() - - configDir := defaultConfigDir() - token, tokenErr := resolveAccessTokenFromDir(ctx, configDir) - if tokenErr != nil && errors.Is(tokenErr, authpkg.ErrTokenDecryption) { - slog.Error(tokenErr.Error()) - return - } - if token != "" { - cachedRuntimeToken = token - } - }) - return cachedRuntimeToken + cacheKey := strings.TrimSpace(authpkg.RuntimeProfile()) + if cacheKey == "" { + cacheKey = "__default__" + } + cachedRuntimeTokenMu.Lock() + if token := cachedRuntimeTokens[cacheKey]; token != "" { + cachedRuntimeTokenMu.Unlock() + return token + } + cachedRuntimeTokenMu.Unlock() + + loadStart := time.Now() + defer func() { RecordTiming(ctx, "auth_keychain", time.Since(loadStart)) }() + + configDir := defaultConfigDir() + token, tokenErr := resolveAccessTokenFromDir(ctx, configDir) + if tokenErr != nil && errors.Is(tokenErr, authpkg.ErrTokenDecryption) { + slog.Error(tokenErr.Error()) + return "" + } + if token == "" { + return "" + } + cachedRuntimeTokenMu.Lock() + cachedRuntimeTokens[cacheKey] = token + cachedRuntimeTokenMu.Unlock() + return token } // generateExecutionID returns a random 16-char hex string used to correlate @@ -640,8 +804,9 @@ func generateExecutionID() string { // ResetRuntimeTokenCache clears the cached token, forcing a reload on next access. // This should be called after login/logout operations. func ResetRuntimeTokenCache() { - cachedRuntimeTokenOnce = sync.Once{} - cachedRuntimeToken = "" + cachedRuntimeTokenMu.Lock() + defer cachedRuntimeTokenMu.Unlock() + cachedRuntimeTokens = map[string]string{} } func newRuntimeContentScanner() safety.Scanner { diff --git a/internal/app/skill_setup.go b/internal/app/skill_setup.go index a47a5340..c07424e3 100644 --- a/internal/app/skill_setup.go +++ b/internal/app/skill_setup.go @@ -123,7 +123,9 @@ func runSkillSetup(cmd *cobra.Command, _ []string) error { if filterErr != nil { return filterErr } - multiSkillNames = filtered + // dws-shared carries the global rules every product skill declares as a + // PREREQUISITE; it must ship even when --skill / --exclude narrows the set. + multiSkillNames = ensureMandatorySharedSkill(filtered, allMultiSkillNames) } if !autoYes { @@ -160,6 +162,33 @@ func runSkillSetup(cmd *cobra.Command, _ []string) error { // bundle in skills/multi/ (e.g. dingtalk-aitable, dingtalk-calendar). const multiSkillPrefix = "dingtalk-" +// multiSharedSkill is the shared, non-product skill that every per-product +// skill declares as a PREREQUISITE. It must always be installed in multi mode +// regardless of --skill / --exclude, otherwise the product skills reference a +// dws-shared that was never installed. +const multiSharedSkill = "dws-shared" + +// ensureMandatorySharedSkill guarantees the shared dependency skill is included +// whenever it exists in the source, even if --skill / --exclude narrowed it out. +func ensureMandatorySharedSkill(selected, all []string) []string { + hasShared := false + for _, n := range all { + if n == multiSharedSkill { + hasShared = true + break + } + } + if !hasShared { + return selected + } + for _, n := range selected { + if n == multiSharedSkill { + return selected + } + } + return append([]string{multiSharedSkill}, selected...) +} + // normalizeMultiSkillName accepts either the short form (aitable) or the // full form (dingtalk-aitable) and returns the canonical full form. // Empty input returns "". Comparison is case-insensitive. diff --git a/internal/auth/auth_extra_test.go b/internal/auth/auth_extra_test.go index 2eb5b682..5d702d32 100644 --- a/internal/auth/auth_extra_test.go +++ b/internal/auth/auth_extra_test.go @@ -330,6 +330,63 @@ func TestBuildTokenData_DefaultExpiry(t *testing.T) { } } +func TestParseMCPTokenResponseIncludesCorpName(t *testing.T) { + provider := &OAuthProvider{} + data, err := provider.parseMCPTokenResponse([]byte(`{ + "accessToken": "access-123", + "refreshToken": "refresh-456", + "expiresIn": 7200, + "corpId": "ding123", + "corpName": "钉钉(中国)信息技术有限公司" + }`)) + if err != nil { + t.Fatalf("parseMCPTokenResponse() error = %v", err) + } + if data.CorpID != "ding123" { + t.Fatalf("corp id = %q, want ding123", data.CorpID) + } + if data.CorpName != "钉钉(中国)信息技术有限公司" { + t.Fatalf("corp name = %q, want 钉钉(中国)信息技术有限公司", data.CorpName) + } +} + +func TestParseMCPTokenResponseCorpNameFallbacks(t *testing.T) { + provider := &OAuthProvider{} + for _, tc := range []struct { + name string + body string + want string + }{ + { + name: "snake", + body: `{"accessToken":"access","refreshToken":"refresh","expiresIn":7200,"corpId":"ding123","corp_name":"Snake Corp"}`, + want: "Snake Corp", + }, + { + name: "orgName", + body: `{"accessToken":"access","refreshToken":"refresh","expiresIn":7200,"corpId":"ding123","orgName":"Org Corp"}`, + want: "Org Corp", + }, + } { + t.Run(tc.name, func(t *testing.T) { + data, err := provider.parseMCPTokenResponse([]byte(tc.body)) + if err != nil { + t.Fatalf("parseMCPTokenResponse() error = %v", err) + } + if data.CorpName != tc.want { + t.Fatalf("corp name = %q, want %q", data.CorpName, tc.want) + } + }) + } +} + +func TestBuildAuthURLIncludesTargetCorpID(t *testing.T) { + authURL := buildAuthURL("client-id", "http://127.0.0.1:1234/callback", "ding-target") + if !strings.Contains(authURL, "corpId=ding-target") { + t.Fatalf("auth URL missing target corpId: %s", authURL) + } +} + func buildTokenDataFromResponse(resp tokenResponse) *TokenData { if resp.AccessToken == "" { return nil diff --git a/internal/auth/keychain_store.go b/internal/auth/keychain_store.go index 1ab513f6..e0390430 100644 --- a/internal/auth/keychain_store.go +++ b/internal/auth/keychain_store.go @@ -17,6 +17,7 @@ import ( "encoding/json" "fmt" "log/slog" + "strings" "sync" "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/internal/keychain" @@ -30,6 +31,24 @@ var ( // SaveTokenDataKeychain saves TokenData to the platform keychain. // This is the new secure storage method using random master key. func SaveTokenDataKeychain(data *TokenData) error { + return saveTokenDataKeychainAccount(keychain.AccountToken, data) +} + +// TokenAccountForCorpID returns the keychain account used for a corp-bound token. +func TokenAccountForCorpID(corpID string) string { + return keychain.AccountToken + ":" + strings.TrimSpace(corpID) +} + +// SaveTokenDataKeychainForCorpID saves TokenData to a corp-scoped keychain slot. +func SaveTokenDataKeychainForCorpID(corpID string, data *TokenData) error { + corpID = strings.TrimSpace(corpID) + if corpID == "" { + return fmt.Errorf("corpId is required for profile token storage") + } + return saveTokenDataKeychainAccount(TokenAccountForCorpID(corpID), data) +} + +func saveTokenDataKeychainAccount(account string, data *TokenData) error { jsonData, err := json.MarshalIndent(data, "", " ") if err != nil { return fmt.Errorf("marshal token data: %w", err) @@ -41,7 +60,7 @@ func SaveTokenDataKeychain(data *TokenData) error { } }() - if err := keychain.Set(keychain.Service, keychain.AccountToken, string(jsonData)); err != nil { + if err := keychain.Set(keychain.Service, account, string(jsonData)); err != nil { return fmt.Errorf("save to keychain: %w", err) } return nil @@ -49,12 +68,25 @@ func SaveTokenDataKeychain(data *TokenData) error { // LoadTokenDataKeychain loads TokenData from the platform keychain. func LoadTokenDataKeychain() (*TokenData, error) { - jsonStr, err := keychain.Get(keychain.Service, keychain.AccountToken) + return loadTokenDataKeychainAccount(keychain.AccountToken) +} + +// LoadTokenDataKeychainForCorpID loads TokenData from a corp-scoped keychain slot. +func LoadTokenDataKeychainForCorpID(corpID string) (*TokenData, error) { + corpID = strings.TrimSpace(corpID) + if corpID == "" { + return nil, fmt.Errorf("corpId is required for profile token storage") + } + return loadTokenDataKeychainAccount(TokenAccountForCorpID(corpID)) +} + +func loadTokenDataKeychainAccount(account string) (*TokenData, error) { + jsonStr, err := keychain.Get(keychain.Service, account) if err != nil { return nil, fmt.Errorf("load from keychain: %w", err) } if jsonStr == "" { - return nil, fmt.Errorf("no token data in keychain") + return nil, fmt.Errorf("no token data in keychain account %q", account) } var data TokenData @@ -69,11 +101,29 @@ func DeleteTokenDataKeychain() error { return keychain.Remove(keychain.Service, keychain.AccountToken) } +// DeleteTokenDataKeychainForCorpID removes TokenData from a corp-scoped keychain slot. +func DeleteTokenDataKeychainForCorpID(corpID string) error { + corpID = strings.TrimSpace(corpID) + if corpID == "" { + return fmt.Errorf("corpId is required for profile token storage") + } + return keychain.Remove(keychain.Service, TokenAccountForCorpID(corpID)) +} + // TokenDataExistsKeychain checks if token data exists in keychain. func TokenDataExistsKeychain() bool { return keychain.Exists(keychain.Service, keychain.AccountToken) } +// TokenDataExistsKeychainForCorpID checks if a corp-scoped token exists. +func TokenDataExistsKeychainForCorpID(corpID string) bool { + corpID = strings.TrimSpace(corpID) + if corpID == "" { + return false + } + return keychain.Exists(keychain.Service, TokenAccountForCorpID(corpID)) +} + // EnsureMigration performs one-time migration from legacy .data to keychain. // This should be called early in the auth flow (e.g., during GetAccessToken). // The migration is idempotent and thread-safe. diff --git a/internal/auth/oauth_helpers.go b/internal/auth/oauth_helpers.go index 372b353a..9c64960a 100644 --- a/internal/auth/oauth_helpers.go +++ b/internal/auth/oauth_helpers.go @@ -23,6 +23,7 @@ import ( "net/url" "os" "slices" + "strings" "time" "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/pkg/config" @@ -143,9 +144,13 @@ func (p *OAuthProvider) refreshWithRefreshToken(ctx context.Context, data *Token updated.CorpID = data.CorpID updated.UserID = data.UserID updated.UserName = data.UserName - updated.CorpName = data.CorpName + if updated.CorpName == "" { + updated.CorpName = data.CorpName + } - if err := SaveTokenData(p.configDir, updated); err != nil { + // Refresh runs under lockedRefresh's dual-layer lock; use the lock-free + // saver to avoid re-acquiring the non-reentrant lock (deadlock). + if err := saveTokenDataLocked(p.configDir, updated); err != nil { return nil, fmt.Errorf("保存刷新后的 token 失败(旧 refresh_token 已失效,请重新登录): %w", err) } return updated, nil @@ -185,9 +190,13 @@ func (p *OAuthProvider) refreshViaMCP(ctx context.Context, data *TokenData) (*To updated.CorpID = data.CorpID updated.UserID = data.UserID updated.UserName = data.UserName - updated.CorpName = data.CorpName + if updated.CorpName == "" { + updated.CorpName = data.CorpName + } - if err := SaveTokenData(p.configDir, updated); err != nil { + // Refresh runs under lockedRefresh's dual-layer lock; use the lock-free + // saver to avoid re-acquiring the non-reentrant lock (deadlock). + if err := saveTokenDataLocked(p.configDir, updated); err != nil { return nil, fmt.Errorf("保存刷新后的 token 失败(旧 refresh_token 已失效,请重新登录): %w", err) } return updated, nil @@ -259,7 +268,7 @@ func (p *OAuthProvider) parseTokenResponse(body []byte) (*TokenData, error) { } // parseMCPTokenResponse parses token response from MCP proxy. -// MCP OAuth response format: {"accessToken": "...", "refreshToken": "...", "expiresIn": 7200, "corpId": "..."} +// MCP OAuth response format: {"accessToken": "...", "refreshToken": "...", "expiresIn": 7200, "corpId": "...", "corpName": "..."} func (p *OAuthProvider) parseMCPTokenResponse(body []byte) (*TokenData, error) { var resp struct { AccessToken string `json:"accessToken"` @@ -267,6 +276,9 @@ func (p *OAuthProvider) parseMCPTokenResponse(body []byte) (*TokenData, error) { PersistentCode string `json:"persistentCode"` ExpiresIn int64 `json:"expiresIn"` CorpID string `json:"corpId"` + CorpName string `json:"corpName"` + CorpNameSnake string `json:"corp_name"` + OrgName string `json:"orgName"` // Error fields (when request fails) ErrorCode string `json:"errorCode,omitempty"` ErrorMsg string `json:"errorMsg,omitempty"` @@ -293,6 +305,7 @@ func (p *OAuthProvider) parseMCPTokenResponse(body []byte) (*TokenData, error) { ExpiresAt: now.Add(time.Duration(expiresIn) * time.Second), RefreshExpAt: now.Add(config.DefaultRefreshTokenLifetime), CorpID: resp.CorpID, + CorpName: firstNonEmpty(resp.CorpName, resp.CorpNameSnake, resp.OrgName), } if resp.PersistentCode != "" { data.PersistentCode = resp.PersistentCode @@ -300,7 +313,16 @@ func (p *OAuthProvider) parseMCPTokenResponse(body []byte) (*TokenData, error) { return data, nil } -func buildAuthURL(clientID, redirectURI string) string { +func firstNonEmpty(values ...string) string { + for _, v := range values { + if trimmed := strings.TrimSpace(v); trimmed != "" { + return trimmed + } + } + return "" +} + +func buildAuthURL(clientID, redirectURI, targetCorpID string) string { params := url.Values{ "client_id": {clientID}, "redirect_uri": {redirectURI}, @@ -308,6 +330,9 @@ func buildAuthURL(clientID, redirectURI string) string { "scope": {DefaultScopes}, "prompt": {"consent"}, } + if targetCorpID = strings.TrimSpace(targetCorpID); targetCorpID != "" { + params.Set("corpId", targetCorpID) + } return AuthorizeURL + "?" + params.Encode() } diff --git a/internal/auth/oauth_provider.go b/internal/auth/oauth_provider.go index 66a9c461..3d08da25 100644 --- a/internal/auth/oauth_provider.go +++ b/internal/auth/oauth_provider.go @@ -37,12 +37,13 @@ var oauthHTTPClient = &http.Client{ // OAuthProvider handles the DingTalk OAuth 2.0 authorization code flow. type OAuthProvider struct { - configDir string - clientID string - logger *slog.Logger - Output io.Writer - httpClient *http.Client - NoBrowser bool + configDir string + clientID string + logger *slog.Logger + Output io.Writer + httpClient *http.Client + NoBrowser bool + TargetCorpID string } // NewOAuthProvider creates a new OAuth provider. @@ -397,7 +398,7 @@ func (p *OAuthProvider) Login(ctx context.Context, force bool) (*TokenData, erro _ = server.Shutdown(shutCtx) }() - authURL := buildAuthURL(p.clientID, redirectURI) + authURL := buildAuthURL(p.clientID, redirectURI, p.TargetCorpID) if p.logger != nil { p.logger.Debug("authorization URL", "url", authURL) } @@ -547,9 +548,12 @@ func (p *OAuthProvider) GetAccessToken(ctx context.Context) (string, error) { if rErr == nil { return refreshed.AccessToken, nil } + _ = MarkProfileStatus(p.configDir, data.CorpID, ProfileStatusExpired) if p.logger != nil { p.logger.Warn(i18n.T("refresh_token 刷新失败"), "error", rErr) } + } else { + _ = MarkProfileStatus(p.configDir, data.CorpID, ProfileStatusExpired) } return "", errors.New(i18n.T("所有凭证已失效,请运行 dws auth login 重新登录")) diff --git a/internal/auth/portable_store.go b/internal/auth/portable_store.go index 0bc3d256..a32023df 100644 --- a/internal/auth/portable_store.go +++ b/internal/auth/portable_store.go @@ -52,6 +52,9 @@ func PortableAuthTargetPopulated(configDir string) bool { if TokenDataExistsKeychain() { return true } + if _, err := os.Stat(ProfilesPath(configDir)); err == nil { + return true + } if _, err := os.Stat(filepath.Join(configDir, "app.json")); err == nil { return true } @@ -199,7 +202,7 @@ func ImportPortableAuthBundle(configDir string, r io.Reader) (PortableImportRepo func portableConfigFiles(configDir string) ([]string, error) { var files []string - patterns := []string{"app*.json", "mcp_url", "terminal_url"} + patterns := []string{"app*.json", profilesJSONFile, "mcp_url", "terminal_url"} for _, pattern := range patterns { matches, err := filepath.Glob(filepath.Join(configDir, pattern)) if err != nil { diff --git a/internal/auth/portable_store_test.go b/internal/auth/portable_store_test.go index 3e019a5b..95ae727b 100644 --- a/internal/auth/portable_store_test.go +++ b/internal/auth/portable_store_test.go @@ -138,3 +138,76 @@ func TestPortableAuthBundleRoundTripPreservesRefreshToken(t *testing.T) { t.Fatalf("imported app config = %#v, want client ID preserved", cfg) } } + +func TestPortableAuthBundleRoundTripPreservesProfiles(t *testing.T) { + t.Setenv(keychain.DisableKeychainEnv, "1") + SetRuntimeProfile("") + t.Cleanup(func() { SetRuntimeProfile("") }) + + sourceKeychain := filepath.Join(t.TempDir(), "source-keychain") + t.Setenv(keychain.StorageDirEnv, sourceKeychain) + sourceConfig := filepath.Join(t.TempDir(), ".dws") + + tokenA := &TokenData{ + AccessToken: "access-a", + RefreshToken: "refresh-a", + ExpiresAt: time.Now().Add(time.Hour), + RefreshExpAt: time.Now().Add(30 * 24 * time.Hour), + CorpID: "corp_a", + CorpName: "A Org", + ClientID: "client-a", + } + tokenB := &TokenData{ + AccessToken: "access-b", + RefreshToken: "refresh-b", + ExpiresAt: time.Now().Add(time.Hour), + RefreshExpAt: time.Now().Add(30 * 24 * time.Hour), + CorpID: "corp_b", + CorpName: "B Org", + ClientID: "client-b", + } + if err := SaveTokenData(sourceConfig, tokenA); err != nil { + t.Fatalf("SaveTokenData(A) error = %v", err) + } + if err := SaveTokenData(sourceConfig, tokenB); err != nil { + t.Fatalf("SaveTokenData(B) error = %v", err) + } + + var bundle bytes.Buffer + if err := ExportPortableAuthBundle(sourceConfig, &bundle); err != nil { + t.Fatalf("ExportPortableAuthBundle() error = %v", err) + } + + targetKeychain := filepath.Join(t.TempDir(), "target-keychain") + t.Setenv(keychain.StorageDirEnv, targetKeychain) + targetConfig := filepath.Join(t.TempDir(), ".dws") + if _, err := ImportPortableAuthBundle(targetConfig, bytes.NewReader(bundle.Bytes())); err != nil { + t.Fatalf("ImportPortableAuthBundle() error = %v", err) + } + + cfg, err := LoadProfiles(targetConfig) + if err != nil { + t.Fatalf("LoadProfiles() after import error = %v", err) + } + if cfg.PrimaryProfile != "corp_a" || cfg.CurrentProfile != "corp_b" || cfg.PreviousProfile != "corp_a" { + t.Fatalf("profiles after import = %#v", cfg) + } + if len(cfg.Profiles) != 2 { + t.Fatalf("profiles len = %d, want 2: %#v", len(cfg.Profiles), cfg.Profiles) + } + + loadedA, err := LoadTokenDataForProfile(targetConfig, "corp_a") + if err != nil { + t.Fatalf("LoadTokenDataForProfile(A) after import error = %v", err) + } + if loadedA.AccessToken != "access-a" { + t.Fatalf("profile A token = %q, want access-a", loadedA.AccessToken) + } + loadedB, err := LoadTokenDataForProfile(targetConfig, "corp_b") + if err != nil { + t.Fatalf("LoadTokenDataForProfile(B) after import error = %v", err) + } + if loadedB.AccessToken != "access-b" { + t.Fatalf("profile B token = %q, want access-b", loadedB.AccessToken) + } +} diff --git a/internal/auth/profiles.go b/internal/auth/profiles.go new file mode 100644 index 00000000..c72770b3 --- /dev/null +++ b/internal/auth/profiles.go @@ -0,0 +1,678 @@ +// Copyright 2026 Alibaba Group +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/google/uuid" + + "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/pkg/config" +) + +// withProfilesLock runs fn while holding the auth dual-layer lock (process + +// cross-process file lock) so that all read-modify-write cycles on +// profiles.json and the legacy token mirror are serialized. +// +// The lock is NOT reentrant. fn must only call the lock-free *Locked variants; +// calling a public (locking) function from within fn would deadlock. Paths that +// already hold the lock (e.g. OAuthProvider.lockedRefresh and the read path +// reached from it) must likewise call the lock-free variants directly. +func withProfilesLock(configDir string, fn func() error) error { + lock, err := AcquireDualLock(context.Background(), configDir) + if err != nil { + return err + } + defer lock.Release() + return fn() +} + +const profilesJSONFile = "profiles.json" + +const ( + ProfileStatusActive = "active" + ProfileStatusExpired = "expired" + ProfileStatusRevoked = "revoked" +) + +// ProfilesConfig stores non-sensitive profile metadata. Token material stays in keychain. +type ProfilesConfig struct { + Version int `json:"version"` + PrimaryProfile string `json:"primaryProfile,omitempty"` + CurrentProfile string `json:"currentProfile,omitempty"` + PreviousProfile string `json:"previousProfile,omitempty"` + Profiles []Profile `json:"profiles,omitempty"` +} + +// Profile is a logged-in DingTalk organization identity. +type Profile struct { + Name string `json:"name"` + CorpID string `json:"corpId"` + CorpName string `json:"corpName,omitempty"` + UserID string `json:"userId,omitempty"` + UserName string `json:"userName,omitempty"` + ClientID string `json:"clientId,omitempty"` + Status string `json:"status,omitempty"` + AuthorizedDomains []string `json:"authorizedDomains,omitempty"` + ExpiresAt string `json:"expiresAt,omitempty"` + RefreshExpAt string `json:"refreshExpAt,omitempty"` + LastLoginAt string `json:"lastLoginAt,omitempty"` + LastUsedAt string `json:"lastUsedAt,omitempty"` + UpdatedAt string `json:"updatedAt,omitempty"` +} + +var ( + runtimeProfileMu sync.RWMutex + runtimeProfile string +) + +// SetRuntimeProfile sets a process-local one-shot profile override. +func SetRuntimeProfile(profile string) { + runtimeProfileMu.Lock() + defer runtimeProfileMu.Unlock() + runtimeProfile = strings.TrimSpace(profile) +} + +// RuntimeProfile returns the process-local one-shot profile override. +func RuntimeProfile() string { + runtimeProfileMu.RLock() + defer runtimeProfileMu.RUnlock() + return runtimeProfile +} + +// ProfilesPath returns the profile metadata path for a config dir. +func ProfilesPath(configDir string) string { + return filepath.Join(configDir, profilesJSONFile) +} + +// LoadProfiles reads profiles.json. A missing file returns an empty config. +func LoadProfiles(configDir string) (*ProfilesConfig, error) { + path := ProfilesPath(configDir) + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return &ProfilesConfig{Version: 1}, nil + } + return nil, fmt.Errorf("read profiles: %w", err) + } + var cfg ProfilesConfig + if err := json.Unmarshal(data, &cfg); err != nil { + // Corrupt file (e.g. an interrupted concurrent write): quarantine it and + // rebuild an empty config so the CLI can self-heal (auth reset / re-login) + // instead of being permanently locked out by an unreadable profiles.json. + quarantine := path + ".corrupt-" + time.Now().Format("20060102-150405.000") + _ = os.Rename(path, quarantine) + return &ProfilesConfig{Version: 1}, nil + } + normalizeProfilesConfig(&cfg) + return &cfg, nil +} + +// SaveProfiles writes profiles.json atomically. +func SaveProfiles(configDir string, cfg *ProfilesConfig) error { + if cfg == nil { + cfg = &ProfilesConfig{} + } + normalizeProfilesConfig(cfg) + if err := os.MkdirAll(configDir, config.DirPerm); err != nil { + return fmt.Errorf("create config dir: %w", err) + } + data, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return fmt.Errorf("marshal profiles: %w", err) + } + data = append(data, '\n') + path := ProfilesPath(configDir) + // Per-write random temp name: a fixed "profiles.json.tmp" lets two + // concurrent writers interleave into the same temp file and rename a + // corrupted result into place. + tmp := path + "." + uuid.New().String() + ".tmp" + if err := os.WriteFile(tmp, data, config.FilePerm); err != nil { + return fmt.Errorf("write profiles tmp: %w", err) + } + if err := os.Rename(tmp, path); err != nil { + _ = os.Remove(tmp) + return fmt.Errorf("rename profiles: %w", err) + } + return nil +} + +// EnsureProfilesMigration initializes profiles.json from the legacy auth-token slot when needed. +// EnsureProfilesMigration migrates a legacy single-slot token into the +// profiles registry. It acquires the lock; call ensureProfilesMigrationLocked +// from contexts that already hold it (refresh / read paths). +func EnsureProfilesMigration(configDir string) error { + return withProfilesLock(configDir, func() error { + return ensureProfilesMigrationLocked(configDir) + }) +} + +func ensureProfilesMigrationLocked(configDir string) error { + cfg, err := LoadProfiles(configDir) + if err != nil { + return err + } + if len(cfg.Profiles) > 0 { + return nil + } + if !TokenDataExistsKeychain() { + return nil + } + data, err := LoadTokenDataKeychain() + if err != nil || data == nil || strings.TrimSpace(data.CorpID) == "" { + return nil + } + if err := SaveTokenDataKeychainForCorpID(data.CorpID, data); err != nil { + return err + } + return upsertProfileFromToken(configDir, cfg, data, false) +} + +// UpsertProfileFromToken updates profiles.json after a successful login or refresh. +func UpsertProfileFromToken(configDir string, data *TokenData) error { + return UpsertProfileFromTokenWithCurrent(configDir, data, true) +} + +// UpsertProfileFromTokenWithCurrent updates profiles.json and optionally makes +// the token's corp the persistent current profile. +func UpsertProfileFromTokenWithCurrent(configDir string, data *TokenData, makeCurrent bool) error { + return withProfilesLock(configDir, func() error { + return upsertProfileFromTokenWithCurrentLocked(configDir, data, makeCurrent) + }) +} + +func upsertProfileFromTokenWithCurrentLocked(configDir string, data *TokenData, makeCurrent bool) error { + cfg, err := LoadProfiles(configDir) + if err != nil { + return err + } + return upsertProfileFromToken(configDir, cfg, data, makeCurrent) +} + +func upsertProfileFromToken(configDir string, cfg *ProfilesConfig, data *TokenData, makeCurrent bool) error { + if data == nil { + return nil + } + corpID := strings.TrimSpace(data.CorpID) + if corpID == "" { + return nil + } + normalizeProfilesConfig(cfg) + now := time.Now().Format(time.RFC3339) + idx := profileIndexByCorpID(cfg, corpID) + if idx < 0 { + profile := Profile{ + Name: chooseProfileName(cfg, data), + CorpID: corpID, + CorpName: strings.TrimSpace(data.CorpName), + UserID: strings.TrimSpace(data.UserID), + UserName: strings.TrimSpace(data.UserName), + ClientID: strings.TrimSpace(data.ClientID), + Status: ProfileStatusActive, + ExpiresAt: timeOrRFC3339(data.ExpiresAt), + RefreshExpAt: timeOrRFC3339(data.RefreshExpAt), + LastLoginAt: now, + LastUsedAt: now, + UpdatedAt: now, + } + cfg.Profiles = append(cfg.Profiles, profile) + } else { + p := &cfg.Profiles[idx] + if shouldRefreshProfileName(p, data) { + p.Name = chooseProfileName(cfg, data) + } + if v := strings.TrimSpace(data.CorpName); v != "" { + p.CorpName = v + } + if v := strings.TrimSpace(data.UserID); v != "" { + p.UserID = v + } + if v := strings.TrimSpace(data.UserName); v != "" { + p.UserName = v + } + if v := strings.TrimSpace(data.ClientID); v != "" { + p.ClientID = v + } + p.Status = ProfileStatusActive + p.ExpiresAt = timeOrRFC3339(data.ExpiresAt) + p.RefreshExpAt = timeOrRFC3339(data.RefreshExpAt) + p.LastLoginAt = now + p.LastUsedAt = now + p.UpdatedAt = now + } + if cfg.PrimaryProfile == "" { + cfg.PrimaryProfile = corpID + } + if makeCurrent && cfg.CurrentProfile != corpID { + if cfg.CurrentProfile != "" { + cfg.PreviousProfile = cfg.CurrentProfile + } + cfg.CurrentProfile = corpID + } + if cfg.CurrentProfile == "" { + cfg.CurrentProfile = corpID + } + return SaveProfiles(configDir, cfg) +} + +// ResolveProfile returns a profile selected by name/corpId or by current/primary fallback. +func ResolveProfile(configDir, selector string) (*Profile, error) { + if err := ensureProfilesMigrationLocked(configDir); err != nil { + return nil, err + } + cfg, err := LoadProfiles(configDir) + if err != nil { + return nil, err + } + selector = strings.TrimSpace(selector) + if selector != "" { + p := findProfile(cfg, selector) + if p == nil { + return nil, fmt.Errorf("profile %q not found", selector) + } + return p, nil + } + if p := findProfile(cfg, cfg.CurrentProfile); p != nil { + return p, nil + } + if p := findProfile(cfg, cfg.PrimaryProfile); p != nil { + return p, nil + } + return nil, nil +} + +func resolveProfileForLoad(configDir, selector string) (*Profile, error) { + if err := ensureProfilesMigrationLocked(configDir); err != nil { + return nil, err + } + cfg, err := LoadProfiles(configDir) + if err != nil { + return nil, err + } + selector = strings.TrimSpace(selector) + if selector != "" { + p := findProfile(cfg, selector) + if p == nil { + return nil, fmt.Errorf("profile %q not found", selector) + } + return p, nil + } + for _, candidate := range []string{cfg.CurrentProfile, cfg.PrimaryProfile} { + if p := findProfile(cfg, candidate); p != nil && TokenDataExistsKeychainForCorpID(p.CorpID) { + return p, nil + } + } + if p := findProfile(cfg, cfg.CurrentProfile); p != nil { + return p, nil + } + if p := findProfile(cfg, cfg.PrimaryProfile); p != nil { + return p, nil + } + return nil, nil +} + +// SetCurrentProfile persists the selected current profile. +func SetCurrentProfile(configDir, selector string) (*Profile, error) { + var result *Profile + err := withProfilesLock(configDir, func() error { + p, e := setCurrentProfileLocked(configDir, selector) + result = p + return e + }) + return result, err +} + +func setCurrentProfileLocked(configDir, selector string) (*Profile, error) { + if err := ensureProfilesMigrationLocked(configDir); err != nil { + return nil, err + } + cfg, err := LoadProfiles(configDir) + if err != nil { + return nil, err + } + p := findProfile(cfg, selector) + if p == nil { + return nil, fmt.Errorf("profile %q not found", strings.TrimSpace(selector)) + } + if cfg.CurrentProfile != p.CorpID { + if cfg.CurrentProfile != "" { + cfg.PreviousProfile = cfg.CurrentProfile + } + cfg.CurrentProfile = p.CorpID + } + touchProfile(cfg, p.CorpID) + if err := SaveProfiles(configDir, cfg); err != nil { + return nil, err + } + if err := syncLegacyTokenMirrorLocked(configDir); err != nil { + return nil, err + } + return findProfile(cfg, p.CorpID), nil +} + +// UsePreviousProfile toggles currentProfile and previousProfile. +func UsePreviousProfile(configDir string) (*Profile, error) { + var result *Profile + err := withProfilesLock(configDir, func() error { + p, e := usePreviousProfileLocked(configDir) + result = p + return e + }) + return result, err +} + +func usePreviousProfileLocked(configDir string) (*Profile, error) { + if err := ensureProfilesMigrationLocked(configDir); err != nil { + return nil, err + } + cfg, err := LoadProfiles(configDir) + if err != nil { + return nil, err + } + prev := strings.TrimSpace(cfg.PreviousProfile) + if prev == "" { + return nil, fmt.Errorf("previous profile is empty") + } + p := findProfile(cfg, prev) + if p == nil { + return nil, fmt.Errorf("previous profile %q not found", prev) + } + cfg.PreviousProfile, cfg.CurrentProfile = cfg.CurrentProfile, p.CorpID + touchProfile(cfg, p.CorpID) + if err := SaveProfiles(configDir, cfg); err != nil { + return nil, err + } + if err := syncLegacyTokenMirrorLocked(configDir); err != nil { + return nil, err + } + return findProfile(cfg, p.CorpID), nil +} + +// RemoveProfile removes a profile from metadata and returns the removed profile. +func RemoveProfile(configDir, selector string) (*Profile, error) { + var result *Profile + err := withProfilesLock(configDir, func() error { + p, e := removeProfileLocked(configDir, selector) + result = p + return e + }) + return result, err +} + +func removeProfileLocked(configDir, selector string) (*Profile, error) { + cfg, err := LoadProfiles(configDir) + if err != nil { + return nil, err + } + p := findProfile(cfg, selector) + if p == nil { + return nil, fmt.Errorf("profile %q not found", strings.TrimSpace(selector)) + } + removed := *p + kept := cfg.Profiles[:0] + for _, profile := range cfg.Profiles { + if profile.CorpID != removed.CorpID { + kept = append(kept, profile) + } + } + cfg.Profiles = kept + if cfg.PrimaryProfile == removed.CorpID { + cfg.PrimaryProfile = firstProfileCorpID(cfg) + } + if cfg.CurrentProfile == removed.CorpID { + cfg.CurrentProfile = cfg.PrimaryProfile + if cfg.CurrentProfile == "" { + cfg.CurrentProfile = firstProfileCorpID(cfg) + } + } + if cfg.PreviousProfile == removed.CorpID { + cfg.PreviousProfile = "" + } + if len(cfg.Profiles) == 0 { + cfg.PrimaryProfile = "" + cfg.CurrentProfile = "" + cfg.PreviousProfile = "" + } + if err := SaveProfiles(configDir, cfg); err != nil { + return nil, err + } + return &removed, nil +} + +// MarkProfileStatus updates a profile status if it exists. +func MarkProfileStatus(configDir, corpID, status string) error { + if strings.TrimSpace(corpID) == "" { + return nil + } + return withProfilesLock(configDir, func() error { + return markProfileStatusLocked(configDir, corpID, status) + }) +} + +func markProfileStatusLocked(configDir, corpID, status string) error { + cfg, err := LoadProfiles(configDir) + if err != nil { + return err + } + p := findProfile(cfg, corpID) + if p == nil { + return nil + } + p.Status = strings.TrimSpace(status) + p.UpdatedAt = time.Now().Format(time.RFC3339) + return SaveProfiles(configDir, cfg) +} + +// SyncLegacyTokenMirror mirrors the current profile token into legacy auth-token. +func SyncLegacyTokenMirror(configDir string) error { + return withProfilesLock(configDir, func() error { + return syncLegacyTokenMirrorLocked(configDir) + }) +} + +func syncLegacyTokenMirrorLocked(configDir string) error { + cfg, err := LoadProfiles(configDir) + if err != nil { + return err + } + hadReadError := false + for _, candidate := range []string{cfg.CurrentProfile, cfg.PrimaryProfile} { + p := findProfile(cfg, candidate) + if p == nil { + continue + } + data, loadErr := LoadTokenDataKeychainForCorpID(p.CorpID) + if loadErr != nil { + // Transient keychain read failure: do NOT touch the existing mirror. + hadReadError = true + continue + } + if data != nil { + if err := SaveTokenDataKeychain(data); err != nil { + return err + } + return WriteTokenMarker(configDir) + } + } + if hadReadError { + // Keep the existing legacy mirror untouched rather than wiping a host + // app's login state just because keychain was momentarily unavailable. + return nil + } + // All candidate profiles confirmed absent (no token): clear the mirror. + _ = DeleteTokenDataKeychain() + _ = DeleteTokenMarker(configDir) + return nil +} + +func normalizeProfilesConfig(cfg *ProfilesConfig) { + if cfg == nil { + return + } + cfg.Version = 1 + seen := make(map[string]bool, len(cfg.Profiles)) + profiles := cfg.Profiles[:0] + for _, p := range cfg.Profiles { + p.CorpID = strings.TrimSpace(p.CorpID) + if p.CorpID == "" || seen[p.CorpID] { + continue + } + seen[p.CorpID] = true + p.Name = strings.TrimSpace(p.Name) + if p.Name == "" { + p.Name = p.CorpID + } + if corpName := strings.TrimSpace(p.CorpName); p.Name == p.CorpID && corpName != "" && !profileNameTakenByOtherCorp(cfg, corpName, p.CorpID) { + p.Name = corpName + } + if p.Status == "" { + p.Status = ProfileStatusActive + } + profiles = append(profiles, p) + } + cfg.Profiles = profiles + if cfg.PrimaryProfile != "" && findProfile(cfg, cfg.PrimaryProfile) == nil { + cfg.PrimaryProfile = "" + } + if cfg.CurrentProfile != "" && findProfile(cfg, cfg.CurrentProfile) == nil { + cfg.CurrentProfile = "" + } + if cfg.PreviousProfile != "" && findProfile(cfg, cfg.PreviousProfile) == nil { + cfg.PreviousProfile = "" + } + if cfg.PrimaryProfile == "" { + cfg.PrimaryProfile = firstProfileCorpID(cfg) + } + if cfg.CurrentProfile == "" { + cfg.CurrentProfile = cfg.PrimaryProfile + } +} + +func chooseProfileName(cfg *ProfilesConfig, data *TokenData) string { + base := strings.TrimSpace(data.CorpName) + if base == "" { + base = strings.TrimSpace(data.CorpID) + } + if base == "" { + base = "profile" + } + if !profileNameTakenByOtherCorp(cfg, base, data.CorpID) { + return base + } + suffix := shortCorpID(data.CorpID) + name := base + "-" + suffix + if !profileNameTakenByOtherCorp(cfg, name, data.CorpID) { + return name + } + for i := 2; ; i++ { + candidate := fmt.Sprintf("%s-%s-%d", base, suffix, i) + if !profileNameTakenByOtherCorp(cfg, candidate, data.CorpID) { + return candidate + } + } +} + +func shouldRefreshProfileName(p *Profile, data *TokenData) bool { + if p == nil || data == nil { + return false + } + name := strings.TrimSpace(p.Name) + if name == "" { + return true + } + return strings.TrimSpace(data.CorpName) != "" && name == strings.TrimSpace(p.CorpID) +} + +func profileNameTakenByOtherCorp(cfg *ProfilesConfig, name, corpID string) bool { + name = strings.TrimSpace(name) + corpID = strings.TrimSpace(corpID) + for _, p := range cfg.Profiles { + if p.CorpID != corpID && p.Name == name { + return true + } + } + return false +} + +func findProfile(cfg *ProfilesConfig, selector string) *Profile { + if cfg == nil { + return nil + } + selector = strings.TrimSpace(selector) + if selector == "" { + return nil + } + var corpNameMatch *Profile + for i := range cfg.Profiles { + if cfg.Profiles[i].CorpID == selector || cfg.Profiles[i].Name == selector { + return &cfg.Profiles[i] + } + if strings.TrimSpace(cfg.Profiles[i].CorpName) == selector { + if corpNameMatch != nil { + return nil + } + corpNameMatch = &cfg.Profiles[i] + } + } + return corpNameMatch +} + +func profileIndexByCorpID(cfg *ProfilesConfig, corpID string) int { + if cfg == nil { + return -1 + } + for i := range cfg.Profiles { + if cfg.Profiles[i].CorpID == corpID { + return i + } + } + return -1 +} + +func firstProfileCorpID(cfg *ProfilesConfig) string { + if cfg == nil || len(cfg.Profiles) == 0 { + return "" + } + return cfg.Profiles[0].CorpID +} + +func touchProfile(cfg *ProfilesConfig, corpID string) { + if p := findProfile(cfg, corpID); p != nil { + now := time.Now().Format(time.RFC3339) + p.LastUsedAt = now + p.UpdatedAt = now + } +} + +func timeOrRFC3339(t time.Time) string { + if t.IsZero() { + return "" + } + return t.Format(time.RFC3339) +} + +func shortCorpID(corpID string) string { + corpID = strings.TrimSpace(corpID) + if len(corpID) <= 8 { + return corpID + } + return corpID[len(corpID)-8:] +} diff --git a/internal/auth/token.go b/internal/auth/token.go index 93239a18..3582263c 100644 --- a/internal/auth/token.go +++ b/internal/auth/token.go @@ -22,8 +22,11 @@ import ( "net/url" "os" "path/filepath" + "strings" "time" + "github.com/google/uuid" + "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/pkg/edition" ) @@ -82,7 +85,7 @@ func WriteTokenMarker(configDir string) error { if err := os.MkdirAll(configDir, 0o700); err != nil { return err } - tmp := filepath.Join(configDir, tokenJSONFile+".tmp") + tmp := filepath.Join(configDir, tokenJSONFile+"."+uuid.New().String()+".tmp") if err := os.WriteFile(tmp, data, 0o600); err != nil { return err } @@ -91,7 +94,10 @@ func WriteTokenMarker(configDir string) error { // DeleteTokenMarker removes the token.json marker file. func DeleteTokenMarker(configDir string) error { - return os.Remove(filepath.Join(configDir, tokenJSONFile)) + if err := os.Remove(filepath.Join(configDir, tokenJSONFile)); err != nil && !os.IsNotExist(err) { + return err + } + return nil } // SaveTokenData persists TokenData. When an edition hook (SaveToken) is @@ -99,20 +105,67 @@ func DeleteTokenMarker(configDir string) error { // to the default keychain-based storage. func SaveTokenData(configDir string, data *TokenData) error { if h := edition.Get(); h.SaveToken != nil { - jsonData, err := json.MarshalIndent(data, "", " ") - if err != nil { - return fmt.Errorf("marshaling token data for hook: %w", err) + return saveTokenViaHook(h, configDir, data) + } + return withProfilesLock(configDir, func() error { + return saveTokenDataLocked(configDir, data) + }) +} + +// saveTokenDataLocked performs the keychain + profiles.json + legacy mirror +// writes assuming the auth dual-layer lock is already held. Callers that +// already hold the lock (OAuthProvider refresh path, the legacy secure->keychain +// migration in LoadTokenDataForProfile) must use this instead of SaveTokenData +// to avoid deadlocking on the non-reentrant lock. +func saveTokenDataLocked(configDir string, data *TokenData) error { + if h := edition.Get(); h.SaveToken != nil { + return saveTokenViaHook(h, configDir, data) + } + if data != nil && strings.TrimSpace(data.CorpID) != "" { + if err := SaveTokenDataKeychainForCorpID(data.CorpID, data); err != nil { + return err + } + makeCurrent := strings.TrimSpace(RuntimeProfile()) == "" + if err := upsertProfileFromTokenWithCurrentLocked(configDir, data, makeCurrent); err != nil { + return err } - return h.SaveToken(configDir, jsonData) + if makeCurrent { + if err := SaveTokenDataKeychain(data); err != nil { + return err + } + } else if err := syncLegacyTokenMirrorLocked(configDir); err != nil { + return err + } + return WriteTokenMarker(configDir) + } + if err := SaveTokenDataKeychain(data); err != nil { + return err } - return SaveTokenDataKeychain(data) + return WriteTokenMarker(configDir) +} + +func saveTokenViaHook(h *edition.Hooks, configDir string, data *TokenData) error { + jsonData, err := json.MarshalIndent(data, "", " ") + if err != nil { + return fmt.Errorf("marshaling token data for hook: %w", err) + } + return h.SaveToken(configDir, jsonData) } // LoadTokenData reads TokenData. When an edition hook (LoadToken) is // registered, it delegates entirely to the hook; otherwise it falls back // to keychain with legacy .data migration. func LoadTokenData(configDir string) (*TokenData, error) { + return LoadTokenDataForProfile(configDir, RuntimeProfile()) +} + +// LoadTokenDataForProfile reads TokenData for a profile selector without mutating +// currentProfile. Empty selector follows the default resolution chain. +func LoadTokenDataForProfile(configDir, profile string) (*TokenData, error) { if h := edition.Get(); h.LoadToken != nil { + if strings.TrimSpace(profile) != "" { + return nil, fmt.Errorf("profile selection is not supported by the current auth backend") + } jsonData, err := h.LoadToken(configDir) if err != nil { return nil, err @@ -125,6 +178,28 @@ func LoadTokenData(configDir string) (*TokenData, error) { } // Default: keychain with legacy .data migration + selected, err := resolveProfileForLoad(configDir, profile) + if err != nil { + return nil, err + } + if selected != nil { + data, err := LoadTokenDataKeychainForCorpID(selected.CorpID) + if err == nil { + return data, nil + } + if strings.TrimSpace(profile) != "" { + return nil, err + } + // No explicit --profile: `selected` is the resolved current/primary + // profile. Only fall back to the legacy single slot when it belongs to + // the SAME org; otherwise surface the error instead of silently acting + // as a different organization (the legacy mirror may have drifted). + if legacy, lerr := LoadTokenDataKeychain(); lerr == nil && legacy != nil && + strings.TrimSpace(legacy.CorpID) == strings.TrimSpace(selected.CorpID) { + return legacy, nil + } + return nil, err + } if TokenDataExistsKeychain() { return LoadTokenDataKeychain() } @@ -132,7 +207,9 @@ func LoadTokenData(configDir string) (*TokenData, error) { if err != nil { return nil, err } - if err := SaveTokenDataKeychain(data); err == nil { + // One-time legacy secure-store -> keychain migration. This read path may run + // while the refresh lock is already held, so use the lock-free saver. + if err := saveTokenDataLocked(configDir, data); err == nil { _ = DeleteSecureData(configDir) } return data, nil @@ -142,15 +219,95 @@ func LoadTokenData(configDir string) (*TokenData, error) { // registered, it delegates entirely to the hook; otherwise it falls back // to keychain + legacy cleanup. func DeleteTokenData(configDir string) error { + return DeleteTokenDataForProfile(configDir, RuntimeProfile()) +} + +// DeleteTokenDataForProfile removes one profile's token data. Empty selector +// removes the current/default profile, falling back to legacy single-slot auth. +func DeleteTokenDataForProfile(configDir, profile string) error { if h := edition.Get(); h.DeleteToken != nil { + if strings.TrimSpace(profile) != "" { + return fmt.Errorf("profile selection is not supported by the current auth backend") + } return h.DeleteToken(configDir) } + return withProfilesLock(configDir, func() error { + return deleteTokenDataForProfileLocked(configDir, profile) + }) +} + +func deleteTokenDataForProfileLocked(configDir, profile string) error { + selected, err := resolveProfileForLoad(configDir, profile) + if err != nil { + return err + } + if selected != nil { + keychainErr := DeleteTokenDataKeychainForCorpID(selected.CorpID) + _, removeErr := removeProfileLocked(configDir, selected.CorpID) + legacyErr := syncLegacyTokenMirrorLocked(configDir) + secureErr := DeleteSecureData(configDir) + if keychainErr != nil { + return keychainErr + } + if removeErr != nil { + return removeErr + } + if legacyErr != nil { + return legacyErr + } + return secureErr + } + keychainErr := DeleteTokenDataKeychain() legacyErr := DeleteSecureData(configDir) + markerErr := DeleteTokenMarker(configDir) if keychainErr != nil { return keychainErr } - return legacyErr + if legacyErr != nil { + return legacyErr + } + return markerErr +} + +// DeleteAllTokenData removes all profile-scoped and legacy token data. +func DeleteAllTokenData(configDir string) error { + if h := edition.Get(); h.DeleteToken != nil { + return h.DeleteToken(configDir) + } + return withProfilesLock(configDir, func() error { + var firstErr error + // Best-effort: even if profiles.json is unreadable, still clear every + // other slot so the user can always self-heal via auth reset / logout. + if cfg, err := LoadProfiles(configDir); err == nil { + for _, profile := range cfg.Profiles { + if e := DeleteTokenDataKeychainForCorpID(profile.CorpID); e != nil && firstErr == nil { + firstErr = e + } + } + } + if e := os.Remove(ProfilesPath(configDir)); e != nil && !os.IsNotExist(e) && firstErr == nil { + firstErr = e + } + // Sweep any quarantined corrupt-profiles files so they don't accumulate. + if matches, _ := filepath.Glob(ProfilesPath(configDir) + ".corrupt-*"); len(matches) > 0 { + for _, m := range matches { + if e := os.Remove(m); e != nil && !os.IsNotExist(e) && firstErr == nil { + firstErr = e + } + } + } + if e := DeleteTokenDataKeychain(); e != nil && firstErr == nil { + firstErr = e + } + if e := DeleteSecureData(configDir); e != nil && firstErr == nil { + firstErr = e + } + if e := DeleteTokenMarker(configDir); e != nil && firstErr == nil { + firstErr = e + } + return firstErr + }) } // RevokeTokenRemote calls the appropriate logout/revoke endpoint to invalidate the access token. diff --git a/internal/auth/token_test.go b/internal/auth/token_test.go index 33f77caa..4651b1b8 100644 --- a/internal/auth/token_test.go +++ b/internal/auth/token_test.go @@ -14,6 +14,7 @@ package auth import ( + "os" "testing" "time" @@ -25,8 +26,10 @@ import ( // written by these tests, and removes test data on completion. func cleanupKeychain(t *testing.T) { t.Helper() + SetRuntimeProfile("") t.Setenv(keychain.StorageDirEnv, t.TempDir()) t.Cleanup(func() { + SetRuntimeProfile("") _ = keychain.Remove(keychain.Service, keychain.AccountToken) }) } @@ -127,6 +130,271 @@ func TestTokenOverwrite(t *testing.T) { } } +func TestMultiProfileSaveLoadAndSwitch(t *testing.T) { + cleanupKeychain(t) + configDir := t.TempDir() + + dataA := testToken("at_a", "corp_a", "A Org") + dataB := testToken("at_b", "corp_b", "B Org") + if err := SaveTokenData(configDir, dataA); err != nil { + t.Fatalf("SaveTokenData(A) error = %v", err) + } + if err := SaveTokenData(configDir, dataB); err != nil { + t.Fatalf("SaveTokenData(B) error = %v", err) + } + + cfg, err := LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if cfg.PrimaryProfile != "corp_a" || cfg.CurrentProfile != "corp_b" || cfg.PreviousProfile != "corp_a" { + t.Fatalf("profile pointers = primary %q current %q previous %q", cfg.PrimaryProfile, cfg.CurrentProfile, cfg.PreviousProfile) + } + + loadedB, err := LoadTokenData(configDir) + if err != nil { + t.Fatalf("LoadTokenData() error = %v", err) + } + if loadedB.AccessToken != "at_b" { + t.Fatalf("default token = %q, want at_b", loadedB.AccessToken) + } + loadedA, err := LoadTokenDataForProfile(configDir, "A Org") + if err != nil { + t.Fatalf("LoadTokenDataForProfile(A Org) error = %v", err) + } + if loadedA.AccessToken != "at_a" { + t.Fatalf("profile A token = %q, want at_a", loadedA.AccessToken) + } + + if _, err := SetCurrentProfile(configDir, "corp_a"); err != nil { + t.Fatalf("SetCurrentProfile(A) error = %v", err) + } + loadedA, err = LoadTokenData(configDir) + if err != nil { + t.Fatalf("LoadTokenData() after switch error = %v", err) + } + if loadedA.AccessToken != "at_a" { + t.Fatalf("default token after switch = %q, want at_a", loadedA.AccessToken) + } + if _, err := UsePreviousProfile(configDir); err != nil { + t.Fatalf("UsePreviousProfile() error = %v", err) + } + loadedB, err = LoadTokenData(configDir) + if err != nil { + t.Fatalf("LoadTokenData() after previous error = %v", err) + } + if loadedB.AccessToken != "at_b" { + t.Fatalf("default token after previous = %q, want at_b", loadedB.AccessToken) + } +} + +func TestRuntimeProfileOverrideDoesNotMutateCurrent(t *testing.T) { + cleanupKeychain(t) + configDir := t.TempDir() + + if err := SaveTokenData(configDir, testToken("at_a", "corp_a", "A Org")); err != nil { + t.Fatalf("SaveTokenData(A) error = %v", err) + } + if err := SaveTokenData(configDir, testToken("at_b", "corp_b", "B Org")); err != nil { + t.Fatalf("SaveTokenData(B) error = %v", err) + } + if _, err := SetCurrentProfile(configDir, "corp_a"); err != nil { + t.Fatalf("SetCurrentProfile(A) error = %v", err) + } + + SetRuntimeProfile("corp_b") + if err := SaveTokenData(configDir, testToken("at_b_refreshed", "corp_b", "B Org")); err != nil { + t.Fatalf("SaveTokenData(B refresh) error = %v", err) + } + SetRuntimeProfile("") + + cfg, err := LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if cfg.CurrentProfile != "corp_a" { + t.Fatalf("current profile = %q, want corp_a", cfg.CurrentProfile) + } + loadedB, err := LoadTokenDataForProfile(configDir, "corp_b") + if err != nil { + t.Fatalf("LoadTokenDataForProfile(B) error = %v", err) + } + if loadedB.AccessToken != "at_b_refreshed" { + t.Fatalf("profile B token = %q, want at_b_refreshed", loadedB.AccessToken) + } + loadedDefault, err := LoadTokenData(configDir) + if err != nil { + t.Fatalf("LoadTokenData() error = %v", err) + } + if loadedDefault.AccessToken != "at_a" { + t.Fatalf("default token = %q, want at_a", loadedDefault.AccessToken) + } +} + +func TestDeleteProfilePreservesOtherProfiles(t *testing.T) { + cleanupKeychain(t) + configDir := t.TempDir() + + if err := SaveTokenData(configDir, testToken("at_a", "corp_a", "A Org")); err != nil { + t.Fatalf("SaveTokenData(A) error = %v", err) + } + if err := SaveTokenData(configDir, testToken("at_b", "corp_b", "B Org")); err != nil { + t.Fatalf("SaveTokenData(B) error = %v", err) + } + if err := DeleteTokenDataForProfile(configDir, "corp_b"); err != nil { + t.Fatalf("DeleteTokenDataForProfile(B) error = %v", err) + } + if _, err := LoadTokenDataForProfile(configDir, "corp_b"); err == nil { + t.Fatal("LoadTokenDataForProfile(B) error = nil after delete, want failure") + } + loadedA, err := LoadTokenDataForProfile(configDir, "corp_a") + if err != nil { + t.Fatalf("LoadTokenDataForProfile(A) error = %v", err) + } + if loadedA.AccessToken != "at_a" { + t.Fatalf("profile A token = %q, want at_a", loadedA.AccessToken) + } + cfg, err := LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if len(cfg.Profiles) != 1 || cfg.CurrentProfile != "corp_a" { + t.Fatalf("profiles after delete = %#v", cfg) + } +} + +func TestUpsertProfileFromTokenOverwritesSameCorp(t *testing.T) { + cleanupKeychain(t) + configDir := t.TempDir() + + first := testToken("at_first", "corp_same", "旧组织名") + if err := SaveTokenData(configDir, first); err != nil { + t.Fatalf("SaveTokenData(first) error = %v", err) + } + second := testToken("at_second", "corp_same", "新组织名") + second.UserID = "user_updated" + second.UserName = "Updated User" + second.ClientID = "client_updated" + if err := SaveTokenData(configDir, second); err != nil { + t.Fatalf("SaveTokenData(second) error = %v", err) + } + + cfg, err := LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if len(cfg.Profiles) != 1 { + t.Fatalf("profiles len = %d, want 1: %#v", len(cfg.Profiles), cfg.Profiles) + } + profile := cfg.Profiles[0] + if profile.CorpName != "新组织名" { + t.Fatalf("corpName = %q, want 新组织名", profile.CorpName) + } + if profile.UserID != "user_updated" || profile.UserName != "Updated User" || profile.ClientID != "client_updated" { + t.Fatalf("profile metadata was not overwritten: %#v", profile) + } + loaded, err := LoadTokenDataForProfile(configDir, "corp_same") + if err != nil { + t.Fatalf("LoadTokenDataForProfile() error = %v", err) + } + if loaded.AccessToken != "at_second" { + t.Fatalf("access token = %q, want at_second", loaded.AccessToken) + } +} + +func TestUpsertProfileFromTokenPromotesCorpIDNameToCorpName(t *testing.T) { + cleanupKeychain(t) + configDir := t.TempDir() + + first := testToken("at_first", "corp_same", "") + if err := SaveTokenData(configDir, first); err != nil { + t.Fatalf("SaveTokenData(first) error = %v", err) + } + second := testToken("at_second", "corp_same", "新组织名") + if err := SaveTokenData(configDir, second); err != nil { + t.Fatalf("SaveTokenData(second) error = %v", err) + } + + cfg, err := LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if len(cfg.Profiles) != 1 { + t.Fatalf("profiles len = %d, want 1: %#v", len(cfg.Profiles), cfg.Profiles) + } + if cfg.Profiles[0].Name != "新组织名" { + t.Fatalf("profile name = %q, want 新组织名", cfg.Profiles[0].Name) + } + + resolved, err := ResolveProfile(configDir, "新组织名") + if err != nil { + t.Fatalf("ResolveProfile(corpName) error = %v", err) + } + if resolved.CorpID != "corp_same" { + t.Fatalf("resolved corpId = %q, want corp_same", resolved.CorpID) + } +} + +func TestLoadProfilesPromotesLegacyCorpIDNameToCorpName(t *testing.T) { + configDir := t.TempDir() + raw := `{ + "version": 1, + "primaryProfile": "corp_same", + "currentProfile": "corp_same", + "profiles": [ + { + "name": "corp_same", + "corpId": "corp_same", + "corpName": "新组织名" + } + ] +}` + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + if err := os.WriteFile(ProfilesPath(configDir), []byte(raw), 0o600); err != nil { + t.Fatalf("WriteFile(profiles.json) error = %v", err) + } + + cfg, err := LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if len(cfg.Profiles) != 1 { + t.Fatalf("profiles len = %d, want 1", len(cfg.Profiles)) + } + if cfg.Profiles[0].Name != "新组织名" { + t.Fatalf("profile name = %q, want 新组织名", cfg.Profiles[0].Name) + } +} + +func TestLegacyKeychainMigrationInitializesProfile(t *testing.T) { + cleanupKeychain(t) + configDir := t.TempDir() + + legacy := testToken("at_legacy", "corp_legacy", "Legacy Org") + if err := SaveTokenDataKeychain(legacy); err != nil { + t.Fatalf("SaveTokenDataKeychain() error = %v", err) + } + loaded, err := LoadTokenData(configDir) + if err != nil { + t.Fatalf("LoadTokenData() error = %v", err) + } + if loaded.AccessToken != "at_legacy" { + t.Fatalf("loaded token = %q, want at_legacy", loaded.AccessToken) + } + cfg, err := LoadProfiles(configDir) + if err != nil { + t.Fatalf("LoadProfiles() error = %v", err) + } + if cfg.PrimaryProfile != "corp_legacy" || cfg.CurrentProfile != "corp_legacy" { + t.Fatalf("profile pointers after migration = %#v", cfg) + } + if !TokenDataExistsKeychainForCorpID("corp_legacy") { + t.Fatal("corp-scoped token should exist after migration") + } +} + func TestTokenDataExistsKeychain(t *testing.T) { cleanupKeychain(t) @@ -152,6 +420,21 @@ func TestTokenDataExistsKeychain(t *testing.T) { } } +func testToken(accessToken, corpID, corpName string) *TokenData { + now := time.Now().UTC() + return &TokenData{ + AccessToken: accessToken, + RefreshToken: "rt_" + accessToken, + ExpiresAt: now.Add(2 * time.Hour), + RefreshExpAt: now.Add(30 * 24 * time.Hour), + CorpID: corpID, + CorpName: corpName, + UserID: "user_" + corpID, + UserName: "User " + corpID, + ClientID: "client_" + corpID, + } +} + func TestTokenValidityChecks(t *testing.T) { t.Parallel() diff --git a/internal/compat/registry.go b/internal/compat/registry.go index f610a934..c47e6dec 100644 --- a/internal/compat/registry.go +++ b/internal/compat/registry.go @@ -751,7 +751,7 @@ func collectSchemaFlags(cmd *cobra.Command, bindings []FlagBinding, params map[s "json": true, "params": true, "help": true, "format": true, "fields": true, "jq": true, "debug": true, "verbose": true, "dry-run": true, - "yes": true, "mock": true, "timeout": true, + "yes": true, "mock": true, "profile": true, "timeout": true, "client-id": true, "client-secret": true, } diff --git a/internal/compat/registry_test.go b/internal/compat/registry_test.go index 6c39d2fb..996d2d38 100644 --- a/internal/compat/registry_test.go +++ b/internal/compat/registry_test.go @@ -288,6 +288,7 @@ func TestCollectSchemaFlagsSkipsGlobalFlags(t *testing.T) { cmd.Flags().Bool("verbose", false, "Verbose") cmd.Flags().Bool("dry-run", false, "Dry run") cmd.Flags().String("format", "json", "Format") + cmd.Flags().String("profile", "", "Profile") cmd.Flags().String("json", "", "") cmd.Flags().String("params", "", "") @@ -296,6 +297,7 @@ func TestCollectSchemaFlagsSkipsGlobalFlags(t *testing.T) { _ = cmd.Flags().Set("verbose", "true") _ = cmd.Flags().Set("dry-run", "true") _ = cmd.Flags().Set("format", "table") + _ = cmd.Flags().Set("profile", "corp_profile") params := make(map[string]any) collectSchemaFlags(cmd, nil, params) @@ -304,7 +306,7 @@ func TestCollectSchemaFlagsSkipsGlobalFlags(t *testing.T) { t.Errorf("name = %v, want Bob", params["name"]) } // Global flags should be skipped - for _, skip := range []string{"debug", "verbose", "dry_run", "format"} { + for _, skip := range []string{"debug", "verbose", "dry_run", "format", "profile"} { if _, exists := params[skip]; exists { t.Errorf("%s should be skipped (global flag)", skip) } diff --git a/scripts/dev/test-multi-profile-e2e.sh b/scripts/dev/test-multi-profile-e2e.sh new file mode 100755 index 00000000..74fe6498 --- /dev/null +++ b/scripts/dev/test-multi-profile-e2e.sh @@ -0,0 +1,623 @@ +#!/usr/bin/env bash +# End-to-end regression script for multi-profile / multi-organization login. +# It uses an isolated DWS_CONFIG_DIR and DWS_KEYCHAIN_DIR, seeds post-login +# token results through the production auth storage API, then verifies the real +# dws CLI command surface. +# +# Usage: +# bash scripts/dev/test-multi-profile-e2e.sh +# bash scripts/dev/test-multi-profile-e2e.sh --skip-go-tests --verbose + +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +RUN_GO_TESTS=1 +VERBOSE=0 +KEEP_WORKDIR=0 + +while [[ $# -gt 0 ]]; do + case "$1" in + --skip-go-tests) + RUN_GO_TESTS=0 + shift + ;; + --verbose) + VERBOSE=1 + shift + ;; + --keep-workdir) + KEEP_WORKDIR=1 + shift + ;; + -h|--help) + sed -n '1,12p' "$0" + exit 0 + ;; + *) + echo "unknown option: $1" >&2 + exit 2 + ;; + esac +done + +mkdir -p "$ROOT/.tmp-bin" +WORKDIR="$(mktemp -d "$ROOT/.tmp-bin/multi-profile-e2e.XXXXXX")" +BIN="$WORKDIR/bin/dws" +HELPER_DIR="$WORKDIR/helper" +CONFIG_DIR="$WORKDIR/config" +KEYCHAIN_DIR="$WORKDIR/keychain" +CACHE_DIR="$WORKDIR/cache" +OUT_DIR="$WORKDIR/out" + +cleanup() { + if [[ "$KEEP_WORKDIR" -eq 1 ]]; then + echo "[INFO] kept workdir: $WORKDIR" + else + rm -rf "$WORKDIR" + fi +} +trap cleanup EXIT + +export DWS_CONFIG_DIR="$CONFIG_DIR" +export DWS_KEYCHAIN_DIR="$KEYCHAIN_DIR" +export DWS_DISABLE_KEYCHAIN=1 +export DWS_CACHE_DIR="$CACHE_DIR" +export DWS_PERF_REPORT= +export DWS_PERF_DEBUG= + +mkdir -p "$HELPER_DIR" "$CONFIG_DIR" "$KEYCHAIN_DIR" "$CACHE_DIR" "$OUT_DIR" "$(dirname "$BIN")" + +log() { + printf '\n==> %s\n' "$*" +} + +fail() { + echo "[FAIL] $*" >&2 + exit 1 +} + +run() { + if [[ "$VERBOSE" -eq 1 ]]; then + "$@" + else + "$@" >/dev/null + fi +} + +capture() { + local file="$1" + shift + if [[ "$VERBOSE" -eq 1 ]]; then + echo "+ $*" >&2 + fi + "$@" >"$file" 2>"$file.stderr" +} + +expect_contains() { + local file="$1" + local needle="$2" + if ! grep -F -- "$needle" "$file" >/dev/null; then + echo "----- $file -----" >&2 + cat "$file" >&2 + fail "expected $file to contain: $needle" + fi +} + +expect_not_contains_line_command() { + local file="$1" + local command="$2" + if grep -E "^[[:space:]]+$command([[:space:]]|$)" "$file" >/dev/null; then + echo "----- $file -----" >&2 + cat "$file" >&2 + fail "did not expect command '$command' in $file" + fi +} + +expect_fail() { + local needle="$1" + shift + local output + set +e + output="$("$@" 2>&1)" + local code=$? + set -e + if [[ "$code" -eq 0 ]]; then + echo "$output" >&2 + fail "expected command to fail: $*" + fi + if ! grep -F -- "$needle" <<<"$output" >/dev/null; then + echo "$output" >&2 + fail "expected failure output to contain: $needle" + fi +} + +cat >"$HELPER_DIR/main.go" <<'GOEOF' +package main + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + auth "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/internal/auth" +) + +type profileListResponse struct { + Success bool `json:"success"` + PrimaryProfile string `json:"primaryProfile"` + CurrentProfile string `json:"currentProfile"` + PreviousProfile string `json:"previousProfile"` + Profiles []profileView `json:"profiles"` +} + +type profileUseResponse struct { + Success bool `json:"success"` + Profile profileView `json:"profile"` +} + +type profileView struct { + CorpID string `json:"corpId"` + CorpName string `json:"corpName"` + UserID string `json:"userId"` + UserName string `json:"userName"` + Status string `json:"status"` + IsPrimary bool `json:"isPrimary"` + IsCurrent bool `json:"isCurrent"` +} + +type authStatusResponse struct { + Success bool `json:"success"` + Authenticated bool `json:"authenticated"` + TokenValid bool `json:"token_valid"` + RefreshTokenValid bool `json:"refresh_token_valid"` + CorpID string `json:"corp_id"` + CorpName string `json:"corp_name"` + UserID string `json:"user_id"` + UserName string `json:"user_name"` +} + +type multiProfileResponse struct { + Success bool `json:"success"` + MultiProfile bool `json:"multiProfile"` + Summary multiProfileSummary `json:"summary"` + Profiles []multiProfileResult `json:"profiles"` +} + +type multiProfileSummary struct { + Total int `json:"total"` + Succeeded int `json:"succeeded"` + Failed int `json:"failed"` +} + +type multiProfileResult struct { + Selector string `json:"selector"` + CorpID string `json:"corpId"` + CorpName string `json:"corpName"` + OK bool `json:"ok"` + Result map[string]any `json:"result"` +} + +func main() { + if len(os.Args) < 2 { + die("missing helper command") + } + configDir := os.Getenv("DWS_CONFIG_DIR") + if strings.TrimSpace(configDir) == "" { + die("DWS_CONFIG_DIR is required") + } + switch os.Args[1] { + case "seed": + needArgs(7) + data := token(os.Args[2], os.Args[3], os.Args[4], os.Args[5], os.Args[6]) + must(auth.SaveTokenData(configDir, data)) + case "seed-legacy": + needArgs(7) + data := token(os.Args[2], os.Args[3], os.Args[4], os.Args[5], os.Args[6]) + must(auth.SaveTokenDataKeychain(data)) + must(auth.WriteTokenMarker(configDir)) + case "write-app-config": + needArgs(4) + must(auth.SaveAppConfig(configDir, &auth.AppConfig{ + ClientID: os.Args[2], + ClientSecret: auth.PlainSecret(os.Args[3]), + })) + case "assert-app-config": + needArgs(3) + cfg, err := auth.LoadAppConfig(configDir) + must(err) + switch os.Args[2] { + case "exists": + if cfg == nil || strings.TrimSpace(cfg.ClientID) == "" { + die("expected app config to exist") + } + case "absent": + if cfg != nil { + die("expected app config to be absent, got clientID=%q", cfg.ClientID) + } + default: + die("unknown app config expectation %q", os.Args[2]) + } + case "assert-profiles": + needArgs(6) + cfg, err := auth.LoadProfiles(configDir) + must(err) + wantCount := atoi(os.Args[2]) + if len(cfg.Profiles) != wantCount { + die("profiles len=%d, want %d: %#v", len(cfg.Profiles), wantCount, cfg.Profiles) + } + assertEqual("primaryProfile", cfg.PrimaryProfile, emptySentinel(os.Args[3])) + assertEqual("currentProfile", cfg.CurrentProfile, emptySentinel(os.Args[4])) + assertEqual("previousProfile", cfg.PreviousProfile, emptySentinel(os.Args[5])) + assertNoSecrets(configDir) + assertProfileMetadata(cfg) + case "assert-list-json": + needArgs(7) + var resp profileListResponse + raw := readJSON(os.Args[2], &resp) + if strings.Contains(string(raw), `"name"`) { + die("profile list JSON must not expose local name: %s", string(raw)) + } + if !resp.Success { + die("profile list success=false") + } + wantCount := atoi(os.Args[3]) + if len(resp.Profiles) != wantCount { + die("list profiles len=%d, want %d: %#v", len(resp.Profiles), wantCount, resp.Profiles) + } + assertEqual("list primaryProfile", resp.PrimaryProfile, emptySentinel(os.Args[4])) + assertEqual("list currentProfile", resp.CurrentProfile, emptySentinel(os.Args[5])) + assertEqual("list previousProfile", resp.PreviousProfile, emptySentinel(os.Args[6])) + for _, p := range resp.Profiles { + if strings.TrimSpace(p.CorpID) == "" || strings.TrimSpace(p.CorpName) == "" { + die("profile list item missing corp identity: %#v", p) + } + if p.CorpID == resp.PrimaryProfile && !p.IsPrimary { + die("profile %s should be primary", p.CorpID) + } + if p.CorpID == resp.CurrentProfile && !p.IsCurrent { + die("profile %s should be current", p.CorpID) + } + } + case "assert-switch-json": + needArgs(5) + var resp profileUseResponse + readJSON(os.Args[2], &resp) + if !resp.Success { + die("switch JSON success=false") + } + assertEqual("switch corpId", resp.Profile.CorpID, os.Args[3]) + assertEqual("switch corpName", resp.Profile.CorpName, os.Args[4]) + if !resp.Profile.IsCurrent { + die("switch profile isCurrent=false") + } + case "assert-status-json": + needArgs(6) + var resp authStatusResponse + readJSON(os.Args[2], &resp) + if !resp.Success || !resp.Authenticated || !resp.TokenValid || !resp.RefreshTokenValid { + die("bad auth status response: %#v", resp) + } + assertEqual("status corpId", resp.CorpID, os.Args[3]) + assertEqual("status corpName", resp.CorpName, os.Args[4]) + assertEqual("status userId", resp.UserID, os.Args[5]) + case "assert-multi-profile-json": + needArgs(5) + var resp multiProfileResponse + readJSON(os.Args[2], &resp) + if !resp.Success || !resp.MultiProfile { + die("bad multi-profile response: %#v", resp) + } + wantCount := atoi(os.Args[3]) + if len(resp.Profiles) != wantCount { + die("multi-profile len=%d, want %d: %#v", len(resp.Profiles), wantCount, resp.Profiles) + } + if resp.Summary.Total != wantCount || resp.Summary.Succeeded != wantCount || resp.Summary.Failed != 0 { + die("bad multi-profile summary: %#v", resp.Summary) + } + wantCorpIDs := strings.Split(os.Args[4], ",") + if len(wantCorpIDs) != wantCount { + die("want corpId count=%d, want %d", len(wantCorpIDs), wantCount) + } + for i, want := range wantCorpIDs { + want = strings.TrimSpace(want) + got := resp.Profiles[i] + if !got.OK { + die("profile %d ok=false: %#v", i, got) + } + assertEqual(fmt.Sprintf("multi-profile corpId[%d]", i), got.CorpID, want) + if got.Result["_mock"] != true { + die("profile %s result is not mock payload: %#v", got.CorpID, got.Result) + } + } + case "assert-token": + needArgs(5) + data, err := loadToken(configDir, os.Args[2]) + must(err) + assertEqual("token corpId", data.CorpID, os.Args[3]) + assertEqual("token access", data.AccessToken, os.Args[4]) + case "assert-empty-auth": + needArgs(2) + cfg, err := auth.LoadProfiles(configDir) + must(err) + if cfg.PrimaryProfile != "" || cfg.CurrentProfile != "" || cfg.PreviousProfile != "" || len(cfg.Profiles) != 0 { + die("expected empty profiles after reset, got %#v", cfg) + } + if auth.TokenDataExistsKeychain() { + die("legacy auth-token still exists") + } + case "assert-duplicate-name-fallback": + needArgs(4) + cfg, err := auth.LoadProfiles(configDir) + must(err) + p := findProfile(cfg, os.Args[2]) + if p == nil { + die("profile %q not found", os.Args[2]) + } + if p.CorpName != os.Args[3] { + die("profile %s corpName=%q, want %q", p.CorpID, p.CorpName, os.Args[3]) + } + if p.Name == os.Args[3] || !strings.HasPrefix(p.Name, os.Args[3]+"-") { + die("profile %s name=%q, want stable fallback prefix %q", p.CorpID, p.Name, os.Args[3]+"-") + } + default: + die("unknown helper command %q", os.Args[1]) + } +} + +func token(corpID, corpName, userID, userName, access string) *auth.TokenData { + return &auth.TokenData{ + AccessToken: access, + RefreshToken: "refresh-" + corpID, + PersistentCode: "persistent-" + corpID, + ExpiresAt: time.Now().Add(2 * time.Hour), + RefreshExpAt: time.Now().Add(720 * time.Hour), + CorpID: corpID, + CorpName: corpName, + UserID: userID, + UserName: userName, + ClientID: "client-" + corpID, + Source: "multi-profile-e2e", + } +} + +func needArgs(n int) { + if len(os.Args) != n { + die("%s: got %d args, want %d", os.Args[1], len(os.Args)-2, n-2) + } +} + +func loadToken(configDir, selector string) (*auth.TokenData, error) { + if selector == "default" { + return auth.LoadTokenData(configDir) + } + return auth.LoadTokenDataForProfile(configDir, selector) +} + +func readJSON(path string, dst any) []byte { + data, err := os.ReadFile(path) + must(err) + if err := json.Unmarshal(data, dst); err != nil { + die("parse %s: %v\n%s", path, err, string(data)) + } + return data +} + +func assertProfileMetadata(cfg *auth.ProfilesConfig) { + names := map[string]string{} + for _, p := range cfg.Profiles { + if strings.TrimSpace(p.CorpID) == "" || strings.TrimSpace(p.CorpName) == "" { + die("profile missing corp metadata: %#v", p) + } + if prev, ok := names[p.Name]; ok { + die("duplicate profile local name %q for %s and %s", p.Name, prev, p.CorpID) + } + names[p.Name] = p.CorpID + } +} + +func assertNoSecrets(configDir string) { + data, err := os.ReadFile(filepath.Join(configDir, "profiles.json")) + if err != nil { + if os.IsNotExist(err) { + return + } + must(err) + } + for _, forbidden := range []string{"access_token", "refresh_token", "persistent_code", "client_secret"} { + if strings.Contains(string(data), forbidden) { + die("profiles.json contains secret field %q", forbidden) + } + } +} + +func findProfile(cfg *auth.ProfilesConfig, corpID string) *auth.Profile { + for i := range cfg.Profiles { + if cfg.Profiles[i].CorpID == corpID { + return &cfg.Profiles[i] + } + } + return nil +} + +func atoi(raw string) int { + var n int + if _, err := fmt.Sscanf(raw, "%d", &n); err != nil { + die("invalid integer %q", raw) + } + return n +} + +func emptySentinel(s string) string { + if s == "_" { + return "" + } + return s +} + +func assertEqual(label, got, want string) { + if got != want { + die("%s=%q, want %q", label, got, want) + } +} + +func must(err error) { + if err != nil { + die("%v", err) + } +} + +func die(format string, args ...any) { + fmt.Fprintf(os.Stderr, format+"\n", args...) + os.Exit(1) +} +GOEOF + +cd "$ROOT" + +if [[ "$RUN_GO_TESTS" -eq 1 ]]; then + log "running multi-profile Go regressions" + go test -timeout 180s -count=1 ./internal/auth ./internal/app ./test/cli +fi + +log "building dws" +run go build -o "$BIN" ./cmd + +helper() { + go run "$HELPER_DIR" "$@" +} + +log "checking command surface" +capture "$OUT_DIR/root-help.txt" "$BIN" --help +expect_contains "$OUT_DIR/root-help.txt" "--profile" +expect_contains "$OUT_DIR/root-help.txt" "--yes" +expect_contains "$OUT_DIR/root-help.txt" "--dry-run" +expect_contains "$OUT_DIR/root-help.txt" "profile" +capture "$OUT_DIR/profile-help.txt" "$BIN" profile --help +expect_contains "$OUT_DIR/profile-help.txt" "list" +expect_contains "$OUT_DIR/profile-help.txt" "switch" +expect_contains "$OUT_DIR/profile-help.txt" "use" +expect_contains "$OUT_DIR/profile-help.txt" "--profile" +capture "$OUT_DIR/auth-login-help.txt" "$BIN" auth login --help +expect_contains "$OUT_DIR/auth-login-help.txt" "--device" +expect_contains "$OUT_DIR/auth-login-help.txt" "--token" +expect_contains "$OUT_DIR/auth-login-help.txt" "--recommend" +expect_contains "$OUT_DIR/auth-login-help.txt" "--yes" +capture "$OUT_DIR/skill-setup-help.txt" "$BIN" skill setup --help +expect_contains "$OUT_DIR/skill-setup-help.txt" "--mode" +expect_contains "$OUT_DIR/skill-setup-help.txt" "--target" +expect_contains "$OUT_DIR/skill-setup-help.txt" "--yes" +expect_contains "$OUT_DIR/skill-setup-help.txt" "--skill" +expect_contains "$OUT_DIR/skill-setup-help.txt" "--exclude" +capture "$OUT_DIR/upgrade-help.txt" "$BIN" upgrade --help +expect_contains "$OUT_DIR/upgrade-help.txt" "--dry-run" +expect_contains "$OUT_DIR/upgrade-help.txt" "--yes" +capture "$OUT_DIR/dev-connect-help.txt" "$BIN" dev connect --help +expect_contains "$OUT_DIR/dev-connect-help.txt" "--robot-client-id" +expect_contains "$OUT_DIR/dev-connect-help.txt" "--robot-client-secret" +expect_contains "$OUT_DIR/dev-connect-help.txt" "--unified-app-id" +expect_contains "$OUT_DIR/dev-connect-help.txt" "--agent-cmd" +expect_contains "$OUT_DIR/dev-connect-help.txt" "--daemon" +capture "$OUT_DIR/doc-delete-help.txt" "$BIN" doc delete --help +expect_contains "$OUT_DIR/doc-delete-help.txt" "--yes" +capture "$OUT_DIR/aitable-base-delete-help.txt" "$BIN" aitable base delete --help +expect_contains "$OUT_DIR/aitable-base-delete-help.txt" "--yes" +capture "$OUT_DIR/auth-help.txt" "$BIN" auth --help +expect_not_contains_line_command "$OUT_DIR/auth-help.txt" "switch" + +log "verifying empty profile list" +capture "$OUT_DIR/list-empty.json" "$BIN" profile list --format json +helper assert-list-json "$OUT_DIR/list-empty.json" 0 _ _ _ + +log "seeding first organization profile" +helper seed corp_alpha "Alpha Org" user_alpha "Alice Alpha" access-alpha-v1 +capture "$OUT_DIR/list-alpha.json" "$BIN" profile list --format json +helper assert-list-json "$OUT_DIR/list-alpha.json" 1 corp_alpha corp_alpha _ +helper assert-profiles 1 corp_alpha corp_alpha _ +helper assert-token default corp_alpha access-alpha-v1 +helper assert-token corp_alpha corp_alpha access-alpha-v1 +capture "$OUT_DIR/status-alpha-default.json" "$BIN" auth status --format json +helper assert-status-json "$OUT_DIR/status-alpha-default.json" corp_alpha "Alpha Org" user_alpha + +log "seeding second organization profile" +helper seed corp_beta "Beta Org" user_beta "Bob Beta" access-beta-v1 +capture "$OUT_DIR/list-alpha-beta.json" "$BIN" profile list --format json +helper assert-list-json "$OUT_DIR/list-alpha-beta.json" 2 corp_alpha corp_beta corp_alpha +helper assert-profiles 2 corp_alpha corp_beta corp_alpha +helper assert-token default corp_beta access-beta-v1 +helper assert-token corp_alpha corp_alpha access-alpha-v1 +helper assert-token corp_beta corp_beta access-beta-v1 + +log "refreshing existing organization without duplicating profile" +helper seed corp_beta "Beta Org" user_beta "Bob Beta" access-beta-v2 +capture "$OUT_DIR/list-beta-refresh.json" "$BIN" profile list --format json +helper assert-list-json "$OUT_DIR/list-beta-refresh.json" 2 corp_alpha corp_beta corp_alpha +helper assert-profiles 2 corp_alpha corp_beta corp_alpha +helper assert-token corp_beta corp_beta access-beta-v2 + +log "seeding duplicate organization name and checking stable fallback" +helper seed corp_gamma "Beta Org" user_gamma "Gina Gamma" access-gamma-v1 +capture "$OUT_DIR/list-duplicate-name.json" "$BIN" profile list --format json +helper assert-list-json "$OUT_DIR/list-duplicate-name.json" 3 corp_alpha corp_gamma corp_beta +helper assert-profiles 3 corp_alpha corp_gamma corp_beta +helper assert-duplicate-name-fallback corp_gamma "Beta Org" + +log "switching profiles and verifying legacy mirror" +capture "$OUT_DIR/switch-alpha.json" "$BIN" profile switch corp_alpha --format json +helper assert-switch-json "$OUT_DIR/switch-alpha.json" corp_alpha "Alpha Org" +helper assert-profiles 3 corp_alpha corp_alpha corp_gamma +helper assert-token default corp_alpha access-alpha-v1 +capture "$OUT_DIR/switch-beta.txt" "$BIN" profile switch corp_beta --format table +expect_contains "$OUT_DIR/switch-beta.txt" "Beta Org" +expect_contains "$OUT_DIR/switch-beta.txt" "corp_beta" +helper assert-profiles 3 corp_alpha corp_beta corp_alpha +helper assert-token default corp_beta access-beta-v2 +capture "$OUT_DIR/switch-previous.json" "$BIN" profile switch - --format json +helper assert-switch-json "$OUT_DIR/switch-previous.json" corp_alpha "Alpha Org" +helper assert-profiles 3 corp_alpha corp_alpha corp_beta +capture "$OUT_DIR/use-gamma.json" "$BIN" profile use corp_gamma --format json +helper assert-switch-json "$OUT_DIR/use-gamma.json" corp_gamma "Beta Org" +helper assert-profiles 3 corp_alpha corp_gamma corp_alpha + +log "checking profile switch validation" +expect_fail "profile selector required" "$BIN" profile switch +expect_fail "只能指定一个组织选择器" "$BIN" profile switch corp_alpha --corpId corp_beta +expect_fail "missing_org" "$BIN" profile switch missing_org + +log "checking one-shot profile override without changing current profile" +capture "$OUT_DIR/status-root-profile-alpha.json" "$BIN" --profile corp_alpha auth status --format json +helper assert-status-json "$OUT_DIR/status-root-profile-alpha.json" corp_alpha "Alpha Org" user_alpha +helper assert-profiles 3 corp_alpha corp_gamma corp_alpha +capture "$OUT_DIR/status-local-profile-beta.json" "$BIN" auth status --profile corp_beta --format json +helper assert-status-json "$OUT_DIR/status-local-profile-beta.json" corp_beta "Beta Org" user_beta +helper assert-profiles 3 corp_alpha corp_gamma corp_alpha +capture "$OUT_DIR/status-current-gamma.json" "$BIN" auth status --format json +helper assert-status-json "$OUT_DIR/status-current-gamma.json" corp_gamma "Beta Org" user_gamma +capture "$OUT_DIR/contact-multi-profile.json" "$BIN" --mock --profile corp_alpha, corp_beta contact user get-self --format json +helper assert-multi-profile-json "$OUT_DIR/contact-multi-profile.json" 2 corp_alpha,corp_beta +helper assert-profiles 3 corp_alpha corp_gamma corp_alpha +capture "$OUT_DIR/contact-multi-profile-leaf-profile.json" "$BIN" --mock contact user get-self --profile corp_alpha, corp_beta --format json +helper assert-multi-profile-json "$OUT_DIR/contact-multi-profile-leaf-profile.json" 2 corp_alpha,corp_beta +helper assert-profiles 3 corp_alpha corp_gamma corp_alpha + +log "checking auth reset cleanup" +helper write-app-config client-reset secret-reset +helper assert-app-config exists +capture "$OUT_DIR/auth-reset.txt" "$BIN" auth reset +expect_contains "$OUT_DIR/auth-reset.txt" "[OK]" +helper assert-empty-auth +helper assert-app-config absent + +log "checking legacy single-slot migration" +helper seed-legacy corp_legacy "Legacy Org" user_legacy "Lena Legacy" access-legacy-v1 +helper assert-profiles 0 _ _ _ +capture "$OUT_DIR/list-legacy-migrated.json" "$BIN" profile list --format json +helper assert-list-json "$OUT_DIR/list-legacy-migrated.json" 1 corp_legacy corp_legacy _ +helper assert-profiles 1 corp_legacy corp_legacy _ +helper assert-token default corp_legacy access-legacy-v1 +helper assert-token corp_legacy corp_legacy access-legacy-v1 + +log "multi-profile e2e passed" +echo "[PASS] isolated multi-profile chain completed" diff --git a/scripts/install-from-branch.sh b/scripts/install-from-branch.sh new file mode 100644 index 00000000..f7d2468a --- /dev/null +++ b/scripts/install-from-branch.sh @@ -0,0 +1,57 @@ +#!/bin/sh +# Copyright 2026 Alibaba Group +# Licensed under the Apache License, Version 2.0 +# +# Build and install dws directly from a Git branch checkout. +# +# Usage: +# curl -fsSL https://raw.githubusercontent.com/shangguanxuan633-lab/dingtalk-workspace-cli/codex/dws-multi-profile-login/scripts/install-from-branch.sh | sh +# +# Environment variables: +# DWS_SOURCE_REPO owner/repo to clone (default: shangguanxuan633-lab/dingtalk-workspace-cli) +# DWS_SOURCE_BRANCH branch to build (default: codex/dws-multi-profile-login) +# DWS_INSTALL_DIR passed through to scripts/install.sh (default there: ~/.local/bin) +# DWS_INSTALL_NAME passed through to scripts/install.sh (default: dws) +# DWS_NO_SKILLS passed through to scripts/install.sh (set 1 to skip skills) +# DWS_KEEP_SOURCE set 1 to keep the temporary source checkout + +set -eu + +REPO="${DWS_SOURCE_REPO:-shangguanxuan633-lab/dingtalk-workspace-cli}" +BRANCH="${DWS_SOURCE_BRANCH:-codex/dws-multi-profile-login}" +KEEP_SOURCE="${DWS_KEEP_SOURCE:-0}" + +say() { + printf ' %s\n' "$@" +} + +err() { + printf ' ❌ %s\n' "$@" >&2 + exit 1 +} + +need_cmd() { + command -v "$1" >/dev/null 2>&1 || err "Missing required command: $1" +} + +need_cmd git +need_cmd sh + +tmpdir="$(mktemp -d 2>/dev/null || mktemp -d -t dws-src)" +cleanup() { + if [ "$KEEP_SOURCE" != "1" ]; then + rm -rf "$tmpdir" + else + say "Source checkout kept at: $tmpdir" + fi +} +trap cleanup EXIT INT TERM + +say "Cloning dws source:" +say " repo: https://github.com/${REPO}.git" +say " branch: ${BRANCH}" + +git clone --depth 1 --branch "$BRANCH" "https://github.com/${REPO}.git" "$tmpdir" + +say "Building and installing from source..." +sh "$tmpdir/scripts/install.sh" diff --git a/skills/mono/SKILL.md b/skills/mono/SKILL.md index fee8b882..4f8dc473 100644 --- a/skills/mono/SKILL.md +++ b/skills/mono/SKILL.md @@ -27,6 +27,7 @@ cli_version: ">=1.0.15" - **脚本优先**:[scripts/](./scripts/) 下的 `python scripts/.py` 已封装翻页/轮询/批量逻辑,遇到对应场景(如 AI 表格批量导入导出、AI 应用创建轮询、文档创建后写内容、钉盘目录树等)**优先调用脚本**而非手写多步命令。脚本均支持 `--dry-run` 预览、`--format json` 输出,失败时回退到手动步骤 - **业务域最佳实践优先**:文档类多步任务先读 [04-document.md](./references/best_practices/04-document.md);AI 表格读取/统计/写入/导入导出先读 [06-data-analytics.md](./references/best_practices/06-data-analytics.md)。本仓库只迁入这些业务域 best practices,不引入其它产品行动指南。 - 知识库容器只用 `dws wiki space/member`;知识库内文件/文档的浏览、搜索、读取、创建、移动、复制统一切到 `dws doc`。`workspaceId` 只能传给 `wiki --workspace`、`doc --workspace` 或 `doc search --workspace-ids`,禁止传给 `doc list --folder`,也不要使用不存在的 `--space-id`。 +- 找群 / 找人 / 找数据在当前组织没命中、且 `dws profile list` 显示 ≥2 个组织时,对每个组织带一次性 `--profile ` 各搜一遍;命中即用,全部组织都没有才追问用户。禁止在当前组织搜不到就判定「不存在」或直接甩给用户选。 ## 开放平台文档 RAG / 错误码排查 @@ -72,6 +73,30 @@ cli_version: ">=1.0.15" 4. **Fallback 单产品路由**:仅当行动指南未命中,且用户意图明确是单一产品单步操作时,才按「产品总览」和「意图判断决策树」选择产品,并读取对应 `references/products/*.md`。 5. **追问**:以上步骤都无法判断时,主动追问用户澄清,严禁猜测命令、flag、URL、ID 或字段名。 +## 多组织处理 +dws 可同时登录多个钉钉组织,一个 profile = 一个已登录组织(corp)。当前 profile 决定本次命令用哪个组织的身份(corpId / userId 按当前 profile 自动注入,不是只支持单组织)。 + +**触发条件(命中任一即进入本节)**: +- 显式:用户提到 切换 / 换 / 跨组织、另一个钉钉、别的公司、看登录了哪些组织、当前是哪个组织、某人 / 某群 / 某数据在别的组织 +- 隐式(最常见、易漏):在当前组织读 / 搜没找到目标(群 / 人 / 数据),且 `dws profile list` 显示已登录 ≥2 个组织 —— 别急着判「不存在」,按下方跨组织铁律去其他组织找 +- 需要跨多个组织汇总 / 对比数据 +- 用户问认证状态 / 登录了哪些组织 / 主组织是哪个 + +**不触发**:只登录 1 个组织时,按当前组织正常处理,不带 `--profile`,不进本节。 + +命令: +- `dws profile list` — 列出已登录组织(主 / 当前标记、状态、有效期),只读元数据 +- `dws profile switch <名称|corpId|->` — 持久切换当前组织;`-` 切回上一个;无参数在交互终端弹选择器(非交互须显式传参)。`dws profile use` 是其别名 +- 全局 `--profile <名称|corpId>` — 单次指定本命令用哪个组织,一次性、不改当前组织 +- `dws auth login` — 再登一个组织即新增 profile(自动从授权账号取 corpId / corpName);同组织重复 login = 刷新 +- `dws auth status [--profile <名称>]` — 查看认证状态 + +多组织数据聚合步骤:`dws profile list` 拿到所有已登录组织,对每个组织带 `--profile ` 各取一次数,合并并标注来源组织;某组织失败则标「该组织暂不可用」并继续返回其余。 +安全护栏: +- 只有 `dws profile list` 显示 ≥2 个组织才启用上面的跨组织逻辑;单组织直接按当前组织走,不带 `--profile`。 +- 自动跨组织只对「读 / 搜」。写 / 发 / 删 / 撤回等操作默认只在当前组织做;确需带 `--profile` 跨组织写时,必须先与用户确认目标组织。 +- 持久切换 `dws profile switch`(改默认组织)按写操作对待:未经用户明确要求不得执行。跨组织找数一律用一次性 `--profile`,不改当前组织。 + ## 行动指南(优先匹配) > 将用户意图与下表做**语义比对**,不要求字面包含关键词。命中后必须读取该行动指南文件,并按其中固定路线执行;多个场景同时命中时,按下方「消歧规则」选择。 @@ -112,6 +137,7 @@ cli_version: ">=1.0.15" 用户提到"在线电子表格/钉钉表格/axls/工作表/单元格读写/合并单元格/筛选视图/导出 xlsx" → `sheet` 用户提到"待办/TODO/任务提醒/循环待办" → `todo` 用户提到"创建知识库/知识库列表/搜索知识库空间/wiki/团队空间/知识库成员管理/我的文档个人空间" → `wiki` +用户提到"切换组织/换组织/跨组织/另一个钉钉/别的公司/多组织/看所有组织/profile/登录了哪些组织" → `profile`(见「多组织 / profile」节) 关键区分: **dev(创建/配置/建联机器人)** vs **chat(查询/发消息已有机器人)**。`dws chat bot search/find` 只查询机器人;**建号**(创建钉钉智能体机器人)走 `dws dev app robot submit`;**建联**(把机器人接到本地 agent 的 Stream)走 `dws dev connect`。凡是"创建机器人""建机器人""接入 agent""建联"一律路由到 `dev`,禁止走 `chat`。 关键区分: aitable(数据表格) vs todo(待办任务) @@ -149,6 +175,7 @@ cli_version: ">=1.0.15" | `oa` | `approval reject` | 拒绝待审批(需加明确理由) | | `todo` | `task delete` | 删除待办 | | `minutes` | `replace-text` | 全文批量替换转写与摘要 | +| `auth` | `logout` | **默认退出所有已登录组织**;只退一个加 `--profile <名称\|corpId>`。注意:退主组织不会被拦,会静默把「主」改选为剩下第一个组织,退主前必须向用户确认 | ### 确认流程 ``` diff --git a/skills/multi/dingtalk-aisearch/SKILL.md b/skills/multi/dingtalk-aisearch/SKILL.md index d5bef865..557db423 100644 --- a/skills/multi/dingtalk-aisearch/SKILL.md +++ b/skills/multi/dingtalk-aisearch/SKILL.md @@ -23,6 +23,8 @@ metadata: > 命令参考:[aisearch.md](references/aisearch.md)。 +> 跨组织:当前组织搜不到人时,别判定「查无此人」——先 `dws profile list` 看有哪些已登录组织,再对每个组织带 `--profile ` 各搜一遍,全无才追问用户。详见 `dingtalk-profile` skill。 + ## 开放平台文档 RAG / 错误码排查 - 任何产品执行中,只要用户问开放平台 API、接口参数、字段含义、权限点、回调、SDK、配额、错误码,或命令返回上游 OpenAPI/SDK 错误,必须先用 `dws devdoc article search --query "<关键词>" --format json` 做官方文档 RAG。 diff --git a/skills/multi/dingtalk-chat/SKILL.md b/skills/multi/dingtalk-chat/SKILL.md index eaf7581c..3fe07faa 100644 --- a/skills/multi/dingtalk-chat/SKILL.md +++ b/skills/multi/dingtalk-chat/SKILL.md @@ -23,6 +23,8 @@ metadata: > 命令参考:[chat.md](references/chat.md);表情:[chat-emoji-list.md](references/chat-emoji-list.md);剧本:[01-messaging.md](references/01-messaging.md)。 +> 跨组织:当前组织搜不到群 / 单聊时,别判定「不存在」——先 `dws profile list` 看有哪些已登录组织,再对每个组织带 `--profile ` 各搜一遍,全无才追问用户。详见 `dingtalk-profile` skill。 + ## 开放平台文档 RAG / 错误码排查 - 任何产品执行中,只要用户问开放平台 API、接口参数、字段含义、权限点、回调、SDK、配额、错误码,或命令返回上游 OpenAPI/SDK 错误,必须先用 `dws devdoc article search --query "<关键词>" --format json` 做官方文档 RAG。 diff --git a/skills/multi/dingtalk-contact/SKILL.md b/skills/multi/dingtalk-contact/SKILL.md index 467cc1b5..5e8ad609 100644 --- a/skills/multi/dingtalk-contact/SKILL.md +++ b/skills/multi/dingtalk-contact/SKILL.md @@ -23,6 +23,8 @@ metadata: > 命令参考:[contact.md](references/contact.md);剧本:[08-directory.md](references/08-directory.md)。 +> 跨组织:当前组织查不到人时,别判定「查无此人」——先 `dws profile list` 看有哪些已登录组织,再对每个组织带 `--profile ` 各查一遍,全无才追问用户。详见 `dingtalk-profile` skill。 + ## 开放平台文档 RAG / 错误码排查 - 任何产品执行中,只要用户问开放平台 API、接口参数、字段含义、权限点、回调、SDK、配额、错误码,或命令返回上游 OpenAPI/SDK 错误,必须先用 `dws devdoc article search --query "<关键词>" --format json` 做官方文档 RAG。 diff --git a/skills/multi/dingtalk-profile/SKILL.md b/skills/multi/dingtalk-profile/SKILL.md new file mode 100644 index 00000000..af01f957 --- /dev/null +++ b/skills/multi/dingtalk-profile/SKILL.md @@ -0,0 +1,48 @@ +--- +name: dingtalk-profile +description: 钉钉多组织 / profile 管理与跨组织取数。Use when 用户说 切换组织/换组织/跨组织/另一个钉钉/别的公司/多组织/看登录了哪些组织/profile,或在当前组织找不到群/人/数据需要去其他组织找。命令前缀:dws profile / dws auth / 全局 --profile。 +cli_version: ">=1.0.40" +metadata: + category: product + stability: experimental + requires: + bins: + - dws +--- + +# 钉钉多组织 / profile Skill + +> 🧪 **EXPERIMENTAL · 试验版 / Preview** — multi 模式当前未达 stable 标准;接口、命名、跨 skill 引用后续可能调整。生产 / 共享环境请优先使用 mono 模式(`dws skill setup --mode mono`)。 + +> **PREREQUISITE:** Read the `dws-shared` skill first for auth, global flags, product routing, URL preflight, error codes, and safety rules. The `dws` binary must be on PATH. + + + +dws 可同时登录多个钉钉组织,一个 profile = 一个已登录组织(corp)。当前 profile 决定本次命令用哪个组织的身份(corpId / userId 按当前 profile 自动注入)。 + +## 触发条件(命中任一即用本 skill) +- 显式:用户提到 切换 / 换 / 跨组织、另一个钉钉、别的公司、看登录了哪些组织、当前是哪个组织、某人 / 某群 / 某数据在别的组织 +- 隐式(最常见、易漏):在当前组织读 / 搜没找到目标(群 / 人 / 数据),且 `dws profile list` 显示已登录 ≥2 个组织 —— 别急着判「不存在」,按下方跨组织铁律去其他组织找 +- 需要跨多个组织汇总 / 对比数据 +- 用户问认证状态 / 登录了哪些组织 / 主组织是哪个 + +**不触发**:只登录 1 个组织时,按当前组织正常处理,不带 `--profile`。 + +## 命令 +- `dws profile list` — 列出已登录组织(主 / 当前标记、状态、有效期),只读元数据 +- `dws profile switch <名称|corpId|->` — 持久切换当前组织;`-` 切回上一个;无参数在交互终端弹选择器(非交互须显式传参)。`dws profile use` 是别名 +- 全局 `--profile <名称|corpId>` — 单次指定本命令用哪个组织,一次性、不改当前组织 +- `dws auth login` — 再登一个组织即新增 profile(自动从授权账号取 corpId / corpName);同组织重复 login = 刷新 +- `dws auth status [--profile <名称>]` — 查看认证状态 + +## 跨组织铁律(必须执行,不得跳过) +「找群 / 找人 / 找数据」(chat search、aisearch / contact、doc / wiki 搜索等读 / 搜场景)在当前组织没命中、且 `dws profile list` 显示 ≥2 个组织时,对每个组织带一次性 `--profile ` 各搜一遍;命中即用,全部组织都没有才追问用户。禁止在当前组织搜不到就判定「不存在」或直接甩给用户选。 + +## 跨组织聚合(agent 编排,无内置 --all-orgs) +① `dws profile list` 拿到所有已登录组织 → ② 对每个组织带 `--profile ` 各取一次数 → ③ 合并并标注来源组织;某组织失败则标「该组织暂不可用」并继续返回其余。 + +## 安全护栏(务必遵守) +- 只有 `dws profile list` 显示 ≥2 个组织才启用跨组织逻辑;单组织直接按当前组织走,不带 `--profile`。 +- 自动跨组织只对「读 / 搜」。写 / 发 / 删 / 撤回等操作默认只在当前组织做;确需带 `--profile` 跨组织写时,必须先与用户确认目标组织。 +- 持久切换 `dws profile switch`(改默认组织)按写操作对待:未经用户明确要求不得执行。跨组织找数一律用一次性 `--profile`,不改当前组织。 +- `dws auth logout` 默认退出所有已登录组织;只退一个加 `--profile <名称|corpId>`。退主组织不会被拦截,会静默改选新主,执行前必须向用户确认。 diff --git a/skills/multi/dws-shared/SKILL.md b/skills/multi/dws-shared/SKILL.md new file mode 100644 index 00000000..f48e1563 --- /dev/null +++ b/skills/multi/dws-shared/SKILL.md @@ -0,0 +1,39 @@ +--- +name: dws-shared +description: dws 多 skill 模式的公共参考——认证、全局参数、多组织 / --profile 规则、安全底线。所有 dingtalk-* 子 skill 执行前先读本 skill。命令前缀:dws。 +cli_version: ">=1.0.40" +metadata: + category: productivity + stability: experimental + requires: + bins: + - dws +--- + +# DWS 公共参考(dws-shared) + +> 🧪 **EXPERIMENTAL · 试验版 / Preview** — multi 模式当前未达 stable 标准;生产 / 共享环境请优先使用 mono 模式(`dws skill setup --mode mono`)。 + +每个 dingtalk-* 子 skill 都把本 skill 列为 PREREQUISITE:执行任何产品命令前先读这里的认证、全局参数与多组织规则。`dws` 必须在 PATH 上。 + +## 认证 +- `dws auth login`(新登一个组织即新增 profile);`--device` 无头 / SSH 登录;`--recommend` 无交互批量授权 +- `dws auth status [--profile <名称|corpId>]` 查看认证状态 + +## 全局参数 +- 所有命令加 `--format json` 取可解析输出 +- 全局 `--profile <名称|corpId>`:单次指定本命令用哪个组织,一次性、不改默认组织 +- 危险 / 写 / 删操作执行前先向用户确认 + +## 多组织 / --profile(关键规则) +dws 可同时登录多个钉钉组织,一个 profile = 一个已登录组织(corp)。当前 profile 决定本次命令用哪个组织的身份(corpId / userId 自动注入)。 + +- **跨组织铁律**:任何「找群 / 找人 / 找数据」(如 chat / aisearch / contact / doc / wiki / aitable / sheet / minutes / mail / report / todo / calendar / oa 的搜索、列表、查询)在当前组织没命中、且 `dws profile list` 显示 ≥2 个组织时,对每个组织带一次性 `--profile ` 各搜一遍;命中即用,全部组织都没有才追问用户。禁止在当前组织搜不到就判定「不存在」或直接甩给用户选。 +- **单组织**:`dws profile list` 只有 1 个组织时,按当前组织正常处理,不带 `--profile`。 +- **安全护栏**:自动跨组织只对「读 / 搜」;写 / 发 / 删 / 撤回等操作默认只在当前组织做,确需带 `--profile` 跨组织写时先与用户确认目标组织;持久切换 `dws profile switch`(改默认组织)属写操作,未经用户明确要求不得执行。 +- 完整命令与跨组织聚合见 `dingtalk-profile` skill。 + +## 错误处理 +- `unknown command` / `unknown flag`:先跑 `dws --help` 查证再修正一次,别把自然语言当命令 / flag +- 认证失败 / token 过期:提示用户 `dws auth login` 重新登录 +- 业务错误码 / 接口语义:用 `dws devdoc article search --query "<关键词>" --format json` 查官方文档,不编造原因 diff --git a/test/cli_compat/helpers_test.go b/test/cli_compat/helpers_test.go index e782b704..e74bd565 100644 --- a/test/cli_compat/helpers_test.go +++ b/test/cli_compat/helpers_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/internal/app" + authpkg "github.com/DingTalk-Real-AI/dingtalk-workspace-cli/internal/auth" "github.com/spf13/cobra" ) @@ -92,6 +93,10 @@ func getCapture(t *testing.T) *mcpCallCapture { func setupTestDeps(t *testing.T, _ string) *mcpCallCapture { t.Helper() + authpkg.SetRuntimeProfile("") + t.Cleanup(func() { + authpkg.SetRuntimeProfile("") + }) cap := &mcpCallCapture{} linkCapture(t, cap) return cap