Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 83 additions & 14 deletions internal/service/sub2api.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,15 +200,14 @@ func (s *Sub2APIService) ListRemoteAccounts(ctx context.Context, server map[stri
continue
}
credentials := util.StringMap(account["credentials"])
accessToken := extractAccessToken(credentials)
if accessToken == "" {
if !hasSub2AccessToken(account, credentials) {
continue
}
id := util.Clean(account["id"])
if id == "" {
id = util.Clean(credentials["chatgpt_account_id"])
}
items = append(items, map[string]any{"id": id, "name": util.Clean(account["name"]), "email": firstNonEmpty(util.Clean(credentials["email"]), util.Clean(account["name"])), "plan_type": util.Clean(credentials["plan_type"]), "status": util.Clean(account["status"]), "expires_at": util.Clean(credentials["expires_at"]), "has_refresh_token": util.Clean(credentials["refresh_token"]) != ""})
items = append(items, map[string]any{"id": id, "name": util.Clean(account["name"]), "email": firstNonEmpty(util.Clean(credentials["email"]), util.Clean(account["name"])), "plan_type": util.Clean(credentials["plan_type"]), "status": util.Clean(account["status"]), "expires_at": util.Clean(credentials["expires_at"]), "has_refresh_token": hasSub2RefreshToken(account, credentials)})
}
if page*200 >= total || len(data) < 200 {
break
Expand Down Expand Up @@ -289,25 +288,59 @@ func (s *Sub2APIService) runImport(serverID string, server map[string]any, ids [
}

func (s *Sub2APIService) fetchAccessTokenForAccount(ctx context.Context, server map[string]any, accountID string) (string, error) {
account, err := s.fetchAccountFromDataExport(ctx, server, accountID)
if err != nil {
return "", err
}
if account == nil {
account, err = s.fetchAccountFromDetail(ctx, server, accountID)
if err != nil {
return "", err
}
}
token := extractAccessToken(util.StringMap(account["credentials"]))
if token == "" {
return "", fmt.Errorf("missing access_token")
}
return token, nil
}

func (s *Sub2APIService) fetchAccountFromDataExport(ctx context.Context, server map[string]any, accountID string) (map[string]any, error) {
baseURL := util.Clean(server["base_url"])
headers, err := s.authHeaders(ctx, server)
if err != nil {
return "", err
return nil, err
}
payload, status, err := s.getJSONWithStatus(ctx, strings.TrimRight(baseURL, "/")+"/api/v1/admin/accounts/data?ids="+urlQuery(accountID)+"&include_proxies=false", headers)
if err != nil {
if status == http.StatusNotFound {
return nil, nil
}
return nil, err
}
accounts := extractDataAccounts(payload)
if len(accounts) == 0 {
return nil, nil
}
return accounts[0], nil
}

func (s *Sub2APIService) fetchAccountFromDetail(ctx context.Context, server map[string]any, accountID string) (map[string]any, error) {
baseURL := util.Clean(server["base_url"])
headers, err := s.authHeaders(ctx, server)
if err != nil {
return nil, err
}
payload, err := s.getJSON(ctx, strings.TrimRight(baseURL, "/")+"/api/v1/admin/accounts/"+accountID, headers)
if err != nil {
return "", err
return nil, err
}
account := unwrapEnvelope(payload)
accountMap, ok := account.(map[string]any)
if !ok {
accountMap = payload
}
token := extractAccessToken(util.StringMap(accountMap["credentials"]))
if token == "" {
return "", fmt.Errorf("missing access_token")
}
return token, nil
return accountMap, nil
}

func (s *Sub2APIService) authHeaders(ctx context.Context, server map[string]any) (map[string]string, error) {
Expand Down Expand Up @@ -371,24 +404,29 @@ func (s *Sub2APIService) login(ctx context.Context, baseURL, email, password str
}

func (s *Sub2APIService) getJSON(ctx context.Context, url string, headers map[string]string) (map[string]any, error) {
payload, _, err := s.getJSONWithStatus(ctx, url, headers)
return payload, err
}

func (s *Sub2APIService) getJSONWithStatus(ctx context.Context, url string, headers map[string]string) (map[string]any, int, error) {
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
for key, value := range headers {
req.Header.Set(key, value)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
return nil, 0, err
}
defer resp.Body.Close()
data, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("sub2api request failed: HTTP %d %s", resp.StatusCode, string(data[:minInt(len(data), 200)]))
return nil, resp.StatusCode, fmt.Errorf("sub2api request failed: HTTP %d %s", resp.StatusCode, string(data[:minInt(len(data), 200)]))
}
var payload map[string]any
if json.Unmarshal(data, &payload) != nil {
return nil, fmt.Errorf("invalid payload")
return nil, resp.StatusCode, fmt.Errorf("invalid payload")
}
return payload, nil
return payload, resp.StatusCode, nil
}

func (s *Sub2APIService) updateJob(serverID string, updates map[string]any) {
Expand Down Expand Up @@ -426,6 +464,18 @@ func extractAccessToken(credentials map[string]any) string {
return ""
}

func sub2CredentialsStatus(account map[string]any) map[string]any {
return util.StringMap(account["credentials_status"])
}

func hasSub2AccessToken(account, credentials map[string]any) bool {
return extractAccessToken(credentials) != "" || util.ToBool(sub2CredentialsStatus(account)["has_access_token"])
}

func hasSub2RefreshToken(account, credentials map[string]any) bool {
return util.Clean(credentials["refresh_token"]) != "" || util.ToBool(sub2CredentialsStatus(account)["has_refresh_token"])
}

func unwrapEnvelope(payload map[string]any) any {
if _, hasData := payload["data"]; hasData {
if _, hasCode := payload["code"]; hasCode {
Expand All @@ -450,6 +500,25 @@ func extractPagedItems(payload map[string]any) ([]any, int) {
return []any{}, 0
}

func extractDataAccounts(payload map[string]any) []map[string]any {
data := unwrapEnvelope(payload)
obj, ok := data.(map[string]any)
if !ok {
return nil
}
raw, ok := asArray(obj["accounts"])
if !ok {
return nil
}
out := make([]map[string]any, 0, len(raw))
for _, item := range raw {
if account, ok := item.(map[string]any); ok {
out = append(out, account)
}
}
return out
}

func asArray(value any) ([]any, bool) {
if list, ok := value.([]any); ok {
return list, true
Expand Down
120 changes: 120 additions & 0 deletions internal/service/sub2api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net/http"
"net/http/httptest"
"reflect"
"testing"
)

Expand Down Expand Up @@ -61,3 +62,122 @@ func TestSub2APIListRemoteGroupsReturnsEmptyArrayForNullItems(t *testing.T) {
t.Fatalf("ListRemoteGroups() length = %d, want 0", len(groups))
}
}

func TestSub2APIListRemoteAccountsKeepsRedactedAccountsWithAccessTokenStatus(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/v1/admin/accounts" {
t.Fatalf("unexpected path %s", r.URL.Path)
}
if r.URL.Query().Get("platform") != "openai" || r.URL.Query().Get("type") != "oauth" {
t.Fatalf("unexpected query %s", r.URL.RawQuery)
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"code":0,"message":"success","data":{"items":[{"id":123,"name":"user@example.com","status":"active","credentials":{"email":"user@example.com","plan_type":"Plus"},"credentials_status":{"has_access_token":true,"has_refresh_token":true}}],"total":1}}`))
}))
defer server.Close()

service := NewSub2APIService(NewSub2APIConfig(newTestStorageBackend(t)), nil)
accounts, err := service.ListRemoteAccounts(context.Background(), map[string]any{
"base_url": server.URL,
"api_key": "test-key",
})
if err != nil {
t.Fatalf("ListRemoteAccounts() error = %v", err)
}
if len(accounts) != 1 {
t.Fatalf("ListRemoteAccounts() length = %d, want 1: %#v", len(accounts), accounts)
}
if accounts[0]["id"] != "123" || accounts[0]["email"] != "user@example.com" || accounts[0]["plan_type"] != "Plus" {
t.Fatalf("ListRemoteAccounts() account = %#v", accounts[0])
}
if accounts[0]["has_refresh_token"] != true {
t.Fatalf("has_refresh_token = %#v, want true", accounts[0]["has_refresh_token"])
}
}

func TestSub2APIListRemoteAccountsSkipsRedactedAccountsWithoutAccessTokenStatus(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/v1/admin/accounts" {
t.Fatalf("unexpected path %s", r.URL.Path)
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"code":0,"message":"success","data":{"items":[{"id":123,"name":"missing-token","credentials":{"email":"missing@example.com"},"credentials_status":{"has_refresh_token":true}}],"total":1}}`))
}))
defer server.Close()

service := NewSub2APIService(NewSub2APIConfig(newTestStorageBackend(t)), nil)
accounts, err := service.ListRemoteAccounts(context.Background(), map[string]any{
"base_url": server.URL,
"api_key": "test-key",
})
if err != nil {
t.Fatalf("ListRemoteAccounts() error = %v", err)
}
if len(accounts) != 0 {
t.Fatalf("ListRemoteAccounts() length = %d, want 0: %#v", len(accounts), accounts)
}
}

func TestSub2APIFetchAccessTokenPrefersDataExportAndIgnoresRefreshToken(t *testing.T) {
var paths []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
paths = append(paths, r.URL.Path)
if r.URL.Path != "/api/v1/admin/accounts/data" {
t.Fatalf("unexpected path %s", r.URL.Path)
}
if r.URL.Query().Get("ids") != "123" || r.URL.Query().Get("include_proxies") != "false" {
t.Fatalf("unexpected query %s", r.URL.RawQuery)
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"code":0,"message":"success","data":{"accounts":[{"name":"user@example.com","credentials":{"access_token":"access-token-from-export","refresh_token":"refresh-token-must-not-be-imported","id_token":"id-token-must-not-be-imported"}}]}}`))
}))
defer server.Close()

service := NewSub2APIService(NewSub2APIConfig(newTestStorageBackend(t)), nil)
token, err := service.fetchAccessTokenForAccount(context.Background(), map[string]any{
"base_url": server.URL,
"api_key": "test-key",
}, "123")
if err != nil {
t.Fatalf("fetchAccessTokenForAccount() error = %v", err)
}
if token != "access-token-from-export" {
t.Fatalf("token = %q, want access-token-from-export", token)
}
if !reflect.DeepEqual(paths, []string{"/api/v1/admin/accounts/data"}) {
t.Fatalf("paths = %#v", paths)
}
}

func TestSub2APIFetchAccessTokenFallsBackToLegacyDetailWhenDataExportMissing(t *testing.T) {
var paths []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
paths = append(paths, r.URL.Path)
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/api/v1/admin/accounts/data":
http.NotFound(w, r)
case "/api/v1/admin/accounts/123":
_, _ = w.Write([]byte(`{"code":0,"message":"success","data":{"credentials":{"access_token":"legacy-access-token","refresh_token":"legacy-refresh-token-must-not-be-imported"}}}`))
default:
t.Fatalf("unexpected path %s", r.URL.Path)
}
}))
defer server.Close()

service := NewSub2APIService(NewSub2APIConfig(newTestStorageBackend(t)), nil)
token, err := service.fetchAccessTokenForAccount(context.Background(), map[string]any{
"base_url": server.URL,
"api_key": "test-key",
}, "123")
if err != nil {
t.Fatalf("fetchAccessTokenForAccount() error = %v", err)
}
if token != "legacy-access-token" {
t.Fatalf("token = %q, want legacy-access-token", token)
}
wantPaths := []string{"/api/v1/admin/accounts/data", "/api/v1/admin/accounts/123"}
if !reflect.DeepEqual(paths, wantPaths) {
t.Fatalf("paths = %#v, want %#v", paths, wantPaths)
}
}