diff --git a/internal/service/sub2api.go b/internal/service/sub2api.go index dd4104eb4..508a3274a 100644 --- a/internal/service/sub2api.go +++ b/internal/service/sub2api.go @@ -200,15 +200,14 @@ func (s *Sub2APIService) ListRemoteAccounts(ctx context.Context, server map[stri continue } credentials := util.StringMap(account["credentials"]) - accessToken := extractAccessToken(credentials) - if accessToken == "" { + if !hasSub2AccessToken(account, credentials) { continue } id := util.Clean(account["id"]) if id == "" { id = util.Clean(credentials["chatgpt_account_id"]) } - items = append(items, map[string]any{"id": id, "name": util.Clean(account["name"]), "email": firstNonEmpty(util.Clean(credentials["email"]), util.Clean(account["name"])), "plan_type": util.Clean(credentials["plan_type"]), "status": util.Clean(account["status"]), "expires_at": util.Clean(credentials["expires_at"]), "has_refresh_token": util.Clean(credentials["refresh_token"]) != ""}) + items = append(items, map[string]any{"id": id, "name": util.Clean(account["name"]), "email": firstNonEmpty(util.Clean(credentials["email"]), util.Clean(account["name"])), "plan_type": util.Clean(credentials["plan_type"]), "status": util.Clean(account["status"]), "expires_at": util.Clean(credentials["expires_at"]), "has_refresh_token": hasSub2RefreshToken(account, credentials)}) } if page*200 >= total || len(data) < 200 { break @@ -289,25 +288,59 @@ func (s *Sub2APIService) runImport(serverID string, server map[string]any, ids [ } func (s *Sub2APIService) fetchAccessTokenForAccount(ctx context.Context, server map[string]any, accountID string) (string, error) { + account, err := s.fetchAccountFromDataExport(ctx, server, accountID) + if err != nil { + return "", err + } + if account == nil { + account, err = s.fetchAccountFromDetail(ctx, server, accountID) + if err != nil { + return "", err + } + } + token := extractAccessToken(util.StringMap(account["credentials"])) + if token == "" { + return "", fmt.Errorf("missing access_token") + } + return token, nil +} + +func (s *Sub2APIService) fetchAccountFromDataExport(ctx context.Context, server map[string]any, accountID string) (map[string]any, error) { baseURL := util.Clean(server["base_url"]) headers, err := s.authHeaders(ctx, server) if err != nil { - return "", err + return nil, err + } + payload, status, err := s.getJSONWithStatus(ctx, strings.TrimRight(baseURL, "/")+"/api/v1/admin/accounts/data?ids="+urlQuery(accountID)+"&include_proxies=false", headers) + if err != nil { + if status == http.StatusNotFound { + return nil, nil + } + return nil, err + } + accounts := extractDataAccounts(payload) + if len(accounts) == 0 { + return nil, nil + } + return accounts[0], nil +} + +func (s *Sub2APIService) fetchAccountFromDetail(ctx context.Context, server map[string]any, accountID string) (map[string]any, error) { + baseURL := util.Clean(server["base_url"]) + headers, err := s.authHeaders(ctx, server) + if err != nil { + return nil, err } payload, err := s.getJSON(ctx, strings.TrimRight(baseURL, "/")+"/api/v1/admin/accounts/"+accountID, headers) if err != nil { - return "", err + return nil, err } account := unwrapEnvelope(payload) accountMap, ok := account.(map[string]any) if !ok { accountMap = payload } - token := extractAccessToken(util.StringMap(accountMap["credentials"])) - if token == "" { - return "", fmt.Errorf("missing access_token") - } - return token, nil + return accountMap, nil } func (s *Sub2APIService) authHeaders(ctx context.Context, server map[string]any) (map[string]string, error) { @@ -371,24 +404,29 @@ func (s *Sub2APIService) login(ctx context.Context, baseURL, email, password str } func (s *Sub2APIService) getJSON(ctx context.Context, url string, headers map[string]string) (map[string]any, error) { + payload, _, err := s.getJSONWithStatus(ctx, url, headers) + return payload, err +} + +func (s *Sub2APIService) getJSONWithStatus(ctx context.Context, url string, headers map[string]string) (map[string]any, int, error) { req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) for key, value := range headers { req.Header.Set(key, value) } resp, err := http.DefaultClient.Do(req) if err != nil { - return nil, err + return nil, 0, err } defer resp.Body.Close() data, _ := io.ReadAll(resp.Body) if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("sub2api request failed: HTTP %d %s", resp.StatusCode, string(data[:minInt(len(data), 200)])) + return nil, resp.StatusCode, fmt.Errorf("sub2api request failed: HTTP %d %s", resp.StatusCode, string(data[:minInt(len(data), 200)])) } var payload map[string]any if json.Unmarshal(data, &payload) != nil { - return nil, fmt.Errorf("invalid payload") + return nil, resp.StatusCode, fmt.Errorf("invalid payload") } - return payload, nil + return payload, resp.StatusCode, nil } func (s *Sub2APIService) updateJob(serverID string, updates map[string]any) { @@ -426,6 +464,18 @@ func extractAccessToken(credentials map[string]any) string { return "" } +func sub2CredentialsStatus(account map[string]any) map[string]any { + return util.StringMap(account["credentials_status"]) +} + +func hasSub2AccessToken(account, credentials map[string]any) bool { + return extractAccessToken(credentials) != "" || util.ToBool(sub2CredentialsStatus(account)["has_access_token"]) +} + +func hasSub2RefreshToken(account, credentials map[string]any) bool { + return util.Clean(credentials["refresh_token"]) != "" || util.ToBool(sub2CredentialsStatus(account)["has_refresh_token"]) +} + func unwrapEnvelope(payload map[string]any) any { if _, hasData := payload["data"]; hasData { if _, hasCode := payload["code"]; hasCode { @@ -450,6 +500,25 @@ func extractPagedItems(payload map[string]any) ([]any, int) { return []any{}, 0 } +func extractDataAccounts(payload map[string]any) []map[string]any { + data := unwrapEnvelope(payload) + obj, ok := data.(map[string]any) + if !ok { + return nil + } + raw, ok := asArray(obj["accounts"]) + if !ok { + return nil + } + out := make([]map[string]any, 0, len(raw)) + for _, item := range raw { + if account, ok := item.(map[string]any); ok { + out = append(out, account) + } + } + return out +} + func asArray(value any) ([]any, bool) { if list, ok := value.([]any); ok { return list, true diff --git a/internal/service/sub2api_test.go b/internal/service/sub2api_test.go index f0899b005..f3c0a2a97 100644 --- a/internal/service/sub2api_test.go +++ b/internal/service/sub2api_test.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/http/httptest" + "reflect" "testing" ) @@ -61,3 +62,122 @@ func TestSub2APIListRemoteGroupsReturnsEmptyArrayForNullItems(t *testing.T) { t.Fatalf("ListRemoteGroups() length = %d, want 0", len(groups)) } } + +func TestSub2APIListRemoteAccountsKeepsRedactedAccountsWithAccessTokenStatus(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/admin/accounts" { + t.Fatalf("unexpected path %s", r.URL.Path) + } + if r.URL.Query().Get("platform") != "openai" || r.URL.Query().Get("type") != "oauth" { + t.Fatalf("unexpected query %s", r.URL.RawQuery) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"code":0,"message":"success","data":{"items":[{"id":123,"name":"user@example.com","status":"active","credentials":{"email":"user@example.com","plan_type":"Plus"},"credentials_status":{"has_access_token":true,"has_refresh_token":true}}],"total":1}}`)) + })) + defer server.Close() + + service := NewSub2APIService(NewSub2APIConfig(newTestStorageBackend(t)), nil) + accounts, err := service.ListRemoteAccounts(context.Background(), map[string]any{ + "base_url": server.URL, + "api_key": "test-key", + }) + if err != nil { + t.Fatalf("ListRemoteAccounts() error = %v", err) + } + if len(accounts) != 1 { + t.Fatalf("ListRemoteAccounts() length = %d, want 1: %#v", len(accounts), accounts) + } + if accounts[0]["id"] != "123" || accounts[0]["email"] != "user@example.com" || accounts[0]["plan_type"] != "Plus" { + t.Fatalf("ListRemoteAccounts() account = %#v", accounts[0]) + } + if accounts[0]["has_refresh_token"] != true { + t.Fatalf("has_refresh_token = %#v, want true", accounts[0]["has_refresh_token"]) + } +} + +func TestSub2APIListRemoteAccountsSkipsRedactedAccountsWithoutAccessTokenStatus(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/admin/accounts" { + t.Fatalf("unexpected path %s", r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"code":0,"message":"success","data":{"items":[{"id":123,"name":"missing-token","credentials":{"email":"missing@example.com"},"credentials_status":{"has_refresh_token":true}}],"total":1}}`)) + })) + defer server.Close() + + service := NewSub2APIService(NewSub2APIConfig(newTestStorageBackend(t)), nil) + accounts, err := service.ListRemoteAccounts(context.Background(), map[string]any{ + "base_url": server.URL, + "api_key": "test-key", + }) + if err != nil { + t.Fatalf("ListRemoteAccounts() error = %v", err) + } + if len(accounts) != 0 { + t.Fatalf("ListRemoteAccounts() length = %d, want 0: %#v", len(accounts), accounts) + } +} + +func TestSub2APIFetchAccessTokenPrefersDataExportAndIgnoresRefreshToken(t *testing.T) { + var paths []string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + paths = append(paths, r.URL.Path) + if r.URL.Path != "/api/v1/admin/accounts/data" { + t.Fatalf("unexpected path %s", r.URL.Path) + } + if r.URL.Query().Get("ids") != "123" || r.URL.Query().Get("include_proxies") != "false" { + t.Fatalf("unexpected query %s", r.URL.RawQuery) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"code":0,"message":"success","data":{"accounts":[{"name":"user@example.com","credentials":{"access_token":"access-token-from-export","refresh_token":"refresh-token-must-not-be-imported","id_token":"id-token-must-not-be-imported"}}]}}`)) + })) + defer server.Close() + + service := NewSub2APIService(NewSub2APIConfig(newTestStorageBackend(t)), nil) + token, err := service.fetchAccessTokenForAccount(context.Background(), map[string]any{ + "base_url": server.URL, + "api_key": "test-key", + }, "123") + if err != nil { + t.Fatalf("fetchAccessTokenForAccount() error = %v", err) + } + if token != "access-token-from-export" { + t.Fatalf("token = %q, want access-token-from-export", token) + } + if !reflect.DeepEqual(paths, []string{"/api/v1/admin/accounts/data"}) { + t.Fatalf("paths = %#v", paths) + } +} + +func TestSub2APIFetchAccessTokenFallsBackToLegacyDetailWhenDataExportMissing(t *testing.T) { + var paths []string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + paths = append(paths, r.URL.Path) + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/api/v1/admin/accounts/data": + http.NotFound(w, r) + case "/api/v1/admin/accounts/123": + _, _ = w.Write([]byte(`{"code":0,"message":"success","data":{"credentials":{"access_token":"legacy-access-token","refresh_token":"legacy-refresh-token-must-not-be-imported"}}}`)) + default: + t.Fatalf("unexpected path %s", r.URL.Path) + } + })) + defer server.Close() + + service := NewSub2APIService(NewSub2APIConfig(newTestStorageBackend(t)), nil) + token, err := service.fetchAccessTokenForAccount(context.Background(), map[string]any{ + "base_url": server.URL, + "api_key": "test-key", + }, "123") + if err != nil { + t.Fatalf("fetchAccessTokenForAccount() error = %v", err) + } + if token != "legacy-access-token" { + t.Fatalf("token = %q, want legacy-access-token", token) + } + wantPaths := []string{"/api/v1/admin/accounts/data", "/api/v1/admin/accounts/123"} + if !reflect.DeepEqual(paths, wantPaths) { + t.Fatalf("paths = %#v, want %#v", paths, wantPaths) + } +}