Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions internal/app/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,9 @@ func resolveIdentityHeaders() map[string]string {
if fn := edition.Get().MergeHeaders; fn != nil {
headers = fn(headers)
}
if fn := edition.Get().EnterpriseCredentialHeaders; fn != nil {
headers = fn(headers)
}
return headers
}

Expand Down
7 changes: 7 additions & 0 deletions internal/auth/classify_denial_reason_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ func TestClassifyDenialReason(t *testing.T) {
},
want: "channel_required",
},
{
name: "error ENTERPRISE_NOT_AUTHORIZED",
status: &CLIAuthStatus{
ErrorCode: "ENTERPRISE_NOT_AUTHORIZED",
},
want: "enterprise_not_authorized",
},
{
name: "error NO_AUTH",
status: &CLIAuthStatus{
Expand Down
8 changes: 8 additions & 0 deletions internal/auth/device_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,14 @@ func (p *DeviceFlowProvider) loginOnce(ctx context.Context, attempt int) (*Token
_, _ = fmt.Fprintln(p.output(), i18n.T(" 请升级到最新版本的 CLI 后重试。"))
_, _ = fmt.Fprintln(p.output(), "")
return nil, errors.New(i18n.T("当前组织已开启渠道管控,请升级到最新版本的 CLI 后重试"))
case "enterprise_not_authorized":
msg := i18n.T("本次请求未通过企业安全认证")
if authStatus != nil && strings.TrimSpace(authStatus.ErrorMsg) != "" {
msg = strings.TrimSpace(authStatus.ErrorMsg)
}
_, _ = fmt.Fprintln(p.output(), dfRed("⚠️ "+msg))
_, _ = fmt.Fprintln(p.output(), "")
return nil, errors.New(msg)
case "no_auth":
_, _ = fmt.Fprintln(p.output(), dfRed(i18n.T("⚠️ 认证已失效")))
_, _ = fmt.Fprintln(p.output(), i18n.T(" 请执行 dws auth 重新登录。"))
Expand Down
42 changes: 42 additions & 0 deletions internal/auth/edition_headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2026 Alibaba Group
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package auth

import (
"net/http"
"strings"

"github.com/DingTalk-Real-AI/dingtalk-workspace-cli/pkg/edition"
)

// applyEditionEnterpriseCredentialHeaders injects overlay-provided enterprise
// credential headers (e.g. x-dws-enterprise-credential) into MCP control-plane
// and OAuth proxy requests.
func applyEditionEnterpriseCredentialHeaders(req *http.Request) {
if req == nil {
return
}
fn := edition.Get().EnterpriseCredentialHeaders
if fn == nil {
return
}
merged := fn(nil)
for k, v := range merged {
k = strings.TrimSpace(k)
v = strings.TrimSpace(v)
if k != "" && v != "" {
req.Header.Set(k, v)
}
}
}
115 changes: 115 additions & 0 deletions internal/auth/oauth_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ import (
"context"
"encoding/json"
"fmt"
"html"
"io"
"net/http"
"net/url"
"os"
"slices"
"strings"
"time"

"github.com/DingTalk-Real-AI/dingtalk-workspace-cli/pkg/config"
Expand Down Expand Up @@ -203,6 +205,7 @@ func (p *OAuthProvider) postJSON(ctx context.Context, endpoint string, body any)
return nil, fmt.Errorf("creating request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
applyEditionEnterpriseCredentialHeaders(req)

client := p.httpClient
if client == nil {
Expand Down Expand Up @@ -1118,6 +1121,112 @@ const channelDeniedHTML = `<!doctype html>
</body>
</html>`

const enterpriseDeniedHTML = `<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>钉钉 CLI</title>
<style>
body {
font-family:
-apple-system, BlinkMacSystemFont, "Segoe UI", Roboto,
"Helvetica Neue", Arial, sans-serif;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
margin: 0;
background: #f5f5f5;
padding: 20px;
}
.card {
height: 600px;
width: 480px;
border-radius: 16px;
background: #ffffff;
box-sizing: border-box;
border: 1px solid #f2f2f6;
box-shadow: 0px 2px 4px 0px rgba(0, 0, 0, 0.12);
padding: 32px 24px 24px;
text-align: center;
display: flex;
justify-content: center;
align-items: center;
flex-direction: column;
}
.lock-icon {
width: 120px;
height: 120px;
margin: 0 auto;
object-fit: contain;
display: block;
}
h1 {
margin: 8px 0 0;
font-family:
"PingFang SC",
-apple-system,
BlinkMacSystemFont,
"Segoe UI",
Roboto,
"Helvetica Neue",
Arial,
sans-serif;
font-size: 18px;
font-weight: 600;
line-height: 44px;
text-align: center;
letter-spacing: normal;
color: #181c1f;
}
p {
margin: 0;
font-family:
"PingFang SC",
-apple-system,
BlinkMacSystemFont,
"Segoe UI",
Roboto,
"Helvetica Neue",
Arial,
sans-serif;
font-size: 14px;
font-weight: normal;
line-height: 21px;
text-align: center;
letter-spacing: normal;
color: rgba(24, 28, 31, 0.6);
}
</style>
</head>
<body>
<div class="card">
<img
class="lock-icon"
src="https://img.alicdn.com/imgextra/i4/O1CN01fS3xxz1vbzZSGjbe0_!!6000000006192-2-tps-480-480.png"
alt="lock icon"
/>
<h1>企业安全认证未通过</h1>
<p>__ENTERPRISE_DENIED_MSG__</p>
</div>
</body>
</html>`

// defaultEnterpriseDeniedMsg is shown when the server returns no errorMsg.
const defaultEnterpriseDeniedMsg = "本次请求未通过企业安全认证"

// renderEnterpriseDeniedHTML injects the server-provided denial message (falling
// back to the default text) into the enterprise-denied page. The message is
// HTML-escaped before insertion.
func renderEnterpriseDeniedHTML(serverMsg string) string {
msg := strings.TrimSpace(serverMsg)
if msg == "" {
msg = defaultEnterpriseDeniedMsg
}
return strings.ReplaceAll(enterpriseDeniedHTML, "__ENTERPRISE_DENIED_MSG__", html.EscapeString(msg)+" 此页面可以关闭。")
}

// CLIAuthStatus represents the response from /cli/cliAuthEnabled API.
type CLIAuthStatus struct {
Success bool `json:"success"`
Expand Down Expand Up @@ -1154,6 +1263,9 @@ func classifyDenialReason(status *CLIAuthStatus, currentChannel string) string {
if status.ErrorCode == "CHANNEL_REQUIRED" {
return "channel_required"
}
if status.ErrorCode == "ENTERPRISE_NOT_AUTHORIZED" {
return "enterprise_not_authorized"
}
if status.ErrorCode == "NO_AUTH" {
return "no_auth"
}
Expand Down Expand Up @@ -1243,6 +1355,7 @@ func (p *OAuthProvider) doCheckCLIAuthEnabled(ctx context.Context, accessToken s
if ch := os.Getenv("DWS_CHANNEL"); ch != "" {
req.Header.Set("x-dws-channel", ch)
}
applyEditionEnterpriseCredentialHeaders(req)

client := p.httpClient
if client == nil {
Expand Down Expand Up @@ -1294,6 +1407,7 @@ func doGetSuperAdmins(ctx context.Context, accessToken string) (*SuperAdminRespo
return nil, fmt.Errorf("creating request: %w", err)
}
req.Header.Set("x-user-access-token", accessToken)
applyEditionEnterpriseCredentialHeaders(req)

resp, err := oauthHTTPClient.Do(req)
if err != nil {
Expand Down Expand Up @@ -1341,6 +1455,7 @@ func doSendCliAuthApply(ctx context.Context, accessToken, adminStaffID string) (
return nil, fmt.Errorf("creating request: %w", err)
}
req.Header.Set("x-user-access-token", accessToken)
applyEditionEnterpriseCredentialHeaders(req)

resp, err := oauthHTTPClient.Do(req)
if err != nil {
Expand Down
18 changes: 17 additions & 1 deletion internal/auth/oauth_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net"
"net/http"
"os"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -141,6 +142,7 @@ func (p *OAuthProvider) Login(ctx context.Context, force bool) (*TokenData, erro
err error
cliAuthDisabled bool
denialReason string
errorMsg string // server-provided errorMsg from /cli/cliAuthEnabled
}
resultCh := make(chan callbackResult, 1)
errCh := make(chan error, 1)
Expand Down Expand Up @@ -261,6 +263,13 @@ func (p *OAuthProvider) Login(ctx context.Context, force bool) (*TokenData, erro
}
cliAuthEnabled := denialReason == ""

// Server-provided errorMsg (nil-safe), surfaced both on the page and to
// the terminal so portal can update copy without releasing the CLI.
serverMsg := ""
if authStatus != nil {
serverMsg = authStatus.ErrorMsg
}

// Update CLI auth disabled state
callbackTokenMu.Lock()
callbackAuthDisabled = !cliAuthEnabled
Expand All @@ -275,6 +284,8 @@ func (p *OAuthProvider) Login(ctx context.Context, force bool) (*TokenData, erro
_, _ = fmt.Fprint(w, accessDeniedHTML)
case denialReason == "channel_not_allowed" || denialReason == "channel_required":
_, _ = fmt.Fprint(w, channelDeniedHTML)
case denialReason == "enterprise_not_authorized":
_, _ = fmt.Fprint(w, renderEnterpriseDeniedHTML(serverMsg))
default:
_, _ = fmt.Fprint(w, notEnabledHTML)
}
Expand All @@ -284,7 +295,7 @@ func (p *OAuthProvider) Login(ctx context.Context, force bool) (*TokenData, erro
}
// Notify main goroutine with full result
select {
case resultCh <- callbackResult{token: tokenData, cliAuthDisabled: !cliAuthEnabled, denialReason: denialReason}:
case resultCh <- callbackResult{token: tokenData, cliAuthDisabled: !cliAuthEnabled, denialReason: denialReason, errorMsg: serverMsg}:
default:
}
})
Expand Down Expand Up @@ -433,6 +444,11 @@ func (p *OAuthProvider) Login(ctx context.Context, force bool) (*TokenData, erro
return nil, errors.New(i18n.T("您不在该组织的 CLI 授权人员范围内,请联系组织管理员将您加入授权名单"))
case "channel_not_allowed", "channel_required":
return nil, errors.New(i18n.T("当前渠道未获得该组织授权,或组织已开启渠道管控,请联系组织管理员开通渠道访问权限,或升级到最新版本的 CLI"))
case "enterprise_not_authorized":
if msg := strings.TrimSpace(result.errorMsg); msg != "" {
return nil, errors.New(msg)
}
return nil, errors.New(i18n.T("本次请求未通过企业安全认证"))
}

_, _ = fmt.Fprintln(p.output(), "")
Expand Down
3 changes: 3 additions & 0 deletions pkg/edition/edition.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ type Hooks struct {
// --- HTTP headers ---
MergeHeaders func(base map[string]string) map[string]string

// --- EnterpriseCredential HTTP headers ---
EnterpriseCredentialHeaders func(base map[string]string) map[string]string

// --- auth ---
AuthClientID string // OAuth client ID for device-flow authorisation
AuthClientFromMCP bool // true → fetch client ID from MCP at runtime
Expand Down