diff --git a/.env.example b/.env.example index 06cd75a3..8600c6aa 100644 --- a/.env.example +++ b/.env.example @@ -30,10 +30,6 @@ APP_BASE_PATH= # Project business timezone. Affects Today, daily aggregation, daily 03:00 cleanup, and log timestamps. Required: no. Default: Asia/Shanghai. TZ=Asia/Shanghai -# usage 同步模式:auto 启动时用 AUTH-only 探测 management data stream;成功则本进程固定使用 redis,失败则本进程固定使用 legacy_export。redis 只使用 data stream;legacy_export 使用旧兼容方式。必填:否。默认值:auto。 -# Usage sync mode: auto runs an AUTH-only startup probe for the management data stream; success fixes this process to redis, failure fixes it to legacy_export. redis uses only the data stream; legacy_export uses the old compatibility path. Required: no. Default: auto. -USAGE_SYNC_MODE=auto - # CPA management data stream 的 Redis/RESP TCP 地址。留空时默认使用 CPA_BASE_URL 的主机名加 8317 端口;如果通过 nginx stream 暴露到其它端口,请显式填写 host:port。 # Redis/RESP TCP address for the CPA management data stream. When empty, defaults to the CPA_BASE_URL hostname plus port 8317; set host:port explicitly when exposed through nginx stream on another port. REDIS_QUEUE_ADDR= @@ -46,10 +42,6 @@ REDIS_QUEUE_BATCH_SIZE=1000 # Idle check interval when the Redis queue is empty. Required: no. Default: 1s. REDIS_QUEUE_IDLE_INTERVAL=1s -# legacy_export 的拉取间隔。必填:否。默认值:5m。 -# Pull interval for legacy_export. Required: no. Default: 5m. -POLL_INTERVAL=5m - # 请求 CPA 接口时的超时时间。必填:否。默认值:30s。 # Timeout for requests to the CPA service. Required: no. Default: 30s. REQUEST_TIMEOUT=30s diff --git a/README.en.md b/README.en.md index f6417d4a..88a47ae9 100644 --- a/README.en.md +++ b/README.en.md @@ -4,7 +4,7 @@ CPA Usage Keeper is a standalone CPA usage persistence and dashboard service. -It relies on [CLIProxyAPI (CPA)](https://github.com/router-for-me/CLIProxyAPI) as the backend CPA data source and adds persistent storage and statistical analysis capabilities on top of CPA. The service periodically pulls CPA data, writes normalized events to SQLite, exposes aggregation APIs, and serves a built-in web dashboard for usage, pricing, request health, and model/API statistics. +It relies on [CLIProxyAPI (CPA)](https://github.com/router-for-me/CLIProxyAPI) as the backend CPA data source and adds persistent storage and statistical analysis capabilities on top of CPA. The service consumes events from the CPA Redis usage queue into SQLite, periodically pulls CPA metadata, exposes aggregation APIs, and serves a built-in web dashboard for usage, pricing, request health, and model/API statistics. ![cpa-usage-keeper-screenshot](https://images.bitskyline.com/i/2026/04/h9se9f.png) @@ -52,11 +52,9 @@ cp .env.example .env | `APP_PORT` | No | `8080` | HTTP listen port | | `APP_BASE_PATH` | No | root path | Subpath prefix such as `/cpa`; empty means `/` | | `TZ` | No | `Asia/Shanghai` | Project business timezone; affects Today, daily aggregation, scheduled tasks, and log timestamps | -| `USAGE_SYNC_MODE` | No | `auto` | Sync mode: `auto` probes at startup and then fixes the process to `redis` or `legacy_export`; can also be set explicitly to `redis` or `legacy_export` | | `REDIS_QUEUE_ADDR` | No | `CPA_BASE_URL` hostname + `8317` | CPA Redis/RESP TCP address; set `host:port` for non-default ports | | `REDIS_QUEUE_BATCH_SIZE` | No | `1000` | Maximum queue records per pull | | `REDIS_QUEUE_IDLE_INTERVAL` | No | `1s` | Empty queue check interval | -| `POLL_INTERVAL` | No | `5m` | Pull interval for `legacy_export` | | `REQUEST_TIMEOUT` | No | `30s` | CPA request timeout | | `WORK_DIR` | No | `./data` | Application work directory; database, logs, and backups default to `app.db`, `logs/`, and `backups/` under it | | `LOG_LEVEL` | No | `info` | Log level | diff --git a/README.md b/README.md index 0617f6bd..94fe8735 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ `CPA Usage Keeper` 是一个独立的 CPA 用量持久化与可视化服务。 -它依赖 [CLIProxyAPI(CPA)](https://github.com/router-for-me/CLIProxyAPI) 作为后端 CPA 数据来源,目标是在 CPA 之上补充持久化存储与统计分析能力。服务会定时拉取 CPA 数据,将规范化后的事件写入 SQLite,暴露聚合 API,并提供内置 Web Dashboard 用于查看 usage、pricing、request health 和 model/API 维度的统计信息。 +它依赖 [CLIProxyAPI(CPA)](https://github.com/router-for-me/CLIProxyAPI) 作为后端 CPA 数据来源,目标是在 CPA 之上补充持久化存储与统计分析能力。服务会从 CPA Redis usage 队列消费事件并写入 SQLite,定时拉取 CPA metadata,暴露聚合 API,并提供内置 Web Dashboard 用于查看 usage、pricing、request health 和 model/API 维度的统计信息。 ![cpa-usage-keeper-screenshot](https://images.bitskyline.com/i/2026/04/h9se9f.png) @@ -52,11 +52,9 @@ cp .env.example .env | `APP_PORT` | 否 | `8080` | HTTP 监听端口 | | `APP_BASE_PATH` | 否 | 根路径 | 子路径部署前缀,例如 `/cpa`;留空表示 `/` | | `TZ` | 否 | `Asia/Shanghai` | 项目业务时区,影响 Today、按天聚合、定时任务和日志时间 | -| `USAGE_SYNC_MODE` | 否 | `auto` | 同步模式:`auto` 启动时探测后固定为 `redis` 或 `legacy_export`;也可显式设置 `redis`、`legacy_export` | | `REDIS_QUEUE_ADDR` | 否 | `CPA_BASE_URL` 主机名 + `8317` | CPA Redis/RESP TCP 地址;非默认端口时填写 `host:port` | | `REDIS_QUEUE_BATCH_SIZE` | 否 | `1000` | 每次最多拉取的队列记录数 | | `REDIS_QUEUE_IDLE_INTERVAL` | 否 | `1s` | 队列为空时的检查间隔 | -| `POLL_INTERVAL` | 否 | `5m` | `legacy_export` 拉取间隔 | | `REQUEST_TIMEOUT` | 否 | `30s` | CPA 请求超时 | | `WORK_DIR` | 否 | `./data` | 应用工作目录;数据库、日志和备份默认分别写入 `app.db`、`logs/`、`backups/` | | `LOG_LEVEL` | 否 | `info` | 日志级别 | diff --git a/docker-compose.example.yml b/docker-compose.example.yml index efe665f1..bfc58bf4 100644 --- a/docker-compose.example.yml +++ b/docker-compose.example.yml @@ -12,11 +12,9 @@ services: APP_PORT: ${APP_PORT:-8080} APP_BASE_PATH: ${APP_BASE_PATH:-} TZ: ${TZ:-Asia/Shanghai} - USAGE_SYNC_MODE: ${USAGE_SYNC_MODE:-auto} REDIS_QUEUE_ADDR: ${REDIS_QUEUE_ADDR:-} REDIS_QUEUE_BATCH_SIZE: ${REDIS_QUEUE_BATCH_SIZE:-1000} REDIS_QUEUE_IDLE_INTERVAL: ${REDIS_QUEUE_IDLE_INTERVAL:-1s} - POLL_INTERVAL: ${POLL_INTERVAL:-5m} REQUEST_TIMEOUT: ${REQUEST_TIMEOUT:-30s} WORK_DIR: ${WORK_DIR:-./data} LOG_LEVEL: ${LOG_LEVEL:-info} diff --git a/internal/api/auth_files.go b/internal/api/auth_files.go deleted file mode 100644 index 4c591cc2..00000000 --- a/internal/api/auth_files.go +++ /dev/null @@ -1,52 +0,0 @@ -package api - -import ( - "net/http" - - "cpa-usage-keeper/internal/models" - "cpa-usage-keeper/internal/service" - "github.com/gin-gonic/gin" -) - -type authFilesResponse struct { - Files []authFileResponse `json:"files"` -} - -type authFileResponse struct { - AuthIndex string `json:"auth_index"` - Name string `json:"name,omitempty"` - Email string `json:"email,omitempty"` - Type string `json:"type,omitempty"` - Provider string `json:"provider,omitempty"` -} - -func registerAuthFileRoutes(router gin.IRoutes, authFileProvider service.AuthFileProvider) { - router.GET("/auth-files", func(c *gin.Context) { - if authFileProvider == nil { - c.JSON(http.StatusOK, authFilesResponse{Files: []authFileResponse{}}) - return - } - - files, err := authFileProvider.ListAuthFiles(c.Request.Context()) - if err != nil { - writeInternalError(c, "list auth files failed", err) - return - } - - response := make([]authFileResponse, 0, len(files)) - for _, file := range files { - response = append(response, mapAuthFileResponse(file)) - } - c.JSON(http.StatusOK, authFilesResponse{Files: response}) - }) -} - -func mapAuthFileResponse(file models.AuthFile) authFileResponse { - return authFileResponse{ - AuthIndex: file.AuthIndex, - Name: file.Name, - Email: file.Email, - Type: file.Type, - Provider: file.Provider, - } -} diff --git a/internal/api/auth_files_test.go b/internal/api/auth_files_test.go deleted file mode 100644 index 6226f28a..00000000 --- a/internal/api/auth_files_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package api - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - - "cpa-usage-keeper/internal/models" -) - -type authFileStub struct { - files []models.AuthFile - err error -} - -func (s authFileStub) ListAuthFiles(context.Context) ([]models.AuthFile, error) { - return s.files, s.err -} - -func TestAuthFilesRouteReturnsEmptyResponseWithoutProvider(t *testing.T) { - router := NewRouter(nil, nil, nil, nil, nil, nil, AuthConfig{}, nil, "") - req := httptest.NewRequest(http.MethodGet, "/api/v1/auth-files", nil) - resp := httptest.NewRecorder() - - router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK || !contains(resp.Body.String(), `"files":[]`) { - t.Fatalf("unexpected response: %d %s", resp.Code, resp.Body.String()) - } -} - -func TestAuthFilesRouteReturnsStoredMetadata(t *testing.T) { - router := NewRouter(nil, nil, nil, authFileStub{files: []models.AuthFile{{ - AuthIndex: "2", - Name: "Claude Desktop", - Email: "user@example.com", - Type: "claude", - Provider: "anthropic", - }}}, nil, nil, AuthConfig{}, nil, "") - req := httptest.NewRequest(http.MethodGet, "/api/v1/auth-files", nil) - resp := httptest.NewRecorder() - - router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", resp.Code) - } - body := resp.Body.String() - if !(contains(body, `"auth_index":"2"`) && contains(body, `"email":"user@example.com"`) && contains(body, `"provider":"anthropic"`)) { - t.Fatalf("unexpected response body: %s", body) - } -} diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index 47ba6821..650d724f 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -11,7 +11,7 @@ import ( ) func TestAuthSessionReportsAuthenticatedWhenDisabled(t *testing.T) { - router := NewRouter(nil, nil, nil, nil, nil, nil, AuthConfig{Enabled: false}, nil, "") + router := NewRouter(nil, nil, nil, nil, AuthConfig{Enabled: false}, nil, "") resp := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/session", nil) @@ -25,7 +25,7 @@ func TestAuthSessionReportsAuthenticatedWhenDisabled(t *testing.T) { func TestAuthProtectedRouteRequiresSessionWhenEnabled(t *testing.T) { sessions := auth.NewSessionManager(time.Hour) config := AuthConfig{Enabled: true, LoginPassword: "secret", SessionTTL: time.Hour} - router := NewRouter(nil, nil, nil, nil, nil, nil, config, NewAuthHandler(config, sessions), "") + router := NewRouter(nil, nil, nil, nil, config, NewAuthHandler(config, sessions), "") resp := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/overview", nil) @@ -40,7 +40,7 @@ func TestAuthLoginSetsCookieAndUnlocksProtectedRoute(t *testing.T) { sessions := auth.NewSessionManager(time.Hour) config := AuthConfig{Enabled: true, LoginPassword: "secret", SessionTTL: time.Hour} handler := NewAuthHandler(config, sessions) - router := NewRouter(nil, nil, nil, nil, nil, nil, config, handler, "") + router := NewRouter(nil, nil, nil, nil, config, handler, "") loginResp := httptest.NewRecorder() loginReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(`{"password":"secret"}`)) @@ -74,7 +74,7 @@ func TestAuthLoginSetsCookieAndUnlocksProtectedRoute(t *testing.T) { func TestAuthLoginRejectsWrongPassword(t *testing.T) { sessions := auth.NewSessionManager(time.Hour) config := AuthConfig{Enabled: true, LoginPassword: "secret", SessionTTL: time.Hour} - router := NewRouter(nil, nil, nil, nil, nil, nil, config, NewAuthHandler(config, sessions), "") + router := NewRouter(nil, nil, nil, nil, config, NewAuthHandler(config, sessions), "") resp := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(`{"password":"wrong"}`)) req.Header.Set("Content-Type", "application/json") @@ -89,7 +89,7 @@ func TestAuthLoginRejectsWrongPassword(t *testing.T) { func TestAuthLoginRateLimitsRepeatedFailures(t *testing.T) { sessions := auth.NewSessionManager(time.Hour) config := AuthConfig{Enabled: true, LoginPassword: "secret", SessionTTL: time.Hour} - router := NewRouter(nil, nil, nil, nil, nil, nil, config, NewAuthHandler(config, sessions), "") + router := NewRouter(nil, nil, nil, nil, config, NewAuthHandler(config, sessions), "") for i := 0; i < 5; i++ { resp := httptest.NewRecorder() @@ -116,7 +116,7 @@ func TestAuthLoginRateLimitsRepeatedFailures(t *testing.T) { func TestAuthLoginAllowsCorrectPasswordAfterRateLimitThreshold(t *testing.T) { sessions := auth.NewSessionManager(time.Hour) config := AuthConfig{Enabled: true, LoginPassword: "secret", SessionTTL: time.Hour} - router := NewRouter(nil, nil, nil, nil, nil, nil, config, NewAuthHandler(config, sessions), "") + router := NewRouter(nil, nil, nil, nil, config, NewAuthHandler(config, sessions), "") for i := 0; i < 5; i++ { resp := httptest.NewRecorder() @@ -144,7 +144,7 @@ func TestAuthLogoutDeletesSessionCookie(t *testing.T) { sessions := auth.NewSessionManager(time.Hour) config := AuthConfig{Enabled: true, LoginPassword: "secret", SessionTTL: time.Hour} handler := NewAuthHandler(config, sessions) - router := NewRouter(nil, nil, nil, nil, nil, nil, config, handler, "") + router := NewRouter(nil, nil, nil, nil, config, handler, "") loginResp := httptest.NewRecorder() loginReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(`{"password":"secret"}`)) @@ -183,7 +183,7 @@ func TestSubpathAuthUsesPrefixedRoutesAndCookiePath(t *testing.T) { sessions := auth.NewSessionManager(time.Hour) config := AuthConfig{Enabled: true, LoginPassword: "secret", SessionTTL: time.Hour, BasePath: "/cpa"} handler := NewAuthHandler(config, sessions) - router := NewRouter(nil, nil, nil, nil, nil, nil, config, handler, "/cpa") + router := NewRouter(nil, nil, nil, nil, config, handler, "/cpa") sessionResp := httptest.NewRecorder() sessionReq := httptest.NewRequest(http.MethodGet, "/cpa/api/v1/auth/session", nil) diff --git a/internal/api/pricing_test.go b/internal/api/pricing_test.go index ad18e656..580a8b8a 100644 --- a/internal/api/pricing_test.go +++ b/internal/api/pricing_test.go @@ -39,7 +39,7 @@ func (s *pricingStub) DeletePricing(_ context.Context, model string) error { } func TestPricingRoutesReturnEmptyResponsesWithoutProvider(t *testing.T) { - router := NewRouter(nil, nil, nil, nil, nil, nil, AuthConfig{}, nil, "") + router := NewRouter(nil, nil, nil, nil, AuthConfig{}, nil, "") usedReq := httptest.NewRequest(http.MethodGet, "/api/v1/models/used", nil) usedResp := httptest.NewRecorder() @@ -57,7 +57,7 @@ func TestPricingRoutesReturnEmptyResponsesWithoutProvider(t *testing.T) { } func TestPricingRoutesReturnConfiguredData(t *testing.T) { - router := NewRouter(nil, nil, nil, nil, nil, &pricingStub{ + router := NewRouter(nil, nil, nil, &pricingStub{ usedModels: []string{"claude-sonnet"}, pricing: []models.ModelPriceSetting{{ Model: "claude-sonnet", @@ -91,7 +91,7 @@ func TestUpdatePricingRoute(t *testing.T) { CachePricePer1M: 0.3, }, } - router := NewRouter(nil, nil, nil, nil, nil, provider, AuthConfig{}, nil, "") + router := NewRouter(nil, nil, nil, provider, AuthConfig{}, nil, "") req := httptest.NewRequest(http.MethodPut, "/api/v1/pricing/claude-sonnet", strings.NewReader(`{"prompt_price_per_1m":3,"completion_price_per_1m":15,"cache_price_per_1m":0.3}`)) req.Header.Set("Content-Type", "application/json") @@ -112,7 +112,7 @@ func TestUpdatePricingRouteAcceptsModelInBody(t *testing.T) { CachePricePer1M: 0.3, }, } - router := NewRouter(nil, nil, nil, nil, nil, provider, AuthConfig{}, nil, "") + router := NewRouter(nil, nil, nil, provider, AuthConfig{}, nil, "") req := httptest.NewRequest(http.MethodPut, "/api/v1/pricing", strings.NewReader(`{"model":"openai/gpt-4.1","prompt_price_per_1m":3,"completion_price_per_1m":15,"cache_price_per_1m":0.3}`)) req.Header.Set("Content-Type", "application/json") @@ -129,7 +129,7 @@ func TestUpdatePricingRouteAcceptsModelInBody(t *testing.T) { func TestDeletePricingRoute(t *testing.T) { provider := &pricingStub{} - router := NewRouter(nil, nil, nil, nil, nil, provider, AuthConfig{}, nil, "") + router := NewRouter(nil, nil, nil, provider, AuthConfig{}, nil, "") req := httptest.NewRequest(http.MethodDelete, "/api/v1/pricing?model=openai%2Fgpt-4.1", nil) resp := httptest.NewRecorder() diff --git a/internal/api/provider_metadata.go b/internal/api/provider_metadata.go deleted file mode 100644 index a3d8b22b..00000000 --- a/internal/api/provider_metadata.go +++ /dev/null @@ -1,51 +0,0 @@ -package api - -import ( - "net/http" - - "cpa-usage-keeper/internal/models" - "cpa-usage-keeper/internal/service" - "github.com/gin-gonic/gin" -) - -type providerMetadataListResponse struct { - Items []providerMetadataResponse `json:"items"` -} - -type providerMetadataResponse struct { - LookupKey string `json:"lookup_key"` - ProviderType string `json:"provider_type,omitempty"` - DisplayName string `json:"display_name,omitempty"` - ProviderKey string `json:"provider_key,omitempty"` -} - -func registerProviderMetadataRoutes(router gin.IRoutes, provider service.ProviderMetadataProvider) { - router.GET("/provider-metadata", func(c *gin.Context) { - if provider == nil { - c.JSON(http.StatusOK, providerMetadataListResponse{Items: []providerMetadataResponse{}}) - return - } - - items, err := provider.ListProviderMetadata(c.Request.Context()) - if err != nil { - writeInternalError(c, "list provider metadata failed", err) - return - } - - response := make([]providerMetadataResponse, 0, len(items)) - for _, item := range items { - response = append(response, mapProviderMetadataResponse(item)) - } - c.JSON(http.StatusOK, providerMetadataListResponse{Items: response}) - }) -} - -func mapProviderMetadataResponse(item models.ProviderMetadata) providerMetadataResponse { - resolved := usageSourceResolutionFromMetadata(item, item.LookupKey) - return providerMetadataResponse{ - LookupKey: resolved.SourceKey, - ProviderType: item.ProviderType, - DisplayName: resolved.DisplayName, - ProviderKey: resolved.SourceKey, - } -} diff --git a/internal/api/provider_metadata_test.go b/internal/api/provider_metadata_test.go deleted file mode 100644 index 6f2a32ac..00000000 --- a/internal/api/provider_metadata_test.go +++ /dev/null @@ -1,97 +0,0 @@ -package api - -import ( - "context" - "errors" - "net/http" - "net/http/httptest" - "testing" - - "cpa-usage-keeper/internal/models" -) - -type providerMetadataStub struct { - items []models.ProviderMetadata - err error -} - -func (s providerMetadataStub) ListProviderMetadata(context.Context) ([]models.ProviderMetadata, error) { - return s.items, s.err -} - -func TestProviderMetadataRouteReturnsEmptyResponseWithoutProvider(t *testing.T) { - router := NewRouter(nil, nil, nil, nil, nil, nil, AuthConfig{}, nil, "") - req := httptest.NewRequest(http.MethodGet, "/api/v1/provider-metadata", nil) - resp := httptest.NewRecorder() - - router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK || !contains(resp.Body.String(), `"items":[]`) { - t.Fatalf("unexpected response: %d %s", resp.Code, resp.Body.String()) - } -} - -func TestProviderMetadataRouteReturnsStoredMetadata(t *testing.T) { - router := NewRouter(nil, nil, nil, nil, providerMetadataStub{items: []models.ProviderMetadata{{ - LookupKey: "sk-test-1234", - ProviderType: "openai", - DisplayName: "ChatGPT Mirror", - ProviderKey: "openai:ChatGPT Mirror", - }}}, nil, AuthConfig{}, nil, "") - req := httptest.NewRequest(http.MethodGet, "/api/v1/provider-metadata", nil) - resp := httptest.NewRecorder() - - router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", resp.Code) - } - body := resp.Body.String() - if contains(body, `sk-test-1234`) { - t.Fatalf("expected raw lookup key to be redacted from response body: %s", body) - } - if !(contains(body, `"lookup_key":"openai:ChatGPT Mirror"`) && contains(body, `"display_name":"ChatGPT Mirror"`) && contains(body, `"provider_key":"openai:ChatGPT Mirror"`)) { - t.Fatalf("unexpected response body: %s", body) - } -} - -func TestProviderMetadataRouteHidesInternalErrors(t *testing.T) { - router := NewRouter(nil, nil, nil, nil, providerMetadataStub{err: errors.New("database contains sk-secret-1234")}, nil, AuthConfig{}, nil, "") - req := httptest.NewRequest(http.MethodGet, "/api/v1/provider-metadata", nil) - resp := httptest.NewRecorder() - - router.ServeHTTP(resp, req) - - body := resp.Body.String() - if resp.Code != http.StatusInternalServerError { - t.Fatalf("expected status 500, got %d", resp.Code) - } - if contains(body, "sk-secret-1234") || contains(body, "database contains") { - t.Fatalf("expected internal error details to be hidden, got %s", body) - } - if !contains(body, `"error":"internal server error"`) { - t.Fatalf("expected stable internal error response, got %s", body) - } -} - -func TestProviderMetadataRouteDisambiguatesSameNamedProviders(t *testing.T) { - router := NewRouter(nil, nil, nil, nil, providerMetadataStub{items: []models.ProviderMetadata{ - {ID: 1, LookupKey: "sk-test-1234", ProviderType: "openai", DisplayName: "Shared"}, - {ID: 2, LookupKey: "sk-test-5678", ProviderType: "openai", DisplayName: "Shared"}, - }}, nil, AuthConfig{}, nil, "") - req := httptest.NewRequest(http.MethodGet, "/api/v1/provider-metadata", nil) - resp := httptest.NewRecorder() - - router.ServeHTTP(resp, req) - - body := resp.Body.String() - if resp.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", resp.Code) - } - if !contains(body, `"lookup_key":"provider:1"`) || !contains(body, `"lookup_key":"provider:2"`) { - t.Fatalf("expected provider ids to disambiguate same display names, got %s", body) - } - if contains(body, "sk-test-1234") || contains(body, "sk-test-5678") { - t.Fatalf("expected raw lookup keys to be redacted, got %s", body) - } -} diff --git a/internal/api/router.go b/internal/api/router.go index 78c014af..3477de11 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -6,6 +6,7 @@ import ( "errors" "io" "io/fs" + "log/slog" "net/http" "path" "strconv" @@ -45,16 +46,19 @@ type SyncRunner interface { SyncNow(ctx context.Context) error } +type syncUserMessageError interface { + UserMessage() string +} + func NewRouter( staticFS fs.FS, statusProvider StatusProvider, usageProvider service.UsageProvider, - authFileProvider service.AuthFileProvider, - providerMetadataProvider service.ProviderMetadataProvider, pricingProvider service.PricingProvider, authConfig AuthConfig, authHandler *authHandler, basePath string, + usageIdentityProviders ...service.UsageIdentityProvider, ) *gin.Engine { router := gin.New() router.Use(gin.Recovery()) @@ -73,16 +77,20 @@ func NewRouter( } authHandler.registerRoutes(authGroup) + var usageIdentityProvider service.UsageIdentityProvider + if len(usageIdentityProviders) > 0 { + usageIdentityProvider = usageIdentityProviders[0] + } + protected := apiV1.Group("") protected.Use(authHandler.middleware()) registerStatusRoutes(protected, statusProvider) registerSyncRoutes(protected, statusProvider, &syncLimiter{window: manualSyncRateLimitWindow}) registerUsageOverviewRoute(protected, usageProvider) registerUsageAnalysisRoute(protected, usageProvider) - registerUsageEventsRoute(protected, usageProvider, authFileProvider, providerMetadataProvider) - registerUsageCredentialsRoute(protected, usageProvider, authFileProvider, providerMetadataProvider) - registerAuthFileRoutes(protected, authFileProvider) - registerProviderMetadataRoutes(protected, providerMetadataProvider) + registerUsageEventsRoute(protected, usageProvider, usageIdentityProvider) + registerUsageCredentialsRoute(protected, usageProvider, usageIdentityProvider) + registerUsageIdentityRoutes(protected, usageIdentityProvider) registerPricingRoutes(protected, pricingProvider) if staticFS != nil { @@ -208,6 +216,14 @@ func registerStatusRoutes(router gin.IRoutes, statusProvider StatusProvider) { }) } +func manualSyncErrorMessage(err error) string { + var userMessage syncUserMessageError + if errors.As(err, &userMessage) && userMessage.UserMessage() != "" { + return userMessage.UserMessage() + } + return "manual sync failed" +} + func registerSyncRoutes(router gin.IRoutes, statusProvider StatusProvider, limiter *syncLimiter) { router.POST("/sync", func(c *gin.Context) { if limiter != nil && !limiter.allow(time.Now()) { @@ -226,10 +242,9 @@ func registerSyncRoutes(router gin.IRoutes, statusProvider StatusProvider, limit c.JSON(http.StatusConflict, gin.H{"error": "sync already running"}) return } - if !errors.Is(err, poller.ErrSyncCompletedWithWarnings) { - writeInternalError(c, "manual sync failed", err) - return - } + slog.Error("manual sync failed", "error", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": manualSyncErrorMessage(err)}) + return } if statusProvider, ok := syncRunner.(StatusProvider); ok { diff --git a/internal/api/router_test.go b/internal/api/router_test.go index 95928156..c28d2a64 100644 --- a/internal/api/router_test.go +++ b/internal/api/router_test.go @@ -31,6 +31,18 @@ type syncStatusStub struct { err error } +type userFacingSyncError struct { + message string +} + +func (e userFacingSyncError) Error() string { + return e.message +} + +func (e userFacingSyncError) UserMessage() string { + return e.message +} + func (s statusStub) Status() poller.Status { return s.status } @@ -45,7 +57,7 @@ func (s *syncStatusStub) SyncNow(context.Context) error { } func TestHealthzReturnsOK(t *testing.T) { - router := NewRouter(nil, nil, nil, nil, nil, nil, AuthConfig{}, nil, "") + router := NewRouter(nil, nil, nil, nil, AuthConfig{}, nil, "") req := httptest.NewRequest(http.MethodGet, "/healthz", nil) resp := httptest.NewRecorder() @@ -65,7 +77,7 @@ func TestStatusReturnsPollerState(t *testing.T) { LastError: "boom", LastWarning: "metadata unavailable", LastStatus: "completed_with_warnings", - }}, nil, nil, nil, nil, AuthConfig{}, nil, "") + }}, nil, nil, AuthConfig{}, nil, "") req := httptest.NewRequest(http.MethodGet, "/api/v1/status", nil) resp := httptest.NewRecorder() @@ -89,7 +101,7 @@ func TestStatusReturnsProjectTimezone(t *testing.T) { t.Cleanup(func() { time.Local = previousLocal }) time.Local = location - router := NewRouter(nil, nil, nil, nil, nil, nil, AuthConfig{}, nil, "") + router := NewRouter(nil, nil, nil, nil, AuthConfig{}, nil, "") req := httptest.NewRequest(http.MethodGet, "/api/v1/status", nil) resp := httptest.NewRecorder() router.ServeHTTP(resp, req) @@ -103,7 +115,7 @@ func TestStatusReturnsProjectTimezone(t *testing.T) { } func TestStatusReturnsEmptyStateWithoutProvider(t *testing.T) { - router := NewRouter(nil, nil, nil, nil, nil, nil, AuthConfig{}, nil, "") + router := NewRouter(nil, nil, nil, nil, AuthConfig{}, nil, "") req := httptest.NewRequest(http.MethodGet, "/api/v1/status", nil) resp := httptest.NewRecorder() router.ServeHTTP(resp, req) @@ -119,7 +131,7 @@ func TestStatusReturnsEmptyStateWithoutProvider(t *testing.T) { func TestManualSyncTriggersSyncRunner(t *testing.T) { lastRunAt := time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC) syncer := &syncStatusStub{status: poller.Status{Running: true, LastRunAt: lastRunAt, LastStatus: "completed"}} - router := NewRouter(nil, syncer, nil, nil, nil, nil, AuthConfig{}, nil, "") + router := NewRouter(nil, syncer, nil, nil, AuthConfig{}, nil, "") req := httptest.NewRequest(http.MethodPost, "/api/v1/sync", nil) resp := httptest.NewRecorder() @@ -139,7 +151,7 @@ func TestManualSyncTriggersSyncRunner(t *testing.T) { func TestManualSyncReturnsConflictWhenAlreadyRunning(t *testing.T) { syncer := &syncStatusStub{err: poller.ErrSyncAlreadyRunning} - router := NewRouter(nil, syncer, nil, nil, nil, nil, AuthConfig{}, nil, "") + router := NewRouter(nil, syncer, nil, nil, AuthConfig{}, nil, "") req := httptest.NewRequest(http.MethodPost, "/api/v1/sync", nil) resp := httptest.NewRecorder() @@ -150,29 +162,44 @@ func TestManualSyncReturnsConflictWhenAlreadyRunning(t *testing.T) { } } -func TestManualSyncReturnsWarningsAsSuccessfulStatus(t *testing.T) { +func TestManualSyncReturnsWarningsAsError(t *testing.T) { syncer := &syncStatusStub{ status: poller.Status{LastStatus: "completed_with_warnings", LastWarning: "metadata unavailable"}, err: poller.ErrSyncCompletedWithWarnings, } - router := NewRouter(nil, syncer, nil, nil, nil, nil, AuthConfig{}, nil, "") + router := NewRouter(nil, syncer, nil, nil, AuthConfig{}, nil, "") req := httptest.NewRequest(http.MethodPost, "/api/v1/sync", nil) resp := httptest.NewRecorder() router.ServeHTTP(resp, req) - if resp.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", resp.Code) + if resp.Code != http.StatusInternalServerError { + t.Fatalf("expected status 500, got %d", resp.Code) } - body := resp.Body.String() - if !(contains(body, `"last_status":"completed_with_warnings"`) && contains(body, `"last_warning":"metadata unavailable"`)) { + if body := resp.Body.String(); !contains(body, `"error":"manual sync failed"`) { + t.Fatalf("unexpected response body: %s", body) + } +} + +func TestManualSyncReturnsUserFacingStageError(t *testing.T) { + syncer := &syncStatusStub{err: userFacingSyncError{message: "metadata sync failed"}} + router := NewRouter(nil, syncer, nil, nil, AuthConfig{}, nil, "") + req := httptest.NewRequest(http.MethodPost, "/api/v1/sync", nil) + resp := httptest.NewRecorder() + + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusInternalServerError { + t.Fatalf("expected status 500, got %d", resp.Code) + } + if body := resp.Body.String(); !contains(body, `"error":"metadata sync failed"`) { t.Fatalf("unexpected response body: %s", body) } } func TestManualSyncRateLimitsRepeatedRequests(t *testing.T) { syncer := &syncStatusStub{status: poller.Status{LastStatus: "completed"}} - router := NewRouter(nil, syncer, nil, nil, nil, nil, AuthConfig{}, nil, "") + router := NewRouter(nil, syncer, nil, nil, AuthConfig{}, nil, "") firstResp := httptest.NewRecorder() firstReq := httptest.NewRequest(http.MethodPost, "/api/v1/sync", nil) @@ -197,7 +224,7 @@ func TestSubpathRoutesOnlyServePrefixedEndpoints(t *testing.T) { router := NewRouter(nil, statusStub{status: poller.Status{ Running: true, LastRunAt: lastRunAt, - }}, nil, nil, nil, nil, AuthConfig{BasePath: "/cpa"}, nil, "/cpa") + }}, nil, nil, AuthConfig{BasePath: "/cpa"}, nil, "/cpa") for _, testCase := range []struct { path string @@ -223,7 +250,7 @@ func TestSubpathStaticRoutesServeOnlyUnderPrefix(t *testing.T) { "assets/app.js": "console.log('ok')", }) - router := NewRouter(staticFS, nil, nil, nil, nil, nil, AuthConfig{BasePath: "/cpa"}, nil, "/cpa") + router := NewRouter(staticFS, nil, nil, nil, AuthConfig{BasePath: "/cpa"}, nil, "/cpa") for _, testCase := range []struct { path string @@ -267,7 +294,7 @@ func TestRootStaticRouteInjectsEmptyBasePath(t *testing.T) { "index.html": `app`, }) - router := NewRouter(staticFS, nil, nil, nil, nil, nil, AuthConfig{}, nil, "") + router := NewRouter(staticFS, nil, nil, nil, AuthConfig{}, nil, "") resp := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) router.ServeHTTP(resp, req) diff --git a/internal/api/usage_analysis_test.go b/internal/api/usage_analysis_test.go index 9d3d5b4f..3a2c8ae9 100644 --- a/internal/api/usage_analysis_test.go +++ b/internal/api/usage_analysis_test.go @@ -82,7 +82,7 @@ func TestUsageAnalysisReturnsAggregatedRows(t *testing.T) { LatencySampleCount: 2, }}, }} - router := NewRouter(nil, nil, provider, nil, nil, nil, AuthConfig{}, nil, "") + router := NewRouter(nil, nil, provider, nil, AuthConfig{}, nil, "") req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/analysis?range=24h", nil) resp := httptest.NewRecorder() @@ -116,7 +116,7 @@ func TestUsageAnalysisReturnsAggregatedRows(t *testing.T) { } func TestUsageAnalysisRequiresAuthWhenEnabled(t *testing.T) { - router := NewRouter(nil, nil, &usageAnalysisStub{}, nil, nil, nil, AuthConfig{Enabled: true, LoginPassword: "secret", SessionTTL: time.Hour}, nil, "") + router := NewRouter(nil, nil, &usageAnalysisStub{}, nil, AuthConfig{Enabled: true, LoginPassword: "secret", SessionTTL: time.Hour}, nil, "") req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/analysis", nil) resp := httptest.NewRecorder() diff --git a/internal/api/usage_credentials.go b/internal/api/usage_credentials.go index 07b6c19a..b554f947 100644 --- a/internal/api/usage_credentials.go +++ b/internal/api/usage_credentials.go @@ -24,8 +24,7 @@ type usageCredentialPayload struct { func registerUsageCredentialsRoute( router gin.IRoutes, usageProvider service.UsageProvider, - authFileProvider service.AuthFileProvider, - providerMetadataProvider service.ProviderMetadataProvider, + usageIdentityProvider service.UsageIdentityProvider, ) { router.GET("/usage/credentials", func(c *gin.Context) { if usageProvider == nil { @@ -45,12 +44,12 @@ func registerUsageCredentialsRoute( return } - authFiles, providerMetadata, err := loadUsageResolutionData(c, authFileProvider, providerMetadataProvider) + identities, err := loadUsageResolutionData(c, usageIdentityProvider) if err != nil { writeInternalError(c, "load usage resolution data failed", err) return } - resolver := newUsageSourceResolver(authFiles, providerMetadata) + resolver := newUsageSourceResolver(identities) c.JSON(http.StatusOK, usageCredentialsResponse{Credentials: buildUsageCredentialsPayload(rows, resolver)}) }) } diff --git a/internal/api/usage_events.go b/internal/api/usage_events.go index 82ac8e87..057bb8e5 100644 --- a/internal/api/usage_events.go +++ b/internal/api/usage_events.go @@ -1,9 +1,12 @@ package api import ( + "fmt" "net/http" + "strings" "time" + "cpa-usage-keeper/internal/models" "cpa-usage-keeper/internal/service" "github.com/gin-gonic/gin" ) @@ -53,8 +56,7 @@ type usageEventTokenPayload struct { func registerUsageEventsRoute( router gin.IRoutes, usageProvider service.UsageProvider, - authFileProvider service.AuthFileProvider, - providerMetadataProvider service.ProviderMetadataProvider, + usageIdentityProvider service.UsageIdentityProvider, ) { router.GET("/usage/events/filters", func(c *gin.Context) { if usageProvider == nil { @@ -62,27 +64,20 @@ func registerUsageEventsRoute( return } - filter, err := parseUsageTimeFilterQuery(c.Request, time.Now().UTC()) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - options, err := usageProvider.ListUsageEventFilterOptions(c.Request.Context(), filter) + options, err := usageProvider.ListUsageEventFilterOptions(c.Request.Context(), service.UsageFilter{}) if err != nil { writeInternalError(c, "list usage event filter options failed", err) return } - authFiles, providerMetadata, err := loadUsageResolutionData(c, authFileProvider, providerMetadataProvider) + identities, err := loadUsageResolutionData(c, usageIdentityProvider) if err != nil { writeInternalError(c, "load usage resolution data failed", err) return } - resolver := newUsageSourceResolver(authFiles, providerMetadata) c.JSON(http.StatusOK, usageEventFilterOptionsResponse{ Models: options.Models, - Sources: buildUsageSourceFilterOptions(options.Sources, resolver), + Sources: buildUsageSourceFilterOptions(options.Sources, identities), }) }) @@ -97,14 +92,10 @@ func registerUsageEventsRoute( c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - - authFiles, providerMetadata, err := loadUsageResolutionData(c, authFileProvider, providerMetadataProvider) - if err != nil { - writeInternalError(c, "load usage resolution data failed", err) + if err := applyUsageEventsSourceFilter(&filter); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - resolver := newUsageSourceResolver(authFiles, providerMetadata) - filter.Source = resolver.rawSourceForPublicValue(filter.Source) rows, err := usageProvider.ListUsageEvents(c.Request.Context(), filter) if err != nil { @@ -112,10 +103,15 @@ func registerUsageEventsRoute( return } + identities, err := loadUsageResolutionData(c, usageIdentityProvider) + if err != nil { + writeInternalError(c, "load usage resolution data failed", err) + return + } c.JSON(http.StatusOK, usageEventsResponse{ - Events: buildUsageEventsPayload(rows.Events, resolver), + Events: buildUsageEventsPayload(rows.Events), Models: rows.Models, - Sources: buildUsageSourceFilterOptions(rows.Sources, resolver), + Sources: buildUsageSourceFilterOptions(rows.Sources, identities), TotalCount: rows.TotalCount, Page: rows.Page, PageSize: rows.PageSize, @@ -124,23 +120,54 @@ func registerUsageEventsRoute( }) } -func buildUsageEventsPayload(rows []service.UsageEventRecord, resolver usageSourceResolver) []usageEventPayload { +func applyUsageEventsSourceFilter(filter *service.UsageFilter) error { + if filter == nil { + return nil + } + source := strings.TrimSpace(filter.Source) + if source == "" { + return nil + } + if value, ok := strings.CutPrefix(source, "auth:"); ok { + value = strings.TrimSpace(value) + if value == "" { + return fmt.Errorf("source auth filter value is required") + } + filter.AuthType = "oauth" + filter.AuthIndex = value + filter.Source = value + filter.Provider = "" + return nil + } + if value, ok := strings.CutPrefix(source, "provider:"); ok { + value = strings.TrimSpace(value) + if value == "" { + return fmt.Errorf("source provider filter value is required") + } + filter.AuthType = "apikey" + filter.Provider = value + filter.Source = "" + filter.AuthIndex = "" + } + return nil +} + +func buildUsageEventsPayload(rows []service.UsageEventRecord) []usageEventPayload { if len(rows) == 0 { return []usageEventPayload{} } payload := make([]usageEventPayload, 0, len(rows)) for _, row := range rows { - resolved := resolver.resolve(row.Source, row.AuthIndex) + source, sourceKey := usageEventPublicSource(row) payload = append(payload, usageEventPayload{ - ID: row.ID, - Timestamp: row.Timestamp.UTC().Format(time.RFC3339), - Model: row.Model, - Source: resolved.DisplayName, - SourceType: resolved.SourceType, - SourceKey: resolved.SourceKey, - AuthIndex: row.AuthIndex, - Failed: row.Failed, - LatencyMS: row.LatencyMS, + ID: row.ID, + Timestamp: row.Timestamp.UTC().Format(time.RFC3339), + Model: row.Model, + Source: source, + SourceKey: sourceKey, + AuthIndex: row.AuthIndex, + Failed: row.Failed, + LatencyMS: row.LatencyMS, Tokens: usageEventTokenPayload{ InputTokens: row.InputTokens, OutputTokens: row.OutputTokens, @@ -153,14 +180,66 @@ func buildUsageEventsPayload(rows []service.UsageEventRecord, resolver usageSour return payload } -func buildUsageSourceFilterOptions(sources []string, resolver usageSourceResolver) []usageSourceFilterOption { - if len(sources) == 0 { +func usageEventPublicSource(row service.UsageEventRecord) (string, string) { + switch strings.TrimSpace(row.AuthType) { + case "apikey": + provider := strings.TrimSpace(row.Provider) + if provider == "" { + provider = "AI Provider" + } + return provider, "provider:" + provider + case "oauth": + source := firstNonEmptyString(row.Source, row.AuthIndex, "unknown") + return source, "auth:" + source + default: + if provider := strings.TrimSpace(row.Provider); provider != "" { + return provider, "provider:" + provider + } + source := firstNonEmptyString(row.Source, row.AuthIndex, "unknown") + return source, "auth:" + source + } +} + +func buildUsageSourceFilterOptions(sources []string, identities []models.UsageIdentity) []usageSourceFilterOption { + if len(identities) == 0 { return []usageSourceFilterOption{} } - options := make([]usageSourceFilterOption, 0, len(sources)) - for _, source := range sources { - resolved := resolver.resolve(source, "") - options = append(options, usageSourceFilterOption{Value: resolved.SourceKey, Label: resolved.DisplayName}) + options := make([]usageSourceFilterOption, 0, len(identities)) + seen := make(map[string]struct{}, len(identities)) + for _, identity := range identities { + if identity.TotalRequests == 0 { + continue + } + option, ok := usageSourceFilterOptionFromIdentity(identity) + if !ok { + continue + } + if _, exists := seen[option.Value]; exists { + continue + } + seen[option.Value] = struct{}{} + options = append(options, option) } return options } + +func usageSourceFilterOptionFromIdentity(identity models.UsageIdentity) (usageSourceFilterOption, bool) { + switch identity.AuthType { + case models.UsageIdentityAuthTypeAuthFile: + value := strings.TrimSpace(identity.Identity) + if value == "" { + return usageSourceFilterOption{}, false + } + label := firstNonEmptyString(identity.Name, value) + return usageSourceFilterOption{Value: "auth:" + value, Label: label}, true + case models.UsageIdentityAuthTypeAIProvider: + provider := safeAIProviderDisplayValue(identity.Provider, identity.Identity, "") + label := firstNonEmptyString(provider, safeAIProviderDisplayValue(identity.Name, identity.Identity, ""), safeAIProviderDisplayValue(identity.Type, identity.Identity, "")) + if label == "" { + return usageSourceFilterOption{}, false + } + return usageSourceFilterOption{Value: "provider:" + label, Label: label}, true + default: + return usageSourceFilterOption{}, false + } +} diff --git a/internal/api/usage_events_test.go b/internal/api/usage_events_test.go index 2b896012..7c881fdc 100644 --- a/internal/api/usage_events_test.go +++ b/internal/api/usage_events_test.go @@ -65,6 +65,8 @@ func TestUsageEventsReturnsFilteredRows(t *testing.T) { ID: 42, Timestamp: time.Date(2026, 4, 22, 11, 0, 0, 0, time.UTC), Model: "claude-sonnet", + AuthType: "apikey", + Provider: "OpenAI Mirror", Source: "sk-provider-key", AuthIndex: "2", Failed: false, @@ -75,7 +77,7 @@ func TestUsageEventsReturnsFilteredRows(t *testing.T) { CachedTokens: 1, TotalTokens: 18, }}} - router := NewRouter(nil, nil, provider, authFileStub{files: []models.AuthFile{{AuthIndex: "2", Email: "user@example.com", Type: "auth-file"}}}, providerMetadataStub{items: []models.ProviderMetadata{{LookupKey: "sk-provider-key", ProviderType: "openai", DisplayName: "OpenAI Mirror", ProviderKey: "openai:OpenAI Mirror"}}}, nil, AuthConfig{}, nil, "") + router := NewRouter(nil, nil, provider, nil, AuthConfig{}, nil, "") req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/events?range=24h", nil) resp := httptest.NewRecorder() @@ -94,14 +96,11 @@ func TestUsageEventsReturnsFilteredRows(t *testing.T) { if !contains(body, `"source":"OpenAI Mirror"`) { t.Fatalf("expected resolved source display in response body: %s", body) } - if contains(body, `sk-provider-key`) { - t.Fatalf("expected raw source to be redacted from response body: %s", body) + if contains(body, `sk-provider-key`) || contains(body, `sk-provider-prefix`) { + t.Fatalf("expected raw source values to be redacted from response body: %s", body) } - if !contains(body, `"source_type":"openai"`) { - t.Fatalf("expected source type in response body: %s", body) - } - if !contains(body, `"source_key":"openai:OpenAI Mirror"`) { - t.Fatalf("expected source key in response body: %s", body) + if contains(body, `"source_type"`) || !contains(body, `"source_key":"provider:OpenAI Mirror"`) { + t.Fatalf("expected provider source key from usage event provider only, got %s", body) } if !contains(body, `"auth_index":"2"`) { t.Fatalf("expected auth index in response body: %s", body) @@ -120,10 +119,10 @@ func TestUsageEventsReturnsFilteredRows(t *testing.T) { } } -func TestUsageEventsPassesPaginationAndServerFilters(t *testing.T) { +func TestUsageEventsPassesPaginationAndProviderSourceFilter(t *testing.T) { provider := &usageEventsStub{eventsPage: &service.UsageEventsPage{Events: []service.UsageEventRecord{}, TotalCount: 0, Page: 3, PageSize: 100, TotalPages: 0}} - router := NewRouter(nil, nil, provider, nil, providerMetadataStub{items: []models.ProviderMetadata{{LookupKey: "source-a", ProviderType: "openai", DisplayName: "Provider A", ProviderKey: "openai:Provider A"}}}, nil, AuthConfig{}, nil, "") - req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/events?page=3&page_size=100&model=claude-sonnet&source=openai:Provider%20A&result=failed", nil) + router := NewRouter(nil, nil, provider, nil, AuthConfig{}, nil, "") + req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/events?page=3&page_size=100&model=claude-sonnet&source=provider:OpenAI%20Mirror&result=failed", nil) resp := httptest.NewRecorder() router.ServeHTTP(resp, req) @@ -134,8 +133,8 @@ func TestUsageEventsPassesPaginationAndServerFilters(t *testing.T) { if provider.lastFilter.Page != 3 || provider.lastFilter.PageSize != 100 || provider.lastFilter.Offset != 200 { t.Fatalf("expected pagination filter, got %+v", provider.lastFilter) } - if provider.lastFilter.Model != "claude-sonnet" || provider.lastFilter.Source != "source-a" || provider.lastFilter.Result != "failed" { - t.Fatalf("expected server-side filters, got %+v", provider.lastFilter) + if provider.lastFilter.Model != "claude-sonnet" || provider.lastFilter.AuthType != "apikey" || provider.lastFilter.Provider != "OpenAI Mirror" || provider.lastFilter.Source != "" || provider.lastFilter.Result != "failed" { + t.Fatalf("expected provider source filter to be translated, got %+v", provider.lastFilter) } body := resp.Body.String() if !contains(body, `"page":3`) || !contains(body, `"page_size":100`) || !contains(body, `"total_count":0`) || !contains(body, `"total_pages":0`) { @@ -143,16 +142,52 @@ func TestUsageEventsPassesPaginationAndServerFilters(t *testing.T) { } } +func TestUsageEventsPassesAuthSourceFilter(t *testing.T) { + provider := &usageEventsStub{eventsPage: &service.UsageEventsPage{Events: []service.UsageEventRecord{}, TotalCount: 0, Page: 1, PageSize: 100, TotalPages: 0}} + router := NewRouter(nil, nil, provider, nil, AuthConfig{}, nil, "") + req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/events?source=auth:2", nil) + resp := httptest.NewRecorder() + + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", resp.Code) + } + if provider.lastFilter.AuthType != "oauth" || provider.lastFilter.AuthIndex != "2" || provider.lastFilter.Source != "2" { + t.Fatalf("expected auth source filter to be translated, got %+v", provider.lastFilter) + } +} + +func TestUsageEventsRejectsEmptyPrefixedSourceFilter(t *testing.T) { + for _, path := range []string{"/api/v1/usage/events?source=auth:", "/api/v1/usage/events?source=provider:"} { + t.Run(path, func(t *testing.T) { + provider := &usageEventsStub{eventsPage: &service.UsageEventsPage{Events: []service.UsageEventRecord{}, TotalCount: 0, Page: 1, PageSize: 100, TotalPages: 0}} + router := NewRouter(nil, nil, provider, nil, AuthConfig{}, nil, "") + req := httptest.NewRequest(http.MethodGet, path, nil) + resp := httptest.NewRecorder() + + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d with body %s", resp.Code, resp.Body.String()) + } + if provider.filterCalls != 0 { + t.Fatalf("expected invalid source filter not to query events, got %d calls", provider.filterCalls) + } + }) + } +} + func TestUsageEventsReturnsFilterOptions(t *testing.T) { provider := &usageEventsStub{eventsPage: &service.UsageEventsPage{ Events: []service.UsageEventRecord{{ - ID: 7, Timestamp: time.Date(2026, 4, 22, 11, 0, 0, 0, time.UTC), Model: "gpt-5", Source: "source-a", Failed: true, + ID: 7, Timestamp: time.Date(2026, 4, 22, 11, 0, 0, 0, time.UTC), Model: "gpt-5", AuthType: "apikey", Provider: "Provider A", Source: "source-a", Failed: true, }}, Models: []string{"claude-sonnet", "gpt-5"}, Sources: []string{"source-a", "source-b"}, TotalCount: 2, Page: 1, PageSize: 20, TotalPages: 1, }} - router := NewRouter(nil, nil, provider, authFileStub{files: []models.AuthFile{{AuthIndex: "1", Email: "user@example.com", Type: "auth-file"}}}, providerMetadataStub{items: []models.ProviderMetadata{{LookupKey: "source-a", ProviderType: "openai", DisplayName: "Provider A", ProviderKey: "openai:Provider A"}, {LookupKey: "source-b", ProviderType: "anthropic", DisplayName: "Provider B", ProviderKey: "anthropic:Provider B"}}}, nil, AuthConfig{}, nil, "") + router := NewRouter(nil, nil, provider, nil, AuthConfig{}, nil, "", usageIdentitiesStub{items: []models.UsageIdentity{{ID: 1, Name: "sk-source-prefix", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "source-a", Type: "openai", Provider: "Provider A", TotalRequests: 1}, {ID: 2, Name: "Provider A", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "source-b", Type: "openai", Provider: "Provider A", TotalRequests: 1}, {ID: 3, Name: "Auth User", AuthType: models.UsageIdentityAuthTypeAuthFile, AuthTypeName: "oauth", Identity: "auth-1", Type: "claude", Provider: "Claude", TotalRequests: 1}}}) req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/events", nil) resp := httptest.NewRecorder() @@ -165,10 +200,10 @@ func TestUsageEventsReturnsFilterOptions(t *testing.T) { if !contains(body, `"models":["claude-sonnet","gpt-5"]`) { t.Fatalf("expected model filter options, got %s", body) } - if !contains(body, `"sources":[`) || !contains(body, `"value":"openai:Provider A"`) || !contains(body, `"label":"Provider A"`) || !contains(body, `"value":"anthropic:Provider B"`) || !contains(body, `"label":"Provider B"`) { - t.Fatalf("expected resolved source filter options, got %s", body) + if !contains(body, `"sources":[`) || !contains(body, `"value":"auth:auth-1"`) || !contains(body, `"label":"Auth User"`) || !contains(body, `"value":"provider:Provider A"`) || !contains(body, `"label":"Provider A"`) { + t.Fatalf("expected prefixed source filter options, got %s", body) } - if contains(body, `"value":"source-a"`) || contains(body, `"value":"source-b"`) { + if contains(body, `"value":"source-a"`) || contains(body, `"value":"source-b"`) || contains(body, `"provider:1"`) || contains(body, `"provider:2"`) || contains(body, `sk-source-prefix`) { t.Fatalf("expected raw source filter values to be redacted, got %s", body) } } @@ -178,7 +213,7 @@ func TestUsageEventFilterOptionsReturnsStableModelsAndSources(t *testing.T) { Models: []string{"claude-sonnet", "gpt-5"}, Sources: []string{"source-a", "source-b"}, }} - router := NewRouter(nil, nil, provider, authFileStub{files: []models.AuthFile{{AuthIndex: "1", Email: "user@example.com", Type: "auth-file"}}}, providerMetadataStub{items: []models.ProviderMetadata{{LookupKey: "source-a", ProviderType: "openai", DisplayName: "Provider A", ProviderKey: "openai:Provider A"}, {LookupKey: "source-b", ProviderType: "anthropic", DisplayName: "Provider B", ProviderKey: "anthropic:Provider B"}}}, nil, AuthConfig{}, nil, "") + router := NewRouter(nil, nil, provider, nil, AuthConfig{}, nil, "", usageIdentitiesStub{items: []models.UsageIdentity{{ID: 1, Name: "sk-source-prefix", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "source-a", Type: "openai", Provider: "Provider A", TotalRequests: 3}, {ID: 2, Name: "Provider A", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "source-b", Type: "openai", Provider: "Provider A"}, {ID: 3, Name: "Auth User", AuthType: models.UsageIdentityAuthTypeAuthFile, AuthTypeName: "oauth", Identity: "auth-1", Type: "claude", Provider: "Claude", TotalRequests: 2}, {ID: 4, Name: "Zero Request User", AuthType: models.UsageIdentityAuthTypeAuthFile, AuthTypeName: "oauth", Identity: "auth-zero", Type: "claude", Provider: "Claude"}, {ID: 5, Name: "Zero Provider", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "source-zero", Type: "openai", Provider: "Zero Provider"}}}) req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/events/filters?range=24h&model=ignored&source=ignored&result=failed&page=3&page_size=20", nil) resp := httptest.NewRecorder() @@ -190,19 +225,22 @@ func TestUsageEventFilterOptionsReturnsStableModelsAndSources(t *testing.T) { if provider.filterOptionCalls != 1 || provider.filterCalls != 0 { t.Fatalf("expected filter options endpoint only, events=%d filterOptions=%d", provider.filterCalls, provider.filterOptionCalls) } - if provider.lastFilter.Range != "24h" || provider.lastFilter.Model != "" || provider.lastFilter.Source != "" || provider.lastFilter.Result != "" || provider.lastFilter.Page != 0 || provider.lastFilter.PageSize != 0 { - t.Fatalf("expected time range only filter, got %+v", provider.lastFilter) + if provider.lastFilter.Range != "" || provider.lastFilter.StartTime != nil || provider.lastFilter.EndTime != nil || provider.lastFilter.Model != "" || provider.lastFilter.Source != "" || provider.lastFilter.Result != "" || provider.lastFilter.Page != 0 || provider.lastFilter.PageSize != 0 { + t.Fatalf("expected filters endpoint to ignore query filters, got %+v", provider.lastFilter) } body := resp.Body.String() if !contains(body, `"models":["claude-sonnet","gpt-5"]`) { t.Fatalf("expected stable model filter options, got %s", body) } - if !contains(body, `"sources":[`) || !contains(body, `"value":"openai:Provider A"`) || !contains(body, `"label":"Provider A"`) || !contains(body, `"value":"anthropic:Provider B"`) || !contains(body, `"label":"Provider B"`) { - t.Fatalf("expected stable resolved source filter options, got %s", body) + if !contains(body, `"sources":[`) || !contains(body, `"value":"auth:auth-1"`) || !contains(body, `"label":"Auth User"`) || !contains(body, `"value":"provider:Provider A"`) || !contains(body, `"label":"Provider A"`) { + t.Fatalf("expected stable prefixed source filter options, got %s", body) } - if contains(body, `"value":"source-a"`) || contains(body, `"value":"source-b"`) { + if contains(body, `"value":"source-a"`) || contains(body, `"value":"source-b"`) || contains(body, `"provider:1"`) || contains(body, `"provider:2"`) || contains(body, `sk-source-prefix`) { t.Fatalf("expected raw source filter values to be redacted, got %s", body) } + if contains(body, `Zero Request User`) || contains(body, `Zero Provider`) || contains(body, `auth-zero`) || contains(body, `source-zero`) { + t.Fatalf("expected zero-request source filter options to be omitted, got %s", body) + } } func TestUsageCredentialsReturnsAggregatedRows(t *testing.T) { @@ -217,7 +255,7 @@ func TestUsageCredentialsReturnsAggregatedRows(t *testing.T) { Failed: true, RequestCount: 1, }}} - router := NewRouter(nil, nil, provider, authFileStub{files: []models.AuthFile{{AuthIndex: "2", Email: "user@example.com", Type: "auth-file"}}}, providerMetadataStub{items: []models.ProviderMetadata{{LookupKey: "sk-provider-key", ProviderType: "openai", DisplayName: "OpenAI Mirror", ProviderKey: "openai:OpenAI Mirror"}}}, nil, AuthConfig{}, nil, "") + router := NewRouter(nil, nil, provider, nil, AuthConfig{}, nil, "", usageIdentitiesStub{items: []models.UsageIdentity{{ID: 1, Name: "sk-provider-prefix", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "sk-provider-key", Type: "openai", Provider: "OpenAI Mirror"}}}) req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/credentials?range=24h", nil) resp := httptest.NewRecorder() @@ -236,9 +274,12 @@ func TestUsageCredentialsReturnsAggregatedRows(t *testing.T) { if !contains(body, `"source_type":"openai"`) { t.Fatalf("expected source type in response body: %s", body) } - if !contains(body, `"source_key":"openai:OpenAI Mirror"`) { + if !contains(body, `"source_key":"provider:1"`) { t.Fatalf("expected source key in response body: %s", body) } + if contains(body, `sk-provider-key`) || contains(body, `sk-provider-prefix`) { + t.Fatalf("expected raw source values to be redacted from response body: %s", body) + } if !contains(body, `"success_count":3`) || !contains(body, `"failure_count":1`) || !contains(body, `"total_count":4`) { t.Fatalf("expected aggregated counts in response body: %s", body) } diff --git a/internal/api/usage_identities.go b/internal/api/usage_identities.go new file mode 100644 index 00000000..380c8f33 --- /dev/null +++ b/internal/api/usage_identities.go @@ -0,0 +1,125 @@ +package api + +import ( + "net/http" + "strings" + "time" + + "cpa-usage-keeper/internal/models" + "cpa-usage-keeper/internal/redact" + "cpa-usage-keeper/internal/service" + "github.com/gin-gonic/gin" +) + +type usageIdentitiesResponse struct { + Identities []usageIdentityResponse `json:"identities"` +} + +type usageIdentityResponse struct { + ID uint `json:"id"` + Name string `json:"name"` + AuthType models.UsageIdentityAuthType `json:"auth_type"` + AuthTypeName string `json:"auth_type_name"` + Identity string `json:"identity"` + Type string `json:"type"` + Provider string `json:"provider"` + TotalRequests int64 `json:"total_requests"` + SuccessCount int64 `json:"success_count"` + FailureCount int64 `json:"failure_count"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + ReasoningTokens int64 `json:"reasoning_tokens"` + CachedTokens int64 `json:"cached_tokens"` + TotalTokens int64 `json:"total_tokens"` + LastAggregatedUsageEventID uint `json:"last_aggregated_usage_event_id"` + FirstUsedAt *time.Time `json:"first_used_at,omitempty"` + LastUsedAt *time.Time `json:"last_used_at,omitempty"` + StatsUpdatedAt *time.Time `json:"stats_updated_at,omitempty"` + IsDeleted bool `json:"is_deleted"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` +} + +func registerUsageIdentityRoutes(router gin.IRoutes, usageIdentityProvider service.UsageIdentityProvider) { + router.GET("/usage/identities", func(c *gin.Context) { + if usageIdentityProvider == nil { + c.JSON(http.StatusOK, usageIdentitiesResponse{Identities: []usageIdentityResponse{}}) + return + } + + items, err := usageIdentityProvider.ListUsageIdentities(c.Request.Context()) + if err != nil { + writeInternalError(c, "list usage identities failed", err) + return + } + + response := make([]usageIdentityResponse, 0, len(items)) + for _, item := range items { + response = append(response, mapUsageIdentityResponse(item)) + } + c.JSON(http.StatusOK, usageIdentitiesResponse{Identities: response}) + }) +} + +func mapUsageIdentityResponse(item models.UsageIdentity) usageIdentityResponse { + identity := item.Identity + name := item.Name + identityType := item.Type + provider := item.Provider + if item.AuthType == models.UsageIdentityAuthTypeAIProvider { + identity = redact.APIKeyDisplayName(item.Identity) + identityType = safeAIProviderDisplayValue(item.Type, item.Identity, item.AuthTypeName) + provider = safeAIProviderDisplayValue(item.Provider, item.Identity, firstNonEmptyString(identityType, identity)) + name = safeAIProviderDisplayValue(item.Name, item.Identity, firstNonEmptyString(provider, identityType, identity)) + } + + return usageIdentityResponse{ + ID: item.ID, + Name: name, + AuthType: item.AuthType, + AuthTypeName: item.AuthTypeName, + Identity: identity, + Type: identityType, + Provider: provider, + TotalRequests: item.TotalRequests, + SuccessCount: item.SuccessCount, + FailureCount: item.FailureCount, + InputTokens: item.InputTokens, + OutputTokens: item.OutputTokens, + ReasoningTokens: item.ReasoningTokens, + CachedTokens: item.CachedTokens, + TotalTokens: item.TotalTokens, + LastAggregatedUsageEventID: item.LastAggregatedUsageEventID, + FirstUsedAt: item.FirstUsedAt, + LastUsedAt: item.LastUsedAt, + StatsUpdatedAt: item.StatsUpdatedAt, + IsDeleted: item.IsDeleted, + CreatedAt: item.CreatedAt, + UpdatedAt: item.UpdatedAt, + DeletedAt: item.DeletedAt, + } +} + +func safeAIProviderDisplayValue(value, rawIdentity, fallback string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return fallback + } + if isSensitiveUsageIdentityValue(trimmed, rawIdentity) { + return fallback + } + return trimmed +} + +func isSensitiveUsageIdentityValue(value, rawIdentity string) bool { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return false + } + if raw := strings.TrimSpace(rawIdentity); raw != "" && strings.Contains(trimmed, raw) { + return true + } + lower := strings.ToLower(trimmed) + return strings.Contains(lower, "sk-") || strings.Contains(lower, "aiza") || strings.Contains(lower, "cr_") || strings.Contains(lower, "cr-") +} diff --git a/internal/api/usage_identities_test.go b/internal/api/usage_identities_test.go new file mode 100644 index 00000000..001e9462 --- /dev/null +++ b/internal/api/usage_identities_test.go @@ -0,0 +1,144 @@ +package api + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "cpa-usage-keeper/internal/models" + "cpa-usage-keeper/internal/redact" +) + +type usageIdentitiesStub struct { + items []models.UsageIdentity + err error +} + +func (s usageIdentitiesStub) ListUsageIdentities(context.Context) ([]models.UsageIdentity, error) { + return s.items, s.err +} + +func TestUsageIdentitiesRouteReturnsMetadataStatsAndDeletedRows(t *testing.T) { + firstUsedAt := time.Date(2026, 5, 4, 8, 0, 0, 0, time.UTC) + lastUsedAt := time.Date(2026, 5, 4, 9, 0, 0, 0, time.UTC) + statsUpdatedAt := time.Date(2026, 5, 4, 10, 0, 0, 0, time.UTC) + createdAt := time.Date(2026, 5, 3, 8, 0, 0, 0, time.UTC) + updatedAt := time.Date(2026, 5, 4, 10, 30, 0, 0, time.UTC) + deletedAt := time.Date(2026, 5, 4, 11, 0, 0, 0, time.UTC) + + router := NewRouter(nil, nil, nil, nil, AuthConfig{}, nil, "", usageIdentitiesStub{items: []models.UsageIdentity{ + { + ID: 1, + Name: "Claude Desktop", + AuthType: models.UsageIdentityAuthTypeAuthFile, + AuthTypeName: "oauth", + Identity: "2", + Type: "auth-file", + Provider: "anthropic", + TotalRequests: 10, + SuccessCount: 8, + FailureCount: 2, + InputTokens: 100, + OutputTokens: 200, + ReasoningTokens: 30, + CachedTokens: 40, + TotalTokens: 370, + LastAggregatedUsageEventID: 99, + FirstUsedAt: &firstUsedAt, + LastUsedAt: &lastUsedAt, + StatsUpdatedAt: &statsUpdatedAt, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + { + ID: 2, + Name: "Deleted Provider", + AuthType: models.UsageIdentityAuthTypeAIProvider, + AuthTypeName: "apikey", + Identity: "sk-deleted-provider-secret", + Type: "openai", + Provider: "OpenAI", + IsDeleted: true, + DeletedAt: &deletedAt, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + }}) + req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/identities", nil) + resp := httptest.NewRecorder() + + router.ServeHTTP(resp, req) + + body := resp.Body.String() + if resp.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", resp.Code, body) + } + if !contains(body, `"identities":[`) || !contains(body, `"id":1`) || !contains(body, `"identity":"2"`) { + t.Fatalf("expected auth file identity row in response, got %s", body) + } + for _, expected := range []string{ + `"name":"Claude Desktop"`, + `"auth_type":1`, + `"auth_type_name":"oauth"`, + `"type":"auth-file"`, + `"provider":"anthropic"`, + `"total_requests":10`, + `"success_count":8`, + `"failure_count":2`, + `"input_tokens":100`, + `"output_tokens":200`, + `"reasoning_tokens":30`, + `"cached_tokens":40`, + `"total_tokens":370`, + `"last_aggregated_usage_event_id":99`, + `"first_used_at":"2026-05-04T08:00:00Z"`, + `"last_used_at":"2026-05-04T09:00:00Z"`, + `"stats_updated_at":"2026-05-04T10:00:00Z"`, + `"is_deleted":true`, + `"deleted_at":"2026-05-04T11:00:00Z"`, + } { + if !contains(body, expected) { + t.Fatalf("expected %s in response body: %s", expected, body) + } + } +} + +func TestUsageIdentitiesRouteMasksAIProviderIdentity(t *testing.T) { + rawLookupKey := "sk-live-secret-value" + rawPrefix := "sk-live-prefix" + maskedLookupKey := redact.APIKeyDisplayName(rawLookupKey) + router := NewRouter(nil, nil, nil, nil, AuthConfig{}, nil, "", usageIdentitiesStub{items: []models.UsageIdentity{ + {ID: 1, Name: rawPrefix, AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: rawLookupKey, Type: "openai " + rawLookupKey, Provider: "OpenAI " + rawPrefix}, + }}) + req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/identities", nil) + resp := httptest.NewRecorder() + + router.ServeHTTP(resp, req) + + body := resp.Body.String() + if resp.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", resp.Code, body) + } + if contains(body, rawLookupKey) || contains(body, rawPrefix) { + t.Fatalf("expected raw AI provider lookup values to be hidden, got %s", body) + } + if !contains(body, `"identity":"`+maskedLookupKey+`"`) { + t.Fatalf("expected masked AI provider identity %q in response body: %s", maskedLookupKey, body) + } +} + +func TestUsageIdentityReplacesLegacyMetadataRoutes(t *testing.T) { + router := NewRouter(nil, nil, nil, nil, AuthConfig{}, nil, "", usageIdentitiesStub{}) + for _, path := range []string{"/api/v1/auth-files", "/api/v1/provider-metadata"} { + req := httptest.NewRequest(http.MethodGet, path, nil) + resp := httptest.NewRecorder() + + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusNotFound { + t.Fatalf("expected %s to return 404, got %d: %s", path, resp.Code, resp.Body.String()) + } + } +} diff --git a/internal/api/usage_overview_test.go b/internal/api/usage_overview_test.go index 5982aa30..165a0488 100644 --- a/internal/api/usage_overview_test.go +++ b/internal/api/usage_overview_test.go @@ -8,7 +8,6 @@ import ( "time" "cpa-usage-keeper/internal/cpa" - "cpa-usage-keeper/internal/models" "cpa-usage-keeper/internal/service" ) @@ -68,7 +67,7 @@ func TestUsageOverviewResponseIncludesResolvedRangeAndTimezone(t *testing.T) { time.Local = location provider := &usageFilterStub{overview: &service.UsageOverviewSnapshot{}} - router := NewRouter(nil, nil, provider, nil, nil, nil, AuthConfig{}, nil, "") + router := NewRouter(nil, nil, provider, nil, AuthConfig{}, nil, "") req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/overview?range=custom&start=2026-04-20&end=2026-04-21", nil) resp := httptest.NewRecorder() @@ -147,7 +146,7 @@ func TestUsageOverviewReturnsFilteredSnapshot(t *testing.T) { }}, }, }} - router := NewRouter(nil, nil, provider, authFileStub{files: []models.AuthFile{{AuthIndex: "2", Email: "user@example.com", Type: "auth-file"}}}, nil, nil, AuthConfig{}, nil, "") + router := NewRouter(nil, nil, provider, nil, AuthConfig{}, nil, "") req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/overview?range=24h", nil) resp := httptest.NewRecorder() diff --git a/internal/api/usage_resolution.go b/internal/api/usage_resolution.go index 81530549..52ff2e2d 100644 --- a/internal/api/usage_resolution.go +++ b/internal/api/usage_resolution.go @@ -8,23 +8,10 @@ import ( func loadUsageResolutionData( c *gin.Context, - authFileProvider service.AuthFileProvider, - providerMetadataProvider service.ProviderMetadataProvider, -) ([]models.AuthFile, []models.ProviderMetadata, error) { - authFiles := []models.AuthFile{} - providerMetadata := []models.ProviderMetadata{} - var err error - if authFileProvider != nil { - authFiles, err = authFileProvider.ListAuthFiles(c.Request.Context()) - if err != nil { - return nil, nil, err - } + usageIdentityProvider service.UsageIdentityProvider, +) ([]models.UsageIdentity, error) { + if usageIdentityProvider == nil { + return []models.UsageIdentity{}, nil } - if providerMetadataProvider != nil { - providerMetadata, err = providerMetadataProvider.ListProviderMetadata(c.Request.Context()) - if err != nil { - return nil, nil, err - } - } - return authFiles, providerMetadata, nil + return usageIdentityProvider.ListUsageIdentities(c.Request.Context()) } diff --git a/internal/api/usage_source_resolution.go b/internal/api/usage_source_resolution.go index 9e15e9fa..b1c939c9 100644 --- a/internal/api/usage_source_resolution.go +++ b/internal/api/usage_source_resolution.go @@ -4,66 +4,41 @@ import ( "strconv" "strings" - "cpa-usage-keeper/internal/cpa" "cpa-usage-keeper/internal/models" "cpa-usage-keeper/internal/redact" ) type usageSourceResolver struct { - authFiles map[string]models.AuthFile - providerMetadata map[string]models.ProviderMetadata - providerRawByKey map[string]string + authIdentities map[string]models.UsageIdentity + providerIdentities map[string]models.UsageIdentity + providerRawByKey map[string]string } -func newUsageSourceResolver(authFiles []models.AuthFile, providerMetadata []models.ProviderMetadata) usageSourceResolver { - authFileMap := make(map[string]models.AuthFile, len(authFiles)) - for _, file := range authFiles { - authIndex := strings.TrimSpace(file.AuthIndex) - if authIndex == "" { +func newUsageSourceResolver(identities []models.UsageIdentity) usageSourceResolver { + authIdentities := make(map[string]models.UsageIdentity, len(identities)) + providerIdentities := make(map[string]models.UsageIdentity, len(identities)) + providerRawByKey := make(map[string]string, len(identities)) + for _, identity := range identities { + key := strings.TrimSpace(identity.Identity) + if key == "" { continue } - authFileMap[authIndex] = file - } - - providerMetadataMap := make(map[string]models.ProviderMetadata, len(providerMetadata)) - providerRawByKey := make(map[string]string, len(providerMetadata)) - for _, item := range providerMetadata { - lookupKey := strings.TrimSpace(item.LookupKey) - if lookupKey == "" { - continue - } - providerMetadataMap[lookupKey] = item - resolved := usageSourceResolutionFromMetadata(item, lookupKey) - if resolved.SourceKey != "" { - providerRawByKey[resolved.SourceKey] = lookupKey + switch identity.AuthType { + case models.UsageIdentityAuthTypeAuthFile: + authIdentities[key] = identity + case models.UsageIdentityAuthTypeAIProvider: + providerIdentities[key] = identity + resolved := usageSourceResolutionFromIdentity(identity, key) + if resolved.SourceKey != "" { + providerRawByKey[resolved.SourceKey] = key + } } } return usageSourceResolver{ - authFiles: authFileMap, - providerMetadata: providerMetadataMap, - providerRawByKey: providerRawByKey, - } -} - -func applyUsageSourceResolution(snapshot *cpa.StatisticsSnapshot, resolver usageSourceResolver) { - if snapshot == nil { - return - } - - for apiName, apiSnapshot := range snapshot.APIs { - for modelName, modelSnapshot := range apiSnapshot.Models { - for i := range modelSnapshot.Details { - resolved := resolver.resolve(modelSnapshot.Details[i].Source, modelSnapshot.Details[i].AuthIndex) - modelSnapshot.Details[i].SourceRaw = modelSnapshot.Details[i].Source - modelSnapshot.Details[i].Source = resolved.DisplayName - modelSnapshot.Details[i].SourceDisplay = resolved.DisplayName - modelSnapshot.Details[i].SourceType = resolved.SourceType - modelSnapshot.Details[i].SourceKey = resolved.SourceKey - } - apiSnapshot.Models[modelName] = modelSnapshot - } - snapshot.APIs[apiName] = apiSnapshot + authIdentities: authIdentities, + providerIdentities: providerIdentities, + providerRawByKey: providerRawByKey, } } @@ -73,20 +48,22 @@ type usageSourceResolution struct { SourceKey string } -func usageSourceResolutionFromMetadata(item models.ProviderMetadata, fallbackLookupKey string) usageSourceResolution { - displayName := firstNonEmptyString(item.DisplayName, item.ProviderType, redact.APIKeyDisplayName(fallbackLookupKey)) - providerType := strings.TrimSpace(item.ProviderType) - providerKey := strings.TrimSpace(item.ProviderKey) - if providerKey == "" && item.ID > 0 { - providerKey = "provider:" + uintToString(item.ID) - } - if providerKey == "" { - providerKey = "provider:" + firstNonEmptyString(providerType, displayName) +func usageSourceResolutionFromIdentity(item models.UsageIdentity, fallbackIdentity string) usageSourceResolution { + identityType := safeAIProviderDisplayValue(item.Type, fallbackIdentity, "") + displayName := firstNonEmptyString( + safeAIProviderDisplayValue(item.Name, fallbackIdentity, ""), + safeAIProviderDisplayValue(item.Provider, fallbackIdentity, ""), + identityType, + redact.APIKeyDisplayName(fallbackIdentity), + ) + sourceKey := "provider:" + uintToString(item.ID) + if item.ID == 0 { + sourceKey = "provider:" + redact.APIKeyDisplayName(fallbackIdentity) } return usageSourceResolution{ DisplayName: displayName, - SourceType: providerType, - SourceKey: providerKey, + SourceType: identityType, + SourceKey: sourceKey, } } @@ -104,18 +81,18 @@ func (r usageSourceResolver) rawSourceForPublicValue(value string) string { func (r usageSourceResolver) resolve(rawSource string, authIndex string) usageSourceResolution { normalizedSource := strings.TrimSpace(rawSource) if normalizedSource != "" { - if item, ok := r.providerMetadata[normalizedSource]; ok { - return usageSourceResolutionFromMetadata(item, normalizedSource) + if item, ok := r.providerIdentities[normalizedSource]; ok { + return usageSourceResolutionFromIdentity(item, normalizedSource) } } normalizedAuthIndex := strings.TrimSpace(authIndex) if normalizedAuthIndex != "" { - if file, ok := r.authFiles[normalizedAuthIndex]; ok { - displayName := firstNonEmptyString(file.Email, file.Label, file.Name, normalizedAuthIndex) + if identity, ok := r.authIdentities[normalizedAuthIndex]; ok { + displayName := firstNonEmptyString(identity.Name, normalizedAuthIndex) return usageSourceResolution{ DisplayName: displayName, - SourceType: firstNonEmptyString(file.Type, file.Provider), + SourceType: firstNonEmptyString(identity.Type, identity.Provider), SourceKey: "auth:" + normalizedAuthIndex, } } diff --git a/internal/api/usage_source_resolution_test.go b/internal/api/usage_source_resolution_test.go deleted file mode 100644 index daabd112..00000000 --- a/internal/api/usage_source_resolution_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package api - -import ( - "testing" - - "cpa-usage-keeper/internal/cpa" - "cpa-usage-keeper/internal/models" -) - -func TestApplyUsageSourceResolutionUsesSharedResolver(t *testing.T) { - snapshot := &cpa.StatisticsSnapshot{ - APIs: map[string]cpa.APISnapshot{ - "provider-a": { - Models: map[string]cpa.ModelSnapshot{ - "claude-sonnet": { - Details: []cpa.RequestDetail{{ - Source: "sk-provider-key", - AuthIndex: "2", - }}, - }, - }, - }, - }, - } - - resolver := newUsageSourceResolver( - []models.AuthFile{{ - AuthIndex: "2", - Email: "user@example.com", - Type: "codex", - }}, - []models.ProviderMetadata{{ - LookupKey: "sk-provider-key", - ProviderType: "openai", - DisplayName: "OpenAI Mirror", - ProviderKey: "openai:OpenAI Mirror", - }}, - ) - - applyUsageSourceResolution(snapshot, resolver) - - detail := snapshot.APIs["provider-a"].Models["claude-sonnet"].Details[0] - if detail.Source != "OpenAI Mirror" { - t.Fatalf("expected resolved source display, got %q", detail.Source) - } - if detail.SourceDisplay != "OpenAI Mirror" { - t.Fatalf("expected source display field to be populated, got %q", detail.SourceDisplay) - } - if detail.SourceType != "openai" { - t.Fatalf("expected provider metadata type to win, got %q", detail.SourceType) - } - if detail.SourceKey != "openai:OpenAI Mirror" { - t.Fatalf("expected provider key to be used for grouping, got %q", detail.SourceKey) - } - if detail.SourceRaw != "sk-provider-key" { - t.Fatalf("expected raw source to be preserved, got %q", detail.SourceRaw) - } -} diff --git a/internal/app/app.go b/internal/app/app.go index 1cebe3b7..51ce2a3d 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "sync" - "time" "cpa-usage-keeper/internal/api" "cpa-usage-keeper/internal/auth" @@ -33,31 +32,19 @@ type Options struct { } type App struct { - Config *config.Config - ConfiguredUsageSyncMode string - DB *gorm.DB - Router *gin.Engine - Poller Runner - Maintenance *StorageCleanupRunner - BackupMaintenance *DatabaseBackupRunner - LogCloser io.Closer + Config *config.Config + DB *gorm.DB + Router *gin.Engine + Poller Runner + Maintenance *StorageCleanupRunner + MetadataSync *MetadataSyncRunner + BackupMaintenance *DatabaseBackupRunner + LogCloser io.Closer backgroundCancel context.CancelFunc backgroundWG sync.WaitGroup } -var redisStartupProbe = func(ctx context.Context, cfg config.Config) error { - client := cpa.NewRedisQueueClient( - cfg.CPABaseURL, - cfg.RedisQueueAddr, - cfg.CPAManagementKey, - cfg.RequestTimeout, - cfg.RedisQueueKey, - cfg.RedisQueueBatchSize, - ) - return client.Probe(ctx) -} - func New() (*App, error) { return NewWithOptions(Options{}) } @@ -82,16 +69,12 @@ func NewWithConfig(cfg config.Config) (*App, error) { _ = logCloser.Close() return nil, err } - if err := runTemporaryStartupSnapshotRunsCleanup(db); err != nil { - _ = closeGormDB(db) - _ = logCloser.Close() - return nil, err - } - configuredUsageSyncMode := cfg.UsageSyncMode - cfg = resolveUsageSyncMode(context.Background(), cfg) syncService := service.NewSyncService(db, cfg) - backgroundPoller := newBackgroundRunner(syncService, cfg) + backgroundPoller := poller.NewRedisDrain(syncService, poller.RedisDrainConfig{ + IdleInterval: cfg.RedisQueueIdleInterval, + ErrorBackoff: cfg.RedisQueueErrorBackoff, + }) var backupMaintenance *DatabaseBackupRunner if cfg.BackupEnabled { sqlDB, err := db.DB() @@ -105,8 +88,7 @@ func NewWithConfig(cfg config.Config) (*App, error) { } usageService := service.NewUsageService(db) - authFileService := service.NewAuthFileService(db) - providerMetadataService := service.NewProviderMetadataService(db) + usageIdentityService := service.NewUsageIdentityService(db) pricingModelsClient := cpa.NewClient(cfg.CPABaseURL, cfg.CPAManagementKey, cfg.RequestTimeout) pricingService := service.NewPricingService(db, pricingModelsClient) sessionManager := auth.NewSessionManager(cfg.AuthSessionTTL) @@ -118,19 +100,17 @@ func NewWithConfig(cfg config.Config) (*App, error) { }, sessionManager) return &App{ - Config: &cfg, - ConfiguredUsageSyncMode: configuredUsageSyncMode, - DB: db, - Poller: backgroundPoller, - Maintenance: NewStorageCleanupRunner(syncService), - BackupMaintenance: backupMaintenance, - LogCloser: logCloser, + Config: &cfg, + DB: db, + Poller: backgroundPoller, + Maintenance: NewStorageCleanupRunner(syncService), + MetadataSync: NewMetadataSyncRunner(syncService, cfg.MetadataSyncInterval), + BackupMaintenance: backupMaintenance, + LogCloser: logCloser, Router: api.NewRouter( webui.Static, - backgroundPoller, + newManualSyncRunner(backgroundPoller, syncService), usageService, - authFileService, - providerMetadataService, pricingService, api.AuthConfig{ Enabled: cfg.AuthEnabled, @@ -140,6 +120,7 @@ func NewWithConfig(cfg config.Config) (*App, error) { }, authHandler, cfg.AppBasePath, + usageIdentityService, ), }, nil } @@ -155,53 +136,6 @@ func closeGormDB(db *gorm.DB) error { return sqlDB.Close() } -func resolveUsageSyncMode(ctx context.Context, cfg config.Config) config.Config { - if cfg.UsageSyncMode != "auto" { - return cfg - } - if err := redisStartupProbe(ctx, cfg); err != nil { - cfg.UsageSyncMode = "legacy_export" - logrus.WithError(err).WithFields(logrus.Fields{ - "configured_mode": "auto", - "effective_mode": cfg.UsageSyncMode, - }).Info("usage sync auto mode resolved") - return cfg - } - cfg.UsageSyncMode = "redis" - logrus.WithFields(logrus.Fields{ - "configured_mode": "auto", - "effective_mode": cfg.UsageSyncMode, - }).Info("usage sync auto mode resolved") - return cfg -} - -func newBackgroundRunner(syncService *service.SyncService, cfg config.Config) Runner { - if cfg.UsageSyncMode == "redis" { - return poller.NewRedisDrain(syncService, poller.RedisDrainConfig{ - IdleInterval: cfg.RedisQueueIdleInterval, - ErrorBackoff: cfg.RedisQueueErrorBackoff, - MetadataInterval: cfg.RedisMetadataSyncInterval, - }) - } - return poller.New(syncService, cfg.PollInterval) -} - -// runTemporaryStartupSnapshotRunsCleanup 是启动期额外执行的 snapshot_runs 治理入口,和每日清理共用 CleanupSnapshotRuns 语义。 -// 它只处理 snapshot_runs 并执行 VACUUM,不包含每日 CleanupStorage 中的 redis_usage_inboxes 清理。 -func runTemporaryStartupSnapshotRunsCleanup(db *gorm.DB) error { - logrus.Info("temporary snapshot runs cleanup started") - if _, err := repository.CleanupSnapshotRuns(db, time.Now()); err != nil { - logrus.WithError(err).Error("temporary snapshot runs cleanup failed") - return err - } - if err := repository.Vacuum(db); err != nil { - logrus.WithError(err).Error("temporary snapshot runs cleanup failed") - return err - } - logrus.Info("temporary snapshot runs cleanup completed") - return nil -} - func (a *App) Close() error { if a == nil { return nil @@ -226,15 +160,6 @@ func (a *App) Run() error { return fmt.Errorf("application is not initialized") } - configuredMode := a.ConfiguredUsageSyncMode - if configuredMode == "" { - configuredMode = a.Config.UsageSyncMode - } - logrus.WithFields(logrus.Fields{ - "configured_mode": configuredMode, - "effective_mode": a.Config.UsageSyncMode, - }).Info("usage sync mode selected") - ctx := a.startBackgroundContext() defer a.stopBackgroundTasks() if a.Poller != nil { @@ -251,6 +176,13 @@ func (a *App) Run() error { } }) } + if a.MetadataSync != nil { + a.startBackgroundTask(func() { + if err := a.MetadataSync.Run(ctx); err != nil { + logrus.Errorf("metadata sync stopped: %v", err) + } + }) + } if a.BackupMaintenance != nil { a.startBackgroundTask(func() { if err := a.BackupMaintenance.Run(ctx); err != nil { diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 5c78abb2..aa70bb27 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -3,21 +3,17 @@ package app import ( "bytes" "context" - "errors" - "strings" "testing" "time" "cpa-usage-keeper/internal/config" - "cpa-usage-keeper/internal/models" "cpa-usage-keeper/internal/poller" - "cpa-usage-keeper/internal/repository" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" ) func TestAppCloseClosesDatabase(t *testing.T) { - app, err := NewWithConfig(testAppConfig(t, "legacy_export")) + app, err := NewWithConfig(testAppConfig(t)) if err != nil { t.Fatalf("NewWithConfig returned error: %v", err) } @@ -35,8 +31,8 @@ func TestAppCloseClosesDatabase(t *testing.T) { } } -func TestNewWithConfigBuildsPollerAndRouter(t *testing.T) { - app, err := NewWithConfig(testAppConfig(t, "legacy_export")) +func TestNewWithConfigBuildsRedisDrainAndRouter(t *testing.T) { + app, err := NewWithConfig(testAppConfig(t)) if err != nil { t.Fatalf("NewWithConfig returned error: %v", err) } @@ -53,10 +49,13 @@ func TestNewWithConfigBuildsPollerAndRouter(t *testing.T) { if app.BackupMaintenance == nil { t.Fatal("expected database backup runner to be initialized") } + if app.MetadataSync == nil { + t.Fatal("expected metadata sync runner to be initialized") + } } func TestNewWithConfigSkipsBackupRunnerWhenDisabled(t *testing.T) { - cfg := testAppConfig(t, "legacy_export") + cfg := testAppConfig(t) cfg.BackupEnabled = false app, err := NewWithConfig(cfg) if err != nil { @@ -68,22 +67,8 @@ func TestNewWithConfigSkipsBackupRunnerWhenDisabled(t *testing.T) { } } -func TestNewWithConfigSelectsLegacyPoller(t *testing.T) { - app, err := NewWithConfig(testAppConfig(t, "legacy_export")) - if err != nil { - t.Fatalf("NewWithConfig returned error: %v", err) - } - defer app.Close() - if _, ok := app.Poller.(*poller.Poller); !ok { - t.Fatalf("expected legacy_export to use interval poller, got %T", app.Poller) - } - if app.Maintenance == nil { - t.Fatal("expected maintenance cleanup runner to be initialized") - } -} - func TestNewWithConfigSelectsRedisDrain(t *testing.T) { - app, err := NewWithConfig(testAppConfig(t, "redis")) + app, err := NewWithConfig(testAppConfig(t)) if err != nil { t.Fatalf("NewWithConfig returned error: %v", err) } @@ -96,223 +81,37 @@ func TestNewWithConfigSelectsRedisDrain(t *testing.T) { } } -func TestNewWithConfigRunsTemporaryStartupSnapshotCleanup(t *testing.T) { - cfg := testAppConfig(t, "legacy_export") - seedDB, err := repository.OpenDatabase(cfg) - if err != nil { - t.Fatalf("OpenDatabase returned error: %v", err) - } - oldRun, err := repository.CreateSnapshotRun(seedDB, repository.SnapshotRunInput{FetchedAt: time.Now().AddDate(0, 0, -8), RawPayload: []byte(`old`)}) - if err != nil { - t.Fatalf("CreateSnapshotRun old returned error: %v", err) - } - latestRun, err := repository.CreateSnapshotRun(seedDB, repository.SnapshotRunInput{FetchedAt: time.Now(), RawPayload: []byte(`latest`)}) - if err != nil { - t.Fatalf("CreateSnapshotRun latest returned error: %v", err) - } - sqlDB, err := seedDB.DB() - if err != nil { - t.Fatalf("load seed sql db: %v", err) - } - if err := sqlDB.Close(); err != nil { - t.Fatalf("close seed sql db: %v", err) - } - - app, err := NewWithConfig(cfg) - if err != nil { - t.Fatalf("NewWithConfig returned error: %v", err) - } - defer app.Close() - - var remaining []models.SnapshotRun - if err := app.DB.Order("id asc").Find(&remaining).Error; err != nil { - t.Fatalf("load remaining snapshot runs: %v", err) - } - if len(remaining) != 1 || remaining[0].ID != latestRun.ID { - t.Fatalf("expected startup cleanup to keep only latest snapshot %d and delete %d, got %+v", latestRun.ID, oldRun.ID, remaining) - } -} - func TestNewWithConfigCreatesIndependentMaintenanceRunner(t *testing.T) { - for _, mode := range []string{"redis", "legacy_export"} { - t.Run(mode, func(t *testing.T) { - app, err := NewWithConfig(testAppConfig(t, mode)) - if err != nil { - t.Fatalf("NewWithConfig returned error: %v", err) - } - defer app.Close() - if app.Poller == nil { - t.Fatal("expected sync poller to be initialized") - } - if app.Maintenance == nil { - t.Fatal("expected independent maintenance runner to be initialized") - } - }) - } -} - -func TestNewWithConfigAutoUsesRedisDrainWhenStartupProbeSucceeds(t *testing.T) { - probeCalls := 0 - withRedisStartupProbe(t, func(context.Context, config.Config) error { - probeCalls++ - return nil - }) - - app, err := NewWithConfig(testAppConfig(t, "auto")) + app, err := NewWithConfig(testAppConfig(t)) if err != nil { t.Fatalf("NewWithConfig returned error: %v", err) } defer app.Close() - - if probeCalls != 1 { - t.Fatalf("expected one startup probe, got %d", probeCalls) - } - if app.Config.UsageSyncMode != "redis" { - t.Fatalf("expected effective mode redis, got %q", app.Config.UsageSyncMode) - } - if _, ok := app.Poller.(*poller.RedisDrain); !ok { - t.Fatalf("expected auto with successful probe to use redis drain, got %T", app.Poller) - } - if app.Maintenance == nil { - t.Fatal("expected maintenance cleanup runner to be initialized") - } -} - -func TestNewWithConfigAutoUsesLegacyPollerWhenStartupProbeFails(t *testing.T) { - probeCalls := 0 - withRedisStartupProbe(t, func(context.Context, config.Config) error { - probeCalls++ - return errors.New("redis unavailable") - }) - - app, err := NewWithConfig(testAppConfig(t, "auto")) - if err != nil { - t.Fatalf("NewWithConfig returned error: %v", err) - } - defer app.Close() - - if probeCalls != 1 { - t.Fatalf("expected one startup probe, got %d", probeCalls) - } - if app.Config.UsageSyncMode != "legacy_export" { - t.Fatalf("expected effective mode legacy_export, got %q", app.Config.UsageSyncMode) - } - if app.Config.PollInterval != time.Minute { - t.Fatalf("expected auto resolved legacy poller to keep configured poll interval, got %s", app.Config.PollInterval) - } - if _, ok := app.Poller.(*poller.Poller); !ok { - t.Fatalf("expected auto with failed probe to use legacy poller, got %T", app.Poller) + if app.Poller == nil { + t.Fatal("expected sync poller to be initialized") } if app.Maintenance == nil { - t.Fatal("expected maintenance cleanup runner to be initialized") - } -} - -func TestResolveUsageSyncModeLogsEffectiveMode(t *testing.T) { - for _, tc := range []struct { - name string - probeErr error - effective string - }{ - {name: "redis", effective: "redis"}, - {name: "legacy_export", probeErr: errors.New("redis unavailable"), effective: "legacy_export"}, - } { - t.Run(tc.name, func(t *testing.T) { - logs := captureAppInfoLogs(t) - withRedisStartupProbe(t, func(context.Context, config.Config) error { - return tc.probeErr - }) - - cfg := testAppConfig(t, "auto") - resolved := resolveUsageSyncMode(context.Background(), cfg) - if resolved.UsageSyncMode != tc.effective { - t.Fatalf("expected effective mode %q, got %q", tc.effective, resolved.UsageSyncMode) - } - if resolved.PollInterval != cfg.PollInterval { - t.Fatalf("expected poll interval to remain %s, got %s", cfg.PollInterval, resolved.PollInterval) - } - content := logs.String() - for _, expected := range []string{"level=info", "msg=\"usage sync auto mode resolved\"", "configured_mode=auto", "effective_mode=" + tc.effective} { - if !strings.Contains(content, expected) { - t.Fatalf("expected auto resolution log to contain %q, got %q", expected, content) - } - } - }) - } -} - -func TestNewWithConfigDoesNotProbeForExplicitModes(t *testing.T) { - withRedisStartupProbe(t, func(context.Context, config.Config) error { - t.Fatal("unexpected startup probe") - return nil - }) - - for _, mode := range []string{"redis", "legacy_export"} { - t.Run(mode, func(t *testing.T) { - app, err := NewWithConfig(testAppConfig(t, mode)) - if err != nil { - t.Fatalf("NewWithConfig returned error: %v", err) - } - defer app.Close() - if app.Config.UsageSyncMode != mode { - t.Fatalf("expected mode %q to remain unchanged, got %q", mode, app.Config.UsageSyncMode) - } - }) - } -} - -func TestTemporaryStartupSnapshotCleanupLogsStartAndSuccess(t *testing.T) { - logs := captureAppInfoLogs(t) - db, err := repository.OpenDatabase(testAppConfig(t, "legacy_export")) - if err != nil { - t.Fatalf("OpenDatabase returned error: %v", err) - } - - if err := runTemporaryStartupSnapshotRunsCleanup(db); err != nil { - t.Fatalf("runTemporaryStartupSnapshotRunsCleanup returned error: %v", err) - } - sqlDB, err := db.DB() - if err != nil { - t.Fatalf("load sql db: %v", err) - } - if err := sqlDB.Close(); err != nil { - t.Fatalf("close sql db: %v", err) - } - - content := logs.String() - for _, expected := range []string{"level=info", "msg=\"temporary snapshot runs cleanup started\"", "msg=\"temporary snapshot runs cleanup completed\""} { - if !strings.Contains(content, expected) { - t.Fatalf("expected temporary cleanup log to contain %q, got %q", expected, content) - } - } -} - -func TestTemporaryStartupSnapshotCleanupLogsFailure(t *testing.T) { - logs := captureAppInfoLogs(t) - - if err := runTemporaryStartupSnapshotRunsCleanup(nil); err == nil { - t.Fatal("expected runTemporaryStartupSnapshotRunsCleanup to return an error") - } - - content := logs.String() - for _, expected := range []string{"level=error", "msg=\"temporary snapshot runs cleanup failed\""} { - if !strings.Contains(content, expected) { - t.Fatalf("expected temporary cleanup error log to contain %q, got %q", expected, content) - } + t.Fatal("expected independent maintenance runner to be initialized") } } func TestRunStartsPollerAndMaintenanceIndependently(t *testing.T) { - cfg := testAppConfig(t, "redis") + cfg := testAppConfig(t) cfg.AppPort = "invalid-port" pollerStarted := make(chan struct{}) maintenanceStarted := make(chan struct{}) + metadataStarted := make(chan struct{}) backupStarted := make(chan struct{}) maintenance := NewStorageCleanupRunner(&maintenanceSyncStub{}) maintenance.sleep = func(context.Context, time.Duration) bool { close(maintenanceStarted) return false } + metadataRunner := NewMetadataSyncRunner(&metadataSyncStub{}, time.Second) + metadataRunner.sleep = func(context.Context, time.Duration) bool { + close(metadataStarted) + return false + } backupRunner := NewDatabaseBackupRunner(&databaseBackupWriterStub{}, nil, time.Second, 0) backupRunner.sleep = func(context.Context, time.Duration) bool { close(backupStarted) @@ -323,6 +122,7 @@ func TestRunStartsPollerAndMaintenanceIndependently(t *testing.T) { Router: gin.New(), Poller: &appRunStub{started: pollerStarted}, Maintenance: maintenance, + MetadataSync: metadataRunner, BackupMaintenance: backupRunner, } @@ -340,6 +140,11 @@ func TestRunStartsPollerAndMaintenanceIndependently(t *testing.T) { t.Fatal("expected maintenance runner to start") } select { + case <-metadataStarted: + case <-time.After(time.Second): + t.Fatal("expected metadata sync runner to start") + } + select { case <-backupStarted: case <-time.After(time.Second): t.Fatal("expected database backup runner to start") @@ -347,7 +152,7 @@ func TestRunStartsPollerAndMaintenanceIndependently(t *testing.T) { } func TestRunCancelsBackgroundTasksWhenRouterStops(t *testing.T) { - cfg := testAppConfig(t, "redis") + cfg := testAppConfig(t) cfg.AppPort = "invalid-port" backupStarted := make(chan struct{}) backupCanceled := make(chan struct{}) @@ -379,39 +184,6 @@ func TestRunCancelsBackgroundTasksWhenRouterStops(t *testing.T) { } } -func TestRunLogsConfiguredUsageSyncMode(t *testing.T) { - var logs bytes.Buffer - previousOutput := logrus.StandardLogger().Out - previousFormatter := logrus.StandardLogger().Formatter - previousLevel := logrus.GetLevel() - logrus.SetOutput(&logs) - logrus.SetFormatter(&logrus.TextFormatter{DisableTimestamp: true}) - logrus.SetLevel(logrus.InfoLevel) - t.Cleanup(func() { - logrus.SetOutput(previousOutput) - logrus.SetFormatter(previousFormatter) - logrus.SetLevel(previousLevel) - }) - - cfg := testAppConfig(t, "redis") - cfg.AppPort = "invalid-port" - app := &App{ - Config: &cfg, - Router: gin.New(), - } - - if err := app.Run(); err == nil { - t.Fatal("expected Run to return an error for invalid port") - } - - content := logs.String() - for _, expected := range []string{"msg=\"usage sync mode selected\"", "configured_mode=redis", "effective_mode=redis"} { - if !strings.Contains(content, expected) { - t.Fatalf("expected usage sync mode log to contain %q, got %q", expected, content) - } - } -} - type appRunStub struct { started chan struct{} } @@ -446,31 +218,22 @@ func captureAppInfoLogs(t *testing.T) *bytes.Buffer { return &logs } -func withRedisStartupProbe(t *testing.T, probe func(context.Context, config.Config) error) { - t.Helper() - previous := redisStartupProbe - redisStartupProbe = probe - t.Cleanup(func() { redisStartupProbe = previous }) -} - -func testAppConfig(t *testing.T, syncMode string) config.Config { +func testAppConfig(t *testing.T) config.Config { t.Helper() return config.Config{ - AppPort: "8080", - CPABaseURL: "https://cpa.example.com", - CPAManagementKey: "secret", - UsageSyncMode: syncMode, - PollInterval: time.Minute, - RedisQueueIdleInterval: time.Second, - RedisQueueErrorBackoff: 10 * time.Second, - RedisMetadataSyncInterval: 30 * time.Second, - SQLitePath: t.TempDir() + "/app.db", - BackupEnabled: true, - BackupDir: t.TempDir() + "/backups", - BackupRetentionDays: 7, - RequestTimeout: 5 * time.Second, - LogLevel: "info", - LogFileEnabled: false, - LogRetentionDays: 7, + AppPort: "8080", + CPABaseURL: "https://cpa.example.com", + CPAManagementKey: "secret", + RedisQueueIdleInterval: time.Second, + RedisQueueErrorBackoff: 10 * time.Second, + MetadataSyncInterval: 30 * time.Second, + SQLitePath: t.TempDir() + "/app.db", + BackupEnabled: true, + BackupDir: t.TempDir() + "/backups", + BackupRetentionDays: 7, + RequestTimeout: 5 * time.Second, + LogLevel: "info", + LogFileEnabled: false, + LogRetentionDays: 7, } } diff --git a/internal/app/manual_sync.go b/internal/app/manual_sync.go new file mode 100644 index 00000000..a47f0e4f --- /dev/null +++ b/internal/app/manual_sync.go @@ -0,0 +1,60 @@ +package app + +import ( + "context" + "fmt" + + "cpa-usage-keeper/internal/poller" +) + +type manualSyncRunner struct { + redis Runner + metadata MetadataSyncer +} + +type manualSyncStageError struct { + message string + err error +} + +func (e manualSyncStageError) Error() string { + return fmt.Sprintf("%s: %v", e.message, e.err) +} + +func (e manualSyncStageError) Unwrap() error { + return e.err +} + +func (e manualSyncStageError) UserMessage() string { + return e.message +} + +func newManualSyncRunner(redis Runner, metadata MetadataSyncer) *manualSyncRunner { + return &manualSyncRunner{redis: redis, metadata: metadata} +} + +func (r *manualSyncRunner) Status() poller.Status { + if r == nil || r.redis == nil { + return poller.Status{} + } + return r.redis.Status() +} + +func (r *manualSyncRunner) SyncNow(ctx context.Context) error { + if r == nil { + return fmt.Errorf("manual sync runner is nil") + } + if r.redis == nil { + return fmt.Errorf("manual redis syncer is nil") + } + if err := r.redis.SyncNow(ctx); err != nil { + return manualSyncStageError{message: "redis sync failed", err: err} + } + if r.metadata == nil { + return fmt.Errorf("manual metadata syncer is nil") + } + if err := r.metadata.SyncMetadata(ctx); err != nil { + return manualSyncStageError{message: "metadata sync failed", err: err} + } + return nil +} diff --git a/internal/app/manual_sync_test.go b/internal/app/manual_sync_test.go new file mode 100644 index 00000000..e95f9441 --- /dev/null +++ b/internal/app/manual_sync_test.go @@ -0,0 +1,100 @@ +package app + +import ( + "context" + "errors" + "reflect" + "testing" + + "cpa-usage-keeper/internal/poller" +) + +type manualRedisSyncStub struct { + status poller.Status + err error + calls int + order *[]string +} + +func (s *manualRedisSyncStub) Run(context.Context) error { + return nil +} + +func (s *manualRedisSyncStub) Status() poller.Status { + return s.status +} + +func (s *manualRedisSyncStub) SyncNow(context.Context) error { + s.calls++ + if s.order != nil { + *s.order = append(*s.order, "redis") + } + return s.err +} + +type manualMetadataSyncStub struct { + err error + calls int + order *[]string +} + +func (s *manualMetadataSyncStub) SyncMetadata(context.Context) error { + s.calls++ + if s.order != nil { + *s.order = append(*s.order, "metadata") + } + return s.err +} + +func TestManualSyncRunnerRunsRedisThenMetadata(t *testing.T) { + var order []string + redis := &manualRedisSyncStub{status: poller.Status{LastStatus: "completed"}, order: &order} + metadata := &manualMetadataSyncStub{order: &order} + runner := newManualSyncRunner(redis, metadata) + + if err := runner.SyncNow(context.Background()); err != nil { + t.Fatalf("SyncNow returned error: %v", err) + } + if !reflect.DeepEqual(order, []string{"redis", "metadata"}) { + t.Fatalf("expected redis then metadata sync, got %v", order) + } + if runner.Status().LastStatus != "completed" { + t.Fatalf("expected status to delegate to redis runner, got %+v", runner.Status()) + } +} + +func TestManualSyncRunnerReturnsRedisErrorWithoutMetadata(t *testing.T) { + redisErr := errors.New("redis failed") + redis := &manualRedisSyncStub{err: redisErr} + metadata := &manualMetadataSyncStub{} + runner := newManualSyncRunner(redis, metadata) + + err := runner.SyncNow(context.Background()) + if !errors.Is(err, redisErr) { + t.Fatalf("expected redis error, got %v", err) + } + if err == nil || err.Error() != "redis sync failed: redis failed" { + t.Fatalf("expected redis-specific sync error, got %v", err) + } + if metadata.calls != 0 { + t.Fatalf("expected metadata sync not to run after redis failure, got %d calls", metadata.calls) + } +} + +func TestManualSyncRunnerReturnsMetadataError(t *testing.T) { + metadataErr := errors.New("metadata failed") + redis := &manualRedisSyncStub{} + metadata := &manualMetadataSyncStub{err: metadataErr} + runner := newManualSyncRunner(redis, metadata) + + err := runner.SyncNow(context.Background()) + if !errors.Is(err, metadataErr) { + t.Fatalf("expected metadata error, got %v", err) + } + if err == nil || err.Error() != "metadata sync failed: metadata failed" { + t.Fatalf("expected metadata-specific sync error, got %v", err) + } + if redis.calls != 1 || metadata.calls != 1 { + t.Fatalf("expected redis and metadata to run once, got redis=%d metadata=%d", redis.calls, metadata.calls) + } +} diff --git a/internal/app/metadata_sync.go b/internal/app/metadata_sync.go new file mode 100644 index 00000000..a8cdd814 --- /dev/null +++ b/internal/app/metadata_sync.go @@ -0,0 +1,74 @@ +package app + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/sirupsen/logrus" +) + +type MetadataSyncer interface { + SyncMetadata(ctx context.Context) error +} + +type MetadataSyncRunner struct { + syncer MetadataSyncer + interval time.Duration + sleep func(context.Context, time.Duration) bool + + mu sync.Mutex + running bool +} + +func NewMetadataSyncRunner(syncer MetadataSyncer, interval time.Duration) *MetadataSyncRunner { + return &MetadataSyncRunner{ + syncer: syncer, + interval: interval, + sleep: maintenanceSleepContext, + } +} + +// Run 启动独立 metadata 同步任务:启动后立即执行一次,之后按固定间隔刷新 auth files 和 provider metadata。 +func (r *MetadataSyncRunner) Run(ctx context.Context) error { + if err := r.validate(); err != nil { + return err + } + logrus.Info("metadata sync task started") + r.setRunning(true) + defer r.setRunning(false) + + delay := time.Duration(0) + for { + if !r.sleep(ctx, delay) { + return nil + } + if err := r.syncer.SyncMetadata(ctx); err != nil { + logrus.WithError(err).Error("metadata sync failed") + } + delay = r.interval + } +} + +func (r *MetadataSyncRunner) validate() error { + if r == nil { + return fmt.Errorf("metadata sync runner is nil") + } + if r.syncer == nil { + return fmt.Errorf("metadata syncer is nil") + } + if r.interval <= 0 { + return fmt.Errorf("metadata sync interval must be positive") + } + if r.sleep == nil { + r.sleep = maintenanceSleepContext + } + return nil +} + +func (r *MetadataSyncRunner) setRunning(running bool) { + r.mu.Lock() + defer r.mu.Unlock() + r.running = running +} diff --git a/internal/app/metadata_sync_test.go b/internal/app/metadata_sync_test.go new file mode 100644 index 00000000..97549ec2 --- /dev/null +++ b/internal/app/metadata_sync_test.go @@ -0,0 +1,83 @@ +package app + +import ( + "context" + "errors" + "strings" + "testing" + "time" +) + +type metadataSyncStub struct { + calls int + errs []error +} + +func (s *metadataSyncStub) SyncMetadata(context.Context) error { + s.calls++ + call := s.calls + if len(s.errs) >= call { + return s.errs[call-1] + } + if len(s.errs) > 0 { + return s.errs[len(s.errs)-1] + } + return nil +} + +func TestMetadataSyncRunnerRunsImmediatelyThenAtInterval(t *testing.T) { + syncer := &metadataSyncStub{} + runner := NewMetadataSyncRunner(syncer, 15*time.Minute) + var delays []time.Duration + runner.sleep = func(_ context.Context, d time.Duration) bool { + delays = append(delays, d) + return len(delays) < 3 + } + + if err := runner.Run(context.Background()); err != nil { + t.Fatalf("Run returned error: %v", err) + } + if syncer.calls != 2 { + t.Fatalf("expected two metadata sync calls, got %d", syncer.calls) + } + expected := []time.Duration{0, 15 * time.Minute, 15 * time.Minute} + if len(delays) != len(expected) { + t.Fatalf("expected delays %+v, got %+v", expected, delays) + } + for i, want := range expected { + if delays[i] != want { + t.Fatalf("expected delay %d to be %s, got %s", i, want, delays[i]) + } + } +} + +func TestMetadataSyncRunnerLogsFailureAndContinues(t *testing.T) { + logs := captureAppInfoLogs(t) + syncer := &metadataSyncStub{errs: []error{errors.New("metadata endpoint failed"), nil}} + runner := NewMetadataSyncRunner(syncer, time.Minute) + sleepCalls := 0 + runner.sleep = func(context.Context, time.Duration) bool { + sleepCalls++ + return sleepCalls < 3 + } + + if err := runner.Run(context.Background()); err != nil { + t.Fatalf("Run returned error: %v", err) + } + if syncer.calls != 2 { + t.Fatalf("expected runner to continue after metadata error, got %d calls", syncer.calls) + } + content := logs.String() + if !strings.Contains(content, "level=error") || !strings.Contains(content, "msg=\"metadata sync failed\"") { + t.Fatalf("expected metadata sync failure error log, got %q", content) + } +} + +func TestMetadataSyncRunnerValidatesConfig(t *testing.T) { + if err := NewMetadataSyncRunner(nil, time.Minute).Run(context.Background()); err == nil { + t.Fatal("expected nil syncer validation error") + } + if err := NewMetadataSyncRunner(&metadataSyncStub{}, 0).Run(context.Background()); err == nil { + t.Fatal("expected non-positive interval validation error") + } +} diff --git a/internal/config/config.go b/internal/config/config.go index ff75a8b3..9b9b151b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -15,10 +15,10 @@ import ( ) const ( - DefaultTimeZone = "Asia/Shanghai" - RedisQueueKeyDefault = cpa.ManagementUsageQueueKey - RedisQueueErrorBackoffDefault = 10 * time.Second - RedisMetadataSyncIntervalDefault = 30 * time.Second + DefaultTimeZone = "Asia/Shanghai" + RedisQueueKeyDefault = cpa.ManagementUsageQueueKey + RedisQueueErrorBackoffDefault = 10 * time.Second + MetadataSyncIntervalDefault = 30 * time.Second ) var ( @@ -40,10 +40,6 @@ type Config struct { CPABaseURL string // CPAManagementKey 是访问 CPA 管理数据的密钥。 CPAManagementKey string - // PollInterval 是 legacy export 拉取间隔。 - PollInterval time.Duration - // UsageSyncMode 决定使用 auto、redis 或 legacy_export;auto 会在启动时解析为一种有效模式。 - UsageSyncMode string // RedisQueueAddr 是 CPA management data stream 的 TCP 地址,空值时按 CPA_BASE_URL 推导。 RedisQueueAddr string // RedisQueueKey 是 CPA usage 队列名。 @@ -54,8 +50,8 @@ type Config struct { RedisQueueIdleInterval time.Duration // RedisQueueErrorBackoff 是 Redis 临时错误后的固定退避间隔。 RedisQueueErrorBackoff time.Duration - // RedisMetadataSyncInterval 是 Redis drain 模式下 metadata 的固定刷新间隔。 - RedisMetadataSyncInterval time.Duration + // MetadataSyncInterval 是 auth files 和 provider metadata 的固定刷新间隔。 + MetadataSyncInterval time.Duration // WorkDir 是应用工作目录,数据库、日志和备份默认从这里派生。 WorkDir string // SQLitePath 是 SQLite 数据库文件路径。 @@ -111,16 +107,6 @@ func Load(options LoadOptions) (*Config, error) { return nil, err } - usageSyncMode := getString("USAGE_SYNC_MODE", "auto") - if usageSyncMode != "auto" && usageSyncMode != "redis" && usageSyncMode != "legacy_export" { - return nil, fmt.Errorf("USAGE_SYNC_MODE must be one of auto, redis, legacy_export") - } - - pollInterval, err := getDuration("POLL_INTERVAL", 5*time.Minute) - if err != nil { - return nil, err - } - redisQueueBatchSize, err := getInt("REDIS_QUEUE_BATCH_SIZE", 1000) if err != nil { return nil, err @@ -196,32 +182,30 @@ func Load(options LoadOptions) (*Config, error) { workDir := getString("WORK_DIR", DefaultWorkDir) cfg := &Config{ - AppPort: getString("APP_PORT", "8080"), - AppBasePath: appBasePath, - CPABaseURL: strings.TrimSpace(os.Getenv("CPA_BASE_URL")), - CPAManagementKey: strings.TrimSpace(os.Getenv("CPA_MANAGEMENT_KEY")), - PollInterval: pollInterval, - UsageSyncMode: usageSyncMode, - RedisQueueAddr: strings.TrimSpace(os.Getenv("REDIS_QUEUE_ADDR")), - RedisQueueKey: RedisQueueKeyDefault, - RedisQueueBatchSize: redisQueueBatchSize, - RedisQueueIdleInterval: redisQueueIdleInterval, - RedisQueueErrorBackoff: RedisQueueErrorBackoffDefault, - RedisMetadataSyncInterval: RedisMetadataSyncIntervalDefault, - WorkDir: workDir, - SQLitePath: filepath.Join(workDir, workDirDatabaseName), - BackupEnabled: backupEnabled, - BackupDir: filepath.Join(workDir, workDirBackupsName), - BackupInterval: backupInterval, - BackupRetentionDays: backupRetentionDays, - RequestTimeout: requestTimeout, - LogLevel: getString("LOG_LEVEL", "info"), - LogFileEnabled: logFileEnabled, - LogDir: filepath.Join(workDir, workDirLogsName), - LogRetentionDays: logRetentionDays, - AuthEnabled: authEnabled, - LoginPassword: strings.TrimSpace(os.Getenv("LOGIN_PASSWORD")), - AuthSessionTTL: authSessionTTL, + AppPort: getString("APP_PORT", "8080"), + AppBasePath: appBasePath, + CPABaseURL: strings.TrimSpace(os.Getenv("CPA_BASE_URL")), + CPAManagementKey: strings.TrimSpace(os.Getenv("CPA_MANAGEMENT_KEY")), + RedisQueueAddr: strings.TrimSpace(os.Getenv("REDIS_QUEUE_ADDR")), + RedisQueueKey: RedisQueueKeyDefault, + RedisQueueBatchSize: redisQueueBatchSize, + RedisQueueIdleInterval: redisQueueIdleInterval, + RedisQueueErrorBackoff: RedisQueueErrorBackoffDefault, + MetadataSyncInterval: MetadataSyncIntervalDefault, + WorkDir: workDir, + SQLitePath: filepath.Join(workDir, workDirDatabaseName), + BackupEnabled: backupEnabled, + BackupDir: filepath.Join(workDir, workDirBackupsName), + BackupInterval: backupInterval, + BackupRetentionDays: backupRetentionDays, + RequestTimeout: requestTimeout, + LogLevel: getString("LOG_LEVEL", "info"), + LogFileEnabled: logFileEnabled, + LogDir: filepath.Join(workDir, workDirLogsName), + LogRetentionDays: logRetentionDays, + AuthEnabled: authEnabled, + LoginPassword: strings.TrimSpace(os.Getenv("LOGIN_PASSWORD")), + AuthSessionTTL: authSessionTTL, } if cfg.CPABaseURL == "" { return nil, fmt.Errorf("CPA_BASE_URL is required") diff --git a/internal/config/config_test.go b/internal/config/config_test.go index ec9a1e8a..4af74430 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -126,9 +126,6 @@ func TestLoadFromEnvAppliesDefaults(t *testing.T) { if cfg.AuthSessionTTL != 7*24*time.Hour { t.Fatalf("expected default auth session ttl 168h, got %s", cfg.AuthSessionTTL) } - if cfg.UsageSyncMode != "auto" { - t.Fatalf("expected default usage sync mode auto, got %s", cfg.UsageSyncMode) - } if cfg.RedisQueueAddr != "" { t.Fatalf("expected default redis queue addr to be empty, got %q", cfg.RedisQueueAddr) } @@ -138,17 +135,14 @@ func TestLoadFromEnvAppliesDefaults(t *testing.T) { if cfg.RedisQueueBatchSize != 1000 { t.Fatalf("expected default redis queue batch size 1000, got %d", cfg.RedisQueueBatchSize) } - if cfg.PollInterval != 5*time.Minute { - t.Fatalf("expected default legacy export poll interval 5m, got %s", cfg.PollInterval) - } if cfg.RedisQueueIdleInterval != time.Second { t.Fatalf("expected default redis queue idle interval 1s, got %s", cfg.RedisQueueIdleInterval) } if cfg.RedisQueueErrorBackoff != RedisQueueErrorBackoffDefault { t.Fatalf("expected default redis queue error backoff 10s, got %s", cfg.RedisQueueErrorBackoff) } - if cfg.RedisMetadataSyncInterval != RedisMetadataSyncIntervalDefault { - t.Fatalf("expected default redis metadata sync interval 30s, got %s", cfg.RedisMetadataSyncInterval) + if cfg.MetadataSyncInterval != MetadataSyncIntervalDefault { + t.Fatalf("expected default metadata sync interval 30s, got %s", cfg.MetadataSyncInterval) } if !cfg.LogFileEnabled { t.Fatal("expected log file output to be enabled by default") @@ -372,33 +366,15 @@ func TestLoadFromEnvRequiresCriticalValues(t *testing.T) { }) } -func TestLoadFromEnvUsesLegacyExportPollDefault(t *testing.T) { - t.Setenv("CPA_BASE_URL", "http://127.0.0.1:"+cpa.ManagementRedisDefaultPort) - t.Setenv("CPA_MANAGEMENT_KEY", "secret") - t.Setenv("USAGE_SYNC_MODE", "legacy_export") - - cfg, err := LoadFromEnv() - if err != nil { - t.Fatalf("LoadFromEnv returned error: %v", err) - } - - if cfg.PollInterval != 5*time.Minute { - t.Fatalf("expected legacy export poll interval 5m, got %s", cfg.PollInterval) - } -} - -func TestLoadFromEnvUsesExplicitPollIntervalOverride(t *testing.T) { +func TestLoadFromEnvIgnoresRemovedLegacySyncEnvVars(t *testing.T) { t.Setenv("CPA_BASE_URL", "http://127.0.0.1:"+cpa.ManagementRedisDefaultPort) t.Setenv("CPA_MANAGEMENT_KEY", "secret") - t.Setenv("POLL_INTERVAL", "1m") + t.Setenv("USAGE_SYNC_MODE", "invalid") + t.Setenv("POLL_INTERVAL", "not-a-duration") - cfg, err := LoadFromEnv() + _, err := LoadFromEnv() if err != nil { - t.Fatalf("LoadFromEnv returned error: %v", err) - } - - if cfg.PollInterval != time.Minute { - t.Fatalf("expected explicit poll interval 1m, got %s", cfg.PollInterval) + t.Fatalf("LoadFromEnv should ignore removed legacy sync env vars, got error: %v", err) } } @@ -431,17 +407,6 @@ func TestLoadFromEnvIgnoresRemovedRedisQueueKeyOverride(t *testing.T) { } } -func TestLoadFromEnvRejectsInvalidUsageSyncMode(t *testing.T) { - t.Setenv("CPA_BASE_URL", "http://127.0.0.1:"+cpa.ManagementRedisDefaultPort) - t.Setenv("CPA_MANAGEMENT_KEY", "secret") - t.Setenv("USAGE_SYNC_MODE", "invalid") - - _, err := LoadFromEnv() - if err == nil || err.Error() != "USAGE_SYNC_MODE must be one of auto, redis, legacy_export" { - t.Fatalf("expected usage sync mode validation error, got %v", err) - } -} - func TestLoadFromEnvRejectsNonPositiveRedisQueueBatchSize(t *testing.T) { t.Setenv("CPA_BASE_URL", "http://127.0.0.1:"+cpa.ManagementRedisDefaultPort) t.Setenv("CPA_MANAGEMENT_KEY", "secret") @@ -459,7 +424,6 @@ func TestLoadFromEnvParsesOverrides(t *testing.T) { t.Setenv("WORK_DIR", "/tmp/work") t.Setenv("APP_PORT", "9090") t.Setenv("APP_BASE_PATH", "/cpa/") - t.Setenv("POLL_INTERVAL", "1m") t.Setenv("BACKUP_ENABLED", "false") t.Setenv("BACKUP_INTERVAL", "2h") t.Setenv("BACKUP_RETENTION_DAYS", "7") @@ -477,7 +441,7 @@ func TestLoadFromEnvParsesOverrides(t *testing.T) { t.Fatalf("LoadFromEnv returned error: %v", err) } - if cfg.AppPort != "9090" || cfg.AppBasePath != "/cpa" || cfg.PollInterval != time.Minute || cfg.WorkDir != "/tmp/work" || cfg.SQLitePath != filepath.Join("/tmp/work", "app.db") || cfg.BackupEnabled || cfg.BackupDir != filepath.Join("/tmp/work", "backups") || cfg.BackupInterval != 2*time.Hour || cfg.BackupRetentionDays != 7 || cfg.RequestTimeout != 15*time.Second || cfg.LogLevel != "debug" || cfg.LogFileEnabled || cfg.LogDir != filepath.Join("/tmp/work", "logs") || cfg.LogRetentionDays != 14 || !cfg.AuthEnabled || cfg.LoginPassword != "top-secret" || cfg.AuthSessionTTL != 12*time.Hour || cfg.RedisQueueIdleInterval != 2*time.Second { + if cfg.AppPort != "9090" || cfg.AppBasePath != "/cpa" || cfg.WorkDir != "/tmp/work" || cfg.SQLitePath != filepath.Join("/tmp/work", "app.db") || cfg.BackupEnabled || cfg.BackupDir != filepath.Join("/tmp/work", "backups") || cfg.BackupInterval != 2*time.Hour || cfg.BackupRetentionDays != 7 || cfg.RequestTimeout != 15*time.Second || cfg.LogLevel != "debug" || cfg.LogFileEnabled || cfg.LogDir != filepath.Join("/tmp/work", "logs") || cfg.LogRetentionDays != 14 || !cfg.AuthEnabled || cfg.LoginPassword != "top-secret" || cfg.AuthSessionTTL != 12*time.Hour || cfg.RedisQueueIdleInterval != 2*time.Second { t.Fatalf("unexpected config override result: %+v", cfg) } } @@ -540,8 +504,8 @@ func TestLoadFromEnvIgnoresRemovedRedisDrainEnvOverrides(t *testing.T) { if err != nil { t.Fatalf("LoadFromEnv returned error: %v", err) } - if cfg.RedisQueueErrorBackoff != RedisQueueErrorBackoffDefault || cfg.RedisMetadataSyncInterval != RedisMetadataSyncIntervalDefault { - t.Fatalf("expected removed env overrides to be ignored, got error_backoff=%s metadata_interval=%s", cfg.RedisQueueErrorBackoff, cfg.RedisMetadataSyncInterval) + if cfg.RedisQueueErrorBackoff != RedisQueueErrorBackoffDefault || cfg.MetadataSyncInterval != MetadataSyncIntervalDefault { + t.Fatalf("expected removed env overrides to be ignored, got error_backoff=%s metadata_interval=%s", cfg.RedisQueueErrorBackoff, cfg.MetadataSyncInterval) } } diff --git a/internal/cpa/client.go b/internal/cpa/client.go index 3fd9068d..5571cd98 100644 --- a/internal/cpa/client.go +++ b/internal/cpa/client.go @@ -16,12 +16,6 @@ type Client struct { httpClient *http.Client } -type ExportResult struct { - StatusCode int - Body []byte - Payload UsageExport -} - func (c *Client) doJSONRequest(ctx context.Context, path string, target any, kind string, configure func(*http.Request)) (int, []byte, error) { if c == nil { return 0, nil, fmt.Errorf("cpa client is nil") @@ -80,17 +74,6 @@ func NewClient(baseURL, managementKey string, timeout time.Duration) *Client { } } -func (c *Client) FetchUsageExport(ctx context.Context) (*ExportResult, error) { - result := &ExportResult{} - statusCode, body, err := c.doManagementJSONRequest(ctx, cpaManagementUsageExportEndpoint, &result.Payload, "export") - result.StatusCode = statusCode - result.Body = body - if err != nil { - return result, err - } - return result, nil -} - func (c *Client) FetchExternalAPIKeys(ctx context.Context) (*ExternalAPIKeysResult, error) { result := &ExternalAPIKeysResult{} statusCode, body, err := c.doManagementJSONRequest(ctx, cpaManagementExternalAPIKeysEndpoint, &result.Payload, "external api keys") @@ -136,43 +119,93 @@ func (c *Client) FetchAuthFiles(ctx context.Context) (*AuthFilesResult, error) { } func (c *Client) FetchGeminiAPIKeys(ctx context.Context) (*ProviderKeyConfigResult, error) { - return c.fetchProviderKeyConfig(ctx, cpaManagementGeminiAPIKeyEndpoint, "gemini api keys") + return c.fetchProviderKeyConfig(ctx, cpaManagementGeminiAPIKeyEndpoint, "gemini-api-key", "gemini api keys") } func (c *Client) FetchClaudeAPIKeys(ctx context.Context) (*ProviderKeyConfigResult, error) { - return c.fetchProviderKeyConfig(ctx, cpaManagementClaudeAPIKeyEndpoint, "claude api keys") + return c.fetchProviderKeyConfig(ctx, cpaManagementClaudeAPIKeyEndpoint, "claude-api-key", "claude api keys") } func (c *Client) FetchCodexAPIKeys(ctx context.Context) (*ProviderKeyConfigResult, error) { - return c.fetchProviderKeyConfig(ctx, cpaManagementCodexAPIKeyEndpoint, "codex api keys") + return c.fetchProviderKeyConfig(ctx, cpaManagementCodexAPIKeyEndpoint, "codex-api-key", "codex api keys") } func (c *Client) FetchVertexAPIKeys(ctx context.Context) (*ProviderKeyConfigResult, error) { - return c.fetchProviderKeyConfig(ctx, cpaManagementVertexAPIKeyEndpoint, "vertex api keys") + return c.fetchProviderKeyConfig(ctx, cpaManagementVertexAPIKeyEndpoint, "vertex-api-key", "vertex api keys") } -func (c *Client) fetchProviderKeyConfig(ctx context.Context, path string, kind string) (*ProviderKeyConfigResult, error) { +func (c *Client) fetchProviderKeyConfig(ctx context.Context, path string, payloadKey string, kind string) (*ProviderKeyConfigResult, error) { result := &ProviderKeyConfigResult{} - statusCode, body, err := c.doManagementJSONRequest(ctx, path, &result.Payload, kind) + var raw json.RawMessage + statusCode, body, err := c.doManagementJSONRequest(ctx, path, &raw, kind) result.StatusCode = statusCode result.Body = body if err != nil { return result, err } + payload, err := decodeProviderKeyConfigPayload(raw, payloadKey) + if err != nil { + return result, fmt.Errorf("decode management %s json: %w", kind, err) + } + result.Payload = payload return result, nil } func (c *Client) FetchOpenAICompatibility(ctx context.Context) (*OpenAICompatibilityResult, error) { result := &OpenAICompatibilityResult{} - statusCode, body, err := c.doManagementJSONRequest(ctx, cpaManagementOpenAICompatibilityEndpoint, &result.Payload, "openai compatibility") + var raw json.RawMessage + statusCode, body, err := c.doManagementJSONRequest(ctx, cpaManagementOpenAICompatibilityEndpoint, &raw, "openai compatibility") result.StatusCode = statusCode result.Body = body if err != nil { return result, err } + payload, err := decodeOpenAICompatibilityPayload(raw, "openai-compatibility") + if err != nil { + return result, fmt.Errorf("decode management openai compatibility json: %w", err) + } + result.Payload = payload return result, nil } +func decodeProviderKeyConfigPayload(raw json.RawMessage, payloadKey string) ([]ProviderKeyConfig, error) { + var direct []ProviderKeyConfig + if err := json.Unmarshal(raw, &direct); err == nil { + return direct, nil + } + var wrapped map[string]json.RawMessage + if err := json.Unmarshal(raw, &wrapped); err != nil { + return nil, err + } + payloadRaw, ok := wrapped[payloadKey] + if !ok { + return nil, fmt.Errorf("missing %s payload", payloadKey) + } + if err := json.Unmarshal(payloadRaw, &direct); err != nil { + return nil, err + } + return direct, nil +} + +func decodeOpenAICompatibilityPayload(raw json.RawMessage, payloadKey string) ([]OpenAICompatibilityConfig, error) { + var direct []OpenAICompatibilityConfig + if err := json.Unmarshal(raw, &direct); err == nil { + return direct, nil + } + var wrapped map[string]json.RawMessage + if err := json.Unmarshal(raw, &wrapped); err != nil { + return nil, err + } + payloadRaw, ok := wrapped[payloadKey] + if !ok { + return nil, fmt.Errorf("missing %s payload", payloadKey) + } + if err := json.Unmarshal(payloadRaw, &direct); err != nil { + return nil, err + } + return direct, nil +} + func firstNonEmptyString(values []string) string { for _, value := range values { trimmed := strings.TrimSpace(value) diff --git a/internal/cpa/client_test.go b/internal/cpa/client_test.go index 48aa5556..1e45c7bf 100644 --- a/internal/cpa/client_test.go +++ b/internal/cpa/client_test.go @@ -8,53 +8,6 @@ import ( "time" ) -func TestFetchUsageExportSendsBearerToken(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("Authorization"); got != "Bearer secret" { - t.Fatalf("expected Authorization header, got %q", got) - } - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"version":1,"exported_at":"2026-04-16T00:00:00Z","usage":{}}`)) - })) - defer server.Close() - - client := NewClient(server.URL, "secret", 2*time.Second) - result, err := client.FetchUsageExport(context.Background()) - if err != nil { - t.Fatalf("FetchUsageExport returned error: %v", err) - } - if result.StatusCode != http.StatusOK { - t.Fatalf("expected status 200, got %d", result.StatusCode) - } -} - -func TestFetchUsageExportHandlesUnauthorized(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, `{"error":"invalid management key"}`, http.StatusUnauthorized) - })) - defer server.Close() - - client := NewClient(server.URL, "secret", 2*time.Second) - _, err := client.FetchUsageExport(context.Background()) - if err == nil { - t.Fatal("expected unauthorized error") - } -} - -func TestFetchUsageExportRejectsInvalidJSON(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`not-json`)) - })) - defer server.Close() - - client := NewClient(server.URL, "secret", 2*time.Second) - _, err := client.FetchUsageExport(context.Background()) - if err == nil { - t.Fatal("expected invalid json error") - } -} - func TestFetchExternalAPIKeysSendsBearerTokenAndParsesExternalKeys(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != cpaManagementExternalAPIKeysEndpoint { @@ -283,6 +236,88 @@ func TestProviderMetadataFetchersUseDedicatedEndpoints(t *testing.T) { } } +func TestProviderMetadataFetchersParseWrappedEndpointResponses(t *testing.T) { + tests := []struct { + name string + path string + fetch func(context.Context, *Client) (*ProviderKeyConfigResult, error) + response string + }{ + { + name: "gemini", + path: cpaManagementGeminiAPIKeyEndpoint, + fetch: func(ctx context.Context, client *Client) (*ProviderKeyConfigResult, error) { + return client.FetchGeminiAPIKeys(ctx) + }, + response: `{"gemini-api-key":[{"apiKey":"gemini-key","prefix":"gemini-prefix","name":"Gemini"}]}`, + }, + { + name: "claude", + path: cpaManagementClaudeAPIKeyEndpoint, + fetch: func(ctx context.Context, client *Client) (*ProviderKeyConfigResult, error) { + return client.FetchClaudeAPIKeys(ctx) + }, + response: `{"claude-api-key":[{"api-key":"claude-key","prefix":"claude-prefix","name":"Claude"}]}`, + }, + { + name: "codex", + path: cpaManagementCodexAPIKeyEndpoint, + fetch: func(ctx context.Context, client *Client) (*ProviderKeyConfigResult, error) { + return client.FetchCodexAPIKeys(ctx) + }, + response: `{"codex-api-key":[{"key":"codex-key","prefix":"codex-prefix","name":"Codex"}]}`, + }, + { + name: "vertex", + path: cpaManagementVertexAPIKeyEndpoint, + fetch: func(ctx context.Context, client *Client) (*ProviderKeyConfigResult, error) { + return client.FetchVertexAPIKeys(ctx) + }, + response: `{"vertex-api-key":[{"apiKey":"vertex-key","prefix":"vertex-prefix","name":"Vertex"}]}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != tt.path { + t.Fatalf("unexpected path %q", r.URL.Path) + } + _, _ = w.Write([]byte(tt.response)) + })) + defer server.Close() + + client := NewClient(server.URL, "management-secret", 2*time.Second) + result, err := tt.fetch(context.Background(), client) + if err != nil { + t.Fatalf("fetch returned error: %v", err) + } + if len(result.Payload) != 1 || result.Payload[0].APIKey == "" || result.Payload[0].Prefix == "" || result.Payload[0].Name == "" { + t.Fatalf("unexpected wrapped provider payload: %#v", result.Payload) + } + }) + } +} + +func TestFetchOpenAICompatibilityParsesWrappedEndpointResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != cpaManagementOpenAICompatibilityEndpoint { + t.Fatalf("unexpected path %q", r.URL.Path) + } + _, _ = w.Write([]byte(`{"openai-compatibility":[{"id":"custom-openai","prefix":"custom","api-keys":["custom-key"]}]}`)) + })) + defer server.Close() + + client := NewClient(server.URL, "management-secret", 2*time.Second) + result, err := client.FetchOpenAICompatibility(context.Background()) + if err != nil { + t.Fatalf("FetchOpenAICompatibility returned error: %v", err) + } + if len(result.Payload) != 1 || result.Payload[0].Name != "custom-openai" || result.Payload[0].Prefix != "custom" || len(result.Payload[0].APIKeyEntries) != 1 || result.Payload[0].APIKeyEntries[0].APIKey != "custom-key" { + t.Fatalf("unexpected wrapped openai compatibility payload: %#v", result.Payload) + } +} + func TestFetchOpenAICompatibilityUsesDedicatedEndpoint(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != cpaManagementOpenAICompatibilityEndpoint { diff --git a/internal/cpa/endpoints.go b/internal/cpa/endpoints.go index aaf1a7fe..c3ec2c4e 100644 --- a/internal/cpa/endpoints.go +++ b/internal/cpa/endpoints.go @@ -1,7 +1,6 @@ package cpa const ( - cpaManagementUsageExportEndpoint = "/v0/management/usage/export" cpaManagementAuthFilesEndpoint = "/v0/management/auth-files" cpaManagementExternalAPIKeysEndpoint = "/v0/management/api-keys" cpaManagementVertexAPIKeyEndpoint = "/v0/management/vertex-api-key" diff --git a/internal/cpa/redis_queue_client.go b/internal/cpa/redis_queue_client.go index 7159f34b..6c9f2f63 100644 --- a/internal/cpa/redis_queue_client.go +++ b/internal/cpa/redis_queue_client.go @@ -33,15 +33,6 @@ func NewRedisQueueClient(baseURL, redisQueueAddr, managementKey string, timeout } } -func (c *RedisQueueClient) Probe(ctx context.Context) error { - conn, _, err := c.openAuthenticatedConnection(ctx) - if err != nil { - return err - } - _ = conn.Close() - return nil -} - func (c *RedisQueueClient) PopUsage(ctx context.Context) ([]string, error) { if c == nil { return nil, fmt.Errorf("redis queue client is nil") diff --git a/internal/cpa/redis_queue_client_test.go b/internal/cpa/redis_queue_client_test.go index adc02121..010b3e43 100644 --- a/internal/cpa/redis_queue_client_test.go +++ b/internal/cpa/redis_queue_client_test.go @@ -70,47 +70,6 @@ func TestRedisQueueClientClassifiesAuthErrors(t *testing.T) { } } -func TestRedisQueueClientProbeAuthenticatesWithoutPopping(t *testing.T) { - server := newRedisQueueTestServer(t, func(t *testing.T, conn net.Conn) { - reader := bufio.NewReader(conn) - if got := readRESPCommand(t, reader); strings.Join(got, " ") != cpaManagementRedisAuthCommand+" secret" { - t.Fatalf("unexpected auth command: %v", got) - } - fmt.Fprint(conn, "+OK\r\n") - if err := conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { - t.Fatalf("set read deadline: %v", err) - } - line, err := reader.ReadString('\n') - if err == nil { - t.Fatalf("expected probe to close without pop command, got %q", line) - } - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - t.Fatal("probe left connection open waiting for another command") - } - }) - - client := NewRedisQueueClient(server.URL, "", "secret", time.Second, ManagementUsageQueueKey, 2) - if err := client.Probe(ctxWithTimeout(t)); err != nil { - t.Fatalf("Probe returned error: %v", err) - } -} - -func TestRedisQueueClientProbeClassifiesAuthErrors(t *testing.T) { - server := newRedisQueueTestServer(t, func(t *testing.T, conn net.Conn) { - readRESPCommand(t, bufio.NewReader(conn)) - fmt.Fprint(conn, "-ERR invalid password\r\n") - }) - - client := NewRedisQueueClient(server.URL, "", "wrong", time.Second, ManagementUsageQueueKey, 1000) - err := client.Probe(ctxWithTimeout(t)) - if err == nil { - t.Fatal("expected auth error") - } - if !errors.Is(err, ErrRedisQueueAuth) { - t.Fatalf("expected ErrRedisQueueAuth, got %v", err) - } -} - func TestRedisQueueClientPrefersExplicitQueueAddr(t *testing.T) { if got := redisQueueAddress("https://cpa.example.com", "redis-stream.example.com:6380"); got != "redis-stream.example.com:6380" { t.Fatalf("expected explicit redis queue addr, got %q", got) diff --git a/internal/cpa/types.go b/internal/cpa/types.go index 2fcb22a5..dd666952 100644 --- a/internal/cpa/types.go +++ b/internal/cpa/types.go @@ -6,12 +6,6 @@ import ( "time" ) -type UsageExport struct { - Version int `json:"version"` - ExportedAt time.Time `json:"exported_at"` - Usage StatisticsSnapshot `json:"usage"` -} - type StatisticsSnapshot struct { TotalRequests int64 `json:"total_requests"` SuccessCount int64 `json:"success_count"` diff --git a/internal/models/models.go b/internal/models/models.go index 24b13f02..eecfc7e1 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -1,34 +1,15 @@ package models -import ( - "time" - - "gorm.io/gorm" -) - -type SnapshotRun struct { - ID uint `gorm:"primaryKey"` - FetchedAt time.Time `gorm:"index:idx_snapshot_runs_fetched_at"` - CPABaseURL string - ExportedAt *time.Time - Version string - Status string `gorm:"index:idx_snapshot_runs_status"` - HTTPStatus int - PayloadHash string - RawPayload []byte - BackupFilePath string - ErrorMessage string - InsertedEvents int - DedupedEvents int - CreatedAt time.Time - UpdatedAt time.Time -} +import "time" type UsageEvent struct { - ID uint `gorm:"primaryKey"` - EventKey string `gorm:"uniqueIndex:uniq_usage_events_event_key"` - SnapshotRunID uint + ID uint `gorm:"primaryKey"` + EventKey string `gorm:"uniqueIndex:uniq_usage_events_event_key"` APIGroupKey string `gorm:"index:idx_usage_events_api_group_key"` + Provider string `gorm:"column:provider"` + Endpoint string `gorm:"column:endpoint"` + AuthType string `gorm:"column:auth_type"` + RequestID string `gorm:"column:request_id"` Model string `gorm:"index:idx_usage_events_model"` Timestamp time.Time `gorm:"index:idx_usage_events_timestamp"` Source string `gorm:"index:idx_usage_events_source"` @@ -51,7 +32,6 @@ type RedisUsageInbox struct { Status string `gorm:"not null;index"` AttemptCount int `gorm:"not null;default:0"` LastError string - SnapshotRunID *uint `gorm:"index"` UsageEventKey string `gorm:"index"` PoppedAt time.Time `gorm:"not null;index"` ProcessedAt *time.Time @@ -59,36 +39,6 @@ type RedisUsageInbox struct { UpdatedAt time.Time } -type AuthFile struct { - ID uint `gorm:"primaryKey"` - AuthIndex string `gorm:"uniqueIndex:uniq_auth_files_auth_index"` - Name string - Email string - Type string - Provider string - Label string - Status string - Source string - Disabled bool - Unavailable bool - RuntimeOnly bool - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt gorm.DeletedAt `gorm:"index"` -} - -type ProviderMetadata struct { - ID uint `gorm:"primaryKey"` - LookupKey string `gorm:"uniqueIndex:uniq_provider_metadata_lookup_key"` - ProviderType string `gorm:"index:idx_provider_metadata_provider_type"` - DisplayName string - ProviderKey string `gorm:"index:idx_provider_metadata_provider_key"` - MatchKind string - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt gorm.DeletedAt `gorm:"index"` -} - type ModelPriceSetting struct { ID uint `gorm:"primaryKey"` Model string `gorm:"uniqueIndex:uniq_model_price_settings_model"` @@ -99,13 +49,47 @@ type ModelPriceSetting struct { UpdatedAt time.Time } +type UsageIdentityAuthType int + +const ( + UsageIdentityAuthTypeAuthFile UsageIdentityAuthType = 1 + UsageIdentityAuthTypeAIProvider UsageIdentityAuthType = 2 +) + +type UsageIdentity struct { + ID uint `gorm:"primaryKey"` + Name string + AuthType UsageIdentityAuthType `gorm:"uniqueIndex:uniq_usage_identities_type_identity;index:idx_usage_identities_auth_type"` + AuthTypeName string `gorm:"index:idx_usage_identities_auth_type_name"` + Identity string `gorm:"uniqueIndex:uniq_usage_identities_type_identity;index:idx_usage_identities_identity"` + Type string `gorm:"column:type"` + Provider string + + TotalRequests int64 + SuccessCount int64 + FailureCount int64 + InputTokens int64 + OutputTokens int64 + ReasoningTokens int64 + CachedTokens int64 + TotalTokens int64 + + LastAggregatedUsageEventID uint `gorm:"index:idx_usage_identities_last_aggregated_usage_event_id"` + FirstUsedAt *time.Time + LastUsedAt *time.Time + StatsUpdatedAt *time.Time + + IsDeleted bool `gorm:"index:idx_usage_identities_is_deleted"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time `gorm:"index:idx_usage_identities_deleted_at"` +} + func All() []any { return []any{ - &SnapshotRun{}, &UsageEvent{}, &RedisUsageInbox{}, - &AuthFile{}, - &ProviderMetadata{}, &ModelPriceSetting{}, + &UsageIdentity{}, } } diff --git a/internal/models/models_test.go b/internal/models/models_test.go index 25f46fee..e411cf9b 100644 --- a/internal/models/models_test.go +++ b/internal/models/models_test.go @@ -4,10 +4,13 @@ import "testing" func TestAllIncludesCoreModels(t *testing.T) { items := All() - if len(items) != 6 { - t.Fatalf("expected 6 models, got %d", len(items)) + if len(items) != 4 { + t.Fatalf("expected 4 models after removing legacy metadata tables, got %d", len(items)) } - if _, ok := items[2].(*RedisUsageInbox); !ok { - t.Fatalf("expected RedisUsageInbox to be registered, got %T", items[2]) + if _, ok := items[0].(*UsageEvent); !ok { + t.Fatalf("expected UsageEvent to be first registered model, got %T", items[0]) + } + if _, ok := items[1].(*RedisUsageInbox); !ok { + t.Fatalf("expected RedisUsageInbox to be registered, got %T", items[1]) } } diff --git a/internal/poller/poller.go b/internal/poller/poller.go index 60675611..2c119b67 100644 --- a/internal/poller/poller.go +++ b/internal/poller/poller.go @@ -3,39 +3,12 @@ package poller import ( "context" "errors" - "fmt" - "sync" "time" - - "github.com/sirupsen/logrus" ) -type Syncer interface { - SyncOnce(ctx context.Context) error -} - -type StatusSyncer interface { - SyncStatus(ctx context.Context) (string, error) -} - var ErrSyncAlreadyRunning = errors.New("sync already running") var ErrSyncCompletedWithWarnings = errors.New("sync completed with warnings") -type Poller struct { - interval time.Duration - syncer Syncer - ticker func(time.Duration) ticker - now func() time.Time - - mu sync.Mutex - running bool - lastRunAt time.Time - lastError string - lastWarning string - lastStatus string - syncRunning bool -} - type Status struct { Running bool LastRunAt time.Time @@ -45,137 +18,6 @@ type Status struct { SyncRunning bool } -type ticker interface { - Chan() <-chan time.Time - Stop() -} - -type realTicker struct { - inner *time.Ticker -} - -func New(syncer Syncer, interval time.Duration) *Poller { - return &Poller{ - interval: interval, - syncer: syncer, - ticker: func(d time.Duration) ticker { - return realTicker{inner: time.NewTicker(d)} - }, - now: time.Now, - } -} - -func (t realTicker) Chan() <-chan time.Time { return t.inner.C } -func (t realTicker) Stop() { t.inner.Stop() } - -func (p *Poller) Run(ctx context.Context) error { - if p == nil { - return fmt.Errorf("poller is nil") - } - if p.syncer == nil { - return fmt.Errorf("poller syncer is nil") - } - if p.interval <= 0 { - return fmt.Errorf("poll interval must be greater than zero") - } - if p.ticker == nil { - p.ticker = func(d time.Duration) ticker { return realTicker{inner: time.NewTicker(d)} } - } - if p.now == nil { - p.now = time.Now - } - - logrus.WithField("interval", p.interval.String()).Info("legacy export poller task started") - p.setRunning(true) - defer p.setRunning(false) - - p.runBackgroundSync(ctx) - - t := p.ticker(p.interval) - defer t.Stop() - - for { - select { - case <-ctx.Done(): - return nil - case <-t.Chan(): - p.runBackgroundSync(ctx) - } - } -} - -func (p *Poller) runBackgroundSync(ctx context.Context) { - if err := p.runSync(ctx); shouldLogSyncError(err) { - logrus.WithError(err).Error("poller sync failed") - } -} - func shouldLogSyncError(err error) bool { return err != nil && !errors.Is(err, ErrSyncCompletedWithWarnings) && !errors.Is(err, ErrSyncAlreadyRunning) && !errors.Is(err, context.Canceled) } - -func (p *Poller) Status() Status { - p.mu.Lock() - defer p.mu.Unlock() - return Status{ - Running: p.running, - LastRunAt: p.lastRunAt, - LastError: p.lastError, - LastWarning: p.lastWarning, - LastStatus: p.lastStatus, - SyncRunning: p.syncRunning, - } -} - -func (p *Poller) SyncNow(ctx context.Context) error { - return p.runSync(ctx) -} - -func (p *Poller) runSync(ctx context.Context) error { - p.mu.Lock() - if p.syncRunning { - p.mu.Unlock() - return ErrSyncAlreadyRunning - } - p.syncRunning = true - p.mu.Unlock() - - defer func() { - p.mu.Lock() - p.syncRunning = false - p.mu.Unlock() - }() - - lastStatus := "" - var err error - if statusSyncer, ok := p.syncer.(StatusSyncer); ok { - lastStatus, err = statusSyncer.SyncStatus(ctx) - } else { - err = p.syncer.SyncOnce(ctx) - } - - warningResult := err != nil && lastStatus != "" && lastStatus != "failed" - p.mu.Lock() - p.lastRunAt = p.now().UTC() - p.lastStatus = lastStatus - p.lastWarning = "" - if warningResult { - p.lastWarning = err.Error() - p.lastError = "" - } else if err != nil { - p.lastError = err.Error() - } else { - p.lastError = "" - } - p.mu.Unlock() - if warningResult { - return fmt.Errorf("%w: %v", ErrSyncCompletedWithWarnings, err) - } - return err -} - -func (p *Poller) setRunning(running bool) { - p.mu.Lock() - defer p.mu.Unlock() - p.running = running -} diff --git a/internal/poller/poller_test.go b/internal/poller/poller_test.go deleted file mode 100644 index ff5d401f..00000000 --- a/internal/poller/poller_test.go +++ /dev/null @@ -1,345 +0,0 @@ -package poller - -import ( - "bytes" - "context" - "errors" - "strings" - "sync" - "testing" - "time" - - "github.com/sirupsen/logrus" -) - -type syncStub struct { - mu sync.Mutex - calls int - err error - started chan struct{} - release chan struct{} -} - -type syncResultStub struct { - status string - err error -} - -func (s *syncStub) SyncOnce(ctx context.Context) error { - s.mu.Lock() - s.calls++ - s.mu.Unlock() - if s.started != nil { - s.started <- struct{}{} - } - if s.release != nil { - select { - case <-s.release: - case <-ctx.Done(): - return ctx.Err() - } - } - return s.err -} - -func (s *syncResultStub) SyncOnce(context.Context) error { - return s.err -} - -func (s *syncResultStub) SyncStatus(context.Context) (string, error) { - return s.status, s.err -} - -func (s *syncStub) CallCount() int { - s.mu.Lock() - defer s.mu.Unlock() - return s.calls -} - -type fakeTicker struct { - ch chan time.Time - stopped bool -} - -func (t *fakeTicker) Chan() <-chan time.Time { return t.ch } -func (t *fakeTicker) Stop() { t.stopped = true } - -func TestRunLogsPollerStart(t *testing.T) { - logs := capturePollerLogrusOutput(t) - syncer := &syncStub{} - ft := &fakeTicker{ch: make(chan time.Time)} - p := New(syncer, time.Minute) - p.ticker = func(time.Duration) ticker { return ft } - - ctx, cancel := context.WithCancel(context.Background()) - done := make(chan error, 1) - go func() { done <- p.Run(ctx) }() - waitFor(t, func() bool { return syncer.CallCount() == 1 }) - cancel() - if err := <-done; err != nil { - t.Fatalf("poller returned error: %v", err) - } - - content := logs.String() - if !strings.Contains(content, "level=info") || !strings.Contains(content, "msg=\"legacy export poller task started\"") { - t.Fatalf("expected poller start info log, got %q", content) - } -} - -func TestRunExecutesImmediateAndScheduledSyncs(t *testing.T) { - syncer := &syncStub{} - ft := &fakeTicker{ch: make(chan time.Time, 2)} - p := New(syncer, time.Minute) - p.ticker = func(time.Duration) ticker { return ft } - p.now = func() time.Time { return time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC) } - - ctx, cancel := context.WithCancel(context.Background()) - done := make(chan error, 1) - go func() { done <- p.Run(ctx) }() - - waitFor(t, func() bool { return syncer.CallCount() == 1 }) - ft.ch <- time.Now() - waitFor(t, func() bool { return syncer.CallCount() == 2 }) - cancel() - - if err := <-done; err != nil { - t.Fatalf("Run returned error: %v", err) - } - status := p.Status() - if status.Running { - t.Fatal("expected poller to stop after context cancellation") - } - if status.LastRunAt.IsZero() { - t.Fatal("expected LastRunAt to be set") - } -} - -func TestRunContinuesAfterSyncFailure(t *testing.T) { - syncer := &syncStub{err: errors.New("boom")} - ft := &fakeTicker{ch: make(chan time.Time, 1)} - p := New(syncer, time.Minute) - p.ticker = func(time.Duration) ticker { return ft } - p.now = time.Now - - ctx, cancel := context.WithCancel(context.Background()) - done := make(chan error, 1) - go func() { done <- p.Run(ctx) }() - - waitFor(t, func() bool { return syncer.CallCount() == 1 }) - ft.ch <- time.Now() - waitFor(t, func() bool { return syncer.CallCount() == 2 }) - cancel() - <-done - - status := p.Status() - if status.LastError != "boom" { - t.Fatalf("expected last error to be recorded, got %q", status.LastError) - } -} - -func TestPollerRunLogsSyncFailure(t *testing.T) { - logs := capturePollerLogrusOutput(t) - syncer := &syncStub{err: errors.New("boom")} - ft := &fakeTicker{ch: make(chan time.Time, 1)} - p := New(syncer, time.Minute) - p.ticker = func(time.Duration) ticker { return ft } - p.now = time.Now - - ctx, cancel := context.WithCancel(context.Background()) - done := make(chan error, 1) - go func() { done <- p.Run(ctx) }() - - waitFor(t, func() bool { return syncer.CallCount() == 1 }) - cancel() - if err := <-done; err != nil { - t.Fatalf("Run returned error: %v", err) - } - - content := logs.String() - for _, want := range []string{ - "level=error", - "msg=\"poller sync failed\"", - "error=boom", - } { - if !strings.Contains(content, want) { - t.Fatalf("expected log output to contain %q, got %q", want, content) - } - } -} - -func TestPollerRunLogsFailedStatusSyncFailure(t *testing.T) { - logs := capturePollerLogrusOutput(t) - syncer := &syncResultStub{ - status: "failed", - err: errors.New("fetch usage export: unavailable"), - } - ft := &fakeTicker{ch: make(chan time.Time, 1)} - p := New(syncer, time.Minute) - p.ticker = func(time.Duration) ticker { return ft } - p.now = time.Now - - ctx, cancel := context.WithCancel(context.Background()) - done := make(chan error, 1) - go func() { done <- p.Run(ctx) }() - - waitFor(t, func() bool { return p.Status().LastError == "fetch usage export: unavailable" }) - cancel() - if err := <-done; err != nil { - t.Fatalf("Run returned error: %v", err) - } - - content := logs.String() - for _, want := range []string{ - "level=error", - "msg=\"poller sync failed\"", - "error=\"fetch usage export: unavailable\"", - } { - if !strings.Contains(content, want) { - t.Fatalf("expected log output to contain %q, got %q", want, content) - } - } -} - -func TestPollerRunDoesNotErrorLogContextCancellation(t *testing.T) { - logs := capturePollerLogrusOutput(t) - syncer := &syncStub{err: context.Canceled} - ft := &fakeTicker{ch: make(chan time.Time, 1)} - p := New(syncer, time.Minute) - p.ticker = func(time.Duration) ticker { return ft } - p.now = time.Now - - ctx, cancel := context.WithCancel(context.Background()) - done := make(chan error, 1) - go func() { done <- p.Run(ctx) }() - - waitFor(t, func() bool { return syncer.CallCount() == 1 }) - cancel() - if err := <-done; err != nil { - t.Fatalf("Run returned error: %v", err) - } - if strings.Contains(logs.String(), "level=error") { - t.Fatalf("did not expect error log for context cancellation, got %q", logs.String()) - } -} - -func TestPollerRunDoesNotErrorLogAlreadyRunning(t *testing.T) { - logs := capturePollerLogrusOutput(t) - p := New(&syncStub{}, time.Minute) - p.syncRunning = true - - p.runBackgroundSync(context.Background()) - - if strings.Contains(logs.String(), "level=error") { - t.Fatalf("did not expect error log for already-running sync, got %q", logs.String()) - } -} - -func TestPollerRunDoesNotErrorLogCompletedWithWarnings(t *testing.T) { - logs := capturePollerLogrusOutput(t) - syncer := &syncResultStub{ - status: "completed_with_warnings", - err: errors.New("metadata unavailable"), - } - ft := &fakeTicker{ch: make(chan time.Time, 1)} - p := New(syncer, time.Minute) - p.ticker = func(time.Duration) ticker { return ft } - p.now = time.Now - - ctx, cancel := context.WithCancel(context.Background()) - done := make(chan error, 1) - go func() { done <- p.Run(ctx) }() - - waitFor(t, func() bool { return p.Status().LastWarning == "metadata unavailable" }) - cancel() - if err := <-done; err != nil { - t.Fatalf("Run returned error: %v", err) - } - if strings.Contains(logs.String(), "level=error") { - t.Fatalf("did not expect error log for warning result, got %q", logs.String()) - } -} - -func TestStatusRecordsCompletedWithWarningsResult(t *testing.T) { - syncer := &syncResultStub{ - status: "completed_with_warnings", - err: errors.New("fetch provider metadata: unavailable"), - } - p := New(syncer, time.Minute) - p.now = func() time.Time { return time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC) } - - p.runSync(context.Background()) - - status := p.Status() - if status.LastStatus != "completed_with_warnings" { - t.Fatalf("expected completed_with_warnings status, got %+v", status) - } - if status.LastError != "" || status.LastWarning != "fetch provider metadata: unavailable" { - t.Fatalf("expected partial sync error to be recorded as warning, got %+v", status) - } -} - -func TestSyncNowSkipsOverlappingSyncs(t *testing.T) { - syncer := &syncStub{ - started: make(chan struct{}, 1), - release: make(chan struct{}, 1), - } - p := New(syncer, time.Minute) - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - firstSyncDone := make(chan error, 1) - go func() { firstSyncDone <- p.SyncNow(ctx) }() - - select { - case <-syncer.started: - case <-time.After(time.Second): - t.Fatal("expected initial sync to start") - } - - if err := p.SyncNow(ctx); !errors.Is(err, ErrSyncAlreadyRunning) { - t.Fatalf("expected overlapping sync to be skipped, got %v", err) - } - if syncer.CallCount() != 1 { - t.Fatalf("expected no overlapping sync runs, got %d calls", syncer.CallCount()) - } - - syncer.release <- struct{}{} - select { - case err := <-firstSyncDone: - if err != nil { - t.Fatalf("initial sync returned error: %v", err) - } - case <-time.After(time.Second): - t.Fatal("expected initial sync to finish") - } -} - -func capturePollerLogrusOutput(t *testing.T) *bytes.Buffer { - t.Helper() - previousOutput := logrus.StandardLogger().Out - previousFormatter := logrus.StandardLogger().Formatter - previousLevel := logrus.GetLevel() - var logs bytes.Buffer - logrus.SetOutput(&logs) - logrus.SetFormatter(&logrus.TextFormatter{DisableTimestamp: true}) - logrus.SetLevel(logrus.DebugLevel) - t.Cleanup(func() { - logrus.SetOutput(previousOutput) - logrus.SetFormatter(previousFormatter) - logrus.SetLevel(previousLevel) - }) - return &logs -} - -func waitFor(t *testing.T, check func() bool) { - t.Helper() - deadline := time.Now().Add(2 * time.Second) - for time.Now().Before(deadline) { - if check() { - return - } - time.Sleep(10 * time.Millisecond) - } - t.Fatal("condition not met before timeout") -} diff --git a/internal/poller/redis_drain.go b/internal/poller/redis_drain.go index e74bad3f..a8243d72 100644 --- a/internal/poller/redis_drain.go +++ b/internal/poller/redis_drain.go @@ -7,7 +7,6 @@ import ( "sync" "time" - "cpa-usage-keeper/internal/cpa" "cpa-usage-keeper/internal/service" "github.com/sirupsen/logrus" ) @@ -17,14 +16,12 @@ const redisInboxProcessInterval = 5 * time.Second type RedisBatchSyncer interface { PullRedisUsageInbox(ctx context.Context) (*service.RedisInboxPullResult, error) - ProcessRedisUsageInbox(ctx context.Context, syncMetadata bool) (*service.RedisBatchSyncResult, error) - SyncMetadata(ctx context.Context) error + ProcessRedisUsageInbox(ctx context.Context) (*service.RedisBatchSyncResult, error) } type RedisDrainConfig struct { - IdleInterval time.Duration - ErrorBackoff time.Duration - MetadataInterval time.Duration + IdleInterval time.Duration + ErrorBackoff time.Duration } type RedisDrain struct { @@ -33,15 +30,14 @@ type RedisDrain struct { now func() time.Time sleep func(context.Context, time.Duration) bool - mu sync.Mutex - running bool - lastRunAt time.Time - lastError string - lastWarning string - lastStatus string - pullRunning bool - processRunning bool - lastMetadataSyncAt time.Time + mu sync.Mutex + running bool + lastRunAt time.Time + lastError string + lastWarning string + lastStatus string + pullRunning bool + processRunning bool } func NewRedisDrain(syncer RedisBatchSyncer, cfg RedisDrainConfig) *RedisDrain { @@ -76,7 +72,7 @@ func (d *RedisDrain) Run(ctx context.Context) error { return nil } -// runPullLoop 只从 CPA Redis 队列 LPOP 数据并写入 redis_usage_inboxes,不解码、不写 usage_events、不创建 snapshot_runs。 +// runPullLoop 只从 CPA Redis 队列 LPOP 数据并写入 redis_usage_inboxes,不解码、不写 usage_events。 func (d *RedisDrain) runPullLoop(ctx context.Context) { logrus.WithField("idle_interval", d.config.IdleInterval.String()).Info("redis inbox pull task started") for { @@ -110,24 +106,18 @@ func (d *RedisDrain) runProcessLoop(ctx context.Context) { if !d.sleep(ctx, redisInboxProcessInterval) { return } - syncMetadata := d.shouldSyncMetadata() - result, err := d.runRedisProcess(ctx, syncMetadata) + result, err := d.runRedisProcess(ctx) if err != nil && !errors.Is(err, ErrSyncCompletedWithWarnings) { if shouldLogSyncError(err) { - d.logBatchFailure(result, syncMetadata, err) + d.logBatchFailure(result, err) } continue } - if syncMetadata && result != nil && (!result.Empty || errors.Is(err, ErrSyncCompletedWithWarnings)) { - d.setLastMetadataSyncAt(d.now().UTC()) - } } } -func (d *RedisDrain) logBatchFailure(result *service.RedisBatchSyncResult, syncMetadata bool, err error) { +func (d *RedisDrain) logBatchFailure(result *service.RedisBatchSyncResult, err error) { fields := logrus.Fields{ - "sync_metadata": syncMetadata, - "auth_error": errors.Is(err, cpa.ErrRedisQueueAuth), "status": "", "empty": false, "inserted_events": 0, @@ -163,7 +153,7 @@ func (d *RedisDrain) SyncNow(ctx context.Context) error { if _, err := d.runRedisPull(ctx); err != nil { return err } - _, err := d.runRedisProcess(ctx, true) + _, err := d.runRedisProcess(ctx) return err } @@ -189,7 +179,7 @@ func (d *RedisDrain) runRedisPull(ctx context.Context) (*service.RedisInboxPullR } // runRedisProcess 只防止 Process 自身重入,不阻塞 Pull;Process 的输入必须来自已持久化的 redis_usage_inboxes。 -func (d *RedisDrain) runRedisProcess(ctx context.Context, syncMetadata bool) (*service.RedisBatchSyncResult, error) { +func (d *RedisDrain) runRedisProcess(ctx context.Context) (*service.RedisBatchSyncResult, error) { d.mu.Lock() if d.processRunning { d.mu.Unlock() @@ -204,20 +194,12 @@ func (d *RedisDrain) runRedisProcess(ctx context.Context, syncMetadata bool) (*s d.mu.Unlock() }() - result, err := d.syncer.ProcessRedisUsageInbox(ctx, syncMetadata) + result, err := d.syncer.ProcessRedisUsageInbox(ctx) returnErr := err if err != nil && result != nil && result.Status != "" && result.Status != "failed" { returnErr = fmt.Errorf("%w: %v", ErrSyncCompletedWithWarnings, err) } d.recordResult(result, err) - if err == nil && result != nil && result.Empty && syncMetadata { - metadataErr := d.syncer.SyncMetadata(ctx) - if metadataErr != nil { - d.recordMetadataWarning(metadataErr) - return result, fmt.Errorf("%w: %v", ErrSyncCompletedWithWarnings, metadataErr) - } - d.setLastMetadataSyncAt(d.now().UTC()) - } return result, returnErr } @@ -240,13 +222,6 @@ func (d *RedisDrain) recordPullResult(result *service.RedisInboxPullResult, err } } -func (d *RedisDrain) recordMetadataWarning(err error) { - d.mu.Lock() - defer d.mu.Unlock() - d.lastWarning = err.Error() - d.lastError = "" -} - func (d *RedisDrain) recordResult(result *service.RedisBatchSyncResult, err error) { d.mu.Lock() defer d.mu.Unlock() @@ -270,19 +245,6 @@ func (d *RedisDrain) recordResult(result *service.RedisBatchSyncResult, err erro } } -func (d *RedisDrain) shouldSyncMetadata() bool { - d.mu.Lock() - last := d.lastMetadataSyncAt - d.mu.Unlock() - return last.IsZero() || d.now().UTC().Sub(last.UTC()) >= d.config.MetadataInterval -} - -func (d *RedisDrain) setLastMetadataSyncAt(t time.Time) { - d.mu.Lock() - defer d.mu.Unlock() - d.lastMetadataSyncAt = t.UTC() -} - func (d *RedisDrain) setRunning(running bool) { d.mu.Lock() defer d.mu.Unlock() @@ -302,9 +264,6 @@ func (d *RedisDrain) validate() error { if d.config.ErrorBackoff <= 0 { return fmt.Errorf("redis drain error backoff must be greater than zero") } - if d.config.MetadataInterval <= 0 { - return fmt.Errorf("redis drain metadata interval must be greater than zero") - } if d.now == nil { d.now = time.Now } diff --git a/internal/poller/redis_drain_test.go b/internal/poller/redis_drain_test.go index 1d695bf6..1f63f6ec 100644 --- a/internal/poller/redis_drain_test.go +++ b/internal/poller/redis_drain_test.go @@ -19,13 +19,10 @@ type redisDrainSyncStub struct { pullErrs []error processResults []*service.RedisBatchSyncResult processErrs []error - metadataFlags []bool pullStarted chan struct{} releasePull chan struct{} pullCalls int processCalls int - metadataCalls int - metadataErr error } func (s *redisDrainSyncStub) PullRedisUsageInbox(context.Context) (*service.RedisInboxPullResult, error) { @@ -56,11 +53,10 @@ func (s *redisDrainSyncStub) PullRedisUsageInbox(context.Context) (*service.Redi return result, err } -func (s *redisDrainSyncStub) ProcessRedisUsageInbox(ctx context.Context, syncMetadata bool) (*service.RedisBatchSyncResult, error) { +func (s *redisDrainSyncStub) ProcessRedisUsageInbox(ctx context.Context) (*service.RedisBatchSyncResult, error) { s.mu.Lock() s.processCalls++ call := s.processCalls - s.metadataFlags = append(s.metadataFlags, syncMetadata) result := &service.RedisBatchSyncResult{Status: "completed", InsertedEvents: 1} if len(s.processResults) >= call { result = s.processResults[call-1] @@ -82,23 +78,10 @@ func (s *redisDrainSyncStub) ProcessRedisUsageInbox(ctx context.Context, syncMet return result, err } -func (s *redisDrainSyncStub) SyncMetadata(context.Context) error { +func (s *redisDrainSyncStub) counts() (int, int) { s.mu.Lock() defer s.mu.Unlock() - s.metadataCalls++ - return s.metadataErr -} - -func (s *redisDrainSyncStub) counts() (int, int, int, int) { - s.mu.Lock() - defer s.mu.Unlock() - return s.pullCalls, s.processCalls, 0, s.metadataCalls -} - -func (s *redisDrainSyncStub) flags() []bool { - s.mu.Lock() - defer s.mu.Unlock() - return append([]bool(nil), s.metadataFlags...) + return s.pullCalls, s.processCalls } func captureRedisDrainLogrusOutput(t *testing.T) *bytes.Buffer { @@ -121,7 +104,7 @@ func captureRedisDrainLogrusOutput(t *testing.T) *bytes.Buffer { func TestRedisDrainLoopsLogTaskStarts(t *testing.T) { logs := captureRedisDrainLogrusOutput(t) syncer := &redisDrainSyncStub{pullResults: []*service.RedisInboxPullResult{{Empty: true, Status: "empty"}}} - drain := NewRedisDrain(syncer, RedisDrainConfig{IdleInterval: time.Hour, ErrorBackoff: time.Hour, MetadataInterval: time.Hour}) + drain := NewRedisDrain(syncer, RedisDrainConfig{IdleInterval: time.Hour, ErrorBackoff: time.Hour}) pullCtx, cancelPull := context.WithCancel(context.Background()) drain.sleep = func(context.Context, time.Duration) bool { @@ -147,7 +130,7 @@ func TestRedisDrainLoopsLogTaskStarts(t *testing.T) { func TestRedisDrainPullLoopDoesNotProcessInbox(t *testing.T) { syncer := &redisDrainSyncStub{pullResults: []*service.RedisInboxPullResult{{Empty: true, Status: "empty"}}} - drain := NewRedisDrain(syncer, RedisDrainConfig{IdleInterval: time.Hour, ErrorBackoff: time.Hour, MetadataInterval: time.Hour}) + drain := NewRedisDrain(syncer, RedisDrainConfig{IdleInterval: time.Hour, ErrorBackoff: time.Hour}) ctx, cancel := context.WithCancel(context.Background()) drain.sleep = func(context.Context, time.Duration) bool { cancel() @@ -156,7 +139,7 @@ func TestRedisDrainPullLoopDoesNotProcessInbox(t *testing.T) { drain.runPullLoop(ctx) - pulls, processes, _, _ := syncer.counts() + pulls, processes := syncer.counts() if pulls != 1 || processes != 0 { t.Fatalf("expected pull loop to pull once and not process inbox, got pulls=%d processes=%d", pulls, processes) } @@ -164,7 +147,7 @@ func TestRedisDrainPullLoopDoesNotProcessInbox(t *testing.T) { func TestRedisDrainProcessLoopUsesFixedInterval(t *testing.T) { syncer := &redisDrainSyncStub{} - drain := NewRedisDrain(syncer, RedisDrainConfig{IdleInterval: time.Hour, ErrorBackoff: time.Hour, MetadataInterval: time.Hour}) + drain := NewRedisDrain(syncer, RedisDrainConfig{IdleInterval: time.Hour, ErrorBackoff: time.Hour}) ctx, cancel := context.WithCancel(context.Background()) calls := 0 drain.sleep = func(_ context.Context, d time.Duration) bool { @@ -181,7 +164,7 @@ func TestRedisDrainProcessLoopUsesFixedInterval(t *testing.T) { drain.runProcessLoop(ctx) - _, processes, _, _ := syncer.counts() + _, processes := syncer.counts() if processes != 1 { t.Fatalf("expected process loop to process once, got %d", processes) } @@ -192,25 +175,21 @@ func TestRedisDrainProcessLoopUsesFixedInterval(t *testing.T) { func TestRedisDrainSyncNowPullsThenProcesses(t *testing.T) { syncer := &redisDrainSyncStub{} - drain := NewRedisDrain(syncer, RedisDrainConfig{IdleInterval: time.Hour, ErrorBackoff: time.Hour, MetadataInterval: time.Hour}) + drain := NewRedisDrain(syncer, RedisDrainConfig{IdleInterval: time.Hour, ErrorBackoff: time.Hour}) if err := drain.SyncNow(context.Background()); err != nil { t.Fatalf("SyncNow returned error: %v", err) } - pulls, processes, _, _ := syncer.counts() + pulls, processes := syncer.counts() if pulls != 1 || processes != 1 { t.Fatalf("expected SyncNow to pull and process once, got pulls=%d processes=%d", pulls, processes) } - flags := syncer.flags() - if len(flags) != 1 || !flags[0] { - t.Fatalf("expected SyncNow processing to sync metadata, got %v", flags) - } } func TestRedisDrainPullAndProcessCanRunIndependently(t *testing.T) { syncer := &redisDrainSyncStub{pullStarted: make(chan struct{}), releasePull: make(chan struct{})} - drain := NewRedisDrain(syncer, RedisDrainConfig{IdleInterval: time.Hour, ErrorBackoff: time.Hour, MetadataInterval: time.Hour}) + drain := NewRedisDrain(syncer, RedisDrainConfig{IdleInterval: time.Hour, ErrorBackoff: time.Hour}) ctx := context.Background() pullDone := make(chan error, 1) go func() { @@ -219,7 +198,7 @@ func TestRedisDrainPullAndProcessCanRunIndependently(t *testing.T) { }() <-syncer.pullStarted - if _, err := drain.runRedisProcess(ctx, false); err != nil { + if _, err := drain.runRedisProcess(ctx); err != nil { close(syncer.releasePull) t.Fatalf("expected process to run while pull is active, got %v", err) } @@ -228,7 +207,7 @@ func TestRedisDrainPullAndProcessCanRunIndependently(t *testing.T) { t.Fatalf("pull returned error: %v", err) } - pulls, processes, _, _ := syncer.counts() + pulls, processes := syncer.counts() if pulls != 1 || processes != 1 { t.Fatalf("expected pull and process to each run once, got pulls=%d processes=%d", pulls, processes) } @@ -236,7 +215,7 @@ func TestRedisDrainPullAndProcessCanRunIndependently(t *testing.T) { func TestRedisDrainBacksOffAfterPullError(t *testing.T) { syncer := &redisDrainSyncStub{pullErrs: []error{errors.New("dial failed")}} - drain := NewRedisDrain(syncer, RedisDrainConfig{IdleInterval: time.Hour, ErrorBackoff: 25 * time.Millisecond, MetadataInterval: time.Hour}) + drain := NewRedisDrain(syncer, RedisDrainConfig{IdleInterval: time.Hour, ErrorBackoff: 25 * time.Millisecond}) ctx, cancel := context.WithCancel(context.Background()) var slept time.Duration drain.sleep = func(_ context.Context, d time.Duration) bool { diff --git a/internal/repository/auth_files.go b/internal/repository/auth_files.go deleted file mode 100644 index 4a05fc06..00000000 --- a/internal/repository/auth_files.go +++ /dev/null @@ -1,105 +0,0 @@ -package repository - -import ( - "fmt" - "strings" - - "cpa-usage-keeper/internal/models" - "gorm.io/gorm" - "gorm.io/gorm/clause" -) - -type AuthFileInput struct { - AuthIndex string - Name string - Email string - Type string - Provider string - Label string - Status string - Source string - Disabled bool - Unavailable bool - RuntimeOnly bool -} - -func ReplaceAuthFiles(db *gorm.DB, files []AuthFileInput) error { - if db == nil { - return fmt.Errorf("database is nil") - } - - normalized := make([]models.AuthFile, 0, len(files)) - authIndexes := make([]string, 0, len(files)) - seen := make(map[string]struct{}, len(files)) - for _, file := range files { - authIndex := strings.TrimSpace(file.AuthIndex) - if authIndex == "" { - continue - } - if _, ok := seen[authIndex]; ok { - continue - } - seen[authIndex] = struct{}{} - authIndexes = append(authIndexes, authIndex) - normalized = append(normalized, models.AuthFile{ - AuthIndex: authIndex, - Name: strings.TrimSpace(file.Name), - Email: strings.TrimSpace(file.Email), - Type: strings.TrimSpace(file.Type), - Provider: strings.TrimSpace(file.Provider), - Label: strings.TrimSpace(file.Label), - Status: strings.TrimSpace(file.Status), - Source: strings.TrimSpace(file.Source), - Disabled: file.Disabled, - Unavailable: file.Unavailable, - RuntimeOnly: file.RuntimeOnly, - }) - } - - return db.Transaction(func(tx *gorm.DB) error { - if len(normalized) == 0 { - if err := tx.Where("1 = 1").Delete(&models.AuthFile{}).Error; err != nil { - return fmt.Errorf("soft delete auth files: %w", err) - } - return nil - } - - if err := tx.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "auth_index"}}, - DoUpdates: clause.AssignmentColumns([]string{ - "name", - "email", - "type", - "provider", - "label", - "status", - "source", - "disabled", - "unavailable", - "runtime_only", - "updated_at", - "deleted_at", - }), - }).Create(&normalized).Error; err != nil { - return fmt.Errorf("upsert auth files: %w", err) - } - - if err := tx.Where("auth_index NOT IN ?", authIndexes).Delete(&models.AuthFile{}).Error; err != nil { - return fmt.Errorf("soft delete stale auth files: %w", err) - } - - return nil - }) -} - -func ListAuthFiles(db *gorm.DB) ([]models.AuthFile, error) { - if db == nil { - return nil, fmt.Errorf("database is nil") - } - - var files []models.AuthFile - if err := db.Order("auth_index asc").Find(&files).Error; err != nil { - return nil, fmt.Errorf("list auth files: %w", err) - } - return files, nil -} diff --git a/internal/repository/auth_files_test.go b/internal/repository/auth_files_test.go deleted file mode 100644 index 05cecdf5..00000000 --- a/internal/repository/auth_files_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package repository - -import ( - "path/filepath" - "testing" - - "cpa-usage-keeper/internal/config" - "gorm.io/gorm" -) - -func TestReplaceAuthFilesUpsertsSoftDeletesAndRestoresRows(t *testing.T) { - db := openAuthFilesTestDatabase(t) - if err := ReplaceAuthFiles(db, []AuthFileInput{{ - AuthIndex: "1", - Name: "First", - Email: "first@example.com", - Type: "claude", - }, { - AuthIndex: "2", - Name: "Second", - Email: "second@example.com", - Type: "gemini", - }}); err != nil { - t.Fatalf("ReplaceAuthFiles returned error: %v", err) - } - - if err := ReplaceAuthFiles(db, []AuthFileInput{{ - AuthIndex: "2", - Name: "Second Updated", - Email: "updated@example.com", - Type: "vertex", - }}); err != nil { - t.Fatalf("ReplaceAuthFiles returned error: %v", err) - } - - files, err := ListAuthFiles(db) - if err != nil { - t.Fatalf("ListAuthFiles returned error: %v", err) - } - if len(files) != 1 { - t.Fatalf("expected 1 auth file after replacement, got %d", len(files)) - } - if files[0].AuthIndex != "2" || files[0].Email != "updated@example.com" || files[0].Type != "vertex" { - t.Fatalf("unexpected auth file after replacement: %+v", files[0]) - } - - if err := ReplaceAuthFiles(db, []AuthFileInput{{ - AuthIndex: "1", - Name: "First Restored", - Email: "first-restored@example.com", - Type: "claude", - }}); err != nil { - t.Fatalf("ReplaceAuthFiles restore returned error: %v", err) - } - - files, err = ListAuthFiles(db) - if err != nil { - t.Fatalf("ListAuthFiles returned error: %v", err) - } - if len(files) != 1 || files[0].AuthIndex != "1" || files[0].Email != "first-restored@example.com" { - t.Fatalf("unexpected restored auth file: %+v", files) - } -} - -func openAuthFilesTestDatabase(t *testing.T) *gorm.DB { - t.Helper() - db, err := OpenDatabase(config.Config{SQLitePath: filepath.Join(t.TempDir(), "auth_files.db")}) - if err != nil { - t.Fatalf("OpenDatabase returned error: %v", err) - } - closeTestDatabase(t, db) - return db -} diff --git a/internal/repository/db.go b/internal/repository/db.go index 6cc90017..fe65466c 100644 --- a/internal/repository/db.go +++ b/internal/repository/db.go @@ -13,34 +13,8 @@ import ( "gorm.io/gorm/clause" ) -type SnapshotRunInput struct { - FetchedAt time.Time - CPABaseURL string - ExportedAt *time.Time - Version string - Status string - HTTPStatus int - PayloadHash string - RawPayload []byte - ErrorMessage string -} - -type SnapshotRunResult struct { - Status string - HTTPStatus int - ErrorMessage string - InsertedEvents int - DedupedEvents int - ExportedAt *time.Time -} - -type SnapshotRunsCleanupResult struct { - Deleted int64 -} - type StorageCleanupResult struct { - RedisInbox RedisUsageInboxCleanupResult - SnapshotRuns SnapshotRunsCleanupResult + RedisInbox RedisUsageInboxCleanupResult } func OpenDatabase(cfg config.Config) (*gorm.DB, error) { @@ -67,6 +41,9 @@ func OpenDatabase(cfg config.Config) (*gorm.DB, error) { return nil, fmt.Errorf("enable sqlite foreign keys: %w", err) } + if err := runSchemaMigrations(db); err != nil { + return nil, fmt.Errorf("run schema migrations: %w", err) + } if err := db.AutoMigrate(models.All()...); err != nil { return nil, fmt.Errorf("auto migrate database: %w", err) } @@ -82,59 +59,6 @@ func sqliteDSN(path string) string { return trimmed + "?_busy_timeout=5000&_foreign_keys=on" } -func CreateSnapshotRun(db *gorm.DB, input SnapshotRunInput) (*models.SnapshotRun, error) { - if db == nil { - return nil, fmt.Errorf("database is nil") - } - - run := &models.SnapshotRun{ - FetchedAt: input.FetchedAt.UTC(), - CPABaseURL: strings.TrimSpace(input.CPABaseURL), - ExportedAt: normalizeOptionalTime(input.ExportedAt), - Version: strings.TrimSpace(input.Version), - Status: strings.TrimSpace(input.Status), - HTTPStatus: input.HTTPStatus, - PayloadHash: strings.TrimSpace(input.PayloadHash), - RawPayload: append([]byte(nil), input.RawPayload...), - ErrorMessage: strings.TrimSpace(input.ErrorMessage), - } - if run.Status == "" { - run.Status = "pending" - } - - if err := db.Create(run).Error; err != nil { - return nil, fmt.Errorf("create snapshot run: %w", err) - } - - return run, nil -} - -func FinalizeSnapshotRun(db *gorm.DB, snapshotRunID uint, result SnapshotRunResult) error { - if db == nil { - return fmt.Errorf("database is nil") - } - - updates := map[string]any{ - "status": strings.TrimSpace(result.Status), - "http_status": result.HTTPStatus, - "error_message": strings.TrimSpace(result.ErrorMessage), - "inserted_events": result.InsertedEvents, - "deduped_events": result.DedupedEvents, - } - if updates["status"] == "" { - updates["status"] = "completed" - } - if exportedAt := normalizeOptionalTime(result.ExportedAt); exportedAt != nil { - updates["exported_at"] = *exportedAt - } - - if err := db.Model(&models.SnapshotRun{}).Where("id = ?", snapshotRunID).Updates(updates).Error; err != nil { - return fmt.Errorf("finalize snapshot run %d: %w", snapshotRunID, err) - } - - return nil -} - func InsertUsageEvents(db *gorm.DB, events []models.UsageEvent) (int, int, error) { if db == nil { return 0, 0, fmt.Errorf("database is nil") @@ -163,75 +87,17 @@ func InsertUsageEvents(db *gorm.DB, events []models.UsageEvent) (int, int, error return inserted, deduped, nil } -func FindLatestUsageEventTimestamp(db *gorm.DB) (*time.Time, error) { - if db == nil { - return nil, fmt.Errorf("database is nil") - } - - var event models.UsageEvent - result := db.Select("timestamp").Order("timestamp DESC").Limit(1).Find(&event) - if result.Error != nil { - return nil, fmt.Errorf("find latest usage event timestamp: %w", result.Error) - } - if result.RowsAffected == 0 { - return nil, nil - } - - timestamp := event.Timestamp.UTC() - return ×tamp, nil -} - -// CleanupSnapshotRuns 按项目本地日期保留今天和往前 7 天内每天最新的一条 snapshot_run。 -// 只要保留窗口内存在快照,其它 snapshot_runs 都会删除;如果窗口内没有任何快照,则直接返回避免误删全表。 -func CleanupSnapshotRuns(db *gorm.DB, now time.Time) (SnapshotRunsCleanupResult, error) { - if db == nil { - return SnapshotRunsCleanupResult{}, fmt.Errorf("database is nil") - } - - localNow := now.In(time.Local) - localTodayStart := time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, time.Local) - keepIDs := make([]uint, 0, 7) - for dayOffset := 0; dayOffset <= 7; dayOffset++ { - dayStart := localTodayStart.AddDate(0, 0, -dayOffset).UTC() - dayEnd := localTodayStart.AddDate(0, 0, -dayOffset+1).UTC() - if dayOffset == 0 { - dayEnd = now.UTC().Add(time.Nanosecond) - } - var dayIDs []uint - err := db.Model(&models.SnapshotRun{}).Select("id").Where("fetched_at >= ? AND fetched_at < ?", dayStart, dayEnd).Order("fetched_at DESC, id DESC").Limit(1).Pluck("id", &dayIDs).Error - if err != nil { - return SnapshotRunsCleanupResult{}, fmt.Errorf("load snapshot run retained for cleanup: %w", err) - } - if len(dayIDs) > 0 { - keepIDs = append(keepIDs, dayIDs[0]) - } - } - - if len(keepIDs) == 0 { - return SnapshotRunsCleanupResult{}, nil - } - deleted := db.Model(&models.SnapshotRun{}).Where("id NOT IN ?", keepIDs).Delete(&models.SnapshotRun{}) - if deleted.Error != nil { - return SnapshotRunsCleanupResult{}, fmt.Errorf("delete old snapshot runs: %w", deleted.Error) - } - return SnapshotRunsCleanupResult{Deleted: deleted.RowsAffected}, nil -} - -// CleanupStorage 是每日维护任务的统一仓储清理入口:先清 Redis inbox,再清 snapshot_runs,最后执行 VACUUM。 +// CleanupStorage 是每日维护任务的统一仓储清理入口:先清 Redis inbox,最后执行 VACUUM。 // VACUUM 必须在删除完成后单独执行,任何一步失败都会停止后续步骤并把已完成部分的结果返回给上层日志。 func CleanupStorage(db *gorm.DB, now time.Time) (StorageCleanupResult, error) { redisResult, err := CleanupRedisUsageInbox(db, now) if err != nil { return StorageCleanupResult{RedisInbox: redisResult}, err } - snapshotResult, err := CleanupSnapshotRuns(db, now) - if err != nil { - return StorageCleanupResult{RedisInbox: redisResult, SnapshotRuns: snapshotResult}, err - } if err := db.Exec("VACUUM").Error; err != nil { - return StorageCleanupResult{RedisInbox: redisResult, SnapshotRuns: snapshotResult}, err + return StorageCleanupResult{RedisInbox: redisResult}, err } - return StorageCleanupResult{RedisInbox: redisResult, SnapshotRuns: snapshotResult}, nil + return StorageCleanupResult{RedisInbox: redisResult}, nil } func Vacuum(db *gorm.DB) error { @@ -240,11 +106,3 @@ func Vacuum(db *gorm.DB) error { } return db.Exec("VACUUM").Error } - -func normalizeOptionalTime(value *time.Time) *time.Time { - if value == nil { - return nil - } - normalized := value.UTC() - return &normalized -} diff --git a/internal/repository/db_test.go b/internal/repository/db_test.go index 40720834..28d5550c 100644 --- a/internal/repository/db_test.go +++ b/internal/repository/db_test.go @@ -23,8 +23,8 @@ func TestOpenDatabaseAutoMigratesCoreTables(t *testing.T) { } closeTestDatabase(t, db) - if !db.Migrator().HasTable("snapshot_runs") { - t.Fatal("expected snapshot_runs table to exist") + if db.Migrator().HasTable("snapshot_runs") { + t.Fatal("expected legacy snapshot_runs table not to exist") } if !db.Migrator().HasTable("usage_events") { t.Fatal("expected usage_events table to exist") @@ -70,47 +70,12 @@ func TestOpenDatabaseConfiguresSQLiteRuntime(t *testing.T) { } } -func TestCreateSnapshotRunStoresInitialState(t *testing.T) { - db := openTestDatabase(t) - fetchedAt := time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC) - exportedAt := time.Date(2026, 4, 16, 11, 55, 0, 0, time.FixedZone("UTC+2", 2*60*60)) - - run, err := CreateSnapshotRun(db, SnapshotRunInput{ - FetchedAt: fetchedAt, - CPABaseURL: " https://cpa.example.com/ ", - ExportedAt: &exportedAt, - Version: "1", - Status: "pending", - HTTPStatus: 200, - PayloadHash: "abc123", - RawPayload: []byte(`{"version":1}`), - ErrorMessage: "", - }) - if err != nil { - t.Fatalf("CreateSnapshotRun returned error: %v", err) - } - - var stored models.SnapshotRun - if err := db.First(&stored, run.ID).Error; err != nil { - t.Fatalf("load snapshot run: %v", err) - } - if stored.Status != "pending" { - t.Fatalf("expected pending status, got %q", stored.Status) - } - if stored.CPABaseURL != "https://cpa.example.com/" { - t.Fatalf("expected trimmed base url, got %q", stored.CPABaseURL) - } - if stored.ExportedAt == nil || !stored.ExportedAt.Equal(exportedAt.UTC()) { - t.Fatalf("expected normalized exported_at, got %+v", stored.ExportedAt) - } -} - func TestInsertUsageEventsDeduplicatesByEventKey(t *testing.T) { db := openTestDatabase(t) events := []models.UsageEvent{ - {EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), TotalTokens: 10}, - {EventKey: "event-1", SnapshotRunID: 2, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), TotalTokens: 10}, - {EventKey: "event-2", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-opus", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), TotalTokens: 20}, + {EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), TotalTokens: 10}, + {EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), TotalTokens: 10}, + {EventKey: "event-2", APIGroupKey: "provider-a", Model: "claude-opus", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), TotalTokens: 20}, } inserted, deduped, err := InsertUsageEvents(db, events) @@ -136,14 +101,13 @@ func TestInsertUsageEventsBatchesLargeInsertSet(t *testing.T) { baseTime := time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC) for i := 0; i < 300; i++ { events = append(events, models.UsageEvent{ - EventKey: fmt.Sprintf("event-%03d", i), - SnapshotRunID: 1, - APIGroupKey: "provider-a", - Model: "claude-sonnet", - Timestamp: baseTime.Add(time.Duration(i) * time.Minute), - Source: "source-a", - AuthIndex: "auth-1", - TotalTokens: int64(i + 1), + EventKey: fmt.Sprintf("event-%03d", i), + APIGroupKey: "provider-a", + Model: "claude-sonnet", + Timestamp: baseTime.Add(time.Duration(i) * time.Minute), + Source: "source-a", + AuthIndex: "auth-1", + TotalTokens: int64(i + 1), }) } @@ -164,265 +128,7 @@ func TestInsertUsageEventsBatchesLargeInsertSet(t *testing.T) { } } -func TestFindLatestUsageEventTimestampReturnsNilForEmptyTable(t *testing.T) { - db := openTestDatabase(t) - - timestamp, err := FindLatestUsageEventTimestamp(db) - if err != nil { - t.Fatalf("FindLatestUsageEventTimestamp returned error: %v", err) - } - if timestamp != nil { - t.Fatalf("expected nil timestamp for empty table, got %v", *timestamp) - } -} - -func TestFindLatestUsageEventTimestampReturnsMaxValue(t *testing.T) { - db := openTestDatabase(t) - events := []models.UsageEvent{ - {EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), TotalTokens: 10}, - {EventKey: "event-2", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 18, 11, 0, 0, 0, time.UTC), TotalTokens: 20}, - {EventKey: "event-3", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 17, 10, 0, 0, 0, time.UTC), TotalTokens: 15}, - } - if _, _, err := InsertUsageEvents(db, events); err != nil { - t.Fatalf("InsertUsageEvents returned error: %v", err) - } - - timestamp, err := FindLatestUsageEventTimestamp(db) - if err != nil { - t.Fatalf("FindLatestUsageEventTimestamp returned error: %v", err) - } - if timestamp == nil { - t.Fatal("expected max timestamp, got nil") - } - expected := time.Date(2026, 4, 18, 11, 0, 0, 0, time.UTC) - if !timestamp.Equal(expected) { - t.Fatalf("expected max timestamp %s, got %s", expected, timestamp) - } -} - -func TestFinalizeSnapshotRunUpdatesResultFields(t *testing.T) { - db := openTestDatabase(t) - run, err := CreateSnapshotRun(db, SnapshotRunInput{FetchedAt: time.Now().UTC(), Status: "pending"}) - if err != nil { - t.Fatalf("CreateSnapshotRun returned error: %v", err) - } - - exportedAt := time.Date(2026, 4, 16, 12, 30, 0, 0, time.UTC) - err = FinalizeSnapshotRun(db, run.ID, SnapshotRunResult{ - Status: "completed", - HTTPStatus: 200, - InsertedEvents: 7, - DedupedEvents: 2, - ExportedAt: &exportedAt, - }) - if err != nil { - t.Fatalf("FinalizeSnapshotRun returned error: %v", err) - } - - var stored models.SnapshotRun - if err := db.First(&stored, run.ID).Error; err != nil { - t.Fatalf("load snapshot run: %v", err) - } - if stored.Status != "completed" { - t.Fatalf("expected completed status, got %q", stored.Status) - } - if stored.InsertedEvents != 7 || stored.DedupedEvents != 2 { - t.Fatalf("unexpected event counts: %+v", stored) - } - if stored.ExportedAt == nil || !stored.ExportedAt.Equal(exportedAt) { - t.Fatalf("expected exportedAt to be updated, got %+v", stored.ExportedAt) - } -} - -func TestCleanupSnapshotRunsKeepsLatestSnapshotPerLocalDayForSevenDays(t *testing.T) { - previousLocal := time.Local - location, err := time.LoadLocation("Asia/Shanghai") - if err != nil { - t.Fatalf("load location: %v", err) - } - time.Local = location - t.Cleanup(func() { time.Local = previousLocal }) - db := openTestDatabase(t) - now := time.Date(2026, 4, 27, 2, 30, 0, 0, time.UTC) - - oldDay, err := CreateSnapshotRun(db, SnapshotRunInput{FetchedAt: time.Date(2026, 4, 20, 15, 0, 0, 0, time.UTC), RawPayload: []byte(`old`)}) - if err != nil { - t.Fatalf("CreateSnapshotRun oldDay returned error: %v", err) - } - if _, err := CreateSnapshotRun(db, SnapshotRunInput{FetchedAt: time.Date(2026, 4, 20, 17, 0, 0, 0, time.UTC), RawPayload: []byte(`first-day-early`)}); err != nil { - t.Fatalf("CreateSnapshotRun firstDayEarly returned error: %v", err) - } - firstDayLatest, err := CreateSnapshotRun(db, SnapshotRunInput{FetchedAt: time.Date(2026, 4, 21, 15, 30, 0, 0, time.UTC), RawPayload: []byte(`first-day-latest`)}) - if err != nil { - t.Fatalf("CreateSnapshotRun firstDayLatest returned error: %v", err) - } - if _, err := CreateSnapshotRun(db, SnapshotRunInput{FetchedAt: time.Date(2026, 4, 26, 16, 10, 0, 0, time.UTC), RawPayload: []byte(`today-early`)}); err != nil { - t.Fatalf("CreateSnapshotRun todayEarly returned error: %v", err) - } - todayLatest, err := CreateSnapshotRun(db, SnapshotRunInput{FetchedAt: time.Date(2026, 4, 27, 2, 0, 0, 0, time.UTC), RawPayload: []byte(`today-latest`)}) - if err != nil { - t.Fatalf("CreateSnapshotRun todayLatest returned error: %v", err) - } - if _, _, err := InsertUsageEvents(db, []models.UsageEvent{{EventKey: "event-old-snapshot", SnapshotRunID: oldDay.ID, Timestamp: now, TotalTokens: 1}}); err != nil { - t.Fatalf("InsertUsageEvents returned error: %v", err) - } - - result, err := CleanupSnapshotRuns(db, now) - if err != nil { - t.Fatalf("CleanupSnapshotRuns returned error: %v", err) - } - if result.Deleted != 2 { - t.Fatalf("expected 2 deleted snapshot runs, got %+v", result) - } - - var remaining []models.SnapshotRun - if err := db.Order("id asc").Find(&remaining).Error; err != nil { - t.Fatalf("load remaining snapshot runs: %v", err) - } - remainingIDs := make([]uint, 0, len(remaining)) - for _, run := range remaining { - remainingIDs = append(remainingIDs, run.ID) - } - expectedIDs := []uint{oldDay.ID, firstDayLatest.ID, todayLatest.ID} - if fmt.Sprint(remainingIDs) != fmt.Sprint(expectedIDs) { - t.Fatalf("expected remaining snapshot ids %v, got %v", expectedIDs, remainingIDs) - } - - var eventCount int64 - if err := db.Model(&models.UsageEvent{}).Count(&eventCount).Error; err != nil { - t.Fatalf("count usage events: %v", err) - } - if eventCount != 1 { - t.Fatalf("expected usage events to remain untouched, got %d", eventCount) - } -} - -func TestCleanupSnapshotRunsKeepsSeventhPreviousLocalDay(t *testing.T) { - previousLocal := time.Local - location, err := time.LoadLocation("Asia/Shanghai") - if err != nil { - t.Fatalf("load location: %v", err) - } - time.Local = location - t.Cleanup(func() { time.Local = previousLocal }) - db := openTestDatabase(t) - now := time.Date(2026, 4, 30, 12, 0, 0, 0, location) - endDay := time.Date(2026, 4, 30, 0, 0, 0, 0, location) - startDay := endDay.AddDate(0, 0, -7) - if endDay.Sub(startDay) != 7*24*time.Hour { - t.Fatalf("expected cleanup window from %s to %s to be 7 days", startDay, endDay) - } - - older, err := CreateSnapshotRun(db, SnapshotRunInput{FetchedAt: time.Date(2026, 4, 22, 23, 0, 0, 0, location), RawPayload: []byte(`older`)}) - if err != nil { - t.Fatalf("CreateSnapshotRun older returned error: %v", err) - } - seventhDayEarly, err := CreateSnapshotRun(db, SnapshotRunInput{FetchedAt: time.Date(2026, 4, 23, 9, 0, 0, 0, location), RawPayload: []byte(`seventh-day-early`)}) - if err != nil { - t.Fatalf("CreateSnapshotRun seventhDayEarly returned error: %v", err) - } - seventhDayLatest, err := CreateSnapshotRun(db, SnapshotRunInput{FetchedAt: time.Date(2026, 4, 23, 20, 0, 0, 0, location), RawPayload: []byte(`seventh-day-latest`)}) - if err != nil { - t.Fatalf("CreateSnapshotRun seventhDayLatest returned error: %v", err) - } - todayLatest, err := CreateSnapshotRun(db, SnapshotRunInput{FetchedAt: time.Date(2026, 4, 30, 11, 0, 0, 0, location), RawPayload: []byte(`today-latest`)}) - if err != nil { - t.Fatalf("CreateSnapshotRun todayLatest returned error: %v", err) - } - - result, err := CleanupSnapshotRuns(db, now) - if err != nil { - t.Fatalf("CleanupSnapshotRuns returned error: %v", err) - } - if result.Deleted != 2 { - t.Fatalf("expected older and early seventh-day snapshots to be deleted, got %+v", result) - } - - var remaining []models.SnapshotRun - if err := db.Order("id asc").Find(&remaining).Error; err != nil { - t.Fatalf("load remaining snapshot runs: %v", err) - } - remainingIDs := make([]uint, 0, len(remaining)) - for _, run := range remaining { - remainingIDs = append(remainingIDs, run.ID) - } - expectedIDs := []uint{seventhDayLatest.ID, todayLatest.ID} - if fmt.Sprint(remainingIDs) != fmt.Sprint(expectedIDs) { - t.Fatalf("expected remaining snapshot ids %v after deleting %d and %d, got %v", expectedIDs, older.ID, seventhDayEarly.ID, remainingIDs) - } -} - -func TestCleanupSnapshotRunsKeepsRowsWhenRetentionWindowHasNoSnapshots(t *testing.T) { - previousLocal := time.Local - location, err := time.LoadLocation("Asia/Shanghai") - if err != nil { - t.Fatalf("load location: %v", err) - } - time.Local = location - t.Cleanup(func() { time.Local = previousLocal }) - db := openTestDatabase(t) - now := time.Date(2026, 4, 30, 12, 0, 0, 0, location) - - oldSnapshot, err := CreateSnapshotRun(db, SnapshotRunInput{FetchedAt: time.Date(2026, 4, 1, 12, 0, 0, 0, location), RawPayload: []byte(`old`)}) - if err != nil { - t.Fatalf("CreateSnapshotRun old returned error: %v", err) - } - - result, err := CleanupSnapshotRuns(db, now) - if err != nil { - t.Fatalf("CleanupSnapshotRuns returned error: %v", err) - } - if result.Deleted != 0 { - t.Fatalf("expected no deletions when retention window has no snapshots, got %+v", result) - } - - var remaining []models.SnapshotRun - if err := db.Find(&remaining).Error; err != nil { - t.Fatalf("load remaining snapshot runs: %v", err) - } - if len(remaining) != 1 || remaining[0].ID != oldSnapshot.ID { - t.Fatalf("expected old snapshot %d to remain when keepIDs is empty, got %+v", oldSnapshot.ID, remaining) - } -} - -func TestCleanupSnapshotRunsDeletesFutureSnapshots(t *testing.T) { - previousLocal := time.Local - location, err := time.LoadLocation("Asia/Shanghai") - if err != nil { - t.Fatalf("load location: %v", err) - } - time.Local = location - t.Cleanup(func() { time.Local = previousLocal }) - db := openTestDatabase(t) - now := time.Date(2026, 4, 27, 2, 30, 0, 0, time.UTC) - - kept, err := CreateSnapshotRun(db, SnapshotRunInput{FetchedAt: time.Date(2026, 4, 27, 2, 0, 0, 0, time.UTC), RawPayload: []byte(`kept`)}) - if err != nil { - t.Fatalf("CreateSnapshotRun kept returned error: %v", err) - } - future, err := CreateSnapshotRun(db, SnapshotRunInput{FetchedAt: time.Date(2026, 4, 27, 4, 0, 0, 0, time.UTC), RawPayload: []byte(`future`)}) - if err != nil { - t.Fatalf("CreateSnapshotRun future returned error: %v", err) - } - - result, err := CleanupSnapshotRuns(db, now) - if err != nil { - t.Fatalf("CleanupSnapshotRuns returned error: %v", err) - } - if result.Deleted != 1 { - t.Fatalf("expected future snapshot to be deleted, got %+v", result) - } - - var remaining []models.SnapshotRun - if err := db.Order("id asc").Find(&remaining).Error; err != nil { - t.Fatalf("load remaining snapshot runs: %v", err) - } - if len(remaining) != 1 || remaining[0].ID != kept.ID { - t.Fatalf("expected only current snapshot %d to remain after deleting %d, got %+v", kept.ID, future.ID, remaining) - } -} - -func TestCleanupStorageCleansRedisInboxAndSnapshotRuns(t *testing.T) { +func TestCleanupStorageCleansRedisInboxAndVacuums(t *testing.T) { previousLocal := time.Local location, err := time.LoadLocation("Asia/Shanghai") if err != nil { @@ -443,20 +149,12 @@ func TestCleanupStorageCleansRedisInboxAndSnapshotRuns(t *testing.T) { if err := db.Model(&models.RedisUsageInbox{}).Where("id = ?", inboxRows[0].ID).Updates(map[string]any{"status": RedisUsageInboxStatusProcessed, "processed_at": time.Date(2026, 4, 26, 15, 59, 59, 0, time.UTC)}).Error; err != nil { t.Fatalf("seed processed inbox row: %v", err) } - oldSnapshot, err := CreateSnapshotRun(db, SnapshotRunInput{FetchedAt: time.Date(2026, 4, 19, 15, 0, 0, 0, time.UTC), RawPayload: []byte(`old`)}) - if err != nil { - t.Fatalf("CreateSnapshotRun old returned error: %v", err) - } - keptSnapshot, err := CreateSnapshotRun(db, SnapshotRunInput{FetchedAt: time.Date(2026, 4, 27, 2, 0, 0, 0, time.UTC), RawPayload: []byte(`kept`)}) - if err != nil { - t.Fatalf("CreateSnapshotRun kept returned error: %v", err) - } result, err := CleanupStorage(db, now) if err != nil { t.Fatalf("CleanupStorage returned error: %v", err) } - if result.RedisInbox.ProcessedDeleted != 1 || result.SnapshotRuns.Deleted != 1 { + if result.RedisInbox.ProcessedDeleted != 1 { t.Fatalf("unexpected cleanup result: %+v", result) } @@ -467,13 +165,6 @@ func TestCleanupStorageCleansRedisInboxAndSnapshotRuns(t *testing.T) { if len(inboxRemaining) != 1 || inboxRemaining[0].ID != inboxRows[1].ID { t.Fatalf("expected only pending inbox row to remain, got %+v", inboxRemaining) } - var snapshotRemaining []models.SnapshotRun - if err := db.Order("id asc").Find(&snapshotRemaining).Error; err != nil { - t.Fatalf("load remaining snapshot runs: %v", err) - } - if len(snapshotRemaining) != 1 || snapshotRemaining[0].ID != keptSnapshot.ID { - t.Fatalf("expected only retained snapshot %d to remain after deleting %d, got %+v", keptSnapshot.ID, oldSnapshot.ID, snapshotRemaining) - } } func openTestDatabase(t *testing.T) *gorm.DB { diff --git a/internal/repository/migrations.go b/internal/repository/migrations.go new file mode 100644 index 00000000..57831b28 --- /dev/null +++ b/internal/repository/migrations.go @@ -0,0 +1,502 @@ +package repository + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "cpa-usage-keeper/internal/models" + "github.com/sirupsen/logrus" + "gorm.io/gorm" +) + +const ( + migrationAddUsageEventRedisFields = "20260503_add_usage_event_redis_fields" + migrationBackfillUsageEventRedisFields = "20260503_backfill_usage_event_redis_fields" + migrationDropSnapshotRuns = "20260503_drop_snapshot_runs" + migrationDropLegacySnapshotRunColumns = "20260504_drop_legacy_snapshot_run_columns" + migrationCreateUsageIdentities = "20260504_create_usage_identities" + migrationMigrateUsageIdentitiesMetadata = "20260504_migrate_usage_identities_metadata" + migrationBackfillUsageEventIdentityFields = "20260504_backfill_usage_event_identity_fields" + migrationBackfillUsageIdentityStats = "20260504_backfill_usage_identity_stats" + migrationDropLegacyMetadataTables = "20260504_drop_legacy_metadata_tables" + migrationRemovePrefixUsageIdentities = "20260504_remove_prefix_usage_identities" +) + +type schemaMigration struct { + Version string `gorm:"primaryKey;column:version"` + AppliedAt time.Time `gorm:"not null;column:applied_at"` +} + +func (schemaMigration) TableName() string { + return "schema_migrations" +} + +type databaseMigration struct { + version string + run func(*gorm.DB) error +} + +func runSchemaMigrations(db *gorm.DB) error { + if err := db.Exec("CREATE TABLE IF NOT EXISTS schema_migrations (version TEXT PRIMARY KEY, applied_at DATETIME NOT NULL)").Error; err != nil { + return fmt.Errorf("create schema_migrations table: %w", err) + } + + migrations := []databaseMigration{ + {version: migrationAddUsageEventRedisFields, run: addUsageEventRedisFieldsMigration}, + {version: migrationBackfillUsageEventRedisFields, run: backfillUsageEventRedisFieldsMigration}, + {version: migrationDropSnapshotRuns, run: dropSnapshotRunsMigration}, + {version: migrationDropLegacySnapshotRunColumns, run: dropLegacySnapshotRunColumnsMigration}, + {version: migrationCreateUsageIdentities, run: createUsageIdentitiesMigration}, + {version: migrationMigrateUsageIdentitiesMetadata, run: migrateUsageIdentitiesMetadataMigration}, + {version: migrationBackfillUsageEventIdentityFields, run: backfillUsageEventIdentityFieldsMigration}, + {version: migrationBackfillUsageIdentityStats, run: backfillUsageIdentityStatsMigration}, + {version: migrationDropLegacyMetadataTables, run: dropLegacyMetadataTablesMigration}, + {version: migrationRemovePrefixUsageIdentities, run: removePrefixUsageIdentitiesMigration}, + } + for _, migration := range migrations { + if err := runSchemaMigration(db, migration); err != nil { + return err + } + } + return nil +} + +func runSchemaMigration(db *gorm.DB, migration databaseMigration) error { + return db.Transaction(func(tx *gorm.DB) error { + logger := logrus.WithField("version", migration.version) + var count int64 + if err := tx.Table("schema_migrations").Where("version = ?", migration.version).Count(&count).Error; err != nil { + logger.WithError(err).Error("schema migration failed") + return fmt.Errorf("check schema migration %s: %w", migration.version, err) + } + if count > 0 { + logger.Info("schema migration skipped") + return nil + } + logger.Info("schema migration started") + if err := migration.run(tx); err != nil { + logger.WithError(err).Error("schema migration failed") + return fmt.Errorf("run schema migration %s: %w", migration.version, err) + } + if err := tx.Create(&schemaMigration{Version: migration.version, AppliedAt: time.Now().UTC()}).Error; err != nil { + logger.WithError(err).Error("schema migration failed") + return fmt.Errorf("record schema migration %s: %w", migration.version, err) + } + logger.Info("schema migration applied") + return nil + }) +} + +func addUsageEventRedisFieldsMigration(tx *gorm.DB) error { + if !tx.Migrator().HasTable(&models.UsageEvent{}) { + return nil + } + columns := []struct { + name string + sql string + }{ + {name: "provider", sql: "ALTER TABLE usage_events ADD COLUMN provider TEXT"}, + {name: "endpoint", sql: "ALTER TABLE usage_events ADD COLUMN endpoint TEXT"}, + {name: "auth_type", sql: "ALTER TABLE usage_events ADD COLUMN auth_type TEXT"}, + {name: "request_id", sql: "ALTER TABLE usage_events ADD COLUMN request_id TEXT"}, + } + for _, column := range columns { + if tx.Migrator().HasColumn(&models.UsageEvent{}, column.name) { + continue + } + if err := tx.Exec(column.sql).Error; err != nil { + return fmt.Errorf("add usage_events.%s column: %w", column.name, err) + } + } + return nil +} + +type redisUsageBackfillPayload struct { + Provider string `json:"provider"` + Endpoint string `json:"endpoint"` + AuthType string `json:"auth_type"` + RequestID string `json:"request_id"` +} + +func backfillUsageEventRedisFieldsMigration(tx *gorm.DB) error { + if !tx.Migrator().HasTable(&models.UsageEvent{}) || !tx.Migrator().HasTable(&models.RedisUsageInbox{}) { + return nil + } + for _, column := range []string{"provider", "endpoint", "auth_type", "request_id"} { + if !tx.Migrator().HasColumn(&models.UsageEvent{}, column) { + return nil + } + } + + var inboxRows []models.RedisUsageInbox + return tx.Where("status = ?", RedisUsageInboxStatusProcessed). + Order("id asc"). + FindInBatches(&inboxRows, 500, func(_ *gorm.DB, _ int) error { + for _, inbox := range inboxRows { + var payload redisUsageBackfillPayload + if err := json.Unmarshal([]byte(inbox.RawMessage), &payload); err != nil { + continue + } + payload.Provider = strings.TrimSpace(payload.Provider) + payload.Endpoint = strings.TrimSpace(payload.Endpoint) + payload.AuthType = normalizeUsageEventRedisAuthType(payload.AuthType) + payload.RequestID = strings.TrimSpace(payload.RequestID) + if payload.Provider == "" && payload.Endpoint == "" && payload.AuthType == "" && payload.RequestID == "" { + continue + } + if err := backfillUsageEventRedisFields(tx, strings.TrimSpace(inbox.UsageEventKey), payload, true); err != nil { + return err + } + } + return nil + }).Error +} + +func backfillUsageEventRedisFields(tx *gorm.DB, usageEventKey string, payload redisUsageBackfillPayload, allowRequestIDFallback bool) error { + if usageEventKey == "" { + if allowRequestIDFallback && payload.RequestID != "" { + return backfillUsageEventRedisFields(tx, payload.RequestID, payload, false) + } + return nil + } + + var event models.UsageEvent + result := tx.Where("event_key = ?", usageEventKey).Limit(1).Find(&event) + if result.Error != nil { + return fmt.Errorf("load usage event %q for redis backfill: %w", usageEventKey, result.Error) + } + if result.RowsAffected == 0 { + if allowRequestIDFallback && payload.RequestID != "" && payload.RequestID != usageEventKey { + return backfillUsageEventRedisFields(tx, payload.RequestID, payload, false) + } + return nil + } + + updates := map[string]any{} + if strings.TrimSpace(event.Provider) == "" && payload.Provider != "" { + updates["provider"] = payload.Provider + } + if strings.TrimSpace(event.Endpoint) == "" && payload.Endpoint != "" { + updates["endpoint"] = payload.Endpoint + } + if strings.TrimSpace(event.AuthType) == "" && payload.AuthType != "" { + updates["auth_type"] = payload.AuthType + } + if strings.TrimSpace(event.RequestID) == "" && payload.RequestID != "" { + updates["request_id"] = payload.RequestID + } + if len(updates) == 0 { + return nil + } + if err := tx.Model(&models.UsageEvent{}).Where("id = ?", event.ID).Updates(updates).Error; err != nil { + return fmt.Errorf("backfill usage event %q redis fields: %w", event.EventKey, err) + } + return nil +} + +func normalizeUsageEventRedisAuthType(value string) string { + trimmed := strings.ToLower(strings.TrimSpace(value)) + if trimmed == "api_key" { + return "apikey" + } + return trimmed +} + +func dropSnapshotRunsMigration(tx *gorm.DB) error { + if !tx.Migrator().HasTable("snapshot_runs") { + return nil + } + if err := tx.Exec("DROP TABLE IF EXISTS snapshot_runs").Error; err != nil { + return fmt.Errorf("drop snapshot_runs table: %w", err) + } + return nil +} + +func dropLegacySnapshotRunColumnsMigration(tx *gorm.DB) error { + for _, indexName := range []string{"idx_usage_events_snapshot_run_id", "idx_redis_usage_inboxes_snapshot_run_id"} { + if err := tx.Exec("DROP INDEX IF EXISTS " + indexName).Error; err != nil { + return fmt.Errorf("drop legacy snapshot_run_id index %s: %w", indexName, err) + } + } + if err := dropColumnIfExists(tx, &models.UsageEvent{}, "snapshot_run_id", "usage_events"); err != nil { + return err + } + if err := dropColumnIfExists(tx, &models.RedisUsageInbox{}, "snapshot_run_id", "redis_usage_inboxes"); err != nil { + return err + } + return nil +} + +func dropColumnIfExists(tx *gorm.DB, model any, columnName string, tableName string) error { + if !tx.Migrator().HasTable(model) || !tx.Migrator().HasColumn(model, columnName) { + return nil + } + if err := tx.Exec("ALTER TABLE " + tableName + " DROP COLUMN " + columnName).Error; err != nil { + return fmt.Errorf("drop %s.%s column: %w", tableName, columnName, err) + } + return nil +} + +func createUsageIdentitiesMigration(tx *gorm.DB) error { + statements := []string{ + `CREATE TABLE IF NOT EXISTS usage_identities ( + id integer PRIMARY KEY AUTOINCREMENT, + name text, + auth_type integer, + auth_type_name text, + identity text, + type text, + provider text, + total_requests integer DEFAULT 0, + success_count integer DEFAULT 0, + failure_count integer DEFAULT 0, + input_tokens integer DEFAULT 0, + output_tokens integer DEFAULT 0, + reasoning_tokens integer DEFAULT 0, + cached_tokens integer DEFAULT 0, + total_tokens integer DEFAULT 0, + last_aggregated_usage_event_id integer DEFAULT 0, + first_used_at datetime, + last_used_at datetime, + stats_updated_at datetime, + is_deleted numeric DEFAULT false, + created_at datetime, + updated_at datetime, + deleted_at datetime + )`, + `CREATE UNIQUE INDEX IF NOT EXISTS uniq_usage_identities_type_identity ON usage_identities(auth_type, identity)`, + `CREATE INDEX IF NOT EXISTS idx_usage_identities_auth_type ON usage_identities(auth_type)`, + `CREATE INDEX IF NOT EXISTS idx_usage_identities_auth_type_name ON usage_identities(auth_type_name)`, + `CREATE INDEX IF NOT EXISTS idx_usage_identities_identity ON usage_identities(identity)`, + `CREATE INDEX IF NOT EXISTS idx_usage_identities_is_deleted ON usage_identities(is_deleted)`, + `CREATE INDEX IF NOT EXISTS idx_usage_identities_last_aggregated_usage_event_id ON usage_identities(last_aggregated_usage_event_id)`, + `CREATE INDEX IF NOT EXISTS idx_usage_identities_deleted_at ON usage_identities(deleted_at)`, + } + for _, statement := range statements { + if err := tx.Exec(statement).Error; err != nil { + return fmt.Errorf("create usage_identities schema: %w", err) + } + } + return nil +} + +func migrateUsageIdentitiesMetadataMigration(tx *gorm.DB) error { + now := time.Now().UTC() + if tx.Migrator().HasTable("auth_files") { + isDeletedSelect, deletedAtSelect := legacyDeletedStateSelect(tx, "auth_files") + if err := tx.Exec(` + INSERT INTO usage_identities (name, auth_type, auth_type_name, identity, type, provider, is_deleted, created_at, updated_at, deleted_at) + SELECT COALESCE(NULLIF(TRIM(email), ''), NULLIF(TRIM(label), ''), NULLIF(TRIM(name), ''), auth_index), + ?, ?, auth_index, type, provider, `+isDeletedSelect+`, COALESCE(created_at, ?), ?, `+deletedAtSelect+` + FROM auth_files + WHERE auth_index IS NOT NULL AND TRIM(auth_index) != '' + ON CONFLICT(auth_type, identity) DO UPDATE SET + name = excluded.name, + auth_type_name = excluded.auth_type_name, + type = excluded.type, + provider = excluded.provider, + is_deleted = excluded.is_deleted, + deleted_at = excluded.deleted_at, + updated_at = excluded.updated_at`, models.UsageIdentityAuthTypeAuthFile, "oauth", now, now).Error; err != nil { + return fmt.Errorf("migrate auth_files to usage_identities: %w", err) + } + } + if tx.Migrator().HasTable("provider_metadata") { + isDeletedSelect, deletedAtSelect := legacyDeletedStateSelect(tx, "provider_metadata") + if err := tx.Exec(` + INSERT INTO usage_identities (name, auth_type, auth_type_name, identity, type, provider, is_deleted, created_at, updated_at, deleted_at) + SELECT display_name, ?, ?, lookup_key, provider_type, display_name, `+isDeletedSelect+`, COALESCE(created_at, ?), ?, `+deletedAtSelect+` + FROM provider_metadata + WHERE lookup_key IS NOT NULL AND TRIM(lookup_key) != '' + ON CONFLICT(auth_type, identity) DO UPDATE SET + name = excluded.name, + auth_type_name = excluded.auth_type_name, + type = excluded.type, + provider = excluded.provider, + is_deleted = excluded.is_deleted, + deleted_at = excluded.deleted_at, + updated_at = excluded.updated_at`, models.UsageIdentityAuthTypeAIProvider, "apikey", now, now).Error; err != nil { + return fmt.Errorf("migrate provider_metadata to usage_identities: %w", err) + } + } + return nil +} + +func legacyDeletedStateSelect(tx *gorm.DB, table string) (string, string) { + if tx.Migrator().HasColumn(table, "deleted_at") { + return "deleted_at IS NOT NULL", "deleted_at" + } + return "false", "NULL" +} + +func backfillUsageEventIdentityFieldsMigration(tx *gorm.DB) error { + if !tx.Migrator().HasTable(&models.UsageIdentity{}) || !tx.Migrator().HasTable(&models.UsageEvent{}) { + return nil + } + for _, column := range []string{"auth_type", "provider", "source", "auth_index"} { + if !tx.Migrator().HasColumn(&models.UsageEvent{}, column) { + return nil + } + } + + if err := tx.Exec(` + UPDATE usage_events + SET auth_type = CASE + WHEN TRIM(COALESCE(auth_type, '')) = '' THEN ? + ELSE auth_type + END, + provider = CASE + WHEN TRIM(COALESCE(provider, '')) = '' THEN COALESCE(( + SELECT NULLIF(TRIM(usage_identities.provider), '') + FROM usage_identities + WHERE usage_identities.auth_type = ? + AND usage_identities.identity = usage_events.source + LIMIT 1 + ), provider) + ELSE provider + END + WHERE EXISTS ( + SELECT 1 + FROM usage_identities + WHERE usage_identities.auth_type = ? + AND usage_identities.identity = usage_events.source + ) + AND (TRIM(COALESCE(auth_type, '')) = '' OR TRIM(COALESCE(provider, '')) = '')`, "apikey", models.UsageIdentityAuthTypeAIProvider, models.UsageIdentityAuthTypeAIProvider).Error; err != nil { + return fmt.Errorf("backfill AI provider usage event identity fields: %w", err) + } + + if err := tx.Exec(` + UPDATE usage_events + SET auth_type = ? + WHERE TRIM(COALESCE(auth_type, '')) = '' + AND EXISTS ( + SELECT 1 + FROM usage_identities + WHERE usage_identities.auth_type = ? + AND usage_identities.identity = usage_events.auth_index + )`, "oauth", models.UsageIdentityAuthTypeAuthFile).Error; err != nil { + return fmt.Errorf("backfill auth file usage event identity fields: %w", err) + } + return nil +} + +func backfillUsageIdentityStatsMigration(tx *gorm.DB) error { + if !tx.Migrator().HasTable(&models.UsageIdentity{}) || !tx.Migrator().HasTable(&models.UsageEvent{}) { + return nil + } + for _, column := range []string{"auth_type", "source", "auth_index"} { + if !tx.Migrator().HasColumn(&models.UsageEvent{}, column) { + return nil + } + } + + var identities []models.UsageIdentity + if err := tx.Find(&identities).Error; err != nil { + return fmt.Errorf("list usage identities for stats backfill: %w", err) + } + for _, identity := range identities { + stats, err := aggregateUsageIdentityFullStats(tx, identity) + if err != nil { + return err + } + updates := map[string]any{ + "total_requests": stats.TotalRequests, + "success_count": stats.SuccessCount, + "failure_count": stats.FailureCount, + "input_tokens": stats.InputTokens, + "output_tokens": stats.OutputTokens, + "reasoning_tokens": stats.ReasoningTokens, + "cached_tokens": stats.CachedTokens, + "total_tokens": stats.TotalTokens, + "first_used_at": stats.FirstUsedAt, + "last_used_at": stats.LastUsedAt, + "stats_updated_at": nil, + "last_aggregated_usage_event_id": stats.MaxUsageEventID, + } + if stats.TotalRequests > 0 { + now := time.Now().UTC() + updates["stats_updated_at"] = now + } + if err := tx.Model(&models.UsageIdentity{}).Where("id = ?", identity.ID).Updates(updates).Error; err != nil { + return fmt.Errorf("backfill usage identity stats for %q: %w", identity.Identity, err) + } + } + return nil +} + +func aggregateUsageIdentityFullStats(tx *gorm.DB, identity models.UsageIdentity) (usageIdentityStatsDelta, error) { + var stats usageIdentityStatsDelta + query, ok := usageIdentityBackfillEventsQuery(tx.Model(&models.UsageEvent{}), identity) + if !ok { + return stats, nil + } + if err := query.Select(` + COUNT(*) AS total_requests, + COALESCE(SUM(CASE WHEN failed THEN 0 ELSE 1 END), 0) AS success_count, + COALESCE(SUM(CASE WHEN failed THEN 1 ELSE 0 END), 0) AS failure_count, + COALESCE(SUM(input_tokens), 0) AS input_tokens, + COALESCE(SUM(output_tokens), 0) AS output_tokens, + COALESCE(SUM(reasoning_tokens), 0) AS reasoning_tokens, + COALESCE(SUM(cached_tokens), 0) AS cached_tokens, + COALESCE(SUM(total_tokens), 0) AS total_tokens, + COALESCE(MAX(id), 0) AS max_usage_event_id`). + Scan(&stats).Error; err != nil { + return stats, fmt.Errorf("aggregate full usage identity stats for %q: %w", identity.Identity, err) + } + if stats.TotalRequests == 0 { + return stats, nil + } + + var firstEvent models.UsageEvent + firstQuery, _ := usageIdentityBackfillEventsQuery(tx.Model(&models.UsageEvent{}), identity) + if err := firstQuery.Order("timestamp asc, id asc").First(&firstEvent).Error; err != nil { + return stats, fmt.Errorf("find first usage identity event for %q: %w", identity.Identity, err) + } + firstUsedAt := firstEvent.Timestamp + stats.FirstUsedAt = &firstUsedAt + + var lastEvent models.UsageEvent + lastQuery, _ := usageIdentityBackfillEventsQuery(tx.Model(&models.UsageEvent{}), identity) + if err := lastQuery.Order("timestamp desc, id desc").First(&lastEvent).Error; err != nil { + return stats, fmt.Errorf("find last usage identity event for %q: %w", identity.Identity, err) + } + lastUsedAt := lastEvent.Timestamp + stats.LastUsedAt = &lastUsedAt + return stats, nil +} + +func usageIdentityBackfillEventsQuery(query *gorm.DB, identity models.UsageIdentity) (*gorm.DB, bool) { + switch identity.AuthType { + case models.UsageIdentityAuthTypeAuthFile: + return query.Where("auth_index = ? AND (auth_type = ? OR TRIM(COALESCE(auth_type, '')) = '')", identity.Identity, "oauth"), true + case models.UsageIdentityAuthTypeAIProvider: + return query.Where("source = ? AND (auth_type = ? OR TRIM(COALESCE(auth_type, '')) = '')", identity.Identity, "apikey"), true + default: + return query, false + } +} + +func dropLegacyMetadataTablesMigration(tx *gorm.DB) error { + if err := tx.Exec("DROP TABLE IF EXISTS auth_files").Error; err != nil { + return fmt.Errorf("drop auth_files table: %w", err) + } + if err := tx.Exec("DROP TABLE IF EXISTS provider_metadata").Error; err != nil { + return fmt.Errorf("drop provider_metadata table: %w", err) + } + return nil +} + +func removePrefixUsageIdentitiesMigration(tx *gorm.DB) error { + if !tx.Migrator().HasTable(&models.UsageIdentity{}) { + return nil + } + if err := tx.Exec(` + DELETE FROM usage_identities + WHERE auth_type = ? + AND LOWER(TRIM(identity)) IN ('gemini', 'claude', 'codex', 'vertex', 'openai')`, models.UsageIdentityAuthTypeAIProvider).Error; err != nil { + return fmt.Errorf("remove prefix-generated usage identities: %w", err) + } + return nil +} diff --git a/internal/repository/migrations_test.go b/internal/repository/migrations_test.go new file mode 100644 index 00000000..cadf6275 --- /dev/null +++ b/internal/repository/migrations_test.go @@ -0,0 +1,875 @@ +package repository + +import ( + "bytes" + "fmt" + "path/filepath" + "strings" + "testing" + "time" + + "cpa-usage-keeper/internal/config" + "cpa-usage-keeper/internal/models" + "github.com/sirupsen/logrus" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func TestOpenDatabaseRunsSchemaMigrationsAndAddsUsageEventRedisFields(t *testing.T) { + db := openTestDatabase(t) + + if !db.Migrator().HasTable("schema_migrations") { + t.Fatal("expected schema_migrations table to exist") + } + for _, column := range []string{"provider", "endpoint", "auth_type", "request_id"} { + if !db.Migrator().HasColumn(&models.UsageEvent{}, column) { + t.Fatalf("expected usage_events.%s column to exist", column) + } + } + + var versions []string + if err := db.Table("schema_migrations").Order("version asc").Pluck("version", &versions).Error; err != nil { + t.Fatalf("load schema migrations: %v", err) + } + expected := []string{ + "20260503_add_usage_event_redis_fields", + "20260503_backfill_usage_event_redis_fields", + "20260503_drop_snapshot_runs", + "20260504_backfill_usage_event_identity_fields", + "20260504_backfill_usage_identity_stats", + "20260504_create_usage_identities", + "20260504_drop_legacy_metadata_tables", + "20260504_drop_legacy_snapshot_run_columns", + "20260504_migrate_usage_identities_metadata", + "20260504_remove_prefix_usage_identities", + } + if len(versions) != len(expected) { + t.Fatalf("expected migration versions %v, got %v", expected, versions) + } + for i := range expected { + if versions[i] != expected[i] { + t.Fatalf("expected migration versions %v, got %v", expected, versions) + } + } +} + +func TestOpenDatabaseMigrationsAreIdempotent(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "app.db") + cfg := config.Config{SQLitePath: dbPath} + + db, err := OpenDatabase(cfg) + if err != nil { + t.Fatalf("first OpenDatabase returned error: %v", err) + } + closeOpenedDatabase(t, db) + + db, err = OpenDatabase(cfg) + if err != nil { + t.Fatalf("second OpenDatabase returned error: %v", err) + } + closeTestDatabase(t, db) + + var count int64 + if err := db.Table("schema_migrations").Count(&count).Error; err != nil { + t.Fatalf("count schema migrations: %v", err) + } + if count != 10 { + t.Fatalf("expected 10 applied migrations after reopening database, got %d", count) + } +} + +func TestOpenDatabaseLogsSchemaMigrations(t *testing.T) { + logs := captureRepositoryLogs(t, logrus.InfoLevel) + dbPath := filepath.Join(t.TempDir(), "app.db") + cfg := config.Config{SQLitePath: dbPath} + + db, err := OpenDatabase(cfg) + if err != nil { + t.Fatalf("first OpenDatabase returned error: %v", err) + } + closeOpenedDatabase(t, db) + + db, err = OpenDatabase(cfg) + if err != nil { + t.Fatalf("second OpenDatabase returned error: %v", err) + } + closeTestDatabase(t, db) + + content := logs.String() + for _, want := range []string{ + "level=info", + "msg=\"schema migration started\"", + "msg=\"schema migration applied\"", + "msg=\"schema migration skipped\"", + "version=20260503_add_usage_event_redis_fields", + "version=20260504_migrate_usage_identities_metadata", + "version=20260504_drop_legacy_metadata_tables", + } { + if !strings.Contains(content, want) { + t.Fatalf("expected migration logs to contain %q, got:\n%s", want, content) + } + } +} + +func TestRunSchemaMigrationLogsErrors(t *testing.T) { + logs := captureRepositoryLogs(t, logrus.InfoLevel) + db, err := gorm.Open(sqlite.Open(sqliteDSN(filepath.Join(t.TempDir(), "app.db"))), &gorm.Config{}) + if err != nil { + t.Fatalf("open database: %v", err) + } + defer closeOpenedDatabase(t, db) + if err := db.Exec("CREATE TABLE schema_migrations (version TEXT PRIMARY KEY, applied_at DATETIME NOT NULL)").Error; err != nil { + t.Fatalf("create schema_migrations: %v", err) + } + + err = runSchemaMigration(db, databaseMigration{ + version: "test_failure", + run: func(*gorm.DB) error { + return fmt.Errorf("boom") + }, + }) + if err == nil { + t.Fatal("expected migration error") + } + + content := logs.String() + for _, want := range []string{ + "level=info", + "msg=\"schema migration started\"", + "version=test_failure", + "level=error", + "msg=\"schema migration failed\"", + "error=boom", + } { + if !strings.Contains(content, want) { + t.Fatalf("expected migration error logs to contain %q, got:\n%s", want, content) + } + } +} + +func TestOpenDatabaseDropsLegacySnapshotRunsTable(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "legacy.db") + db, err := gorm.Open(sqlite.Open(sqliteDSN(dbPath)), &gorm.Config{}) + if err != nil { + t.Fatalf("open legacy database: %v", err) + } + if err := db.Exec(`CREATE TABLE snapshot_runs (id integer PRIMARY KEY AUTOINCREMENT, fetched_at datetime, status text)`).Error; err != nil { + t.Fatalf("create legacy snapshot_runs table: %v", err) + } + if err := db.Exec(`INSERT INTO snapshot_runs (fetched_at, status) VALUES (?, ?)`, time.Date(2026, 5, 3, 8, 0, 0, 0, time.UTC), "completed").Error; err != nil { + t.Fatalf("seed legacy snapshot_runs table: %v", err) + } + closeOpenedDatabase(t, db) + + db = openMigratedDatabase(t, dbPath) + defer closeOpenedDatabase(t, db) + + if db.Migrator().HasTable("snapshot_runs") { + t.Fatal("expected legacy snapshot_runs table to be dropped") + } + var count int64 + if err := db.Table("schema_migrations").Where("version = ?", "20260503_drop_snapshot_runs").Count(&count).Error; err != nil { + t.Fatalf("count drop snapshot migration: %v", err) + } + if count != 1 { + t.Fatalf("expected drop snapshot migration to be recorded once, got %d", count) + } +} + +func TestOpenDatabaseDropsLegacySnapshotRunIDColumns(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "legacy.db") + seedLegacyRedisUsageTables(t, dbPath) + + db := openMigratedDatabase(t, dbPath) + defer closeOpenedDatabase(t, db) + + if db.Migrator().HasColumn(&models.UsageEvent{}, "snapshot_run_id") { + t.Fatal("expected usage_events.snapshot_run_id to be dropped") + } + if db.Migrator().HasColumn(&models.RedisUsageInbox{}, "snapshot_run_id") { + t.Fatal("expected redis_usage_inboxes.snapshot_run_id to be dropped") + } + var oldIndexCount int64 + if err := db.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type = 'index' AND name IN (?, ?)", "idx_usage_events_snapshot_run_id", "idx_redis_usage_inboxes_snapshot_run_id").Scan(&oldIndexCount).Error; err != nil { + t.Fatalf("count legacy snapshot_run_id indexes: %v", err) + } + if oldIndexCount != 0 { + t.Fatalf("expected legacy snapshot_run_id indexes to be dropped, got %d", oldIndexCount) + } + var migrationCount int64 + if err := db.Table("schema_migrations").Where("version = ?", "20260504_drop_legacy_snapshot_run_columns").Count(&migrationCount).Error; err != nil { + t.Fatalf("count drop snapshot_run_id columns migration: %v", err) + } + if migrationCount != 1 { + t.Fatalf("expected drop snapshot_run_id columns migration to be recorded once, got %d", migrationCount) + } +} + +func TestOpenDatabaseBackfillsUsageEventRedisFieldsByUsageEventKey(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "legacy.db") + seedLegacyRedisUsageTables(t, dbPath) + + db := openMigratedDatabase(t, dbPath) + defer closeOpenedDatabase(t, db) + + var event models.UsageEvent + if err := db.Where("event_key = ?", "legacy-canonical-key").First(&event).Error; err != nil { + t.Fatalf("load usage event: %v", err) + } + if event.Provider != "claude" || event.Endpoint != "/v1/messages" || event.AuthType != "apikey" || event.RequestID != "req-from-raw" { + t.Fatalf("expected backfill by usage_event_key, got %+v", event) + } +} + +func TestOpenDatabaseBackfillsUsageEventRedisFieldsByRawRequestIDFallback(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "legacy.db") + seedLegacyRedisUsageTables(t, dbPath) + + db := openMigratedDatabase(t, dbPath) + defer closeOpenedDatabase(t, db) + + var event models.UsageEvent + if err := db.Where("event_key = ?", "req-fallback").First(&event).Error; err != nil { + t.Fatalf("load fallback usage event: %v", err) + } + if event.Provider != "fallback-provider" || event.Endpoint != "/fallback" || event.AuthType != "oauth" || event.RequestID != "req-fallback" { + t.Fatalf("expected fallback backfill by raw request_id, got %+v", event) + } +} + +func TestOpenDatabaseBackfillsUsageEventRedisFieldsByRawRequestIDWhenUsageEventKeyIsBlank(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "legacy.db") + seedLegacyRedisUsageTables(t, dbPath) + + db := openMigratedDatabase(t, dbPath) + defer closeOpenedDatabase(t, db) + + var event models.UsageEvent + if err := db.Where("event_key = ?", "req-blank-fallback").First(&event).Error; err != nil { + t.Fatalf("load blank fallback usage event: %v", err) + } + if event.Provider != "blank-provider" || event.Endpoint != "/blank" || event.AuthType != "oauth" || event.RequestID != "req-blank-fallback" { + t.Fatalf("expected blank usage_event_key to fall back by raw request_id, got %+v", event) + } + + var emptyEvent models.UsageEvent + if err := db.Where("event_key = ?", "").First(&emptyEvent).Error; err != nil { + t.Fatalf("load empty-key usage event: %v", err) + } + if emptyEvent.Provider != "" || emptyEvent.Endpoint != "" || emptyEvent.AuthType != "" || emptyEvent.RequestID != "" { + t.Fatalf("expected empty-key usage event to remain unchanged, got %+v", emptyEvent) + } +} + +func TestOpenDatabaseBackfillDoesNotOverwriteExistingUsageEventFields(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "legacy.db") + seedLegacyRedisUsageTables(t, dbPath) + + // 模拟目标列已经有值的部分迁移数据库。 + db, err := gorm.Open(sqlite.Open(sqliteDSN(dbPath)), &gorm.Config{}) + if err != nil { + t.Fatalf("open partially migrated database: %v", err) + } + for _, statement := range []string{ + "ALTER TABLE usage_events ADD COLUMN provider TEXT", + "ALTER TABLE usage_events ADD COLUMN endpoint TEXT", + "ALTER TABLE usage_events ADD COLUMN auth_type TEXT", + "ALTER TABLE usage_events ADD COLUMN request_id TEXT", + "UPDATE usage_events SET provider = 'existing-provider', endpoint = 'existing-endpoint', auth_type = 'existing-auth', request_id = 'existing-request' WHERE event_key = 'existing-key'", + } { + if err := db.Exec(statement).Error; err != nil { + t.Fatalf("prepare partially migrated database with %q: %v", statement, err) + } + } + closeOpenedDatabase(t, db) + + db = openMigratedDatabase(t, dbPath) + defer closeOpenedDatabase(t, db) + + var event models.UsageEvent + if err := db.Where("event_key = ?", "existing-key").First(&event).Error; err != nil { + t.Fatalf("load existing usage event: %v", err) + } + if event.Provider != "existing-provider" || event.Endpoint != "existing-endpoint" || event.AuthType != "existing-auth" || event.RequestID != "existing-request" { + t.Fatalf("expected existing fields to remain unchanged, got %+v", event) + } +} + +func TestOpenDatabaseUsageIdentityMigratesLegacyMetadataAndDropsOldTables(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "legacy-identities.db") + seedLegacyUsageIdentityTables(t, dbPath) + + db := openMigratedDatabase(t, dbPath) + defer closeOpenedDatabase(t, db) + + if !db.Migrator().HasTable(&models.UsageIdentity{}) { + t.Fatal("expected usage_identities table to exist") + } + if db.Migrator().HasTable("auth_files") { + t.Fatal("expected auth_files table to be dropped") + } + if db.Migrator().HasTable("provider_metadata") { + t.Fatal("expected provider_metadata table to be dropped") + } + + var identities []models.UsageIdentity + if err := db.Order("auth_type asc, identity asc").Find(&identities).Error; err != nil { + t.Fatalf("load usage identities: %v", err) + } + if len(identities) != 4 { + t.Fatalf("expected all 4 legacy usage identities, got %d: %+v", len(identities), identities) + } + + oauth := findUsageIdentity(t, identities, models.UsageIdentityAuthTypeAuthFile, "auth-1") + if oauth.Name != "person@example.com" || oauth.AuthTypeName != "oauth" || oauth.Type != "claude" || oauth.Provider != "claude" { + t.Fatalf("unexpected oauth identity mapping: %+v", oauth) + } + if oauth.TotalRequests != 3 || oauth.SuccessCount != 2 || oauth.FailureCount != 1 || oauth.InputTokens != 31 || oauth.OutputTokens != 41 || oauth.ReasoningTokens != 11 || oauth.CachedTokens != 7 || oauth.TotalTokens != 90 { + t.Fatalf("unexpected oauth identity stats: %+v", oauth) + } + if oauth.FirstUsedAt == nil || !oauth.FirstUsedAt.Equal(time.Date(2026, 5, 4, 8, 0, 0, 0, time.UTC)) { + t.Fatalf("unexpected oauth first used timestamp: %+v", oauth.FirstUsedAt) + } + if oauth.LastUsedAt == nil || !oauth.LastUsedAt.Equal(time.Date(2026, 5, 4, 9, 0, 0, 0, time.UTC)) { + t.Fatalf("unexpected oauth last used timestamp: %+v", oauth.LastUsedAt) + } + if oauth.StatsUpdatedAt == nil { + t.Fatal("expected oauth stats_updated_at to be set") + } + if oauth.LastAggregatedUsageEventID != 3 { + t.Fatalf("expected oauth last aggregated usage event id 3, got %d", oauth.LastAggregatedUsageEventID) + } + + provider := findUsageIdentity(t, identities, models.UsageIdentityAuthTypeAIProvider, "api-source-1") + if provider.Name != "Claude API" || provider.AuthTypeName != "apikey" || provider.Type != "claude" || provider.Provider != "Claude API" { + t.Fatalf("unexpected provider identity mapping: %+v", provider) + } + if provider.TotalRequests != 2 || provider.SuccessCount != 2 || provider.FailureCount != 0 || provider.InputTokens != 9 || provider.OutputTokens != 9 || provider.ReasoningTokens != 10 || provider.CachedTokens != 11 || provider.TotalTokens != 39 { + t.Fatalf("unexpected provider identity stats: %+v", provider) + } + if provider.FirstUsedAt == nil || !provider.FirstUsedAt.Equal(time.Date(2026, 5, 4, 10, 0, 0, 0, time.UTC)) || provider.LastUsedAt == nil || !provider.LastUsedAt.Equal(time.Date(2026, 5, 4, 10, 30, 0, 0, time.UTC)) { + t.Fatalf("unexpected provider usage timestamps: first=%+v last=%+v", provider.FirstUsedAt, provider.LastUsedAt) + } + if provider.StatsUpdatedAt == nil { + t.Fatal("expected provider stats_updated_at to be set") + } + if provider.LastAggregatedUsageEventID != 5 { + t.Fatalf("expected provider last aggregated usage event id 5, got %d", provider.LastAggregatedUsageEventID) + } + + deletedOAuth := findUsageIdentity(t, identities, models.UsageIdentityAuthTypeAuthFile, "auth-deleted") + if !deletedOAuth.IsDeleted || deletedOAuth.DeletedAt == nil || !deletedOAuth.DeletedAt.Equal(time.Date(2026, 5, 4, 7, 30, 0, 0, time.UTC)) { + t.Fatalf("expected deleted auth file state to be preserved, got %+v", deletedOAuth) + } + if deletedOAuth.TotalRequests != 1 || deletedOAuth.TotalTokens != 100 || deletedOAuth.LastAggregatedUsageEventID != 6 { + t.Fatalf("expected deleted auth file usage stats to be backfilled, got %+v", deletedOAuth) + } + + deletedProvider := findUsageIdentity(t, identities, models.UsageIdentityAuthTypeAIProvider, "api-deleted") + if !deletedProvider.IsDeleted || deletedProvider.DeletedAt == nil || !deletedProvider.DeletedAt.Equal(time.Date(2026, 5, 4, 7, 30, 0, 0, time.UTC)) { + t.Fatalf("expected deleted provider state to be preserved, got %+v", deletedProvider) + } + if deletedProvider.TotalRequests != 1 || deletedProvider.TotalTokens != 100 || deletedProvider.LastAggregatedUsageEventID != 7 { + t.Fatalf("expected deleted provider usage stats to be backfilled, got %+v", deletedProvider) + } +} + +func TestOpenDatabaseBackfillsUsageEventIdentityFieldsFromUsageIdentities(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "legacy-identities.db") + seedLegacyUsageIdentityTables(t, dbPath) + + db, err := gorm.Open(sqlite.Open(sqliteDSN(dbPath)), &gorm.Config{}) + if err != nil { + t.Fatalf("open seeded legacy database: %v", err) + } + if err := db.Exec("UPDATE usage_events SET provider = '' WHERE event_key IN (?, ?)", "legacy-apikey", "legacy-oauth").Error; err != nil { + t.Fatalf("blank legacy usage event providers: %v", err) + } + if err := db.Exec("UPDATE usage_events SET provider = ? WHERE event_key = ?", "existing-provider", "apikey-success").Error; err != nil { + t.Fatalf("set existing provider: %v", err) + } + closeOpenedDatabase(t, db) + + db = openMigratedDatabase(t, dbPath) + defer closeOpenedDatabase(t, db) + + var legacyProvider models.UsageEvent + if err := db.Where("event_key = ?", "legacy-apikey").First(&legacyProvider).Error; err != nil { + t.Fatalf("load legacy provider event: %v", err) + } + if legacyProvider.AuthType != "apikey" || legacyProvider.Provider != "Claude API" { + t.Fatalf("expected legacy provider event identity fields to be backfilled, got %+v", legacyProvider) + } + + var legacyOAuth models.UsageEvent + if err := db.Where("event_key = ?", "legacy-oauth").First(&legacyOAuth).Error; err != nil { + t.Fatalf("load legacy oauth event: %v", err) + } + if legacyOAuth.AuthType != "oauth" { + t.Fatalf("expected legacy oauth event auth_type to be backfilled, got %+v", legacyOAuth) + } + + var existingProvider models.UsageEvent + if err := db.Where("event_key = ?", "apikey-success").First(&existingProvider).Error; err != nil { + t.Fatalf("load existing provider event: %v", err) + } + if existingProvider.Provider != "existing-provider" { + t.Fatalf("expected existing provider field to remain unchanged, got %+v", existingProvider) + } + + var providerFilterCount int64 + if err := db.Model(&models.UsageEvent{}).Where("auth_type = ? AND provider = ?", "apikey", "Claude API").Count(&providerFilterCount).Error; err != nil { + t.Fatalf("count provider-filtered usage events: %v", err) + } + if providerFilterCount != 1 { + t.Fatalf("expected provider filter to match migrated legacy event, got %d", providerFilterCount) + } +} + +func TestOpenDatabaseUsageIdentityMigrationsAreIdempotent(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "legacy-identities.db") + seedLegacyUsageIdentityTables(t, dbPath) + + db := openMigratedDatabase(t, dbPath) + closeOpenedDatabase(t, db) + db = openMigratedDatabase(t, dbPath) + defer closeOpenedDatabase(t, db) + + var identities []models.UsageIdentity + if err := db.Order("auth_type asc, identity asc").Find(&identities).Error; err != nil { + t.Fatalf("load usage identities after reopen: %v", err) + } + if len(identities) != 4 { + t.Fatalf("expected all 4 usage identities after reopen, got %d: %+v", len(identities), identities) + } + oauth := findUsageIdentity(t, identities, models.UsageIdentityAuthTypeAuthFile, "auth-1") + if oauth.TotalRequests != 3 || oauth.TotalTokens != 90 || oauth.LastAggregatedUsageEventID != 3 { + t.Fatalf("expected oauth stats not to double-add after reopen, got %+v", oauth) + } + provider := findUsageIdentity(t, identities, models.UsageIdentityAuthTypeAIProvider, "api-source-1") + if provider.TotalRequests != 2 || provider.TotalTokens != 39 || provider.LastAggregatedUsageEventID != 5 { + t.Fatalf("expected provider stats not to double-add after reopen, got %+v", provider) + } + deletedOAuth := findUsageIdentity(t, identities, models.UsageIdentityAuthTypeAuthFile, "auth-deleted") + if !deletedOAuth.IsDeleted || deletedOAuth.TotalRequests != 1 || deletedOAuth.TotalTokens != 100 || deletedOAuth.LastAggregatedUsageEventID != 6 { + t.Fatalf("expected deleted oauth stats not to double-add after reopen, got %+v", deletedOAuth) + } + deletedProvider := findUsageIdentity(t, identities, models.UsageIdentityAuthTypeAIProvider, "api-deleted") + if !deletedProvider.IsDeleted || deletedProvider.TotalRequests != 1 || deletedProvider.TotalTokens != 100 || deletedProvider.LastAggregatedUsageEventID != 7 { + t.Fatalf("expected deleted provider stats not to double-add after reopen, got %+v", deletedProvider) + } + + var duplicateVersions int64 + if err := db.Table("schema_migrations").Select("COUNT(*) - COUNT(DISTINCT version)").Scan(&duplicateVersions).Error; err != nil { + t.Fatalf("count duplicate schema migration versions: %v", err) + } + if duplicateVersions != 0 { + t.Fatalf("expected no duplicate schema migration versions, got %d", duplicateVersions) + } + for _, version := range []string{"20260504_create_usage_identities", "20260504_migrate_usage_identities_metadata", "20260504_backfill_usage_event_identity_fields", "20260504_backfill_usage_identity_stats", "20260504_drop_legacy_metadata_tables", "20260504_drop_legacy_snapshot_run_columns", "20260504_remove_prefix_usage_identities"} { + var count int64 + if err := db.Table("schema_migrations").Where("version = ?", version).Count(&count).Error; err != nil { + t.Fatalf("count schema migration %s: %v", version, err) + } + if count != 1 { + t.Fatalf("expected schema migration %s to be recorded once, got %d", version, count) + } + } + if db.Migrator().HasTable("auth_files") || db.Migrator().HasTable("provider_metadata") { + t.Fatal("expected old metadata tables to stay dropped after reopen") + } +} + +func TestOpenDatabaseSkipsUsageIdentityMetadataMigrationWhenLegacyTablesAreMissing(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "no-legacy-identities.db") + + db := openMigratedDatabase(t, dbPath) + defer closeOpenedDatabase(t, db) + + if !db.Migrator().HasTable(&models.UsageIdentity{}) { + t.Fatal("expected usage_identities table to exist") + } + var count int64 + if err := db.Model(&models.UsageIdentity{}).Count(&count).Error; err != nil { + t.Fatalf("count usage identities: %v", err) + } + if count != 0 { + t.Fatalf("expected no usage identities without legacy metadata, got %d", count) + } + if db.Migrator().HasTable("auth_files") || db.Migrator().HasTable("provider_metadata") { + t.Fatal("expected legacy metadata tables not to be recreated") + } +} + +func TestOpenDatabaseRemovesPrefixGeneratedUsageIdentities(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "prefix-identities.db") + seedPrefixGeneratedUsageIdentities(t, dbPath) + + db := openMigratedDatabase(t, dbPath) + defer closeOpenedDatabase(t, db) + + for _, prefix := range []string{"gemini", "claude", "codex", "vertex", "openai"} { + var prefixCount int64 + if err := db.Model(&models.UsageIdentity{}).Where("auth_type = ? AND identity = ?", models.UsageIdentityAuthTypeAIProvider, prefix).Count(&prefixCount).Error; err != nil { + t.Fatalf("count prefix usage identity %q: %v", prefix, err) + } + if prefixCount != 0 { + t.Fatalf("expected fixed prefix usage identity %q to be removed, got %d", prefix, prefixCount) + } + } + + var apiKey models.UsageIdentity + if err := db.Where("auth_type = ? AND identity = ?", models.UsageIdentityAuthTypeAIProvider, "claude-key").First(&apiKey).Error; err != nil { + t.Fatalf("load real api key identity: %v", err) + } + if apiKey.TotalRequests != 1 || apiKey.LastAggregatedUsageEventID != 1 { + t.Fatalf("expected real api key identity stats to remain, got %+v", apiKey) + } + + var unusedSiblingKey models.UsageIdentity + if err := db.Where("auth_type = ? AND identity = ?", models.UsageIdentityAuthTypeAIProvider, "claude-unused-key").First(&unusedSiblingKey).Error; err != nil { + t.Fatalf("load unused sibling api key identity: %v", err) + } + if unusedSiblingKey.Type != "claude" || unusedSiblingKey.Provider != "Claude Team" { + t.Fatalf("expected unused sibling api key identity to remain unchanged, got %+v", unusedSiblingKey) + } + + var unusedKey models.UsageIdentity + if err := db.Where("auth_type = ? AND identity = ?", models.UsageIdentityAuthTypeAIProvider, "gemini-unused-key").First(&unusedKey).Error; err != nil { + t.Fatalf("load unused real api key identity: %v", err) + } + if unusedKey.Type != "gemini" || unusedKey.Provider != "Gemini Team" { + t.Fatalf("expected unused real api key identity to remain unchanged, got %+v", unusedKey) + } + + var customPrefix models.UsageIdentity + if err := db.Where("auth_type = ? AND identity = ?", models.UsageIdentityAuthTypeAIProvider, "https://proxy.internal/v1").First(&customPrefix).Error; err != nil { + t.Fatalf("load custom prefix-like identity: %v", err) + } + if customPrefix.Type != "openai" || customPrefix.Provider != "Custom OpenAI" { + t.Fatalf("expected non-fixed custom prefix-like identity to remain unchanged, got %+v", customPrefix) + } +} + +func findUsageIdentity(t *testing.T, identities []models.UsageIdentity, authType models.UsageIdentityAuthType, identity string) models.UsageIdentity { + t.Helper() + for _, usageIdentity := range identities { + if usageIdentity.AuthType == authType && usageIdentity.Identity == identity { + return usageIdentity + } + } + t.Fatalf("usage identity auth_type=%d identity=%q not found in %+v", authType, identity, identities) + return models.UsageIdentity{} +} + +func seedPrefixGeneratedUsageIdentities(t *testing.T, dbPath string) { + t.Helper() + db, err := gorm.Open(sqlite.Open(sqliteDSN(dbPath)), &gorm.Config{}) + if err != nil { + t.Fatalf("open prefix identity database: %v", err) + } + defer closeOpenedDatabase(t, db) + + if err := db.Exec(`CREATE TABLE usage_identities ( + id integer PRIMARY KEY AUTOINCREMENT, + name text, + auth_type integer, + auth_type_name text, + identity text, + type text, + provider text, + total_requests integer DEFAULT 0, + success_count integer DEFAULT 0, + failure_count integer DEFAULT 0, + input_tokens integer DEFAULT 0, + output_tokens integer DEFAULT 0, + reasoning_tokens integer DEFAULT 0, + cached_tokens integer DEFAULT 0, + total_tokens integer DEFAULT 0, + last_aggregated_usage_event_id integer DEFAULT 0, + first_used_at datetime, + last_used_at datetime, + stats_updated_at datetime, + is_deleted numeric DEFAULT false, + created_at datetime, + updated_at datetime, + deleted_at datetime + )`).Error; err != nil { + t.Fatalf("create usage_identities table: %v", err) + } + if err := db.Exec(`CREATE UNIQUE INDEX uniq_usage_identities_type_identity ON usage_identities(auth_type, identity)`).Error; err != nil { + t.Fatalf("create usage identity unique index: %v", err) + } + if err := db.Exec(`CREATE TABLE usage_events ( + id integer PRIMARY KEY AUTOINCREMENT, + event_key text, + api_group_key text, + provider text, + endpoint text, + auth_type text, + request_id text, + model text, + timestamp datetime, + source text, + auth_index text, + failed numeric, + latency_ms integer, + input_tokens integer, + output_tokens integer, + reasoning_tokens integer, + cached_tokens integer, + total_tokens integer, + created_at datetime + )`).Error; err != nil { + t.Fatalf("create usage_events table: %v", err) + } + + now := time.Date(2026, 5, 4, 8, 0, 0, 0, time.UTC) + rows := []models.UsageIdentity{ + {Name: "Claude Team", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "claude-key", Type: "claude", Provider: "Claude Team", TotalRequests: 1, SuccessCount: 1, TotalTokens: 30, LastAggregatedUsageEventID: 1, CreatedAt: now, UpdatedAt: now}, + {Name: "Claude Team", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "claude-unused-key", Type: "claude", Provider: "Claude Team", CreatedAt: now, UpdatedAt: now}, + {Name: "Gemini Team", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "gemini", Type: "gemini", Provider: "Gemini Team", TotalRequests: 2, SuccessCount: 2, TotalTokens: 40, LastAggregatedUsageEventID: 2, CreatedAt: now, UpdatedAt: now}, + {Name: "Claude Team", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "claude", Type: "claude", Provider: "Claude Team", CreatedAt: now, UpdatedAt: now}, + {Name: "Codex Team", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "codex", Type: "codex", Provider: "Codex Team", CreatedAt: now, UpdatedAt: now}, + {Name: "Vertex Team", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "vertex", Type: "vertex", Provider: "Vertex Team", CreatedAt: now, UpdatedAt: now}, + {Name: "OpenAI Team", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "openai", Type: "openai", Provider: "OpenAI Team", CreatedAt: now, UpdatedAt: now}, + {Name: "Gemini Team", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "gemini-unused-key", Type: "gemini", Provider: "Gemini Team", CreatedAt: now, UpdatedAt: now}, + {Name: "Custom OpenAI", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "https://proxy.internal/v1", Type: "openai", Provider: "Custom OpenAI", CreatedAt: now, UpdatedAt: now}, + } + if err := db.Create(&rows).Error; err != nil { + t.Fatalf("seed usage identities: %v", err) + } + if err := db.Exec(`INSERT INTO usage_events (event_key, api_group_key, provider, endpoint, auth_type, request_id, model, timestamp, source, failed, latency_ms, input_tokens, output_tokens, total_tokens, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, "claude-event", "group", "Claude Team", "/v1/messages", "apikey", "req", "claude-sonnet", now, "claude-key", false, 100, 10, 20, 30, now).Error; err != nil { + t.Fatalf("seed usage event: %v", err) + } +} + +func seedLegacyUsageIdentityTables(t *testing.T, dbPath string) { + t.Helper() + db, err := gorm.Open(sqlite.Open(sqliteDSN(dbPath)), &gorm.Config{}) + if err != nil { + t.Fatalf("open legacy identity database: %v", err) + } + defer closeOpenedDatabase(t, db) + + statements := []string{ + `CREATE TABLE auth_files ( + id integer PRIMARY KEY AUTOINCREMENT, + auth_index text, + name text, + email text, + type text, + provider text, + label text, + created_at datetime, + updated_at datetime, + deleted_at datetime + )`, + `CREATE TABLE provider_metadata ( + id integer PRIMARY KEY AUTOINCREMENT, + lookup_key text, + provider_type text, + display_name text, + created_at datetime, + updated_at datetime, + deleted_at datetime + )`, + `CREATE TABLE usage_events ( + id integer PRIMARY KEY AUTOINCREMENT, + event_key text, + api_group_key text, + provider text, + endpoint text, + auth_type text, + request_id text, + model text, + timestamp datetime, + source text, + auth_index text, + failed numeric, + latency_ms integer, + input_tokens integer, + output_tokens integer, + reasoning_tokens integer, + cached_tokens integer, + total_tokens integer, + created_at datetime + )`, + } + for _, statement := range statements { + if err := db.Exec(statement).Error; err != nil { + t.Fatalf("seed legacy identity schema with %q: %v", statement, err) + } + } + + now := time.Date(2026, 5, 4, 7, 0, 0, 0, time.UTC) + deletedAt := time.Date(2026, 5, 4, 7, 30, 0, 0, time.UTC) + if err := db.Exec("INSERT INTO auth_files (auth_index, name, email, type, provider, label, created_at, updated_at, deleted_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", "auth-1", "OAuth Name", "person@example.com", "claude", "claude", "OAuth Label", now, now, nil).Error; err != nil { + t.Fatalf("seed active auth file: %v", err) + } + if err := db.Exec("INSERT INTO auth_files (auth_index, name, email, type, provider, label, created_at, updated_at, deleted_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", "auth-deleted", "Deleted OAuth", "deleted@example.com", "claude", "claude", "Deleted", now, now, deletedAt).Error; err != nil { + t.Fatalf("seed deleted auth file: %v", err) + } + if err := db.Exec("INSERT INTO provider_metadata (lookup_key, provider_type, display_name, created_at, updated_at, deleted_at) VALUES (?, ?, ?, ?, ?, ?)", "api-source-1", "claude", "Claude API", now, now, nil).Error; err != nil { + t.Fatalf("seed active provider metadata: %v", err) + } + if err := db.Exec("INSERT INTO provider_metadata (lookup_key, provider_type, display_name, created_at, updated_at, deleted_at) VALUES (?, ?, ?, ?, ?, ?)", "api-deleted", "claude", "Deleted API", now, now, deletedAt).Error; err != nil { + t.Fatalf("seed deleted provider metadata: %v", err) + } + + events := []struct { + eventKey string + authType string + authIndex string + source string + failed bool + inputTokens int64 + outputTokens int64 + reasoningTokens int64 + cachedTokens int64 + totalTokens int64 + timestamp time.Time + }{ + {eventKey: "oauth-success", authType: "oauth", authIndex: "auth-1", failed: false, inputTokens: 10, outputTokens: 20, reasoningTokens: 3, cachedTokens: 4, totalTokens: 37, timestamp: time.Date(2026, 5, 4, 8, 0, 0, 0, time.UTC)}, + {eventKey: "legacy-oauth", authIndex: "auth-1", failed: false, inputTokens: 1, outputTokens: 1, reasoningTokens: 1, cachedTokens: 1, totalTokens: 4, timestamp: time.Date(2026, 5, 4, 8, 30, 0, 0, time.UTC)}, + {eventKey: "oauth-failure", authType: "oauth", authIndex: "auth-1", failed: true, inputTokens: 20, outputTokens: 20, reasoningTokens: 7, cachedTokens: 2, totalTokens: 49, timestamp: time.Date(2026, 5, 4, 9, 0, 0, 0, time.UTC)}, + {eventKey: "apikey-success", authType: "apikey", source: "api-source-1", failed: false, inputTokens: 7, outputTokens: 8, reasoningTokens: 9, cachedTokens: 10, totalTokens: 34, timestamp: time.Date(2026, 5, 4, 10, 0, 0, 0, time.UTC)}, + {eventKey: "legacy-apikey", source: "api-source-1", failed: false, inputTokens: 2, outputTokens: 1, reasoningTokens: 1, cachedTokens: 1, totalTokens: 5, timestamp: time.Date(2026, 5, 4, 10, 30, 0, 0, time.UTC)}, + {eventKey: "deleted-oauth", authType: "oauth", authIndex: "auth-deleted", failed: false, totalTokens: 100, timestamp: time.Date(2026, 5, 4, 11, 0, 0, 0, time.UTC)}, + {eventKey: "deleted-api", authType: "apikey", source: "api-deleted", failed: false, totalTokens: 100, timestamp: time.Date(2026, 5, 4, 12, 0, 0, 0, time.UTC)}, + } + for _, event := range events { + if err := db.Exec( + `INSERT INTO usage_events (event_key, api_group_key, provider, endpoint, auth_type, request_id, model, timestamp, source, auth_index, failed, latency_ms, input_tokens, output_tokens, reasoning_tokens, cached_tokens, total_tokens, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + event.eventKey, "group", "claude", "/v1/messages", event.authType, event.eventKey, "claude-sonnet", event.timestamp, event.source, event.authIndex, event.failed, 100, event.inputTokens, event.outputTokens, event.reasoningTokens, event.cachedTokens, event.totalTokens, event.timestamp, + ).Error; err != nil { + t.Fatalf("seed usage event %s: %v", event.eventKey, err) + } + } +} + +func seedLegacyRedisUsageTables(t *testing.T, dbPath string) { + t.Helper() + db, err := gorm.Open(sqlite.Open(sqliteDSN(dbPath)), &gorm.Config{}) + if err != nil { + t.Fatalf("open legacy database: %v", err) + } + defer closeOpenedDatabase(t, db) + + statements := []string{ + `CREATE TABLE usage_events ( + id integer PRIMARY KEY AUTOINCREMENT, + event_key text, + snapshot_run_id integer, + api_group_key text, + model text, + timestamp datetime, + source text, + auth_index text, + failed numeric, + latency_ms integer, + input_tokens integer, + output_tokens integer, + reasoning_tokens integer, + cached_tokens integer, + total_tokens integer, + created_at datetime + )`, + `CREATE UNIQUE INDEX uniq_usage_events_event_key ON usage_events(event_key)`, + `CREATE TABLE redis_usage_inboxes ( + id integer PRIMARY KEY AUTOINCREMENT, + queue_key text NOT NULL DEFAULT '', + message_hash text NOT NULL DEFAULT '', + raw_message text NOT NULL DEFAULT '', + status text NOT NULL DEFAULT '', + attempt_count integer NOT NULL DEFAULT 0, + last_error text, + snapshot_run_id integer, + usage_event_key text, + popped_at datetime NOT NULL DEFAULT '1970-01-01 00:00:00', + processed_at datetime, + created_at datetime, + updated_at datetime + )`, + } + for _, statement := range statements { + if err := db.Exec(statement).Error; err != nil { + t.Fatalf("seed legacy schema with %q: %v", statement, err) + } + } + + now := time.Date(2026, 5, 3, 8, 0, 0, 0, time.UTC) + legacyEvents := []map[string]any{ + {"event_key": "legacy-canonical-key", "api_group_key": "raw-key", "model": "claude-sonnet", "timestamp": now, "created_at": now}, + {"event_key": "req-fallback", "api_group_key": "fallback", "model": "claude-opus", "timestamp": now, "created_at": now}, + {"event_key": "req-blank-fallback", "api_group_key": "blank", "model": "claude-opus", "timestamp": now, "created_at": now}, + {"event_key": "", "api_group_key": "empty", "model": "claude-empty", "timestamp": now, "created_at": now}, + {"event_key": "existing-key", "api_group_key": "existing", "model": "claude-haiku", "timestamp": now, "created_at": now}, + } + for _, values := range legacyEvents { + if err := db.Table("usage_events").Create(values).Error; err != nil { + t.Fatalf("seed legacy usage event: %v", err) + } + } + + inboxes := []struct { + hash string + rawMessage string + status string + usageEventKey string + processedAt *time.Time + }{ + {hash: "hash-1", rawMessage: `{"provider":" claude ","endpoint":" /v1/messages ","auth_type":" API_KEY ","request_id":" req-from-raw "}`, status: RedisUsageInboxStatusProcessed, usageEventKey: "legacy-canonical-key", processedAt: &now}, + {hash: "hash-2", rawMessage: `{"provider":" fallback-provider ","endpoint":" /fallback ","auth_type":" OAuth ","request_id":" req-fallback "}`, status: RedisUsageInboxStatusProcessed, usageEventKey: "missing-key", processedAt: &now}, + {hash: "hash-3", rawMessage: `{"provider":" overwrite-provider ","endpoint":" /overwrite ","auth_type":" api_key ","request_id":" overwrite-request "}`, status: RedisUsageInboxStatusProcessed, usageEventKey: "existing-key", processedAt: &now}, + {hash: "hash-4", rawMessage: `{"provider":" blank-provider ","endpoint":" /blank ","auth_type":" OAuth ","request_id":" req-blank-fallback "}`, status: RedisUsageInboxStatusProcessed, usageEventKey: "", processedAt: &now}, + {hash: "hash-5", rawMessage: `{"provider":"pending-provider","request_id":"pending-key"}`, status: RedisUsageInboxStatusPending, usageEventKey: "pending-key"}, + } + for _, inbox := range inboxes { + if err := db.Exec( + "INSERT INTO redis_usage_inboxes (queue_key, message_hash, raw_message, status, attempt_count, usage_event_key, popped_at, processed_at, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + "queue", inbox.hash, inbox.rawMessage, inbox.status, 0, inbox.usageEventKey, now, inbox.processedAt, now, now, + ).Error; err != nil { + t.Fatalf("seed legacy redis inbox: %v", err) + } + } +} + +func openMigratedDatabase(t *testing.T, dbPath string) *gorm.DB { + t.Helper() + db, err := OpenDatabase(config.Config{SQLitePath: dbPath}) + if err != nil { + t.Fatalf("OpenDatabase returned error: %v", err) + } + return db +} + +func captureRepositoryLogs(t *testing.T, level logrus.Level) *bytes.Buffer { + t.Helper() + var logs bytes.Buffer + previousOutput := logrus.StandardLogger().Out + previousFormatter := logrus.StandardLogger().Formatter + previousLevel := logrus.GetLevel() + logrus.SetOutput(&logs) + logrus.SetFormatter(&logrus.TextFormatter{DisableTimestamp: true}) + logrus.SetLevel(level) + t.Cleanup(func() { + logrus.SetOutput(previousOutput) + logrus.SetFormatter(previousFormatter) + logrus.SetLevel(previousLevel) + }) + return &logs +} + +func closeOpenedDatabase(t *testing.T, db *gorm.DB) { + t.Helper() + sqlDB, err := db.DB() + if err != nil { + t.Fatalf("get sql database: %v", err) + } + if err := sqlDB.Close(); err != nil { + t.Fatalf("close database: %v", err) + } +} diff --git a/internal/repository/provider_metadata.go b/internal/repository/provider_metadata.go deleted file mode 100644 index efefc9b3..00000000 --- a/internal/repository/provider_metadata.go +++ /dev/null @@ -1,163 +0,0 @@ -package repository - -import ( - "fmt" - "strings" - - "cpa-usage-keeper/internal/models" - "gorm.io/gorm" - "gorm.io/gorm/clause" -) - -type ProviderMetadataInput struct { - LookupKey string - ProviderType string - DisplayName string - ProviderKey string - MatchKind string -} - -func ReplaceProviderMetadata(db *gorm.DB, items []ProviderMetadataInput) error { - if db == nil { - return fmt.Errorf("database is nil") - } - - normalized := make([]models.ProviderMetadata, 0, len(items)) - lookupKeys := make([]string, 0, len(items)) - seen := make(map[string]struct{}, len(items)) - for _, item := range items { - lookupKey := strings.TrimSpace(item.LookupKey) - if lookupKey == "" { - continue - } - if _, ok := seen[lookupKey]; ok { - continue - } - seen[lookupKey] = struct{}{} - lookupKeys = append(lookupKeys, lookupKey) - normalized = append(normalized, models.ProviderMetadata{ - LookupKey: lookupKey, - ProviderType: strings.TrimSpace(item.ProviderType), - DisplayName: strings.TrimSpace(item.DisplayName), - ProviderKey: strings.TrimSpace(item.ProviderKey), - MatchKind: strings.TrimSpace(item.MatchKind), - }) - } - - return db.Transaction(func(tx *gorm.DB) error { - if len(normalized) == 0 { - if err := tx.Where("1 = 1").Delete(&models.ProviderMetadata{}).Error; err != nil { - return fmt.Errorf("soft delete provider metadata: %w", err) - } - return nil - } - - if err := tx.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "lookup_key"}}, - DoUpdates: clause.AssignmentColumns([]string{ - "provider_type", - "display_name", - "provider_key", - "match_kind", - "updated_at", - "deleted_at", - }), - }).Create(&normalized).Error; err != nil { - return fmt.Errorf("upsert provider metadata: %w", err) - } - - if err := tx.Where("lookup_key NOT IN ?", lookupKeys).Delete(&models.ProviderMetadata{}).Error; err != nil { - return fmt.Errorf("soft delete stale provider metadata: %w", err) - } - - return nil - }) -} - -func ReplaceProviderMetadataForProviderTypes(db *gorm.DB, items []ProviderMetadataInput, providerTypes []string) error { - if db == nil { - return fmt.Errorf("database is nil") - } - - allowedTypes := make(map[string]struct{}, len(providerTypes)) - for _, providerType := range providerTypes { - providerType = strings.TrimSpace(providerType) - if providerType != "" { - allowedTypes[providerType] = struct{}{} - } - } - if len(allowedTypes) == 0 { - return nil - } - - normalized := make([]models.ProviderMetadata, 0, len(items)) - lookupKeys := make([]string, 0, len(items)) - seen := make(map[string]struct{}, len(items)) - for _, item := range items { - lookupKey := strings.TrimSpace(item.LookupKey) - providerType := strings.TrimSpace(item.ProviderType) - if lookupKey == "" { - continue - } - if _, ok := allowedTypes[providerType]; !ok { - continue - } - if _, ok := seen[lookupKey]; ok { - continue - } - seen[lookupKey] = struct{}{} - lookupKeys = append(lookupKeys, lookupKey) - normalized = append(normalized, models.ProviderMetadata{ - LookupKey: lookupKey, - ProviderType: providerType, - DisplayName: strings.TrimSpace(item.DisplayName), - ProviderKey: strings.TrimSpace(item.ProviderKey), - MatchKind: strings.TrimSpace(item.MatchKind), - }) - } - - types := make([]string, 0, len(allowedTypes)) - for providerType := range allowedTypes { - types = append(types, providerType) - } - - return db.Transaction(func(tx *gorm.DB) error { - if len(normalized) > 0 { - if err := tx.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "lookup_key"}}, - DoUpdates: clause.AssignmentColumns([]string{ - "provider_type", - "display_name", - "provider_key", - "match_kind", - "updated_at", - "deleted_at", - }), - }).Create(&normalized).Error; err != nil { - return fmt.Errorf("upsert provider metadata: %w", err) - } - } - - query := tx.Where("provider_type IN ?", types) - if len(lookupKeys) > 0 { - query = query.Where("lookup_key NOT IN ?", lookupKeys) - } - if err := query.Delete(&models.ProviderMetadata{}).Error; err != nil { - return fmt.Errorf("soft delete stale provider metadata: %w", err) - } - - return nil - }) -} - -func ListProviderMetadata(db *gorm.DB) ([]models.ProviderMetadata, error) { - if db == nil { - return nil, fmt.Errorf("database is nil") - } - - var items []models.ProviderMetadata - if err := db.Order("provider_type asc, display_name asc, lookup_key asc").Find(&items).Error; err != nil { - return nil, fmt.Errorf("list provider metadata: %w", err) - } - return items, nil -} diff --git a/internal/repository/provider_metadata_test.go b/internal/repository/provider_metadata_test.go deleted file mode 100644 index 2c3f2a02..00000000 --- a/internal/repository/provider_metadata_test.go +++ /dev/null @@ -1,123 +0,0 @@ -package repository - -import ( - "path/filepath" - "testing" - - "cpa-usage-keeper/internal/config" - "gorm.io/gorm" -) - -func TestReplaceProviderMetadataUpsertsSoftDeletesAndRestoresRows(t *testing.T) { - db := openProviderMetadataTestDatabase(t) - if err := ReplaceProviderMetadata(db, []ProviderMetadataInput{{ - LookupKey: "sk-a", - ProviderType: "openai", - DisplayName: "Provider A", - ProviderKey: "openai:Provider A", - MatchKind: "api_key", - }, { - LookupKey: "prefix-b", - ProviderType: "claude", - DisplayName: "Provider B", - ProviderKey: "claude:Provider B", - MatchKind: "prefix", - }}); err != nil { - t.Fatalf("ReplaceProviderMetadata returned error: %v", err) - } - - if err := ReplaceProviderMetadata(db, []ProviderMetadataInput{{ - LookupKey: "prefix-b", - ProviderType: "claude", - DisplayName: "Provider B Updated", - ProviderKey: "claude:Provider B Updated", - MatchKind: "prefix", - }}); err != nil { - t.Fatalf("ReplaceProviderMetadata returned error: %v", err) - } - - items, err := ListProviderMetadata(db) - if err != nil { - t.Fatalf("ListProviderMetadata returned error: %v", err) - } - if len(items) != 1 { - t.Fatalf("expected 1 provider metadata row after replacement, got %d", len(items)) - } - if items[0].LookupKey != "prefix-b" || items[0].DisplayName != "Provider B Updated" { - t.Fatalf("unexpected provider metadata after replacement: %+v", items[0]) - } - - if err := ReplaceProviderMetadata(db, []ProviderMetadataInput{{ - LookupKey: "sk-a", - ProviderType: "openai", - DisplayName: "Provider A Restored", - ProviderKey: "openai:Provider A Restored", - MatchKind: "api_key", - }}); err != nil { - t.Fatalf("ReplaceProviderMetadata restore returned error: %v", err) - } - - items, err = ListProviderMetadata(db) - if err != nil { - t.Fatalf("ListProviderMetadata returned error: %v", err) - } - if len(items) != 1 || items[0].LookupKey != "sk-a" || items[0].DisplayName != "Provider A Restored" { - t.Fatalf("unexpected restored provider metadata: %+v", items) - } -} - -func TestReplaceProviderMetadataForProviderTypesOnlyDeletesFetchedProviderRows(t *testing.T) { - db := openProviderMetadataTestDatabase(t) - if err := ReplaceProviderMetadata(db, []ProviderMetadataInput{{ - LookupKey: "gemini-old", - ProviderType: "gemini", - DisplayName: "Gemini Old", - ProviderKey: "gemini:old", - MatchKind: "api_key", - }, { - LookupKey: "claude-old", - ProviderType: "claude", - DisplayName: "Claude Old", - ProviderKey: "claude:old", - MatchKind: "api_key", - }}); err != nil { - t.Fatalf("ReplaceProviderMetadata returned error: %v", err) - } - - if err := ReplaceProviderMetadataForProviderTypes(db, []ProviderMetadataInput{{ - LookupKey: "claude-new", - ProviderType: "claude", - DisplayName: "Claude New", - ProviderKey: "claude:new", - MatchKind: "api_key", - }}, []string{"claude"}); err != nil { - t.Fatalf("ReplaceProviderMetadataForProviderTypes returned error: %v", err) - } - - items, err := ListProviderMetadata(db) - if err != nil { - t.Fatalf("ListProviderMetadata returned error: %v", err) - } - lookupKeys := make(map[string]struct{}, len(items)) - for _, item := range items { - lookupKeys[item.LookupKey] = struct{}{} - } - for _, expected := range []string{"gemini-old", "claude-new"} { - if _, ok := lookupKeys[expected]; !ok { - t.Fatalf("expected %q to remain, got %+v", expected, items) - } - } - if _, ok := lookupKeys["claude-old"]; ok { - t.Fatalf("expected stale claude row to be deleted, got %+v", items) - } -} - -func openProviderMetadataTestDatabase(t *testing.T) *gorm.DB { - t.Helper() - db, err := OpenDatabase(config.Config{SQLitePath: filepath.Join(t.TempDir(), "provider_metadata.db")}) - if err != nil { - t.Fatalf("OpenDatabase returned error: %v", err) - } - closeTestDatabase(t, db) - return db -} diff --git a/internal/repository/redis_usage_inbox.go b/internal/repository/redis_usage_inbox.go index 88162f4b..f71c1034 100644 --- a/internal/repository/redis_usage_inbox.go +++ b/internal/repository/redis_usage_inbox.go @@ -18,7 +18,8 @@ const ( RedisUsageInboxStatusProcessFailed = "process_failed" RedisUsageInboxStatusDiscarded = "discarded" - redisUsageInboxMaxErrorLength = 1024 + redisUsageInboxMaxErrorLength = 1024 + redisUsageInboxMaxProcessAttempts = 5 ) type RedisInboxInsert struct { @@ -58,20 +59,9 @@ func InsertRedisUsageInboxMessages(db *gorm.DB, inputs []RedisInboxInsert) ([]mo return rows, nil } -func MarkRedisUsageInboxProcessed(db *gorm.DB, id uint, snapshotRunID uint, eventKey string, processedAt time.Time) error { +func MarkRedisUsageInboxProcessed(db *gorm.DB, id uint, eventKey string, processedAt time.Time) error { return db.Model(&models.RedisUsageInbox{}).Where("id = ?", id).Updates(map[string]any{ "status": RedisUsageInboxStatusProcessed, - "snapshot_run_id": snapshotRunID, - "usage_event_key": eventKey, - "processed_at": processedAt.UTC(), - "last_error": "", - }).Error -} - -func MarkRedisUsageInboxProcessedWithoutSnapshot(db *gorm.DB, id uint, eventKey string, processedAt time.Time) error { - return db.Model(&models.RedisUsageInbox{}).Where("id = ?", id).Updates(map[string]any{ - "status": RedisUsageInboxStatusProcessed, - "snapshot_run_id": nil, "usage_event_key": eventKey, "processed_at": processedAt.UTC(), "last_error": "", @@ -83,7 +73,17 @@ func MarkRedisUsageInboxDecodeFailed(db *gorm.DB, id uint, decodeErr error) erro } func MarkRedisUsageInboxProcessFailed(db *gorm.DB, id uint, processErr error) error { - return markRedisUsageInboxFailed(db, id, RedisUsageInboxStatusProcessFailed, processErr) + return db.Model(&models.RedisUsageInbox{}).Where("id = ?", id).Updates(map[string]any{ + "status": gorm.Expr( + "CASE WHEN attempt_count + ? >= ? THEN ? ELSE ? END", + 1, + redisUsageInboxMaxProcessAttempts, + RedisUsageInboxStatusDiscarded, + RedisUsageInboxStatusProcessFailed, + ), + "attempt_count": gorm.Expr("attempt_count + ?", 1), + "last_error": boundedRedisUsageInboxError(processErr), + }).Error } // ListProcessableRedisUsageInbox 返回待处理和可重试的数据,不返回已解码失败或已丢弃的数据。 diff --git a/internal/repository/redis_usage_inbox_test.go b/internal/repository/redis_usage_inbox_test.go index 63bb47db..07e5ee45 100644 --- a/internal/repository/redis_usage_inbox_test.go +++ b/internal/repository/redis_usage_inbox_test.go @@ -71,7 +71,7 @@ func TestRedisUsageInboxStatusTransitions(t *testing.T) { t.Fatalf("InsertRedisUsageInboxMessages returned error: %v", err) } - if err := MarkRedisUsageInboxProcessed(db, rows[0].ID, 42, "event-1", processedAt); err != nil { + if err := MarkRedisUsageInboxProcessed(db, rows[0].ID, "event-1", processedAt); err != nil { t.Fatalf("MarkRedisUsageInboxProcessed returned error: %v", err) } @@ -82,9 +82,6 @@ func TestRedisUsageInboxStatusTransitions(t *testing.T) { if stored.Status != RedisUsageInboxStatusProcessed { t.Fatalf("expected processed status, got %q", stored.Status) } - if stored.SnapshotRunID == nil || *stored.SnapshotRunID != 42 { - t.Fatalf("expected snapshot id 42, got %+v", stored.SnapshotRunID) - } if stored.UsageEventKey != "event-1" { t.Fatalf("expected event key to be stored, got %q", stored.UsageEventKey) } @@ -137,7 +134,7 @@ func TestRedisUsageInboxFailureTransitionsBoundErrors(t *testing.T) { } } -func TestMarkRedisUsageInboxProcessFailedKeepsRowsRetryableAfterRepeatedFailures(t *testing.T) { +func TestMarkRedisUsageInboxProcessFailedDiscardsRowsAfterMaxAttempts(t *testing.T) { db := openTestDatabase(t) poppedAt := time.Date(2026, 4, 27, 10, 0, 0, 0, time.UTC) @@ -145,7 +142,8 @@ func TestMarkRedisUsageInboxProcessFailedKeepsRowsRetryableAfterRepeatedFailures if err != nil { t.Fatalf("InsertRedisUsageInboxMessages returned error: %v", err) } - for i := 0; i < 5; i++ { + const maxProcessAttempts = 5 + for i := 0; i < maxProcessAttempts; i++ { if err := MarkRedisUsageInboxProcessFailed(db, rows[0].ID, fmt.Errorf("insert failed %d", i+1)); err != nil { t.Fatalf("MarkRedisUsageInboxProcessFailed attempt %d returned error: %v", i+1, err) } @@ -155,11 +153,11 @@ func TestMarkRedisUsageInboxProcessFailedKeepsRowsRetryableAfterRepeatedFailures if err := db.First(&stored, rows[0].ID).Error; err != nil { t.Fatalf("load inbox row: %v", err) } - if stored.Status != RedisUsageInboxStatusProcessFailed { - t.Fatalf("expected process_failed after repeated process failures, got %q", stored.Status) + if stored.Status != RedisUsageInboxStatusDiscarded { + t.Fatalf("expected discarded after repeated process failures, got %q", stored.Status) } - if stored.AttemptCount != 5 { - t.Fatalf("expected 5 attempts, got %d", stored.AttemptCount) + if stored.AttemptCount != maxProcessAttempts { + t.Fatalf("expected %d attempts, got %d", maxProcessAttempts, stored.AttemptCount) } if stored.LastError != "insert failed 5" { t.Fatalf("expected last error from final attempt, got %q", stored.LastError) @@ -169,8 +167,8 @@ func TestMarkRedisUsageInboxProcessFailedKeepsRowsRetryableAfterRepeatedFailures if err != nil { t.Fatalf("ListProcessableRedisUsageInbox returned error: %v", err) } - if len(processable) != 1 || processable[0].ID != rows[0].ID { - t.Fatalf("expected repeated process failure row to remain processable, got %+v", processable) + if len(processable) != 0 { + t.Fatalf("expected discarded row to be excluded from processing, got %+v", processable) } } @@ -220,6 +218,7 @@ func TestCleanupRedisUsageInboxRemovesOldProcessedAndFailedRows(t *testing.T) { {QueueKey: "queue", RawMessage: `{"request_id":"processed-old"}`, PoppedAt: now.Add(-48 * time.Hour)}, {QueueKey: "queue", RawMessage: `{"request_id":"processed-today"}`, PoppedAt: now.Add(-time.Hour)}, {QueueKey: "queue", RawMessage: `{"request_id":"failed-old"}`, PoppedAt: now.AddDate(0, 0, -8)}, + {QueueKey: "queue", RawMessage: `{"request_id":"discarded-old"}`, PoppedAt: now.AddDate(0, 0, -8)}, {QueueKey: "queue", RawMessage: `{"request_id":"failed-recent"}`, PoppedAt: now.AddDate(0, 0, -6)}, {QueueKey: "queue", RawMessage: `{"request_id":"pending-old"}`, PoppedAt: now.AddDate(0, 0, -10)}, }) @@ -237,7 +236,10 @@ func TestCleanupRedisUsageInboxRemovesOldProcessedAndFailedRows(t *testing.T) { if err := db.Model(&models.RedisUsageInbox{}).Where("id = ?", rows[2].ID).Updates(map[string]any{"status": RedisUsageInboxStatusProcessFailed, "updated_at": now.AddDate(0, 0, -8)}).Error; err != nil { t.Fatalf("seed old failed row: %v", err) } - if err := db.Model(&models.RedisUsageInbox{}).Where("id = ?", rows[3].ID).Updates(map[string]any{"status": RedisUsageInboxStatusDecodeFailed, "updated_at": now.AddDate(0, 0, -6)}).Error; err != nil { + if err := db.Model(&models.RedisUsageInbox{}).Where("id = ?", rows[3].ID).Updates(map[string]any{"status": RedisUsageInboxStatusDiscarded, "updated_at": now.AddDate(0, 0, -8)}).Error; err != nil { + t.Fatalf("seed old discarded row: %v", err) + } + if err := db.Model(&models.RedisUsageInbox{}).Where("id = ?", rows[4].ID).Updates(map[string]any{"status": RedisUsageInboxStatusDecodeFailed, "updated_at": now.AddDate(0, 0, -6)}).Error; err != nil { t.Fatalf("seed recent failed row: %v", err) } @@ -245,7 +247,7 @@ func TestCleanupRedisUsageInboxRemovesOldProcessedAndFailedRows(t *testing.T) { if err != nil { t.Fatalf("CleanupRedisUsageInbox returned error: %v", err) } - if result.ProcessedDeleted != 1 || result.FailedDeleted != 1 { + if result.ProcessedDeleted != 1 || result.FailedDeleted != 2 { t.Fatalf("unexpected cleanup result: %+v", result) } @@ -257,7 +259,7 @@ func TestCleanupRedisUsageInboxRemovesOldProcessedAndFailedRows(t *testing.T) { for _, row := range remaining { remainingIDs = append(remainingIDs, row.ID) } - expectedIDs := []uint{rows[1].ID, rows[3].ID, rows[4].ID} + expectedIDs := []uint{rows[1].ID, rows[4].ID, rows[5].ID} if fmt.Sprint(remainingIDs) != fmt.Sprint(expectedIDs) { t.Fatalf("expected remaining ids %v, got %v", expectedIDs, remainingIDs) } @@ -275,7 +277,7 @@ func TestListPendingRedisUsageInboxReturnsPendingRowsInIDOrder(t *testing.T) { if err != nil { t.Fatalf("InsertRedisUsageInboxMessages returned error: %v", err) } - if err := MarkRedisUsageInboxProcessed(db, rows[1].ID, 7, "event-2", poppedAt.Add(time.Minute)); err != nil { + if err := MarkRedisUsageInboxProcessed(db, rows[1].ID, "event-2", poppedAt.Add(time.Minute)); err != nil { t.Fatalf("MarkRedisUsageInboxProcessed returned error: %v", err) } diff --git a/internal/repository/usage.go b/internal/repository/usage.go index b3ef8b43..5e8398c9 100644 --- a/internal/repository/usage.go +++ b/internal/repository/usage.go @@ -70,6 +70,8 @@ func ListUsageEventsWithFilter(db *gorm.DB, filter UsageQueryFilter) (*UsageEven Timestamp: event.Timestamp.UTC(), APIGroupKey: strings.TrimSpace(event.APIGroupKey), Model: strings.TrimSpace(event.Model), + AuthType: strings.TrimSpace(event.AuthType), + Provider: strings.TrimSpace(event.Provider), Source: strings.TrimSpace(event.Source), AuthIndex: strings.TrimSpace(event.AuthIndex), Failed: event.Failed, @@ -127,10 +129,19 @@ func applyUsageEventsListFilter(query *gorm.DB, filter UsageQueryFilter) *gorm.D if model := strings.TrimSpace(filter.Model); model != "" { query = query.Where("TRIM(model) = ?", model) } - if source := strings.TrimSpace(filter.Source); source != "" { - query = query.Where("TRIM(source) = ?", source) + if authType := strings.TrimSpace(filter.AuthType); authType != "" { + query = query.Where("TRIM(auth_type) = ?", authType) + } + if provider := strings.TrimSpace(filter.Provider); provider != "" { + query = query.Where("TRIM(provider) = ?", provider) } - if authIndex := strings.TrimSpace(filter.AuthIndex); authIndex != "" { + if source := strings.TrimSpace(filter.Source); source != "" { + if authIndex := strings.TrimSpace(filter.AuthIndex); authIndex != "" && strings.TrimSpace(filter.AuthType) == "oauth" { + query = query.Where("(TRIM(auth_index) = ? OR TRIM(source) = ?)", authIndex, source) + } else { + query = query.Where("TRIM(source) = ?", source) + } + } else if authIndex := strings.TrimSpace(filter.AuthIndex); authIndex != "" { query = query.Where("TRIM(auth_index) = ?", authIndex) } switch strings.TrimSpace(filter.Result) { @@ -493,6 +504,10 @@ func applyUsageEventToOverviewSeries(series *UsageOverviewSeriesRecord, event mo series.Models[modelName] = modelSeries } +func usageEventRequiresPricing(event models.UsageEvent) bool { + return event.InputTokens > 0 || event.OutputTokens > 0 || event.CachedTokens > 0 +} + func applyUsageEventToOverview(overview *UsageOverviewRecord, event models.UsageEvent, bucketByDay bool, latestHourlyStart *time.Time, pricingByModel map[string]models.ModelPriceSetting) { overview.Summary.CachedTokens += event.CachedTokens overview.Summary.ReasoningTokens += event.ReasoningTokens @@ -502,7 +517,7 @@ func applyUsageEventToOverview(overview *UsageOverviewRecord, event models.Usage overview.Health.TotalSuccess++ } pricing, ok := pricingByModel[strings.TrimSpace(event.Model)] - if !ok { + if !ok && usageEventRequiresPricing(event) { overview.Summary.CostAvailable = false } cost := calculateUsageEventCost(event, pricing) diff --git a/internal/repository/usage_events_test.go b/internal/repository/usage_events_test.go index 837ac957..202f6119 100644 --- a/internal/repository/usage_events_test.go +++ b/internal/repository/usage_events_test.go @@ -17,9 +17,9 @@ func TestListUsageEventsWithFilterAppliesTimeBoundsAndPagination(t *testing.T) { closeTestDatabase(t, db) events := []models.UsageEvent{ - {EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), Source: "source-a", AuthIndex: "1", TotalTokens: 10}, - {EventKey: "event-2", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), Source: "source-b", AuthIndex: "2", TotalTokens: 20}, - {EventKey: "event-3", SnapshotRunID: 1, APIGroupKey: "provider-b", Model: "claude-opus", Timestamp: time.Date(2026, 4, 16, 11, 0, 0, 0, time.UTC), Source: "source-c", AuthIndex: "3", TotalTokens: 30}, + {EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), Source: "source-a", AuthIndex: "1", TotalTokens: 10}, + {EventKey: "event-2", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), Source: "source-b", AuthIndex: "2", TotalTokens: 20}, + {EventKey: "event-3", APIGroupKey: "provider-b", Model: "claude-opus", Timestamp: time.Date(2026, 4, 16, 11, 0, 0, 0, time.UTC), Source: "source-c", AuthIndex: "3", TotalTokens: 30}, } if _, _, err := InsertUsageEvents(db, events); err != nil { t.Fatalf("InsertUsageEvents returned error: %v", err) @@ -50,9 +50,9 @@ func TestListUsageEventsWithFilterPagesByTimestampAndID(t *testing.T) { closeTestDatabase(t, db) timestamp := time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC) events := []models.UsageEvent{ - {EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: timestamp, Source: "source-a", AuthIndex: "1", TotalTokens: 10}, - {EventKey: "event-2", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: timestamp, Source: "source-b", AuthIndex: "2", TotalTokens: 20}, - {EventKey: "event-3", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: timestamp.Add(-time.Hour), Source: "source-c", AuthIndex: "3", TotalTokens: 30}, + {EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: timestamp, Source: "source-a", AuthIndex: "1", TotalTokens: 10}, + {EventKey: "event-2", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: timestamp, Source: "source-b", AuthIndex: "2", TotalTokens: 20}, + {EventKey: "event-3", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: timestamp.Add(-time.Hour), Source: "source-c", AuthIndex: "3", TotalTokens: 30}, } if _, _, err := InsertUsageEvents(db, events); err != nil { t.Fatalf("InsertUsageEvents returned error: %v", err) @@ -84,10 +84,10 @@ func TestListUsageEventsWithFilterAppliesModelSourceAndResultFilters(t *testing. } closeTestDatabase(t, db) events := []models.UsageEvent{ - {EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), Source: "source-a", Failed: false, TotalTokens: 10}, - {EventKey: "event-2", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), Source: "source-a", Failed: true, TotalTokens: 20}, - {EventKey: "event-3", SnapshotRunID: 1, APIGroupKey: "provider-b", Model: "claude-opus", Timestamp: time.Date(2026, 4, 16, 11, 0, 0, 0, time.UTC), Source: "source-a", Failed: false, TotalTokens: 30}, - {EventKey: "event-4", SnapshotRunID: 1, APIGroupKey: "provider-c", Model: "gpt-5", Timestamp: time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC), Source: "source-b", Failed: false, TotalTokens: 40}, + {EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), Source: "source-a", Failed: false, TotalTokens: 10}, + {EventKey: "event-2", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), Source: "source-a", Failed: true, TotalTokens: 20}, + {EventKey: "event-3", APIGroupKey: "provider-b", Model: "claude-opus", Timestamp: time.Date(2026, 4, 16, 11, 0, 0, 0, time.UTC), Source: "source-a", Failed: false, TotalTokens: 30}, + {EventKey: "event-4", APIGroupKey: "provider-c", Model: "gpt-5", Timestamp: time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC), Source: "source-b", Failed: false, TotalTokens: 40}, } if _, _, err := InsertUsageEvents(db, events); err != nil { t.Fatalf("InsertUsageEvents returned error: %v", err) @@ -105,6 +105,66 @@ func TestListUsageEventsWithFilterAppliesModelSourceAndResultFilters(t *testing. } } +func TestListUsageEventsWithFilterAppliesProviderAuthTypeFilter(t *testing.T) { + db, err := OpenDatabase(config.Config{SQLitePath: filepath.Join(t.TempDir(), "usage-events-provider-filter.db")}) + if err != nil { + t.Fatalf("OpenDatabase returned error: %v", err) + } + closeTestDatabase(t, db) + events := []models.UsageEvent{ + {EventKey: "event-1", Model: "gpt-5", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), AuthType: "apikey", Provider: "OpenAI Mirror", Source: "sk-key-a", TotalTokens: 10}, + {EventKey: "event-2", Model: "gpt-5", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), AuthType: "apikey", Provider: "OpenAI Mirror", Source: "sk-key-b", TotalTokens: 20}, + {EventKey: "event-3", Model: "gpt-5", Timestamp: time.Date(2026, 4, 16, 11, 0, 0, 0, time.UTC), AuthType: "apikey", Provider: "Other Provider", Source: "sk-key-c", TotalTokens: 30}, + {EventKey: "event-4", Model: "gpt-5", Timestamp: time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC), AuthType: "oauth", Provider: "OpenAI Mirror", Source: "oauth-source", AuthIndex: "auth-1", TotalTokens: 40}, + } + if _, _, err := InsertUsageEvents(db, events); err != nil { + t.Fatalf("InsertUsageEvents returned error: %v", err) + } + + page, err := ListUsageEventsWithFilter(db, UsageQueryFilter{AuthType: "apikey", Provider: "OpenAI Mirror", Page: 1, PageSize: 20}) + if err != nil { + t.Fatalf("ListUsageEventsWithFilter returned error: %v", err) + } + if page.TotalCount != 2 || len(page.Events) != 2 { + t.Fatalf("expected two matching provider events, got %+v", page) + } + for _, event := range page.Events { + if event.AuthType != "apikey" || event.Provider != "OpenAI Mirror" { + t.Fatalf("unexpected provider filtered event: %+v", event) + } + } +} + +func TestListUsageEventsWithFilterAppliesAuthSourceOrAuthIndexFilter(t *testing.T) { + db, err := OpenDatabase(config.Config{SQLitePath: filepath.Join(t.TempDir(), "usage-events-auth-filter.db")}) + if err != nil { + t.Fatalf("OpenDatabase returned error: %v", err) + } + closeTestDatabase(t, db) + events := []models.UsageEvent{ + {EventKey: "event-1", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), AuthType: "oauth", Source: "auth-1", AuthIndex: "1", TotalTokens: 10}, + {EventKey: "event-2", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), AuthType: "oauth", Source: "source-alias", AuthIndex: "auth-1", TotalTokens: 20}, + {EventKey: "event-3", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 11, 0, 0, 0, time.UTC), AuthType: "oauth", Source: "other", AuthIndex: "other", TotalTokens: 30}, + {EventKey: "event-4", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC), AuthType: "apikey", Source: "auth-1", AuthIndex: "auth-1", Provider: "Provider A", TotalTokens: 40}, + } + if _, _, err := InsertUsageEvents(db, events); err != nil { + t.Fatalf("InsertUsageEvents returned error: %v", err) + } + + page, err := ListUsageEventsWithFilter(db, UsageQueryFilter{AuthType: "oauth", Source: "auth-1", AuthIndex: "auth-1", Page: 1, PageSize: 20}) + if err != nil { + t.Fatalf("ListUsageEventsWithFilter returned error: %v", err) + } + if page.TotalCount != 2 || len(page.Events) != 2 { + t.Fatalf("expected two matching auth events, got %+v", page) + } + for _, event := range page.Events { + if event.AuthType != "oauth" || (event.Source != "auth-1" && event.AuthIndex != "auth-1") { + t.Fatalf("unexpected auth filtered event: %+v", event) + } + } +} + func TestListUsageEventFilterOptionsWithFilterReturnsStableOptions(t *testing.T) { db, err := OpenDatabase(config.Config{SQLitePath: filepath.Join(t.TempDir(), "usage-events-facets.db")}) if err != nil { @@ -112,9 +172,9 @@ func TestListUsageEventFilterOptionsWithFilterReturnsStableOptions(t *testing.T) } closeTestDatabase(t, db) events := []models.UsageEvent{ - {EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), Source: "source-a", Failed: false, TotalTokens: 10}, - {EventKey: "event-2", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), Source: "source-b", Failed: true, TotalTokens: 20}, - {EventKey: "event-3", SnapshotRunID: 1, APIGroupKey: "provider-b", Model: "gpt-5", Timestamp: time.Date(2026, 4, 16, 11, 0, 0, 0, time.UTC), Source: "source-a", Failed: false, TotalTokens: 30}, + {EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), Source: "source-a", Failed: false, TotalTokens: 10}, + {EventKey: "event-2", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), Source: "source-b", Failed: true, TotalTokens: 20}, + {EventKey: "event-3", APIGroupKey: "provider-b", Model: "gpt-5", Timestamp: time.Date(2026, 4, 16, 11, 0, 0, 0, time.UTC), Source: "source-a", Failed: false, TotalTokens: 30}, } if _, _, err := InsertUsageEvents(db, events); err != nil { t.Fatalf("InsertUsageEvents returned error: %v", err) @@ -141,17 +201,17 @@ func TestListUsageAnalysisWithFilterAggregatesApisAndModels(t *testing.T) { events := []models.UsageEvent{ { - EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", + EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), Failed: false, LatencyMS: 100, InputTokens: 10, OutputTokens: 4, ReasoningTokens: 2, CachedTokens: 1, TotalTokens: 17, }, { - EventKey: "event-2", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", + EventKey: "event-2", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), Failed: true, LatencyMS: 250, InputTokens: 20, OutputTokens: 5, ReasoningTokens: 0, CachedTokens: 0, TotalTokens: 25, }, { - EventKey: "event-3", SnapshotRunID: 1, APIGroupKey: "provider-b", Model: "gpt-5", + EventKey: "event-3", APIGroupKey: "provider-b", Model: "gpt-5", Timestamp: time.Date(2026, 4, 16, 11, 0, 0, 0, time.UTC), Failed: false, LatencyMS: 400, InputTokens: 30, OutputTokens: 7, ReasoningTokens: 3, CachedTokens: 2, TotalTokens: 42, }, diff --git a/internal/repository/usage_filter.go b/internal/repository/usage_filter.go index 7b9acc1a..e27796a8 100644 --- a/internal/repository/usage_filter.go +++ b/internal/repository/usage_filter.go @@ -17,6 +17,8 @@ type UsageQueryFilter struct { Model string Source string AuthIndex string + AuthType string + Provider string Result string } @@ -42,6 +44,8 @@ type UsageEventRecord struct { Timestamp time.Time APIGroupKey string Model string + AuthType string + Provider string Source string AuthIndex string Failed bool diff --git a/internal/repository/usage_filter_test.go b/internal/repository/usage_filter_test.go index caa607eb..f8cca5a6 100644 --- a/internal/repository/usage_filter_test.go +++ b/internal/repository/usage_filter_test.go @@ -30,9 +30,9 @@ func TestBuildUsageSnapshotWithFilterAppliesTimeBounds(t *testing.T) { closeTestDatabase(t, db) events := []models.UsageEvent{ - {EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), Source: "source-a", AuthIndex: "1", TotalTokens: 10}, - {EventKey: "event-2", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), Source: "source-b", AuthIndex: "2", TotalTokens: 20}, - {EventKey: "event-3", SnapshotRunID: 1, APIGroupKey: "provider-b", Model: "claude-opus", Timestamp: time.Date(2026, 4, 17, 10, 0, 0, 0, time.UTC), Source: "source-c", AuthIndex: "3", TotalTokens: 30}, + {EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), Source: "source-a", AuthIndex: "1", TotalTokens: 10}, + {EventKey: "event-2", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), Source: "source-b", AuthIndex: "2", TotalTokens: 20}, + {EventKey: "event-3", APIGroupKey: "provider-b", Model: "claude-opus", Timestamp: time.Date(2026, 4, 17, 10, 0, 0, 0, time.UTC), Source: "source-c", AuthIndex: "3", TotalTokens: 30}, } if _, _, err := InsertUsageEvents(db, events); err != nil { t.Fatalf("InsertUsageEvents returned error: %v", err) @@ -80,17 +80,17 @@ func TestBuildUsageOverviewWithFilterComputesSummaryAndSeries(t *testing.T) { events := []models.UsageEvent{ { - EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", + EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 15, 0, 0, time.UTC), Failed: false, InputTokens: 1000, OutputTokens: 500, ReasoningTokens: 100, CachedTokens: 200, TotalTokens: 1800, }, { - EventKey: "event-2", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", + EventKey: "event-2", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 45, 0, 0, time.UTC), Failed: true, InputTokens: 2000, OutputTokens: 1000, ReasoningTokens: 50, CachedTokens: 100, TotalTokens: 3150, }, { - EventKey: "event-3", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", + EventKey: "event-3", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 17, 11, 5, 0, 0, time.UTC), Failed: false, InputTokens: 500, OutputTokens: 250, ReasoningTokens: 25, CachedTokens: 50, TotalTokens: 825, }, @@ -198,13 +198,13 @@ func TestBuildUsageOverviewWithFilterComputesSummaryAndSeries(t *testing.T) { func TestBuildUsageOverviewFromEventsBuildsSnapshotAndOverviewInOnePass(t *testing.T) { events := []models.UsageEvent{ { - EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", + EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 15, 0, 0, time.UTC), Failed: false, InputTokens: 1000, OutputTokens: 500, ReasoningTokens: 100, CachedTokens: 200, TotalTokens: 1800, Source: "source-a", AuthIndex: "1", LatencyMS: 120, }, { - EventKey: "event-2", SnapshotRunID: 1, APIGroupKey: "", Model: "", + EventKey: "event-2", APIGroupKey: "", Model: "", Timestamp: time.Date(2026, 4, 16, 10, 45, 0, 0, time.UTC), Failed: true, InputTokens: 2000, OutputTokens: 1000, ReasoningTokens: 50, CachedTokens: 100, TotalTokens: 3150, Source: " source-b ", AuthIndex: " 2 ", LatencyMS: 250, @@ -252,7 +252,7 @@ func TestBuildUsageOverviewFromEventsBuildsSnapshotAndOverviewInOnePass(t *testi t.Fatalf("unexpected summary token breakdown: %+v", overview.Summary) } if overview.Summary.CostAvailable { - t.Fatalf("expected cost to be unavailable when any event model is unpriced, got %+v", overview.Summary) + t.Fatalf("expected cost to be unavailable when any event model with billable tokens is unpriced, got %+v", overview.Summary) } if overview.Series.Requests["2026-04-16T09:00:00Z"] != 1 || overview.Series.Requests["2026-04-16T10:00:00Z"] != 1 { t.Fatalf("unexpected hourly request series: %+v", overview.Series.Requests) @@ -282,8 +282,8 @@ func TestBuildUsageOverviewWithFilterBuilds24hHealthGridFor24hRange(t *testing.T closeTestDatabase(t, db) events := []models.UsageEvent{ - {EventKey: "event-success", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 17, 9, 31, 0, 0, time.UTC), Failed: false, TotalTokens: 10}, - {EventKey: "event-failed", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 17, 23, 59, 0, 0, time.UTC), Failed: true, TotalTokens: 20}, + {EventKey: "event-success", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 17, 9, 31, 0, 0, time.UTC), Failed: false, TotalTokens: 10}, + {EventKey: "event-failed", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 17, 23, 59, 0, 0, time.UTC), Failed: true, TotalTokens: 20}, } if _, _, err := InsertUsageEvents(db, events); err != nil { t.Fatalf("InsertUsageEvents returned error: %v", err) @@ -365,12 +365,12 @@ func TestBuildUsageOverviewWithFilterReturnsUnavailableCostForPartialPricing(t * events := []models.UsageEvent{ { - EventKey: "event-priced", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "priced-model", + EventKey: "event-priced", APIGroupKey: "provider-a", Model: "priced-model", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), TotalTokens: 1_000_000, InputTokens: 1_000_000, }, { - EventKey: "event-unpriced", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "unpriced-model", + EventKey: "event-unpriced", APIGroupKey: "provider-a", Model: "unpriced-model", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), TotalTokens: 1_000_000, InputTokens: 1_000_000, }, @@ -387,13 +387,59 @@ func TestBuildUsageOverviewWithFilterReturnsUnavailableCostForPartialPricing(t * } if overview.Summary.CostAvailable { - t.Fatalf("expected cost to be unavailable when any in-range model is unpriced, got %+v", overview.Summary) + t.Fatalf("expected cost to be unavailable when any in-range event model with billable tokens is unpriced, got %+v", overview.Summary) } if overview.Summary.TotalCost != 1 { t.Fatalf("expected priced portion to remain in total cost, got %+v", overview.Summary) } } +func TestBuildUsageOverviewWithFilterReturnsAvailableCostWhenUnpricedEventsHaveNoBillableTokens(t *testing.T) { + db, err := OpenDatabase(config.Config{SQLitePath: filepath.Join(t.TempDir(), "usage-overview-zero-token-unpriced.db")}) + if err != nil { + t.Fatalf("OpenDatabase returned error: %v", err) + } + closeTestDatabase(t, db) + + if _, err := UpsertModelPriceSetting(db, ModelPriceSettingInput{ + Model: "priced-model", + PromptPricePer1M: 1, + CompletionPricePer1M: 0, + CachePricePer1M: 0, + }); err != nil { + t.Fatalf("UpsertModelPriceSetting returned error: %v", err) + } + + events := []models.UsageEvent{ + { + EventKey: "event-priced", APIGroupKey: "provider-a", Model: "priced-model", + Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), TotalTokens: 1_000_000, + InputTokens: 1_000_000, + }, + { + EventKey: "event-zero-token", APIGroupKey: "provider-a", Model: "unpriced-image-model", + Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), + }, + } + if _, _, err := InsertUsageEvents(db, events); err != nil { + t.Fatalf("InsertUsageEvents returned error: %v", err) + } + + start := time.Date(2026, 4, 16, 0, 0, 0, 0, time.UTC) + end := time.Date(2026, 4, 16, 23, 59, 59, 999000000, time.UTC) + overview, err := BuildUsageOverviewWithFilter(db, UsageQueryFilter{Range: "24h", StartTime: &start, EndTime: &end}) + if err != nil { + t.Fatalf("BuildUsageOverviewWithFilter returned error: %v", err) + } + + if !overview.Summary.CostAvailable { + t.Fatalf("expected zero-token unpriced model not to make cost unavailable, got %+v", overview.Summary) + } + if overview.Summary.TotalCost != 1 { + t.Fatalf("expected priced event cost to remain available, got %+v", overview.Summary) + } +} + func TestBuildUsageOverviewWithFilterReturnsUnavailableCostWithoutPricing(t *testing.T) { db, err := OpenDatabase(config.Config{SQLitePath: filepath.Join(t.TempDir(), "usage-overview-no-pricing.db")}) if err != nil { @@ -402,7 +448,7 @@ func TestBuildUsageOverviewWithFilterReturnsUnavailableCostWithoutPricing(t *tes closeTestDatabase(t, db) events := []models.UsageEvent{{ - EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", + EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 15, 0, 0, time.UTC), TotalTokens: 1800, InputTokens: 1000, OutputTokens: 500, CachedTokens: 200, }} @@ -462,7 +508,6 @@ func TestBuildUsageOverviewWithFilterUsesExactPresetWindowMinutes(t *testing.T) t.Run(tc.name, func(t *testing.T) { event := models.UsageEvent{ EventKey: "event-" + tc.rangeName, - SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: tc.end, @@ -511,9 +556,9 @@ func TestBuildUsageOverviewWithFilterBuildsLatestHourlySeriesForLongRanges(t *te } events := []models.UsageEvent{ - {EventKey: "event-old", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 17, 8, 0, 0, 0, time.UTC), TotalTokens: 1_000_000, InputTokens: 1_000_000}, - {EventKey: "event-latest-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 23, 22, 15, 0, 0, time.UTC), TotalTokens: 2_000_000, InputTokens: 2_000_000, OutputTokens: 5, CachedTokens: 7, ReasoningTokens: 11}, - {EventKey: "event-latest-2", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 23, 23, 45, 0, 0, time.UTC), TotalTokens: 3_000_000, InputTokens: 3_000_000, OutputTokens: 13, CachedTokens: 17, ReasoningTokens: 19}, + {EventKey: "event-old", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 17, 8, 0, 0, 0, time.UTC), TotalTokens: 1_000_000, InputTokens: 1_000_000}, + {EventKey: "event-latest-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 23, 22, 15, 0, 0, time.UTC), TotalTokens: 2_000_000, InputTokens: 2_000_000, OutputTokens: 5, CachedTokens: 7, ReasoningTokens: 11}, + {EventKey: "event-latest-2", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 23, 23, 45, 0, 0, time.UTC), TotalTokens: 3_000_000, InputTokens: 3_000_000, OutputTokens: 13, CachedTokens: 17, ReasoningTokens: 19}, } if _, _, err := InsertUsageEvents(db, events); err != nil { t.Fatalf("InsertUsageEvents returned error: %v", err) @@ -555,8 +600,8 @@ func TestBuildUsageOverviewWithFilterUsesDailyBucketsForLongCustomRanges(t *test closeTestDatabase(t, db) events := []models.UsageEvent{ - {EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC), TotalTokens: 10}, - {EventKey: "event-2", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 26, 18, 0, 0, 0, time.UTC), TotalTokens: 20}, + {EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC), TotalTokens: 10}, + {EventKey: "event-2", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 26, 18, 0, 0, 0, time.UTC), TotalTokens: 20}, } if _, _, err := InsertUsageEvents(db, events); err != nil { t.Fatalf("InsertUsageEvents returned error: %v", err) diff --git a/internal/repository/usage_identities.go b/internal/repository/usage_identities.go new file mode 100644 index 00000000..6b06de5f --- /dev/null +++ b/internal/repository/usage_identities.go @@ -0,0 +1,278 @@ +package repository + +import ( + "context" + "fmt" + "strings" + "time" + + "cpa-usage-keeper/internal/models" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func ReplaceUsageIdentitiesForAuthType(ctx context.Context, db *gorm.DB, identities []models.UsageIdentity, authType models.UsageIdentityAuthType, now time.Time) error { + if db == nil { + return fmt.Errorf("database is nil") + } + + normalized, incomingIdentities := normalizeUsageIdentities(identities, authType) + + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := upsertUsageIdentities(tx, normalized); err != nil { + return err + } + + query := tx.Model(&models.UsageIdentity{}).Where("auth_type = ?", authType) + if len(incomingIdentities) > 0 { + query = query.Where("identity NOT IN ?", incomingIdentities) + } + if err := query.Updates(map[string]any{ + "is_deleted": true, + "deleted_at": now, + }).Error; err != nil { + return fmt.Errorf("mark stale usage identities deleted: %w", err) + } + + return nil + }) +} + +func ReplaceUsageIdentitiesForProviderTypes(ctx context.Context, db *gorm.DB, identities []models.UsageIdentity, providerTypes []string, now time.Time) error { + if db == nil { + return fmt.Errorf("database is nil") + } + + normalized, incomingIdentities := normalizeUsageIdentities(identities, models.UsageIdentityAuthTypeAIProvider) + types := normalizeProviderTypes(providerTypes) + + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := upsertUsageIdentities(tx, normalized); err != nil { + return err + } + if len(types) == 0 { + return nil + } + + query := tx.Model(&models.UsageIdentity{}). + Where("auth_type = ?", models.UsageIdentityAuthTypeAIProvider). + Where("type IN ?", types) + if len(incomingIdentities) > 0 { + query = query.Where("identity NOT IN ?", incomingIdentities) + } + if err := query.Updates(map[string]any{ + "is_deleted": true, + "deleted_at": now, + }).Error; err != nil { + return fmt.Errorf("mark stale provider usage identities deleted: %w", err) + } + + return nil + }) +} + +func ListUsageIdentities(ctx context.Context, db *gorm.DB) ([]models.UsageIdentity, error) { + if db == nil { + return nil, fmt.Errorf("database is nil") + } + + var identities []models.UsageIdentity + if err := db.WithContext(ctx).Order("auth_type asc, name asc, id asc").Find(&identities).Error; err != nil { + return nil, fmt.Errorf("list usage identities: %w", err) + } + return identities, nil +} + +func AggregateUsageIdentityStats(ctx context.Context, db *gorm.DB, now time.Time) error { + if db == nil { + return fmt.Errorf("database is nil") + } + + var identities []models.UsageIdentity + if err := db.WithContext(ctx).Find(&identities).Error; err != nil { + return fmt.Errorf("list usage identities for aggregation: %w", err) + } + + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + for _, identity := range identities { + delta, err := aggregateUsageIdentityDelta(tx, identity) + if err != nil { + return err + } + if delta.TotalRequests == 0 { + continue + } + + firstUsedAt := identity.FirstUsedAt + if delta.FirstUsedAt != nil && (firstUsedAt == nil || delta.FirstUsedAt.Before(*firstUsedAt)) { + first := *delta.FirstUsedAt + firstUsedAt = &first + } + + lastUsedAt := identity.LastUsedAt + if delta.LastUsedAt != nil && (lastUsedAt == nil || delta.LastUsedAt.After(*lastUsedAt)) { + last := *delta.LastUsedAt + lastUsedAt = &last + } + + updates := map[string]any{ + "total_requests": identity.TotalRequests + delta.TotalRequests, + "success_count": identity.SuccessCount + delta.SuccessCount, + "failure_count": identity.FailureCount + delta.FailureCount, + "input_tokens": identity.InputTokens + delta.InputTokens, + "output_tokens": identity.OutputTokens + delta.OutputTokens, + "reasoning_tokens": identity.ReasoningTokens + delta.ReasoningTokens, + "cached_tokens": identity.CachedTokens + delta.CachedTokens, + "total_tokens": identity.TotalTokens + delta.TotalTokens, + "first_used_at": firstUsedAt, + "last_used_at": lastUsedAt, + "stats_updated_at": now, + "last_aggregated_usage_event_id": delta.MaxUsageEventID, + } + if err := tx.Model(&models.UsageIdentity{}).Where("id = ?", identity.ID).Updates(updates).Error; err != nil { + return fmt.Errorf("update usage identity stats for %q: %w", identity.Identity, err) + } + } + return nil + }) +} + +type usageIdentityStatsDelta struct { + TotalRequests int64 + SuccessCount int64 + FailureCount int64 + InputTokens int64 + OutputTokens int64 + ReasoningTokens int64 + CachedTokens int64 + TotalTokens int64 + FirstUsedAt *time.Time + LastUsedAt *time.Time + MaxUsageEventID uint +} + +func aggregateUsageIdentityDelta(tx *gorm.DB, identity models.UsageIdentity) (usageIdentityStatsDelta, error) { + var delta usageIdentityStatsDelta + query, ok := usageIdentityEventsQuery(tx.Model(&models.UsageEvent{}), identity) + if !ok { + return delta, nil + } + + if err := query. + Select(` + COUNT(*) AS total_requests, + COALESCE(SUM(CASE WHEN failed THEN 0 ELSE 1 END), 0) AS success_count, + COALESCE(SUM(CASE WHEN failed THEN 1 ELSE 0 END), 0) AS failure_count, + COALESCE(SUM(input_tokens), 0) AS input_tokens, + COALESCE(SUM(output_tokens), 0) AS output_tokens, + COALESCE(SUM(reasoning_tokens), 0) AS reasoning_tokens, + COALESCE(SUM(cached_tokens), 0) AS cached_tokens, + COALESCE(SUM(total_tokens), 0) AS total_tokens, + COALESCE(MAX(id), 0) AS max_usage_event_id`). + Where("id > ?", identity.LastAggregatedUsageEventID). + Scan(&delta).Error; err != nil { + return delta, fmt.Errorf("aggregate usage identity stats for %q: %w", identity.Identity, err) + } + if delta.TotalRequests == 0 { + return delta, nil + } + + var firstEvent models.UsageEvent + firstQuery, _ := usageIdentityEventsQuery(tx.Model(&models.UsageEvent{}), identity) + if err := firstQuery.Where("id > ?", identity.LastAggregatedUsageEventID).Order("timestamp asc, id asc").First(&firstEvent).Error; err != nil { + return delta, fmt.Errorf("find first usage identity event for %q: %w", identity.Identity, err) + } + firstUsedAt := firstEvent.Timestamp + delta.FirstUsedAt = &firstUsedAt + + var lastEvent models.UsageEvent + lastQuery, _ := usageIdentityEventsQuery(tx.Model(&models.UsageEvent{}), identity) + if err := lastQuery.Where("id > ?", identity.LastAggregatedUsageEventID).Order("timestamp desc, id desc").First(&lastEvent).Error; err != nil { + return delta, fmt.Errorf("find last usage identity event for %q: %w", identity.Identity, err) + } + lastUsedAt := lastEvent.Timestamp + delta.LastUsedAt = &lastUsedAt + + return delta, nil +} + +func usageIdentityEventsQuery(query *gorm.DB, identity models.UsageIdentity) (*gorm.DB, bool) { + switch identity.AuthType { + case models.UsageIdentityAuthTypeAuthFile: + return query.Where("auth_type = ? AND auth_index = ?", "oauth", identity.Identity), true + case models.UsageIdentityAuthTypeAIProvider: + return query.Where("auth_type = ? AND source = ?", "apikey", identity.Identity), true + default: + return query, false + } +} + +func normalizeUsageIdentities(identities []models.UsageIdentity, authType models.UsageIdentityAuthType) ([]models.UsageIdentity, []string) { + normalized := make([]models.UsageIdentity, 0, len(identities)) + incomingIdentities := make([]string, 0, len(identities)) + seen := make(map[string]struct{}, len(identities)) + + for _, identity := range identities { + key := strings.TrimSpace(identity.Identity) + if key == "" { + continue + } + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + incomingIdentities = append(incomingIdentities, key) + + identity.ID = 0 + identity.AuthType = authType + identity.Identity = key + identity.Name = strings.TrimSpace(identity.Name) + identity.AuthTypeName = strings.TrimSpace(identity.AuthTypeName) + identity.Type = strings.TrimSpace(identity.Type) + identity.Provider = strings.TrimSpace(identity.Provider) + identity.IsDeleted = false + identity.DeletedAt = nil + normalized = append(normalized, identity) + } + + return normalized, incomingIdentities +} + +func normalizeProviderTypes(providerTypes []string) []string { + seen := make(map[string]struct{}, len(providerTypes)) + types := make([]string, 0, len(providerTypes)) + for _, providerType := range providerTypes { + providerType = strings.TrimSpace(providerType) + if providerType == "" { + continue + } + if _, ok := seen[providerType]; ok { + continue + } + seen[providerType] = struct{}{} + types = append(types, providerType) + } + return types +} + +func upsertUsageIdentities(tx *gorm.DB, identities []models.UsageIdentity) error { + if len(identities) == 0 { + return nil + } + + if err := tx.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "auth_type"}, {Name: "identity"}}, + DoUpdates: clause.Assignments(map[string]any{ + "name": gorm.Expr("excluded.name"), + "auth_type_name": gorm.Expr("excluded.auth_type_name"), + "type": gorm.Expr("excluded.type"), + "provider": gorm.Expr("excluded.provider"), + "is_deleted": false, + "deleted_at": nil, + "updated_at": gorm.Expr("excluded.updated_at"), + }), + }).Create(&identities).Error; err != nil { + return fmt.Errorf("upsert usage identities: %w", err) + } + return nil +} diff --git a/internal/repository/usage_identities_test.go b/internal/repository/usage_identities_test.go new file mode 100644 index 00000000..a99f49e6 --- /dev/null +++ b/internal/repository/usage_identities_test.go @@ -0,0 +1,532 @@ +package repository + +import ( + "context" + "testing" + "time" + + "cpa-usage-keeper/internal/models" +) + +func TestUsageIdentityReplaceForAuthTypeMarksStaleRowsDeletedAndPreservesStats(t *testing.T) { + db := openTestDatabase(t) + ctx := context.Background() + now := time.Date(2026, 5, 4, 10, 0, 0, 0, time.UTC) + firstUsedAt := now.Add(-2 * time.Hour) + lastUsedAt := now.Add(-time.Hour) + statsUpdatedAt := now.Add(-30 * time.Minute) + + existingActive := models.UsageIdentity{ + Name: "Old Name", + AuthType: models.UsageIdentityAuthTypeAuthFile, + Identity: "auth-1", + Type: "account", + Provider: "claude", + TotalRequests: 10, + SuccessCount: 8, + FailureCount: 2, + InputTokens: 100, + OutputTokens: 50, + TotalTokens: 150, + LastAggregatedUsageEventID: 42, + FirstUsedAt: &firstUsedAt, + LastUsedAt: &lastUsedAt, + StatsUpdatedAt: &statsUpdatedAt, + } + existingStale := models.UsageIdentity{ + Name: "Stale", + AuthType: models.UsageIdentityAuthTypeAuthFile, + Identity: "auth-stale", + Type: "account", + Provider: "claude", + } + unrelatedProvider := models.UsageIdentity{ + Name: "Provider", + AuthType: models.UsageIdentityAuthTypeAIProvider, + Identity: "provider-1", + Type: "openai", + Provider: "OpenAI", + } + if err := db.Create(&[]models.UsageIdentity{existingActive, existingStale, unrelatedProvider}).Error; err != nil { + t.Fatalf("seed usage identities: %v", err) + } + + err := ReplaceUsageIdentitiesForAuthType(ctx, db, []models.UsageIdentity{ + { + Name: "New Name", + AuthTypeName: "oauth", + Identity: "auth-1", + Type: "account", + Provider: "claude-code", + }, + { + Name: "New Auth", + AuthTypeName: "oauth", + Identity: "auth-2", + Type: "account", + Provider: "claude-code", + }, + }, models.UsageIdentityAuthTypeAuthFile, now) + if err != nil { + t.Fatalf("ReplaceUsageIdentitiesForAuthType returned error: %v", err) + } + + rows, err := ListUsageIdentities(ctx, db) + if err != nil { + t.Fatalf("ListUsageIdentities returned error: %v", err) + } + byIdentity := usageIdentitiesByIdentity(rows) + + updated := byIdentity["auth-1"] + if updated.Name != "New Name" || updated.Provider != "claude-code" || updated.AuthType != models.UsageIdentityAuthTypeAuthFile || updated.IsDeleted { + t.Fatalf("expected active metadata update for auth-1, got %+v", updated) + } + if updated.TotalRequests != 10 || updated.SuccessCount != 8 || updated.FailureCount != 2 || updated.InputTokens != 100 || updated.OutputTokens != 50 || updated.TotalTokens != 150 || updated.LastAggregatedUsageEventID != 42 { + t.Fatalf("expected stats to be preserved, got %+v", updated) + } + if updated.FirstUsedAt == nil || !updated.FirstUsedAt.Equal(firstUsedAt) || updated.LastUsedAt == nil || !updated.LastUsedAt.Equal(lastUsedAt) || updated.StatsUpdatedAt == nil || !updated.StatsUpdatedAt.Equal(statsUpdatedAt) { + t.Fatalf("expected usage timestamps to be preserved, got %+v", updated) + } + + inserted := byIdentity["auth-2"] + if inserted.ID == 0 || inserted.IsDeleted || inserted.AuthType != models.UsageIdentityAuthTypeAuthFile || inserted.Name != "New Auth" { + t.Fatalf("expected active inserted auth-2, got %+v", inserted) + } + + stale := byIdentity["auth-stale"] + if !stale.IsDeleted || stale.DeletedAt == nil || !stale.DeletedAt.Equal(now) { + t.Fatalf("expected stale auth identity to be deleted at %s, got %+v", now, stale) + } + + provider := byIdentity["provider-1"] + if provider.IsDeleted || provider.DeletedAt != nil { + t.Fatalf("expected unrelated provider identity untouched, got %+v", provider) + } +} + +func TestUsageIdentityReplaceForAuthTypeRestoresDeletedIdentity(t *testing.T) { + db := openTestDatabase(t) + ctx := context.Background() + deletedAt := time.Date(2026, 5, 3, 10, 0, 0, 0, time.UTC) + now := deletedAt.Add(24 * time.Hour) + + deleted := models.UsageIdentity{ + Name: "Deleted", + AuthType: models.UsageIdentityAuthTypeAuthFile, + AuthTypeName: "oauth", + Identity: "auth-1", + Type: "account", + Provider: "claude", + TotalRequests: 7, + IsDeleted: true, + DeletedAt: &deletedAt, + } + if err := db.Create(&deleted).Error; err != nil { + t.Fatalf("seed deleted identity: %v", err) + } + + err := ReplaceUsageIdentitiesForAuthType(ctx, db, []models.UsageIdentity{ + { + Name: "Restored", + AuthTypeName: "oauth", + Identity: "auth-1", + Type: "account", + Provider: "claude-code", + }, + }, models.UsageIdentityAuthTypeAuthFile, now) + if err != nil { + t.Fatalf("ReplaceUsageIdentitiesForAuthType returned error: %v", err) + } + + rows, err := ListUsageIdentities(ctx, db) + if err != nil { + t.Fatalf("ListUsageIdentities returned error: %v", err) + } + restored := usageIdentitiesByIdentity(rows)["auth-1"] + if restored.IsDeleted || restored.DeletedAt != nil { + t.Fatalf("expected deleted identity to be restored, got %+v", restored) + } + if restored.Name != "Restored" || restored.Provider != "claude-code" || restored.TotalRequests != 7 { + t.Fatalf("expected metadata update with stats preserved, got %+v", restored) + } +} + +func TestUsageIdentityReplaceForProviderTypesMarksOnlyScopedProviderTypesDeleted(t *testing.T) { + db := openTestDatabase(t) + ctx := context.Background() + now := time.Date(2026, 5, 4, 10, 0, 0, 0, time.UTC) + + seed := []models.UsageIdentity{ + {Name: "OpenAI Keep", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "openai-keep", Type: "openai", Provider: "OpenAI", TotalRequests: 3}, + {Name: "OpenAI Stale", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "openai-stale", Type: "openai", Provider: "OpenAI"}, + {Name: "Gemini Untouched", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "gemini-untouched", Type: "gemini", Provider: "Gemini"}, + {Name: "Auth Untouched", AuthType: models.UsageIdentityAuthTypeAuthFile, AuthTypeName: "oauth", Identity: "auth-untouched", Type: "account", Provider: "claude"}, + } + if err := db.Create(&seed).Error; err != nil { + t.Fatalf("seed usage identities: %v", err) + } + + err := ReplaceUsageIdentitiesForProviderTypes(ctx, db, []models.UsageIdentity{ + {Name: "OpenAI Updated", AuthTypeName: "apikey", Identity: "openai-keep", Type: "openai", Provider: "OpenAI"}, + {Name: "Anthropic New", AuthTypeName: "apikey", Identity: "anthropic-new", Type: "anthropic", Provider: "Anthropic"}, + }, []string{"openai", "anthropic"}, now) + if err != nil { + t.Fatalf("ReplaceUsageIdentitiesForProviderTypes returned error: %v", err) + } + + rows, err := ListUsageIdentities(ctx, db) + if err != nil { + t.Fatalf("ListUsageIdentities returned error: %v", err) + } + byIdentity := usageIdentitiesByIdentity(rows) + + openAIKeep := byIdentity["openai-keep"] + if openAIKeep.IsDeleted || openAIKeep.Name != "OpenAI Updated" || openAIKeep.TotalRequests != 3 { + t.Fatalf("expected scoped provider identity updated with stats preserved, got %+v", openAIKeep) + } + + openAIStale := byIdentity["openai-stale"] + if !openAIStale.IsDeleted || openAIStale.DeletedAt == nil || !openAIStale.DeletedAt.Equal(now) { + t.Fatalf("expected missing scoped provider identity to be deleted, got %+v", openAIStale) + } + + gemini := byIdentity["gemini-untouched"] + if gemini.IsDeleted || gemini.DeletedAt != nil { + t.Fatalf("expected unmentioned provider type untouched, got %+v", gemini) + } + + auth := byIdentity["auth-untouched"] + if auth.IsDeleted || auth.DeletedAt != nil { + t.Fatalf("expected auth identity untouched by provider replacement, got %+v", auth) + } + + anthropic := byIdentity["anthropic-new"] + if anthropic.ID == 0 || anthropic.IsDeleted || anthropic.AuthType != models.UsageIdentityAuthTypeAIProvider { + t.Fatalf("expected new provider identity active, got %+v", anthropic) + } +} + +func TestUsageIdentityReplaceForProviderTypesWithEmptyProviderTypesDoesNotDeleteExistingRows(t *testing.T) { + db := openTestDatabase(t) + ctx := context.Background() + deletedAt := time.Date(2026, 5, 3, 10, 0, 0, 0, time.UTC) + now := deletedAt.Add(24 * time.Hour) + + seed := []models.UsageIdentity{ + {Name: "OpenAI Active", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "openai-active", Type: "openai", Provider: "OpenAI"}, + {Name: "Gemini Active", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "gemini-active", Type: "gemini", Provider: "Gemini"}, + {Name: "Deleted Provider", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "provider-restore", Type: "anthropic", Provider: "Anthropic", TotalRequests: 9, IsDeleted: true, DeletedAt: &deletedAt}, + } + if err := db.Create(&seed).Error; err != nil { + t.Fatalf("seed usage identities: %v", err) + } + + err := ReplaceUsageIdentitiesForProviderTypes(ctx, db, []models.UsageIdentity{ + {Name: "Restored Provider", AuthTypeName: "apikey", Identity: "provider-restore", Type: "anthropic", Provider: "Anthropic Updated"}, + }, []string{"", " ", "\t"}, now) + if err != nil { + t.Fatalf("ReplaceUsageIdentitiesForProviderTypes returned error: %v", err) + } + + rows, err := ListUsageIdentities(ctx, db) + if err != nil { + t.Fatalf("ListUsageIdentities returned error: %v", err) + } + byIdentity := usageIdentitiesByIdentity(rows) + + for _, identity := range []string{"openai-active", "gemini-active"} { + row := byIdentity[identity] + if row.IsDeleted || row.DeletedAt != nil { + t.Fatalf("expected existing provider identity %s untouched, got %+v", identity, row) + } + } + + restored := byIdentity["provider-restore"] + if restored.IsDeleted || restored.DeletedAt != nil { + t.Fatalf("expected incoming provider identity restored, got %+v", restored) + } + if restored.Name != "Restored Provider" || restored.Provider != "Anthropic Updated" || restored.AuthTypeName != "apikey" || restored.TotalRequests != 9 { + t.Fatalf("expected incoming provider identity updated with stats preserved, got %+v", restored) + } +} + +func TestUsageIdentityListOrdersByAuthTypeNameIDAndIncludesDeletedRows(t *testing.T) { + db := openTestDatabase(t) + ctx := context.Background() + deletedAt := time.Date(2026, 5, 4, 10, 0, 0, 0, time.UTC) + + seed := []models.UsageIdentity{ + {Name: "Zulu", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "provider-zulu", Type: "openai", Provider: "OpenAI"}, + {Name: "Beta", AuthType: models.UsageIdentityAuthTypeAuthFile, AuthTypeName: "oauth", Identity: "auth-beta-1", Type: "account", Provider: "claude"}, + {Name: "Alpha", AuthType: models.UsageIdentityAuthTypeAuthFile, AuthTypeName: "oauth", Identity: "auth-alpha", Type: "account", Provider: "claude", IsDeleted: true, DeletedAt: &deletedAt}, + {Name: "Beta", AuthType: models.UsageIdentityAuthTypeAuthFile, AuthTypeName: "oauth", Identity: "auth-beta-2", Type: "account", Provider: "claude"}, + {Name: "Alpha", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "provider-alpha", Type: "gemini", Provider: "Gemini", IsDeleted: true, DeletedAt: &deletedAt}, + } + if err := db.Create(&seed).Error; err != nil { + t.Fatalf("seed usage identities: %v", err) + } + + rows, err := ListUsageIdentities(ctx, db) + if err != nil { + t.Fatalf("ListUsageIdentities returned error: %v", err) + } + + got := make([]string, 0, len(rows)) + for _, row := range rows { + deleted := "active" + if row.IsDeleted { + deleted = "deleted" + } + got = append(got, row.Identity+":"+deleted) + } + + want := []string{ + "auth-alpha:deleted", + "auth-beta-1:active", + "auth-beta-2:active", + "provider-alpha:deleted", + "provider-zulu:active", + } + if len(got) != len(want) { + t.Fatalf("expected %d identities, got %d: %v", len(want), len(got), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("expected identities ordered by auth_type asc, name asc, id asc including deleted rows\nwant: %v\n got: %v", want, got) + } + } +} + +func TestUsageIdentityAggregateStatsForAuthFileUsesOAuthAuthIndex(t *testing.T) { + db := openTestDatabase(t) + ctx := context.Background() + now := time.Date(2026, 5, 4, 12, 0, 0, 0, time.UTC) + first := now.Add(-3 * time.Hour) + last := now.Add(-time.Hour) + + identity := models.UsageIdentity{ + Name: "Auth Account", + AuthType: models.UsageIdentityAuthTypeAuthFile, + AuthTypeName: "oauth", + Identity: "auth-1", + Type: "account", + Provider: "claude", + } + if err := db.Create(&identity).Error; err != nil { + t.Fatalf("seed usage identity: %v", err) + } + + events := []models.UsageEvent{ + {EventKey: "auth-1", APIGroupKey: "g1", AuthType: "oauth", AuthIndex: "auth-1", Source: "wrong-source", RequestID: "r1", Timestamp: last, Failed: false, InputTokens: 10, OutputTokens: 20, ReasoningTokens: 3, CachedTokens: 4, TotalTokens: 37}, + {EventKey: "auth-2", APIGroupKey: "g1", AuthType: "oauth", AuthIndex: "auth-1", Source: "wrong-source", RequestID: "r2", Timestamp: first, Failed: true, InputTokens: 5, OutputTokens: 6, ReasoningTokens: 7, CachedTokens: 8, TotalTokens: 26}, + {EventKey: "auth-ignore-auth-type", APIGroupKey: "g1", AuthType: "apikey", AuthIndex: "auth-1", Source: "auth-1", RequestID: "r3", Timestamp: now, Failed: false, InputTokens: 100, TotalTokens: 100}, + {EventKey: "auth-ignore-index", APIGroupKey: "g1", AuthType: "oauth", AuthIndex: "other-auth", Source: "auth-1", RequestID: "r4", Timestamp: now, Failed: false, InputTokens: 100, TotalTokens: 100}, + } + if err := db.Create(&events).Error; err != nil { + t.Fatalf("seed usage events: %v", err) + } + + if err := AggregateUsageIdentityStats(ctx, db, now); err != nil { + t.Fatalf("AggregateUsageIdentityStats returned error: %v", err) + } + + var got models.UsageIdentity + if err := db.First(&got, identity.ID).Error; err != nil { + t.Fatalf("load usage identity: %v", err) + } + if got.TotalRequests != 2 || got.SuccessCount != 1 || got.FailureCount != 1 || got.InputTokens != 15 || got.OutputTokens != 26 || got.ReasoningTokens != 10 || got.CachedTokens != 12 || got.TotalTokens != 63 { + t.Fatalf("expected aggregated auth stats, got %+v", got) + } + if got.FirstUsedAt == nil || !got.FirstUsedAt.Equal(first) || got.LastUsedAt == nil || !got.LastUsedAt.Equal(last) || got.StatsUpdatedAt == nil || !got.StatsUpdatedAt.Equal(now) { + t.Fatalf("expected usage timestamps first=%s last=%s updated=%s, got %+v", first, last, now, got) + } + if got.LastAggregatedUsageEventID != events[1].ID { + t.Fatalf("expected cursor %d, got %d", events[1].ID, got.LastAggregatedUsageEventID) + } +} + +func TestUsageIdentityAggregateStatsForAIProviderUsesAPIKeySourceNotProvider(t *testing.T) { + db := openTestDatabase(t) + ctx := context.Background() + now := time.Date(2026, 5, 4, 13, 0, 0, 0, time.UTC) + + identity := models.UsageIdentity{Name: "Provider", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "provider-source", Type: "openai", Provider: "Display Provider"} + if err := db.Create(&identity).Error; err != nil { + t.Fatalf("seed usage identity: %v", err) + } + + events := []models.UsageEvent{ + {EventKey: "provider-source-1", APIGroupKey: "g1", Provider: "wrong-provider", AuthType: "apikey", Source: "provider-source", RequestID: "r1", Timestamp: now.Add(-2 * time.Hour), Failed: false, InputTokens: 11, OutputTokens: 12, ReasoningTokens: 13, CachedTokens: 14, TotalTokens: 50}, + {EventKey: "provider-source-2", APIGroupKey: "g1", Provider: "Display Provider", AuthType: "apikey", Source: "provider-source", RequestID: "r2", Timestamp: now.Add(-time.Hour), Failed: true, InputTokens: 1, OutputTokens: 2, ReasoningTokens: 3, CachedTokens: 4, TotalTokens: 10}, + {EventKey: "provider-ignore-provider", APIGroupKey: "g1", Provider: "provider-source", AuthType: "apikey", Source: "other-source", RequestID: "r3", Timestamp: now, Failed: false, InputTokens: 100, TotalTokens: 100}, + {EventKey: "provider-ignore-auth-type", APIGroupKey: "g1", Provider: "wrong-provider", AuthType: "oauth", Source: "provider-source", RequestID: "r4", Timestamp: now, Failed: false, InputTokens: 100, TotalTokens: 100}, + } + if err := db.Create(&events).Error; err != nil { + t.Fatalf("seed usage events: %v", err) + } + + if err := AggregateUsageIdentityStats(ctx, db, now); err != nil { + t.Fatalf("AggregateUsageIdentityStats returned error: %v", err) + } + + var got models.UsageIdentity + if err := db.First(&got, identity.ID).Error; err != nil { + t.Fatalf("load usage identity: %v", err) + } + if got.TotalRequests != 2 || got.SuccessCount != 1 || got.FailureCount != 1 || got.InputTokens != 12 || got.OutputTokens != 14 || got.ReasoningTokens != 16 || got.CachedTokens != 18 || got.TotalTokens != 60 { + t.Fatalf("expected provider stats matched by source, got %+v", got) + } + if got.LastAggregatedUsageEventID != events[1].ID { + t.Fatalf("expected cursor %d, got %d", events[1].ID, got.LastAggregatedUsageEventID) + } +} + +func TestUsageIdentityAggregateStatsSecondRunOnlyIncludesEventsAfterCursor(t *testing.T) { + db := openTestDatabase(t) + ctx := context.Background() + now := time.Date(2026, 5, 4, 14, 0, 0, 0, time.UTC) + first := now.Add(-2 * time.Hour) + last := now.Add(-time.Hour) + + identity := models.UsageIdentity{Name: "Auth Account", AuthType: models.UsageIdentityAuthTypeAuthFile, AuthTypeName: "oauth", Identity: "auth-cursor", Type: "account", Provider: "claude"} + if err := db.Create(&identity).Error; err != nil { + t.Fatalf("seed usage identity: %v", err) + } + initialEvents := []models.UsageEvent{ + {EventKey: "cursor-1", APIGroupKey: "g1", AuthType: "oauth", AuthIndex: "auth-cursor", RequestID: "r1", Timestamp: first, Failed: false, InputTokens: 10, TotalTokens: 10}, + {EventKey: "cursor-2", APIGroupKey: "g1", AuthType: "oauth", AuthIndex: "auth-cursor", RequestID: "r2", Timestamp: last, Failed: true, InputTokens: 20, TotalTokens: 20}, + } + if err := db.Create(&initialEvents).Error; err != nil { + t.Fatalf("seed initial usage events: %v", err) + } + if err := AggregateUsageIdentityStats(ctx, db, now); err != nil { + t.Fatalf("first AggregateUsageIdentityStats returned error: %v", err) + } + + newEvent := models.UsageEvent{EventKey: "cursor-3", APIGroupKey: "g1", AuthType: "oauth", AuthIndex: "auth-cursor", RequestID: "r3", Timestamp: now, Failed: false, InputTokens: 30, OutputTokens: 5, TotalTokens: 35} + if err := db.Create(&newEvent).Error; err != nil { + t.Fatalf("seed new usage event: %v", err) + } + secondNow := now.Add(time.Hour) + if err := AggregateUsageIdentityStats(ctx, db, secondNow); err != nil { + t.Fatalf("second AggregateUsageIdentityStats returned error: %v", err) + } + + var got models.UsageIdentity + if err := db.First(&got, identity.ID).Error; err != nil { + t.Fatalf("load usage identity: %v", err) + } + if got.TotalRequests != 3 || got.SuccessCount != 2 || got.FailureCount != 1 || got.InputTokens != 60 || got.OutputTokens != 5 || got.TotalTokens != 65 { + t.Fatalf("expected second aggregation to include only new event once, got %+v", got) + } + if got.LastAggregatedUsageEventID != newEvent.ID || got.StatsUpdatedAt == nil || !got.StatsUpdatedAt.Equal(secondNow) { + t.Fatalf("expected cursor %d and updated timestamp %s, got %+v", newEvent.ID, secondNow, got) + } +} + +func TestUsageIdentityAggregateStatsLateTimestampWithLargerIDStillAggregates(t *testing.T) { + db := openTestDatabase(t) + ctx := context.Background() + now := time.Date(2026, 5, 4, 15, 0, 0, 0, time.UTC) + initialTime := now.Add(-time.Hour) + earlierLateTime := now.Add(-24 * time.Hour) + + identity := models.UsageIdentity{Name: "Auth Late", AuthType: models.UsageIdentityAuthTypeAuthFile, AuthTypeName: "oauth", Identity: "auth-late", Type: "account", Provider: "claude"} + if err := db.Create(&identity).Error; err != nil { + t.Fatalf("seed usage identity: %v", err) + } + initialEvent := models.UsageEvent{EventKey: "late-1", APIGroupKey: "g1", AuthType: "oauth", AuthIndex: "auth-late", RequestID: "r1", Timestamp: initialTime, Failed: false, InputTokens: 10, TotalTokens: 10} + if err := db.Create(&initialEvent).Error; err != nil { + t.Fatalf("seed initial event: %v", err) + } + if err := AggregateUsageIdentityStats(ctx, db, now); err != nil { + t.Fatalf("first AggregateUsageIdentityStats returned error: %v", err) + } + + lateEvent := models.UsageEvent{EventKey: "late-2", APIGroupKey: "g1", AuthType: "oauth", AuthIndex: "auth-late", RequestID: "r2", Timestamp: earlierLateTime, Failed: true, InputTokens: 20, TotalTokens: 20} + if err := db.Create(&lateEvent).Error; err != nil { + t.Fatalf("seed late event: %v", err) + } + if err := AggregateUsageIdentityStats(ctx, db, now.Add(time.Hour)); err != nil { + t.Fatalf("second AggregateUsageIdentityStats returned error: %v", err) + } + + var got models.UsageIdentity + if err := db.First(&got, identity.ID).Error; err != nil { + t.Fatalf("load usage identity: %v", err) + } + if got.TotalRequests != 2 || got.SuccessCount != 1 || got.FailureCount != 1 || got.InputTokens != 30 || got.TotalTokens != 30 { + t.Fatalf("expected late timestamp event with larger DB id aggregated, got %+v", got) + } + if got.FirstUsedAt == nil || !got.FirstUsedAt.Equal(earlierLateTime) || got.LastUsedAt == nil || !got.LastUsedAt.Equal(initialTime) || got.LastAggregatedUsageEventID != lateEvent.ID { + t.Fatalf("expected first_used_at to move earlier and cursor to late event id %d, got %+v", lateEvent.ID, got) + } +} + +func TestUsageIdentityAggregateStatsUsesDatabaseIDNotRequestIDOrdering(t *testing.T) { + db := openTestDatabase(t) + ctx := context.Background() + now := time.Date(2026, 5, 4, 16, 0, 0, 0, time.UTC) + + identity := models.UsageIdentity{Name: "Auth Request", AuthType: models.UsageIdentityAuthTypeAuthFile, AuthTypeName: "oauth", Identity: "auth-request", Type: "account", Provider: "claude"} + if err := db.Create(&identity).Error; err != nil { + t.Fatalf("seed usage identity: %v", err) + } + events := []models.UsageEvent{ + {EventKey: "request-1", APIGroupKey: "g1", AuthType: "oauth", AuthIndex: "auth-request", RequestID: "z-last-lexically", Timestamp: now.Add(-2 * time.Hour), Failed: false, InputTokens: 10, TotalTokens: 10}, + {EventKey: "request-2", APIGroupKey: "g1", AuthType: "oauth", AuthIndex: "auth-request", RequestID: "a-first-lexically", Timestamp: now.Add(-time.Hour), Failed: false, InputTokens: 20, TotalTokens: 20}, + } + if err := db.Create(&events).Error; err != nil { + t.Fatalf("seed usage events: %v", err) + } + if err := AggregateUsageIdentityStats(ctx, db, now); err != nil { + t.Fatalf("AggregateUsageIdentityStats returned error: %v", err) + } + + var got models.UsageIdentity + if err := db.First(&got, identity.ID).Error; err != nil { + t.Fatalf("load usage identity: %v", err) + } + if got.TotalRequests != 2 || got.InputTokens != 30 || got.TotalTokens != 30 || got.LastAggregatedUsageEventID != events[1].ID { + t.Fatalf("expected unordered request_id values aggregated by DB id, got %+v", got) + } +} + +func TestUsageIdentityAggregateStatsDeletedIdentityStillAggregates(t *testing.T) { + db := openTestDatabase(t) + ctx := context.Background() + now := time.Date(2026, 5, 4, 17, 0, 0, 0, time.UTC) + deletedAt := now.Add(-time.Hour) + + identity := models.UsageIdentity{Name: "Deleted Provider", AuthType: models.UsageIdentityAuthTypeAIProvider, AuthTypeName: "apikey", Identity: "deleted-source", Type: "openai", Provider: "OpenAI", IsDeleted: true, DeletedAt: &deletedAt} + if err := db.Create(&identity).Error; err != nil { + t.Fatalf("seed deleted usage identity: %v", err) + } + event := models.UsageEvent{EventKey: "deleted-1", APIGroupKey: "g1", AuthType: "apikey", Source: "deleted-source", RequestID: "r1", Timestamp: now, Failed: false, InputTokens: 10, OutputTokens: 5, TotalTokens: 15} + if err := db.Create(&event).Error; err != nil { + t.Fatalf("seed usage event: %v", err) + } + + if err := AggregateUsageIdentityStats(ctx, db, now); err != nil { + t.Fatalf("AggregateUsageIdentityStats returned error: %v", err) + } + + var got models.UsageIdentity + if err := db.First(&got, identity.ID).Error; err != nil { + t.Fatalf("load usage identity: %v", err) + } + if !got.IsDeleted || got.DeletedAt == nil || !got.DeletedAt.Equal(deletedAt) { + t.Fatalf("expected deleted state preserved, got %+v", got) + } + if got.TotalRequests != 1 || got.SuccessCount != 1 || got.FailureCount != 0 || got.InputTokens != 10 || got.OutputTokens != 5 || got.TotalTokens != 15 || got.LastAggregatedUsageEventID != event.ID { + t.Fatalf("expected deleted identity to aggregate matching event, got %+v", got) + } +} + +func usageIdentitiesByIdentity(rows []models.UsageIdentity) map[string]models.UsageIdentity { + result := make(map[string]models.UsageIdentity, len(rows)) + for _, row := range rows { + result[row.Identity] = row + } + return result +} diff --git a/internal/repository/usage_test.go b/internal/repository/usage_test.go index aa129e36..785fd25b 100644 --- a/internal/repository/usage_test.go +++ b/internal/repository/usage_test.go @@ -28,9 +28,9 @@ func TestBuildUsageSnapshotReturnsEmptyStructureWithoutEvents(t *testing.T) { func TestBuildUsageSnapshotAggregatesEvents(t *testing.T) { db := openUsageTestDatabase(t) events := []models.UsageEvent{ - {EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), Source: "codex-a", AuthIndex: "1", Failed: false, LatencyMS: 100, InputTokens: 10, OutputTokens: 20, ReasoningTokens: 5, CachedTokens: 0, TotalTokens: 35}, - {EventKey: "event-2", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), Source: "codex-b", AuthIndex: "2", Failed: true, LatencyMS: 200, InputTokens: 2, OutputTokens: 3, ReasoningTokens: 0, CachedTokens: 0, TotalTokens: 5}, - {EventKey: "event-3", SnapshotRunID: 1, APIGroupKey: "provider-b", Model: "claude-opus", Timestamp: time.Date(2026, 4, 17, 10, 0, 0, 0, time.UTC), Source: "codex-c", AuthIndex: "3", Failed: false, LatencyMS: 300, InputTokens: 100, OutputTokens: 50, ReasoningTokens: 25, CachedTokens: 10, TotalTokens: 185}, + {EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), Source: "codex-a", AuthIndex: "1", Failed: false, LatencyMS: 100, InputTokens: 10, OutputTokens: 20, ReasoningTokens: 5, CachedTokens: 0, TotalTokens: 35}, + {EventKey: "event-2", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), Source: "codex-b", AuthIndex: "2", Failed: true, LatencyMS: 200, InputTokens: 2, OutputTokens: 3, ReasoningTokens: 0, CachedTokens: 0, TotalTokens: 5}, + {EventKey: "event-3", APIGroupKey: "provider-b", Model: "claude-opus", Timestamp: time.Date(2026, 4, 17, 10, 0, 0, 0, time.UTC), Source: "codex-c", AuthIndex: "3", Failed: false, LatencyMS: 300, InputTokens: 100, OutputTokens: 50, ReasoningTokens: 25, CachedTokens: 10, TotalTokens: 185}, } if _, _, err := InsertUsageEvents(db, events); err != nil { t.Fatalf("InsertUsageEvents returned error: %v", err) @@ -73,7 +73,6 @@ func TestBuildUsageSnapshotBucketsDaysByLocalTime(t *testing.T) { db := openUsageTestDatabase(t) events := []models.UsageEvent{{ EventKey: "event-local-day", - SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 23, 30, 0, 0, time.UTC), @@ -112,7 +111,6 @@ func TestBuildUsageSnapshotPreservesStoredAPIKey(t *testing.T) { db := openUsageTestDatabase(t) events := []models.UsageEvent{{ EventKey: "event-1", - SnapshotRunID: 1, APIGroupKey: "sk-live-secret-value", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 20, 12, 0, 0, 0, time.UTC), @@ -136,9 +134,9 @@ func TestBuildUsageSnapshotPreservesStoredAPIKey(t *testing.T) { func TestUsageAggregatesApplyModelSourceAuthAndResultFilters(t *testing.T) { db := openUsageTestDatabase(t) events := []models.UsageEvent{ - {EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), Source: "source-a", AuthIndex: "1", Failed: false, TotalTokens: 35}, - {EventKey: "event-2", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), Source: "source-a", AuthIndex: "1", Failed: true, TotalTokens: 5}, - {EventKey: "event-3", SnapshotRunID: 1, APIGroupKey: "provider-b", Model: "claude-opus", Timestamp: time.Date(2026, 4, 16, 11, 0, 0, 0, time.UTC), Source: "source-b", AuthIndex: "2", Failed: false, TotalTokens: 185}, + {EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), Source: "source-a", AuthIndex: "1", Failed: false, TotalTokens: 35}, + {EventKey: "event-2", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), Source: "source-a", AuthIndex: "1", Failed: true, TotalTokens: 5}, + {EventKey: "event-3", APIGroupKey: "provider-b", Model: "claude-opus", Timestamp: time.Date(2026, 4, 16, 11, 0, 0, 0, time.UTC), Source: "source-b", AuthIndex: "2", Failed: false, TotalTokens: 185}, } if _, _, err := InsertUsageEvents(db, events); err != nil { t.Fatalf("InsertUsageEvents returned error: %v", err) diff --git a/internal/service/auth_files_service.go b/internal/service/auth_files_service.go deleted file mode 100644 index 11158e26..00000000 --- a/internal/service/auth_files_service.go +++ /dev/null @@ -1,25 +0,0 @@ -package service - -import ( - "context" - - "cpa-usage-keeper/internal/models" - "cpa-usage-keeper/internal/repository" - "gorm.io/gorm" -) - -type AuthFileProvider interface { - ListAuthFiles(context.Context) ([]models.AuthFile, error) -} - -type authFileService struct { - db *gorm.DB -} - -func NewAuthFileService(db *gorm.DB) AuthFileProvider { - return &authFileService{db: db} -} - -func (s *authFileService) ListAuthFiles(context.Context) ([]models.AuthFile, error) { - return repository.ListAuthFiles(s.db) -} diff --git a/internal/service/flatten.go b/internal/service/flatten.go index dfcf7a1a..0e470f0d 100644 --- a/internal/service/flatten.go +++ b/internal/service/flatten.go @@ -8,57 +8,8 @@ import ( "time" "cpa-usage-keeper/internal/cpa" - "cpa-usage-keeper/internal/models" ) -func FlattenUsageExport(snapshotRunID uint, export cpa.UsageExport) []models.UsageEvent { - if len(export.Usage.APIs) == 0 { - return nil - } - - events := make([]models.UsageEvent, 0) - for apiGroupKey, apiSnapshot := range export.Usage.APIs { - apiGroupKey = strings.TrimSpace(apiGroupKey) - if apiGroupKey == "" { - continue - } - - for modelName, modelSnapshot := range apiSnapshot.Models { - modelName = strings.TrimSpace(modelName) - if modelName == "" { - modelName = "unknown" - } - - for _, detail := range modelSnapshot.Details { - tokens := normalizeTokens(detail.Tokens) - timestamp := detail.Timestamp.UTC() - if timestamp.IsZero() { - timestamp = export.ExportedAt.UTC() - } - - events = append(events, models.UsageEvent{ - SnapshotRunID: snapshotRunID, - EventKey: BuildEventKey(apiGroupKey, modelName, timestamp, detail.Source, detail.AuthIndex, detail.Failed, tokens), - APIGroupKey: apiGroupKey, - Model: modelName, - Timestamp: timestamp, - Source: strings.TrimSpace(detail.Source), - AuthIndex: strings.TrimSpace(detail.AuthIndex), - Failed: detail.Failed, - LatencyMS: max(detail.LatencyMS, 0), - InputTokens: tokens.InputTokens, - OutputTokens: tokens.OutputTokens, - ReasoningTokens: tokens.ReasoningTokens, - CachedTokens: tokens.CachedTokens, - TotalTokens: tokens.TotalTokens, - }) - } - } - } - - return events -} - func BuildEventKey(apiGroupKey, model string, timestamp time.Time, source, authIndex string, failed bool, tokens cpa.TokenStats) string { normalized := normalizeTokens(tokens) payload := fmt.Sprintf( diff --git a/internal/service/flatten_test.go b/internal/service/flatten_test.go index 1c7be32f..5ae85928 100644 --- a/internal/service/flatten_test.go +++ b/internal/service/flatten_test.go @@ -18,85 +18,3 @@ func TestBuildEventKeyIsStable(t *testing.T) { t.Fatalf("expected stable event key, got %s and %s", key1, key2) } } - -func TestFlattenUsageExportBuildsEvents(t *testing.T) { - exportedAt := time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC) - export := cpa.UsageExport{ - Version: 1, - ExportedAt: exportedAt, - Usage: cpa.StatisticsSnapshot{ - APIs: map[string]cpa.APISnapshot{ - "provider-a": { - Models: map[string]cpa.ModelSnapshot{ - "claude-sonnet": { - Details: []cpa.RequestDetail{ - { - Timestamp: time.Date(2026, 4, 16, 9, 30, 0, 0, time.UTC), - LatencyMS: 123, - Source: "codex-account-a", - AuthIndex: "1", - Failed: false, - Tokens: cpa.TokenStats{ - InputTokens: 10, - OutputTokens: 20, - ReasoningTokens: 5, - CachedTokens: 0, - TotalTokens: 35, - }, - }, - }, - }, - }, - }, - }, - }, - } - - events := FlattenUsageExport(42, export) - if len(events) != 1 { - t.Fatalf("expected 1 event, got %d", len(events)) - } - - event := events[0] - if event.SnapshotRunID != 42 { - t.Fatalf("expected snapshot run id 42, got %d", event.SnapshotRunID) - } - if event.APIGroupKey != "provider-a" || event.Model != "claude-sonnet" { - t.Fatalf("unexpected event grouping: %+v", event) - } - if event.TotalTokens != 35 || event.InputTokens != 10 || event.OutputTokens != 20 || event.ReasoningTokens != 5 { - t.Fatalf("unexpected token values: %+v", event) - } - if event.EventKey == "" { - t.Fatal("expected event key to be generated") - } -} - -func TestFlattenUsageExportUsesExportedAtForZeroTimestamp(t *testing.T) { - exportedAt := time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC) - events := FlattenUsageExport(1, cpa.UsageExport{ - Version: 1, - ExportedAt: exportedAt, - Usage: cpa.StatisticsSnapshot{ - APIs: map[string]cpa.APISnapshot{ - "provider-a": { - Models: map[string]cpa.ModelSnapshot{ - "claude-sonnet": { - Details: []cpa.RequestDetail{{Tokens: cpa.TokenStats{InputTokens: 1, OutputTokens: 1}}}, - }, - }, - }, - }, - }, - }) - - if len(events) != 1 { - t.Fatalf("expected 1 event, got %d", len(events)) - } - if !events[0].Timestamp.Equal(exportedAt.UTC()) { - t.Fatalf("expected exportedAt timestamp fallback, got %s", events[0].Timestamp) - } - if events[0].TotalTokens != 2 { - t.Fatalf("expected normalized total tokens 2, got %d", events[0].TotalTokens) - } -} diff --git a/internal/service/provider_metadata_service.go b/internal/service/provider_metadata_service.go deleted file mode 100644 index 0043c501..00000000 --- a/internal/service/provider_metadata_service.go +++ /dev/null @@ -1,25 +0,0 @@ -package service - -import ( - "context" - - "cpa-usage-keeper/internal/models" - "cpa-usage-keeper/internal/repository" - "gorm.io/gorm" -) - -type ProviderMetadataProvider interface { - ListProviderMetadata(context.Context) ([]models.ProviderMetadata, error) -} - -type providerMetadataService struct { - db *gorm.DB -} - -func NewProviderMetadataService(db *gorm.DB) ProviderMetadataProvider { - return &providerMetadataService{db: db} -} - -func (s *providerMetadataService) ListProviderMetadata(context.Context) ([]models.ProviderMetadata, error) { - return repository.ListProviderMetadata(s.db) -} diff --git a/internal/service/redis_usage.go b/internal/service/redis_usage.go index 7765a490..1214c671 100644 --- a/internal/service/redis_usage.go +++ b/internal/service/redis_usage.go @@ -11,54 +11,10 @@ import ( "cpa-usage-keeper/internal/models" ) -type UsageFetchResult struct { - HTTPStatus int - RawPayload []byte - ExportedAt *time.Time - Version string - Events []models.UsageEvent -} - type RedisQueue interface { PopUsage(ctx context.Context) ([]string, error) } -type redisUsageFetcher struct { - queue RedisQueue -} - -func newRedisUsageFetcher(queue RedisQueue) redisUsageFetcher { - return redisUsageFetcher{queue: queue} -} - -func (f redisUsageFetcher) FetchUsage(ctx context.Context, fetchedAt time.Time) (*UsageFetchResult, error) { - if f.queue == nil { - return nil, fmt.Errorf("redis usage queue is nil") - } - messages, err := f.queue.PopUsage(ctx) - if err != nil { - return nil, err - } - rawMessages := make([]json.RawMessage, 0, len(messages)) - events := make([]models.UsageEvent, 0, len(messages)) - for i, message := range messages { - event, raw, err := DecodeRedisUsageMessage(message, fetchedAt) - if err != nil { - return nil, fmt.Errorf("decode redis usage message %d: %w", i, err) - } - rawMessages = append(rawMessages, raw) - events = append(events, event) - } - rawPayload, err := json.Marshal(rawMessages) - if err != nil { - return nil, fmt.Errorf("encode redis usage batch: %w", err) - } - return &UsageFetchResult{ - RawPayload: rawPayload, - Events: events, - }, nil -} - func DecodeRedisUsageMessage(message string, fetchedAt time.Time) (models.UsageEvent, json.RawMessage, error) { raw := json.RawMessage(message) var payload queuedUsageDetail @@ -83,6 +39,14 @@ type queuedUsageDetail struct { RequestID string `json:"request_id"` } +func normalizeRedisAuthType(value string) string { + trimmed := strings.ToLower(strings.TrimSpace(value)) + if trimmed == "api_key" { + return "apikey" + } + return trimmed +} + func (d queuedUsageDetail) toUsageEvent(fetchedAt time.Time) models.UsageEvent { tokens := normalizeTokens(d.Tokens) apiGroupKey := firstNonEmpty(d.APIKey, d.Provider, d.Endpoint, "unknown") @@ -100,6 +64,10 @@ func (d queuedUsageDetail) toUsageEvent(fetchedAt time.Time) models.UsageEvent { return models.UsageEvent{ EventKey: eventKey, APIGroupKey: apiGroupKey, + Provider: strings.TrimSpace(d.Provider), + Endpoint: strings.TrimSpace(d.Endpoint), + AuthType: normalizeRedisAuthType(d.AuthType), + RequestID: strings.TrimSpace(d.RequestID), Model: model, Timestamp: timestamp, Source: source, diff --git a/internal/service/redis_usage_test.go b/internal/service/redis_usage_test.go index 4a7b6e26..9fd6170c 100644 --- a/internal/service/redis_usage_test.go +++ b/internal/service/redis_usage_test.go @@ -2,7 +2,6 @@ package service import ( "context" - "encoding/json" "strings" "testing" "time" @@ -10,9 +9,10 @@ import ( "cpa-usage-keeper/internal/cpa" ) -func TestRedisUsageFetcherMapsPayloadToUsageEvent(t *testing.T) { +func TestDecodeRedisUsageMessageMapsPayloadToUsageEvent(t *testing.T) { fetchedAt := time.Date(2026, 4, 27, 8, 0, 0, 0, time.UTC) - fetcher := redisUsageFetcher{queue: staticRedisQueue{messages: []string{`{ + + event, raw, err := DecodeRedisUsageMessage(`{ "timestamp":"2026-04-27T07:59:00Z", "latency_ms":1234, "source":"sk-test", @@ -26,46 +26,40 @@ func TestRedisUsageFetcherMapsPayloadToUsageEvent(t *testing.T) { "api_key":"raw-key", "request_id":"req-123", "unknown":"ignored" - }`}}} - - result, err := fetcher.FetchUsage(context.Background(), fetchedAt) + }`, fetchedAt) if err != nil { - t.Fatalf("FetchUsage returned error: %v", err) - } - if len(result.Events) != 1 { - t.Fatalf("expected one event, got %d", len(result.Events)) + t.Fatalf("DecodeRedisUsageMessage returned error: %v", err) } - event := result.Events[0] if event.EventKey != "req-123" || event.APIGroupKey != "raw-key" || event.Model != "claude-sonnet-4-6" || event.Source != "sk-test" || event.AuthIndex != "auth-1" || !event.Failed || event.LatencyMS != 1234 { t.Fatalf("unexpected event: %+v", event) } + if event.Provider != "claude" || event.Endpoint != "/v1/messages" || event.AuthType != "apikey" || event.RequestID != "req-123" { + t.Fatalf("unexpected redis identity fields: %+v", event) + } if event.InputTokens != 10 || event.OutputTokens != 20 || event.ReasoningTokens != 3 || event.CachedTokens != 4 || event.TotalTokens != 33 { t.Fatalf("unexpected tokens: %+v", event) } if !event.Timestamp.Equal(time.Date(2026, 4, 27, 7, 59, 0, 0, time.UTC)) { t.Fatalf("unexpected timestamp: %s", event.Timestamp) } - var rawBatch []map[string]any - if err := json.Unmarshal(result.RawPayload, &rawBatch); err != nil { - t.Fatalf("raw payload is not a JSON array: %v", err) - } - if len(rawBatch) != 1 || rawBatch[0]["request_id"] != "req-123" || rawBatch[0]["unknown"] != "ignored" { - t.Fatalf("unexpected raw payload: %s", string(result.RawPayload)) + if !strings.Contains(string(raw), `"unknown":"ignored"`) { + t.Fatalf("expected raw message to be preserved, got %s", string(raw)) } } -func TestRedisUsageFetcherFallsBackFieldsAndEventKey(t *testing.T) { +func TestDecodeRedisUsageMessageFallsBackFieldsAndEventKey(t *testing.T) { fetchedAt := time.Date(2026, 4, 27, 8, 0, 0, 0, time.UTC) - fetcher := redisUsageFetcher{queue: staticRedisQueue{messages: []string{`{"latency_ms":-5,"tokens":{"input_tokens":1,"output_tokens":2},"endpoint":"/fallback"}`}}} - result, err := fetcher.FetchUsage(context.Background(), fetchedAt) + event, _, err := DecodeRedisUsageMessage(`{"latency_ms":-5,"tokens":{"input_tokens":1,"output_tokens":2},"endpoint":"/fallback"}`, fetchedAt) if err != nil { - t.Fatalf("FetchUsage returned error: %v", err) + t.Fatalf("DecodeRedisUsageMessage returned error: %v", err) } - event := result.Events[0] if event.APIGroupKey != "/fallback" || event.Model != "unknown" || event.LatencyMS != 0 { t.Fatalf("unexpected fallback event: %+v", event) } + if event.Provider != "" || event.Endpoint != "/fallback" || event.AuthType != "" || event.RequestID != "" { + t.Fatalf("unexpected fallback redis identity fields: %+v", event) + } if !event.Timestamp.Equal(fetchedAt) { t.Fatalf("expected fetchedAt timestamp, got %s", event.Timestamp) } @@ -75,14 +69,11 @@ func TestRedisUsageFetcherFallsBackFieldsAndEventKey(t *testing.T) { } } -func TestRedisUsageFetcherFallsBackToProviderWhenAPIKeyIsBlank(t *testing.T) { - fetcher := redisUsageFetcher{queue: staticRedisQueue{messages: []string{`{"api_key":" ","provider":"claude","endpoint":"/v1/messages","request_id":"req-blank-key"}`}}} - - result, err := fetcher.FetchUsage(context.Background(), time.Date(2026, 4, 27, 8, 0, 0, 0, time.UTC)) +func TestDecodeRedisUsageMessageFallsBackToProviderWhenAPIKeyIsBlank(t *testing.T) { + event, _, err := DecodeRedisUsageMessage(`{"api_key":" ","provider":"claude","endpoint":"/v1/messages","request_id":"req-blank-key"}`, time.Date(2026, 4, 27, 8, 0, 0, 0, time.UTC)) if err != nil { - t.Fatalf("FetchUsage returned error: %v", err) + t.Fatalf("DecodeRedisUsageMessage returned error: %v", err) } - event := result.Events[0] if event.EventKey != "req-blank-key" || event.APIGroupKey != "claude" { t.Fatalf("unexpected fallback event: %+v", event) } @@ -95,30 +86,6 @@ func TestDecodeRedisUsageMessageReportsOnlyMessageError(t *testing.T) { } } -func TestRedisUsageFetcherReportsMalformedJSONIndex(t *testing.T) { - fetcher := redisUsageFetcher{queue: staticRedisQueue{messages: []string{`{"request_id":"ok"}`, `{bad-json}`}}} - - _, err := fetcher.FetchUsage(context.Background(), time.Now()) - if err == nil || !strings.Contains(err.Error(), "decode redis usage message 1") { - t.Fatalf("expected message index decode error, got %v", err) - } -} - -func TestRedisUsageFetcherHandlesEmptyBatch(t *testing.T) { - fetcher := redisUsageFetcher{queue: staticRedisQueue{messages: nil}} - - result, err := fetcher.FetchUsage(context.Background(), time.Now()) - if err != nil { - t.Fatalf("FetchUsage returned error: %v", err) - } - if len(result.Events) != 0 { - t.Fatalf("expected no events, got %d", len(result.Events)) - } - if string(result.RawPayload) != "[]" { - t.Fatalf("expected empty raw payload array, got %s", string(result.RawPayload)) - } -} - type staticRedisQueue struct { messages []string err error diff --git a/internal/service/sync.go b/internal/service/sync.go index ef7e8054..f95c3e8a 100644 --- a/internal/service/sync.go +++ b/internal/service/sync.go @@ -2,8 +2,6 @@ package service import ( "context" - "crypto/sha256" - "encoding/hex" "fmt" "strings" "time" @@ -16,14 +14,6 @@ import ( "gorm.io/gorm" ) -type ExportFetcher interface { - FetchUsageExport(ctx context.Context) (*cpa.ExportResult, error) -} - -type UsageFetcher interface { - FetchUsage(ctx context.Context, fetchedAt time.Time) (*UsageFetchResult, error) -} - type MetadataFetcher interface { FetchAuthFiles(ctx context.Context) (*cpa.AuthFilesResult, error) FetchGeminiAPIKeys(ctx context.Context) (*cpa.ProviderKeyConfigResult, error) @@ -34,11 +24,9 @@ type MetadataFetcher interface { } type CPAClientFetcher interface { - ExportFetcher MetadataFetcher } -const syncPrefilterOverlapWindow = 24 * time.Hour const redisInboxProcessLimit = 1000 const ( @@ -47,33 +35,24 @@ const ( ) type SyncService struct { - db *gorm.DB - client CPAClientFetcher - usageFetcher UsageFetcher - redisUsageFetcher UsageFetcher - redisQueue RedisQueue - redisQueueKey string - usageSyncMode string - legacyUsageFetcher UsageFetcher - metadataFetcher MetadataFetcher - baseURL string - now func() time.Time + db *gorm.DB + client CPAClientFetcher + redisQueue RedisQueue + redisQueueKey string + metadataFetcher MetadataFetcher + baseURL string + now func() time.Time } type SyncResult struct { - SnapshotRunID uint Status string - HTTPStatus int InsertedEvents int DedupedEvents int - PayloadHash string - ExportedAt *time.Time } type RedisBatchSyncResult struct { Empty bool Status string - SnapshotRunID uint InsertedEvents int DedupedEvents int } @@ -88,7 +67,6 @@ func NewSyncService(db *gorm.DB, cfg config.Config) *SyncService { return NewSyncServiceWithOptions(db, SyncServiceOptions{ BaseURL: cfg.CPABaseURL, Client: cpa.NewClient(cfg.CPABaseURL, cfg.CPAManagementKey, cfg.RequestTimeout), - UsageSyncMode: cfg.UsageSyncMode, RedisQueue: cpa.NewRedisQueueClient(cfg.CPABaseURL, cfg.RedisQueueAddr, cfg.CPAManagementKey, cfg.RequestTimeout, cfg.RedisQueueKey, cfg.RedisQueueBatchSize), RedisQueueKey: cfg.RedisQueueKey, }) @@ -97,9 +75,7 @@ func NewSyncService(db *gorm.DB, cfg config.Config) *SyncService { type SyncServiceOptions struct { BaseURL string Client CPAClientFetcher - UsageFetcher UsageFetcher MetadataFetcher MetadataFetcher - UsageSyncMode string RedisQueue RedisQueue RedisQueueKey string Now func() time.Time @@ -110,37 +86,18 @@ func NewSyncServiceWithOptions(db *gorm.DB, opts SyncServiceOptions) *SyncServic if now == nil { now = time.Now } - usageFetcher := opts.UsageFetcher metadataFetcher := opts.MetadataFetcher if metadataFetcher == nil { metadataFetcher = opts.Client } - legacyFetcher := legacyUsageFetcher{client: opts.Client} - var redisFetcher UsageFetcher - if opts.RedisQueue != nil { - redisFetcher = newRedisUsageFetcher(opts.RedisQueue) - } - if usageFetcher == nil && opts.Client != nil { - usageFetcher = legacyFetcher - } - if opts.UsageSyncMode == "redis" { - if redisFetcher == nil { - redisFetcher = newRedisUsageFetcher(opts.RedisQueue) - } - usageFetcher = redisFetcher - } return &SyncService{ - db: db, - client: opts.Client, - usageFetcher: usageFetcher, - redisUsageFetcher: redisFetcher, - redisQueue: opts.RedisQueue, - redisQueueKey: redisQueueKey(opts.RedisQueueKey), - usageSyncMode: strings.TrimSpace(opts.UsageSyncMode), - legacyUsageFetcher: legacyFetcher, - metadataFetcher: metadataFetcher, - baseURL: strings.TrimSpace(opts.BaseURL), - now: now, + db: db, + client: opts.Client, + redisQueue: opts.RedisQueue, + redisQueueKey: redisQueueKey(opts.RedisQueueKey), + metadataFetcher: metadataFetcher, + baseURL: strings.TrimSpace(opts.BaseURL), + now: now, } } @@ -152,19 +109,16 @@ func NewSyncServiceWithClient(db *gorm.DB, baseURL string, client CPAClientFetch } func (s *SyncService) SyncOnce(ctx context.Context) error { - _, err := s.syncOnce(ctx) + _, err := s.SyncNow(ctx) return err } func (s *SyncService) SyncNow(ctx context.Context) (*SyncResult, error) { - if s != nil && s.redisQueue != nil && s.usageSyncMode == "redis" { - if _, err := s.PullRedisUsageInbox(ctx); err != nil { - return nil, err - } - result, err := s.ProcessRedisUsageInbox(ctx, true) - return syncResultFromRedisBatch(result), err + if _, err := s.PullRedisUsageInbox(ctx); err != nil { + return nil, err } - return s.syncOnce(ctx) + result, err := s.ProcessRedisUsageInbox(ctx) + return syncResultFromRedisBatch(result), err } func syncResultFromRedisBatch(result *RedisBatchSyncResult) *SyncResult { @@ -172,7 +126,6 @@ func syncResultFromRedisBatch(result *RedisBatchSyncResult) *SyncResult { return nil } return &SyncResult{ - SnapshotRunID: result.SnapshotRunID, Status: result.Status, InsertedEvents: result.InsertedEvents, DedupedEvents: result.DedupedEvents, @@ -180,7 +133,7 @@ func syncResultFromRedisBatch(result *RedisBatchSyncResult) *SyncResult { } func (s *SyncService) SyncStatus(ctx context.Context) (string, error) { - result, err := s.syncOnce(ctx) + result, err := s.SyncNow(ctx) if result == nil { return "", err } @@ -192,12 +145,20 @@ func (s *SyncService) SyncMetadata(ctx context.Context) error { return err } logrus.Debug("metadata sync started") + fetchedAt := s.now().UTC() authFilesResult, authFilesErr := s.metadataFetcher.FetchAuthFiles(ctx) providerConfig, fetchedProviderTypes, providerMetadataErr := fetchProviderMetadata(ctx, s.metadataFetcher) - err := joinErrors( - syncAuthFiles(s.db, authFilesResult, authFilesErr), - syncProviderMetadata(s.db, providerConfig, fetchedProviderTypes, providerMetadataErr), - ) + authSyncErr := syncAuthFiles(ctx, s.db, authFilesResult, authFilesErr, fetchedAt) + providerSyncErr, providerWarningErr := syncProviderMetadata(ctx, s.db, providerConfig, fetchedProviderTypes, providerMetadataErr, fetchedAt) + upsertErr := joinErrors(authSyncErr, providerSyncErr) + var aggregateErr error + if upsertErr == nil { + aggregateErr = repository.AggregateUsageIdentityStats(ctx, s.db, fetchedAt) + if aggregateErr != nil { + aggregateErr = fmt.Errorf("aggregate usage identity stats: %w", aggregateErr) + } + } + err := joinErrors(upsertErr, aggregateErr, providerWarningErr) fields := logrus.Fields{ "status": "completed", } @@ -210,7 +171,7 @@ func (s *SyncService) SyncMetadata(ctx context.Context) error { } // PullRedisUsageInbox 是 Redis 同步的拉取阶段:只 LPOP 队列消息并原样写入 redis_usage_inboxes。 -// 这个阶段不解码消息、不写 usage_events、不创建 snapshot_runs,保证 Redis 消费和本地处理职责分离。 +// 这个阶段不解码消息、不写 usage_events,保证 Redis 消费和本地处理职责分离。 func (s *SyncService) PullRedisUsageInbox(ctx context.Context) (*RedisInboxPullResult, error) { if err := s.validate(syncMetadataOptional); err != nil { return nil, err @@ -244,9 +205,9 @@ func (s *SyncService) PullRedisUsageInbox(ctx context.Context) (*RedisInboxPullR } // ProcessRedisUsageInbox 是 Redis 同步的本地处理阶段:只读取 pending/process_failed inbox 行并写入 usage_events。 -// Redis 路径不再写 snapshot_runs;成功处理后仅用 usage_event_key 记录 inbox 与最终事件的关联。 -func (s *SyncService) ProcessRedisUsageInbox(ctx context.Context, syncMetadata bool) (*RedisBatchSyncResult, error) { - if err := s.validate(syncMetadata); err != nil { +// 成功处理后仅用 usage_event_key 记录 inbox 与最终事件的关联。 +func (s *SyncService) ProcessRedisUsageInbox(ctx context.Context) (*RedisBatchSyncResult, error) { + if err := s.validate(syncMetadataOptional); err != nil { return nil, err } fetchedAt := s.now().UTC() @@ -258,7 +219,7 @@ func (s *SyncService) ProcessRedisUsageInbox(ctx context.Context, syncMetadata b return &RedisBatchSyncResult{Empty: true, Status: "empty"}, nil } logrus.WithField("row_count", len(processableRows)).Debug("redis usage inbox rows found for processing") - return s.processRedisInboxRows(ctx, processableRows, fetchedAt, syncMetadata) + return s.processRedisInboxRows(processableRows, fetchedAt) } // CleanupRedisUsageInbox 只清理 Redis inbox 表,供测试和单独维护入口使用;每日任务使用 CleanupStorage 统一执行。 @@ -270,7 +231,7 @@ func (s *SyncService) CleanupRedisUsageInbox(ctx context.Context) error { return err } -// CleanupStorage 是每日 03:00 维护任务调用的统一入口:先清 Redis inbox,再清 snapshot_runs,最后 VACUUM 收缩 SQLite。 +// CleanupStorage 是每日 03:00 维护任务调用的统一入口:先清 Redis inbox,最后 VACUUM 收缩 SQLite。 func (s *SyncService) CleanupStorage(ctx context.Context) error { if err := s.validate(syncMetadataOptional); err != nil { return err @@ -281,23 +242,20 @@ func (s *SyncService) CleanupStorage(ctx context.Context) error { // SyncRedisBatch 保留为兼容入口:先处理本地存量 inbox,空了再拉一次 Redis 并立即处理。 // 后台任务不要调用它,后台必须使用拆分后的 PullRedisUsageInbox、ProcessRedisUsageInbox 和 CleanupStorage。 -func (s *SyncService) SyncRedisBatch(ctx context.Context, syncMetadata bool) (*RedisBatchSyncResult, error) { - if result, err := s.ProcessRedisUsageInbox(ctx, syncMetadata); err != nil || result == nil || !result.Empty { +func (s *SyncService) SyncRedisBatch(ctx context.Context) (*RedisBatchSyncResult, error) { + if result, err := s.ProcessRedisUsageInbox(ctx); err != nil || result == nil || !result.Empty { return result, err } if _, err := s.PullRedisUsageInbox(ctx); err != nil { return &RedisBatchSyncResult{Status: "failed"}, err } - return s.ProcessRedisUsageInbox(ctx, syncMetadata) + return s.ProcessRedisUsageInbox(ctx) } // processRedisInboxRows 只从已落库的原始消息解码和写入事件,坏消息会标记为 decode_failed,不阻塞同批其它数据。 // 可解码但入库失败的消息标记为 process_failed,后续 ProcessRedisUsageInbox 会按 id 顺序重试。 -func (s *SyncService) processRedisInboxRows(ctx context.Context, inboxRows []models.RedisUsageInbox, fetchedAt time.Time, syncMetadata bool) (*RedisBatchSyncResult, error) { - logrus.WithFields(logrus.Fields{ - "row_count": len(inboxRows), - "sync_metadata": syncMetadata, - }).Debug("redis usage inbox processing started") +func (s *SyncService) processRedisInboxRows(inboxRows []models.RedisUsageInbox, fetchedAt time.Time) (*RedisBatchSyncResult, error) { + logrus.WithField("row_count", len(inboxRows)).Debug("redis usage inbox processing started") validRows := make([]models.RedisUsageInbox, 0, len(inboxRows)) events := make([]models.UsageEvent, 0, len(inboxRows)) decodeErrs := make([]error, 0) @@ -326,9 +284,8 @@ func (s *SyncService) processRedisInboxRows(ctx context.Context, inboxRows []mod return &RedisBatchSyncResult{Empty: true, Status: "empty"}, nil } - fetchResult := &UsageFetchResult{Events: events} logrus.WithField("event_count", len(events)).Debug("redis usage events persistence started") - result, err := s.persistRedisUsageEvents(ctx, fetchResult, syncMetadata) + result, err := s.persistRedisUsageEvents(events) if result == nil { markRedisInboxRowsProcessFailed(s.db, validRows, err) return nil, err @@ -338,7 +295,7 @@ func (s *SyncService) processRedisInboxRows(ctx context.Context, inboxRows []mod return &RedisBatchSyncResult{Status: result.Status}, err } for i, row := range validRows { - if markErr := repository.MarkRedisUsageInboxProcessedWithoutSnapshot(s.db, row.ID, fetchResult.Events[i].EventKey, fetchedAt); markErr != nil { + if markErr := repository.MarkRedisUsageInboxProcessed(s.db, row.ID, events[i].EventKey, fetchedAt); markErr != nil { return &RedisBatchSyncResult{Status: "failed"}, fmt.Errorf("mark redis usage inbox processed: %w", markErr) } } @@ -361,78 +318,15 @@ func (s *SyncService) processRedisInboxRows(ctx context.Context, inboxRows []mod } return &RedisBatchSyncResult{ Status: status, - SnapshotRunID: result.SnapshotRunID, InsertedEvents: result.InsertedEvents, DedupedEvents: result.DedupedEvents, }, returnErr } -// SyncLegacyStatus 执行 legacy_export 回退路径并返回 snapshot_run 最终状态。 -// legacy_export 仍然会创建 snapshot_runs、保存原始导出 payload,并把 usage_events 关联到本次 snapshot_run。 -func (s *SyncService) SyncLegacyStatus(ctx context.Context) (string, error) { - if err := s.validate(syncMetadataRequired); err != nil { - return "", err - } - if s.legacyUsageFetcher == nil { - return "", fmt.Errorf("sync service legacy usage fetcher is nil") - } - - fetchedAt := s.now().UTC() - fetchResult, fetchErr := s.legacyUsageFetcher.FetchUsage(ctx, fetchedAt) - result, err := s.persistUsageResult(ctx, fetchedAt, fetchResult, fetchErr, true, true) - if result == nil { - return "", err - } - return result.Status, err -} - -// syncOnce 执行一次完整的 legacy_export 同步:拉取导出、创建 snapshot_run、写 usage_events 并同步 metadata。 -// 该路径用于 legacy_export 模式以及 auto 探测 Redis 不可用后的回退模式,不参与 Redis inbox 分阶段处理。 -func (s *SyncService) syncOnce(ctx context.Context) (*SyncResult, error) { - if err := s.validate(syncMetadataRequired); err != nil { - return nil, err - } - if s.usageFetcher == nil && s.client != nil { - s.usageFetcher = legacyUsageFetcher{client: s.client} - } - if s.usageFetcher == nil { - return nil, fmt.Errorf("sync service usage fetcher is nil") - } - - fetchedAt := s.now().UTC() - fetchResult, fetchErr := s.usageFetcher.FetchUsage(ctx, fetchedAt) - return s.persistUsageResult(ctx, fetchedAt, fetchResult, fetchErr, true, true) -} - -// persistRedisUsageEvents 是 Redis inbox 专用入库路径,只写 usage_events 和可选 metadata,不创建 snapshot_runs。 -func (s *SyncService) persistRedisUsageEvents(ctx context.Context, fetchResult *UsageFetchResult, syncMetadata bool) (*SyncResult, error) { - if fetchResult == nil { - return nil, fmt.Errorf("redis usage fetch result is nil") - } - - var authFilesResult *cpa.AuthFilesResult - var providerConfig cpa.ProviderMetadataConfig - var fetchedProviderTypes []string - var authFilesErr error - var providerMetadataErr error - if syncMetadata { - if s.metadataFetcher == nil && s.client != nil { - s.metadataFetcher = s.client - } - if s.metadataFetcher == nil { - return nil, fmt.Errorf("sync service metadata fetcher is nil") - } - authFilesResult, authFilesErr = s.metadataFetcher.FetchAuthFiles(ctx) - providerConfig, fetchedProviderTypes, providerMetadataErr = fetchProviderMetadata(ctx, s.metadataFetcher) - } - - events := fetchResult.Events - for i := range events { - events[i].SnapshotRunID = 0 - } +// persistRedisUsageEvents 写入 Redis inbox 解码出的 usage_events。 +func (s *SyncService) persistRedisUsageEvents(events []models.UsageEvent) (*SyncResult, error) { var err error events, err = alignUsageEventKeysWithExistingCanonicalEvents(s.db, events) - fetchResult.Events = events if err != nil { return &SyncResult{Status: "failed"}, fmt.Errorf("align usage events: %w", err) } @@ -445,197 +339,7 @@ func (s *SyncService) persistRedisUsageEvents(ctx context.Context, fetchResult * "inserted_events": inserted, "deduped_events": deduped, }).Debug("usage events insert finished") - - var partialSyncErr error - if syncMetadata { - authFilesSyncErr := syncAuthFiles(s.db, authFilesResult, authFilesErr) - providerMetadataSyncErr := syncProviderMetadata(s.db, providerConfig, fetchedProviderTypes, providerMetadataErr) - partialSyncErr = joinErrors(authFilesSyncErr, providerMetadataSyncErr) - } - status := "completed" - if partialSyncErr != nil { - status = "completed_with_warnings" - } - result := &SyncResult{Status: status, InsertedEvents: inserted, DedupedEvents: deduped} - if partialSyncErr != nil { - return result, partialSyncErr - } - return result, nil -} - -// persistUsageResult 是 legacy_export 专用入库路径,负责 snapshot_runs 的完整生命周期和 usage_events 写入。 -// 即使拉取失败也会创建并 finalize snapshot_run,用于保留失败状态、HTTP 状态、错误信息和原始 payload 审计线索。 -func (s *SyncService) persistUsageResult(ctx context.Context, fetchedAt time.Time, fetchResult *UsageFetchResult, fetchErr error, syncMetadata bool, filterByWatermark bool) (*SyncResult, error) { - logrus.WithFields(logrus.Fields{ - "sync_metadata": syncMetadata, - "filter_by_watermark": filterByWatermark, - }).Debug("usage persistence started") - - var ( - httpStatus int - rawPayload []byte - payloadHash string - exportedAt *time.Time - version string - ) - if fetchResult != nil { - httpStatus = fetchResult.HTTPStatus - rawPayload = append([]byte(nil), fetchResult.RawPayload...) - payloadHash = hashPayload(rawPayload) - exportedAt = fetchResult.ExportedAt - version = fetchResult.Version - } - - var authFilesResult *cpa.AuthFilesResult - var providerConfig cpa.ProviderMetadataConfig - var fetchedProviderTypes []string - var authFilesErr error - var providerMetadataErr error - if syncMetadata { - if s.metadataFetcher == nil && s.client != nil { - s.metadataFetcher = s.client - } - if s.metadataFetcher == nil { - return nil, fmt.Errorf("sync service metadata fetcher is nil") - } - authFilesResult, authFilesErr = s.metadataFetcher.FetchAuthFiles(ctx) - providerConfig, fetchedProviderTypes, providerMetadataErr = fetchProviderMetadata(ctx, s.metadataFetcher) - } - - snapshotRun, err := repository.CreateSnapshotRun(s.db, repository.SnapshotRunInput{ - FetchedAt: fetchedAt, - CPABaseURL: s.baseURL, - ExportedAt: exportedAt, - Version: version, - Status: initialSnapshotStatus(fetchErr), - HTTPStatus: httpStatus, - PayloadHash: payloadHash, - RawPayload: rawPayload, - ErrorMessage: errorMessage(fetchErr), - }) - if err != nil { - return nil, err - } - logrus.WithFields(logrus.Fields{ - "snapshot_run_id": snapshotRun.ID, - "status": snapshotRun.Status, - "payload_bytes": len(rawPayload), - }).Debug("snapshot run created") - - if fetchErr != nil { - finalizeErr := repository.FinalizeSnapshotRun(s.db, snapshotRun.ID, repository.SnapshotRunResult{ - Status: "failed", - HTTPStatus: httpStatus, - ErrorMessage: errorMessage(fetchErr), - ExportedAt: exportedAt, - }) - if finalizeErr != nil { - return nil, fmt.Errorf("fetch usage export: %v; finalize snapshot run: %w", fetchErr, finalizeErr) - } - return &SyncResult{ - SnapshotRunID: snapshotRun.ID, - Status: "failed", - HTTPStatus: httpStatus, - PayloadHash: payloadHash, - ExportedAt: exportedAt, - }, fmt.Errorf("fetch usage export: %w", fetchErr) - } - - events := fetchResult.Events - for i := range events { - events[i].SnapshotRunID = snapshotRun.ID - } - if filterByWatermark { - events, err = filterUsageEventsByLocalWatermark(s.db, events, syncPrefilterOverlapWindow) - if err != nil { - finalizeErr := repository.FinalizeSnapshotRun(s.db, snapshotRun.ID, repository.SnapshotRunResult{ - Status: "failed", - HTTPStatus: httpStatus, - ErrorMessage: errorMessage(err), - ExportedAt: exportedAt, - }) - if finalizeErr != nil { - return nil, fmt.Errorf("filter usage events: %v; finalize snapshot run: %w", err, finalizeErr) - } - return nil, fmt.Errorf("filter usage events: %w", err) - } - } - events, err = alignUsageEventKeysWithExistingCanonicalEvents(s.db, events) - fetchResult.Events = events - if err != nil { - finalizeErr := repository.FinalizeSnapshotRun(s.db, snapshotRun.ID, repository.SnapshotRunResult{ - Status: "failed", - HTTPStatus: httpStatus, - ErrorMessage: errorMessage(err), - ExportedAt: exportedAt, - }) - if finalizeErr != nil { - return nil, fmt.Errorf("align usage events: %v; finalize snapshot run: %w", err, finalizeErr) - } - return nil, fmt.Errorf("align usage events: %w", err) - } - logrus.WithField("event_count", len(events)).Debug("usage events insert started") - inserted, deduped, err := repository.InsertUsageEvents(s.db, events) - if err != nil { - finalizeErr := repository.FinalizeSnapshotRun(s.db, snapshotRun.ID, repository.SnapshotRunResult{ - Status: "failed", - HTTPStatus: httpStatus, - ErrorMessage: errorMessage(err), - ExportedAt: exportedAt, - }) - if finalizeErr != nil { - return nil, fmt.Errorf("insert usage events: %v; finalize snapshot run: %w", err, finalizeErr) - } - return nil, fmt.Errorf("insert usage events: %w", err) - } - - logrus.WithFields(logrus.Fields{ - "snapshot_run_id": snapshotRun.ID, - "inserted_events": inserted, - "deduped_events": deduped, - }).Debug("usage events insert finished") - - var partialSyncErr error - if syncMetadata { - authFilesSyncErr := syncAuthFiles(s.db, authFilesResult, authFilesErr) - providerMetadataSyncErr := syncProviderMetadata(s.db, providerConfig, fetchedProviderTypes, providerMetadataErr) - partialSyncErr = joinErrors(authFilesSyncErr, providerMetadataSyncErr) - } - finalStatus := "completed" - if partialSyncErr != nil { - finalStatus = "completed_with_warnings" - } - finalErrorMessage := errorMessage(partialSyncErr) - if err := repository.FinalizeSnapshotRun(s.db, snapshotRun.ID, repository.SnapshotRunResult{ - Status: finalStatus, - HTTPStatus: httpStatus, - InsertedEvents: inserted, - DedupedEvents: deduped, - ExportedAt: exportedAt, - ErrorMessage: finalErrorMessage, - }); err != nil { - return nil, err - } - logrus.WithFields(logrus.Fields{ - "snapshot_run_id": snapshotRun.ID, - "status": finalStatus, - "inserted_events": inserted, - "deduped_events": deduped, - }).Debug("snapshot run finalized") - - result := &SyncResult{ - SnapshotRunID: snapshotRun.ID, - Status: finalStatus, - HTTPStatus: httpStatus, - InsertedEvents: inserted, - DedupedEvents: deduped, - PayloadHash: payloadHash, - ExportedAt: exportedAt, - } - if partialSyncErr != nil { - return result, partialSyncErr - } - return result, nil + return &SyncResult{Status: "completed", InsertedEvents: inserted, DedupedEvents: deduped}, nil } func alignUsageEventKeysWithExistingCanonicalEvents(db *gorm.DB, events []models.UsageEvent) ([]models.UsageEvent, error) { @@ -648,6 +352,10 @@ func alignUsageEventKeysWithExistingCanonicalEvents(db *gorm.DB, events []models events[i].Timestamp = events[i].Timestamp.UTC() canonicalKey := canonicalUsageEventKey(events[i]) incomingKey := strings.TrimSpace(events[i].EventKey) + if strings.TrimSpace(events[i].RequestID) != "" { + canonicalEventKeys[canonicalKey] = incomingKey + continue + } if existingKey := canonicalEventKeys[canonicalKey]; existingKey != "" { if incomingKey == canonicalKey { events[i].EventKey = existingKey @@ -730,52 +438,6 @@ func canonicalUsageEventKey(event models.UsageEvent) string { ) } -type legacyUsageFetcher struct { - client interface { - FetchUsageExport(ctx context.Context) (*cpa.ExportResult, error) - } -} - -// FetchUsage 从 legacy export 接口拉取完整导出结果,保留 raw payload 给 persistUsageResult 写入 snapshot_runs。 -// 这是 Redis 队列不可用时的回退数据源,事件 key 仍在后续入库阶段与既有 canonical event 对齐。 -func (f legacyUsageFetcher) FetchUsage(ctx context.Context, _ time.Time) (*UsageFetchResult, error) { - if f.client == nil { - return nil, fmt.Errorf("legacy usage client is nil") - } - logrus.Debug("legacy usage pull started") - result, err := f.client.FetchUsageExport(ctx) - if result == nil { - logrus.WithError(err).Debug("legacy usage pull finished") - return nil, err - } - var exportedAt *time.Time - if !result.Payload.ExportedAt.IsZero() { - normalized := result.Payload.ExportedAt.UTC() - exportedAt = &normalized - } - version := "" - if result.Payload.Version > 0 { - version = fmt.Sprintf("%d", result.Payload.Version) - } - events := FlattenUsageExport(0, result.Payload) - logFields := logrus.Fields{ - "http_status": result.StatusCode, - "event_count": len(events), - "payload_bytes": len(result.Body), - } - if err != nil { - logFields["error"] = err.Error() - } - logrus.WithFields(logFields).Debug("legacy usage pull finished") - return &UsageFetchResult{ - HTTPStatus: result.StatusCode, - RawPayload: append([]byte(nil), result.Body...), - ExportedAt: exportedAt, - Version: version, - Events: events, - }, err -} - func (s *SyncService) validate(syncMetadata bool) error { if s == nil { return fmt.Errorf("sync service is nil") @@ -823,13 +485,12 @@ func markRedisInboxRowsProcessFailed(db *gorm.DB, rows []models.RedisUsageInbox, } if stored.Status == repository.RedisUsageInboxStatusDiscarded { logrus.WithFields(logrus.Fields{ - "inbox_id": stored.ID, - "queue_key": stored.QueueKey, - "message_hash": stored.MessageHash, - "attempt_count": stored.AttemptCount, - "last_error": stored.LastError, - "popped_at": stored.PoppedAt, - "snapshot_run_id": stored.SnapshotRunID, + "inbox_id": stored.ID, + "queue_key": stored.QueueKey, + "message_hash": stored.MessageHash, + "attempt_count": stored.AttemptCount, + "last_error": stored.LastError, + "popped_at": stored.PoppedAt, }).Warn("discarded redis usage inbox row after repeated process failures") } } @@ -843,54 +504,6 @@ func redisQueueKey(value string) string { return trimmed } -func filterUsageEventsByLocalWatermark(db *gorm.DB, events []models.UsageEvent, overlapWindow time.Duration) ([]models.UsageEvent, error) { - if len(events) == 0 { - return events, nil - } - - watermark, err := repository.FindLatestUsageEventTimestamp(db) - if err != nil { - return nil, err - } - if watermark == nil { - return events, nil - } - - cutoff := watermark.UTC().Add(-overlapWindow) - filtered := make([]models.UsageEvent, 0, len(events)) - for _, event := range events { - if event.Timestamp.IsZero() || !event.Timestamp.UTC().Before(cutoff) { - filtered = append(filtered, event) - } - } - skipped := len(events) - len(filtered) - if skipped > 0 { - logrus.WithFields(logrus.Fields{ - "watermark": watermark.UTC().Format(time.RFC3339), - "cutoff": cutoff.Format(time.RFC3339), - "overlap_hours": overlapWindow.Hours(), - "filtered_events": skipped, - "total_events": len(events), - }).Info("filtered old usage events before insert") - } - return filtered, nil -} - -func hashPayload(payload []byte) string { - if len(payload) == 0 { - return "" - } - sum := sha256.Sum256(payload) - return hex.EncodeToString(sum[:]) -} - -func initialSnapshotStatus(err error) string { - if err != nil { - return "failed" - } - return "pending" -} - func errorMessage(err error) string { if err == nil { return "" @@ -898,7 +511,7 @@ func errorMessage(err error) string { return strings.TrimSpace(err.Error()) } -func syncAuthFiles(db *gorm.DB, result *cpa.AuthFilesResult, fetchErr error) error { +func syncAuthFiles(ctx context.Context, db *gorm.DB, result *cpa.AuthFilesResult, fetchErr error, now time.Time) error { if fetchErr != nil { return fmt.Errorf("fetch auth files: %w", fetchErr) } @@ -909,24 +522,19 @@ func syncAuthFiles(db *gorm.DB, result *cpa.AuthFilesResult, fetchErr error) err return fmt.Errorf("fetch auth files: empty response") } - inputs := make([]repository.AuthFileInput, 0, len(result.Payload.Files)) + identities := make([]models.UsageIdentity, 0, len(result.Payload.Files)) for _, file := range result.Payload.Files { - inputs = append(inputs, repository.AuthFileInput{ - AuthIndex: file.AuthIndex, - Name: file.Name, - Email: file.Email, - Type: file.Type, - Provider: file.Provider, - Label: file.Label, - Status: file.Status, - Source: file.Source, - Disabled: file.Disabled, - Unavailable: file.Unavailable, - RuntimeOnly: file.RuntimeOnly, + identities = append(identities, models.UsageIdentity{ + Name: firstNonEmpty(file.Email, file.Label, file.Name, file.AuthIndex), + AuthType: models.UsageIdentityAuthTypeAuthFile, + AuthTypeName: "oauth", + Identity: file.AuthIndex, + Type: file.Type, + Provider: file.Provider, }) } - if err := repository.ReplaceAuthFiles(db, inputs); err != nil { - return fmt.Errorf("sync auth files: %w", err) + if err := repository.ReplaceUsageIdentitiesForAuthType(ctx, db, identities, models.UsageIdentityAuthTypeAuthFile, now); err != nil { + return fmt.Errorf("sync auth file usage identities: %w", err) } return nil } @@ -980,51 +588,67 @@ func fetchProviderMetadata(ctx context.Context, fetcher MetadataFetcher) (cpa.Pr return cfg, fetchedProviderTypes, joinErrors(errs...) } -func syncProviderMetadata(db *gorm.DB, cfg cpa.ProviderMetadataConfig, fetchedProviderTypes []string, fetchErr error) error { +func syncProviderMetadata(ctx context.Context, db *gorm.DB, cfg cpa.ProviderMetadataConfig, fetchedProviderTypes []string, fetchErr error, now time.Time) (error, error) { if db == nil { - return fmt.Errorf("database is nil") + return fmt.Errorf("database is nil"), nil } inputs := flattenProviderMetadata(cfg) - if err := repository.ReplaceProviderMetadataForProviderTypes(db, inputs, fetchedProviderTypes); err != nil { - return fmt.Errorf("sync provider metadata: %w", err) + identities := providerMetadataUsageIdentities(inputs) + if err := repository.ReplaceUsageIdentitiesForProviderTypes(ctx, db, identities, fetchedProviderTypes, now); err != nil { + return fmt.Errorf("sync provider usage identities: %w", err), nil } if fetchErr != nil { - return fmt.Errorf("fetch provider metadata: %w", fetchErr) + return nil, fmt.Errorf("fetch provider metadata: %w", fetchErr) } - return nil + return nil, nil +} + +type providerMetadataInput struct { + LookupKey string + ProviderType string + DisplayName string +} + +func providerMetadataUsageIdentities(inputs []providerMetadataInput) []models.UsageIdentity { + identities := make([]models.UsageIdentity, 0, len(inputs)) + for _, input := range inputs { + identities = append(identities, models.UsageIdentity{ + Name: input.DisplayName, + AuthType: models.UsageIdentityAuthTypeAIProvider, + AuthTypeName: "apikey", + Identity: input.LookupKey, + Type: input.ProviderType, + Provider: input.DisplayName, + }) + } + return identities } -func flattenProviderMetadata(cfg cpa.ProviderMetadataConfig) []repository.ProviderMetadataInput { - items := make([]repository.ProviderMetadataInput, 0) +func flattenProviderMetadata(cfg cpa.ProviderMetadataConfig) []providerMetadataInput { + items := make([]providerMetadataInput, 0) seen := make(map[string]struct{}) - appendItem := func(lookupKey, providerType, displayName, providerKey, matchKind string) { + appendItem := func(lookupKey, providerType, displayName string) { lookupKey = strings.TrimSpace(lookupKey) providerType = strings.TrimSpace(providerType) displayName = strings.TrimSpace(displayName) - providerKey = strings.TrimSpace(providerKey) - matchKind = strings.TrimSpace(matchKind) - if lookupKey == "" || providerType == "" || displayName == "" || providerKey == "" || matchKind == "" { + if lookupKey == "" || providerType == "" || displayName == "" { return } if _, ok := seen[lookupKey]; ok { return } seen[lookupKey] = struct{}{} - items = append(items, repository.ProviderMetadataInput{ + items = append(items, providerMetadataInput{ LookupKey: lookupKey, ProviderType: providerType, DisplayName: displayName, - ProviderKey: providerKey, - MatchKind: matchKind, }) } appendProviderEntries := func(providerType string, configs []cpa.ProviderKeyConfig) { for _, cfg := range configs { - displayName := firstNonEmpty(cfg.Prefix, cfg.Name, providerType) - providerKey := providerType + ":" + displayName - appendItem(cfg.APIKey, providerType, displayName, providerKey, "api_key") - appendItem(cfg.Prefix, providerType, displayName, providerKey, "prefix") + displayName := firstNonEmpty(cfg.Name, providerType) + appendItem(cfg.APIKey, providerType, displayName) } } @@ -1034,11 +658,9 @@ func flattenProviderMetadata(cfg cpa.ProviderMetadataConfig) []repository.Provid appendProviderEntries("vertex", cfg.VertexAPIKeys) for _, provider := range cfg.OpenAICompatibility { - displayName := firstNonEmpty(provider.Name, provider.Prefix, "openai") - providerKey := "openai:" + displayName - appendItem(provider.Prefix, "openai", displayName, providerKey, "prefix") + displayName := firstNonEmpty(provider.Name, "openai") for _, entry := range provider.APIKeyEntries { - appendItem(entry.APIKey, "openai", displayName, providerKey, "api_key") + appendItem(entry.APIKey, "openai", displayName) } } diff --git a/internal/service/sync_test.go b/internal/service/sync_test.go index 26dc0df4..dc572005 100644 --- a/internal/service/sync_test.go +++ b/internal/service/sync_test.go @@ -21,9 +21,7 @@ import ( gormlogger "gorm.io/gorm/logger" ) -type stubExportFetcher struct { - result *cpa.ExportResult - err error +type stubMetadataFetcher struct { authFilesResult *cpa.AuthFilesResult authFilesErr error providerConfig cpa.ProviderMetadataConfig @@ -46,37 +44,38 @@ type trackingMetadataFetcher struct { providerErr error } -func (s stubExportFetcher) FetchUsageExport(context.Context) (*cpa.ExportResult, error) { - return s.result, s.err +type observingMetadataFetcher struct { + db *gorm.DB + usageEventsBeforeMetadataSync int64 } -func (s stubExportFetcher) FetchAuthFiles(context.Context) (*cpa.AuthFilesResult, error) { +func (s stubMetadataFetcher) FetchAuthFiles(context.Context) (*cpa.AuthFilesResult, error) { if s.authFilesResult != nil || s.authFilesErr != nil { return s.authFilesResult, s.authFilesErr } return &cpa.AuthFilesResult{StatusCode: 200, Payload: cpa.AuthFilesResponse{}}, nil } -func (s stubExportFetcher) FetchGeminiAPIKeys(context.Context) (*cpa.ProviderKeyConfigResult, error) { +func (s stubMetadataFetcher) FetchGeminiAPIKeys(context.Context) (*cpa.ProviderKeyConfigResult, error) { if s.geminiNilResult { return nil, nil } return providerKeyConfigResult(s.providerConfig.GeminiAPIKeys, s.geminiErr) } -func (s stubExportFetcher) FetchClaudeAPIKeys(context.Context) (*cpa.ProviderKeyConfigResult, error) { +func (s stubMetadataFetcher) FetchClaudeAPIKeys(context.Context) (*cpa.ProviderKeyConfigResult, error) { return providerKeyConfigResult(s.providerConfig.ClaudeAPIKeys, s.claudeErr) } -func (s stubExportFetcher) FetchCodexAPIKeys(context.Context) (*cpa.ProviderKeyConfigResult, error) { +func (s stubMetadataFetcher) FetchCodexAPIKeys(context.Context) (*cpa.ProviderKeyConfigResult, error) { return providerKeyConfigResult(s.providerConfig.CodexAPIKeys, s.codexErr) } -func (s stubExportFetcher) FetchVertexAPIKeys(context.Context) (*cpa.ProviderKeyConfigResult, error) { +func (s stubMetadataFetcher) FetchVertexAPIKeys(context.Context) (*cpa.ProviderKeyConfigResult, error) { return providerKeyConfigResult(s.providerConfig.VertexAPIKeys, s.vertexErr) } -func (s stubExportFetcher) FetchOpenAICompatibility(context.Context) (*cpa.OpenAICompatibilityResult, error) { +func (s stubMetadataFetcher) FetchOpenAICompatibility(context.Context) (*cpa.OpenAICompatibilityResult, error) { return openAICompatibilityResult(s.providerConfig.OpenAICompatibility, s.openAIErr) } @@ -127,304 +126,35 @@ func (s *trackingMetadataFetcher) FetchOpenAICompatibility(context.Context) (*cp return openAICompatibilityResult(nil, s.providerErr) } -func (s *trackingMetadataFetcher) providerCalls() int { - return s.geminiCalls + s.claudeCalls + s.codexCalls + s.vertexCalls + s.openAICalls -} - -func TestSyncOncePersistsSnapshotAndEvents(t *testing.T) { - db := openSyncTestDatabase(t) - body := []byte(`{"version":1,"exported_at":"2026-04-16T10:00:00Z","usage":{"apis":{"provider-a":{"models":{"claude-sonnet":{"details":[{"timestamp":"2026-04-16T09:30:00Z","latency_ms":123,"source":"codex-a","auth_index":"1","failed":false,"tokens":{"input_tokens":10,"output_tokens":20,"reasoning_tokens":5,"cached_tokens":0,"total_tokens":35}}]}}}}}}`) - service := NewSyncServiceWithOptions(db, SyncServiceOptions{ - BaseURL: "https://cpa.example.com", - Client: stubExportFetcher{ - result: successfulExportResult(body), - authFilesResult: &cpa.AuthFilesResult{StatusCode: 200, Payload: cpa.AuthFilesResponse{Files: []cpa.AuthFile{{ - AuthIndex: "1", - Name: "Claude Desktop", - Email: "user@example.com", - Type: "claude", - Provider: "anthropic", - }}}}, - }, - }) - - result, err := service.SyncNow(context.Background()) - if err != nil { - t.Fatalf("SyncOnce returned error: %v", err) - } - if result.Status != "completed" || result.HTTPStatus != 200 { - t.Fatalf("unexpected sync result: %+v", result) - } - if result.InsertedEvents != 1 || result.DedupedEvents != 0 { - t.Fatalf("unexpected sync counts: %+v", result) - } - var snapshot models.SnapshotRun - if err := db.First(&snapshot, result.SnapshotRunID).Error; err != nil { - t.Fatalf("load snapshot run: %v", err) - } - if snapshot.Status != "completed" { - t.Fatalf("expected completed snapshot run, got %q", snapshot.Status) - } - if snapshot.PayloadHash == "" || snapshot.InsertedEvents != 1 { - t.Fatalf("unexpected snapshot values: %+v", snapshot) - } - var event models.UsageEvent - if err := db.First(&event).Error; err != nil { - t.Fatalf("load usage event: %v", err) - } - if event.SnapshotRunID != result.SnapshotRunID || event.Source != "codex-a" || event.TotalTokens != 35 { - t.Fatalf("unexpected usage event: %+v", event) - } - - var authFile models.AuthFile - if err := db.First(&authFile).Error; err != nil { - t.Fatalf("load auth file: %v", err) - } - if authFile.AuthIndex != "1" || authFile.Email != "user@example.com" { - t.Fatalf("unexpected auth file: %+v", authFile) +func (s *observingMetadataFetcher) FetchAuthFiles(context.Context) (*cpa.AuthFilesResult, error) { + if err := s.db.Model(&models.UsageEvent{}).Count(&s.usageEventsBeforeMetadataSync).Error; err != nil { + return nil, err } + return &cpa.AuthFilesResult{StatusCode: 200, Payload: cpa.AuthFilesResponse{}}, nil } -func TestSyncOnceMarksFetchFailureOnSnapshotRun(t *testing.T) { - db := openSyncTestDatabase(t) - service := NewSyncServiceWithClient(db, "https://cpa.example.com", stubExportFetcher{ - err: errors.New("management export request failed with status 401"), - result: &cpa.ExportResult{ - StatusCode: 401, - Body: []byte(`{"error":"unauthorized"}`), - }, - }) - - result, err := service.SyncNow(context.Background()) - if err == nil { - t.Fatal("expected sync error") - } - if result == nil || result.Status != "failed" || result.HTTPStatus != 401 { - t.Fatalf("unexpected sync result: %+v", result) - } - - var snapshot models.SnapshotRun - if err := db.First(&snapshot, result.SnapshotRunID).Error; err != nil { - t.Fatalf("load snapshot run: %v", err) - } - if snapshot.Status != "failed" { - t.Fatalf("expected failed snapshot run, got %q", snapshot.Status) - } - if snapshot.ErrorMessage == "" { - t.Fatal("expected snapshot error message to be stored") - } +func (s *observingMetadataFetcher) FetchGeminiAPIKeys(context.Context) (*cpa.ProviderKeyConfigResult, error) { + return providerKeyConfigResult(nil, nil) } -func TestSyncOnceReturnsAuthFilesFailureWithoutClearingExistingData(t *testing.T) { - db := openSyncTestDatabase(t) - if err := repository.ReplaceAuthFiles(db, []repository.AuthFileInput{{ - AuthIndex: "existing", - Email: "existing@example.com", - }}); err != nil { - t.Fatalf("seed auth files: %v", err) - } - - service := NewSyncServiceWithClient(db, "https://cpa.example.com", stubExportFetcher{ - result: successfulExportResult([]byte(`{"version":1}`)), - authFilesErr: errors.New("management auth files request failed with status 503"), - }) - - result, err := service.SyncNow(context.Background()) - if err == nil { - t.Fatal("expected auth files sync error") - } - if result == nil || result.Status != "completed_with_warnings" { - t.Fatalf("expected completed_with_warnings sync result with partial failure, got %+v", result) - } - - files, listErr := repository.ListAuthFiles(db) - if listErr != nil { - t.Fatalf("list auth files: %v", listErr) - } - if len(files) != 1 || files[0].AuthIndex != "existing" { - t.Fatalf("expected existing auth files to remain available, got %+v", files) - } - - var snapshot models.SnapshotRun - if err := db.First(&snapshot, result.SnapshotRunID).Error; err != nil { - t.Fatalf("load snapshot run: %v", err) - } - if snapshot.Status != "completed_with_warnings" || snapshot.ErrorMessage == "" { - t.Fatalf("expected completed_with_warnings snapshot with error message, got %+v", snapshot) - } +func (s *observingMetadataFetcher) FetchClaudeAPIKeys(context.Context) (*cpa.ProviderKeyConfigResult, error) { + return providerKeyConfigResult(nil, nil) } -func TestSyncOnceDeduplicatesExistingEvents(t *testing.T) { - db := openSyncTestDatabase(t) - service := NewSyncServiceWithClient(db, "https://cpa.example.com", stubExportFetcher{result: successfulExportResult([]byte(`{"version":1}`))}) - - first, err := service.SyncNow(context.Background()) - if err != nil { - t.Fatalf("first SyncOnce returned error: %v", err) - } - second, err := service.SyncNow(context.Background()) - if err != nil { - t.Fatalf("second SyncOnce returned error: %v", err) - } - if first.InsertedEvents != 1 || second.InsertedEvents != 0 || second.DedupedEvents != 1 { - t.Fatalf("unexpected dedup results: first=%+v second=%+v", first, second) - } +func (s *observingMetadataFetcher) FetchCodexAPIKeys(context.Context) (*cpa.ProviderKeyConfigResult, error) { + return providerKeyConfigResult(nil, nil) } -func TestSyncOnceDoesNotLogExpectedEventAlignmentMiss(t *testing.T) { - db, logs := openSyncTestDatabaseWithLogs(t) - service := NewSyncServiceWithClient(db, "https://cpa.example.com", stubExportFetcher{result: successfulExportResult([]byte(`{"version":1}`))}) - - if _, err := service.SyncNow(context.Background()); err != nil { - t.Fatalf("SyncNow returned error: %v", err) - } - if strings.Contains(logs.String(), "record not found") { - t.Fatalf("expected normal event alignment miss not to be logged, got %s", logs.String()) - } +func (s *observingMetadataFetcher) FetchVertexAPIKeys(context.Context) (*cpa.ProviderKeyConfigResult, error) { + return providerKeyConfigResult(nil, nil) } -func TestSyncOnceFiltersEventsOlderThanLocalWatermarkOverlap(t *testing.T) { - db := openSyncTestDatabase(t) - seedTime := time.Date(2026, 4, 20, 12, 0, 0, 0, time.UTC) - if _, _, err := repository.InsertUsageEvents(db, []models.UsageEvent{{ - EventKey: "seed-event", - SnapshotRunID: 1, - APIGroupKey: "provider-a", - Model: "claude-sonnet", - Timestamp: seedTime, - Source: "seed-source", - AuthIndex: "1", - TotalTokens: 10, - }}); err != nil { - t.Fatalf("seed usage event: %v", err) - } - - service := NewSyncServiceWithClient(db, "https://cpa.example.com", stubExportFetcher{result: &cpa.ExportResult{ - StatusCode: 200, - Payload: cpa.UsageExport{ - Version: 1, - ExportedAt: seedTime.Add(time.Hour), - Usage: cpa.StatisticsSnapshot{APIs: map[string]cpa.APISnapshot{ - "provider-a": {Models: map[string]cpa.ModelSnapshot{ - "claude-sonnet": {Details: []cpa.RequestDetail{ - {Timestamp: seedTime.Add(-48 * time.Hour), Source: "old-source", AuthIndex: "2", Tokens: cpa.TokenStats{InputTokens: 1, OutputTokens: 1}}, - {Timestamp: seedTime.Add(-12 * time.Hour), Source: "recent-source", AuthIndex: "3", Tokens: cpa.TokenStats{InputTokens: 2, OutputTokens: 2}}, - }}, - }}, - }}, - }, - }}) - - result, err := service.SyncNow(context.Background()) - if err != nil { - t.Fatalf("SyncNow returned error: %v", err) - } - if result.InsertedEvents != 1 || result.DedupedEvents != 0 { - t.Fatalf("expected only recent event to be inserted, got %+v", result) - } - - var count int64 - if err := db.Model(&models.UsageEvent{}).Where("source = ?", "old-source").Count(&count).Error; err != nil { - t.Fatalf("count old filtered events: %v", err) - } - if count != 0 { - t.Fatalf("expected old event to be filtered out, found %d rows", count) - } - if err := db.Model(&models.UsageEvent{}).Where("source = ?", "recent-source").Count(&count).Error; err != nil { - t.Fatalf("count recent events: %v", err) - } - if count != 1 { - t.Fatalf("expected recent event to be inserted, found %d rows", count) - } +func (s *observingMetadataFetcher) FetchOpenAICompatibility(context.Context) (*cpa.OpenAICompatibilityResult, error) { + return openAICompatibilityResult(nil, nil) } -func TestSyncOnceKeepsOverlapWindowEventsForExistingDedupe(t *testing.T) { - db := openSyncTestDatabase(t) - seedTime := time.Date(2026, 4, 20, 12, 0, 0, 0, time.UTC) - seedTokens := cpa.TokenStats{InputTokens: 10, OutputTokens: 20, ReasoningTokens: 5, TotalTokens: 35} - seedEvent := models.UsageEvent{ - EventKey: BuildEventKey("provider-a", "claude-sonnet", seedTime.Add(-2*time.Hour), "codex-a", "1", false, seedTokens), - SnapshotRunID: 1, - APIGroupKey: "provider-a", - Model: "claude-sonnet", - Timestamp: seedTime.Add(-2 * time.Hour), - Source: "codex-a", - AuthIndex: "1", - TotalTokens: 35, - InputTokens: 10, - OutputTokens: 20, - ReasoningTokens: 5, - } - if _, _, err := repository.InsertUsageEvents(db, []models.UsageEvent{seedEvent}); err != nil { - t.Fatalf("seed usage event: %v", err) - } - - service := NewSyncServiceWithClient(db, "https://cpa.example.com", stubExportFetcher{result: &cpa.ExportResult{ - StatusCode: 200, - Payload: cpa.UsageExport{ - Version: 1, - ExportedAt: seedTime.Add(time.Hour), - Usage: cpa.StatisticsSnapshot{APIs: map[string]cpa.APISnapshot{ - "provider-a": {Models: map[string]cpa.ModelSnapshot{ - "claude-sonnet": {Details: []cpa.RequestDetail{{ - Timestamp: seedTime.Add(-2 * time.Hour), - Source: "codex-a", - AuthIndex: "1", - Tokens: seedTokens, - }}}, - }}, - }}, - }, - }}) - - result, err := service.SyncNow(context.Background()) - if err != nil { - t.Fatalf("SyncNow returned error: %v", err) - } - if result.InsertedEvents != 0 || result.DedupedEvents != 1 { - t.Fatalf("expected overlap event to reach dedupe path, got %+v", result) - } -} - -func TestSyncOnceKeepsZeroTimestampEvents(t *testing.T) { - db := openSyncTestDatabase(t) - seedTime := time.Date(2026, 4, 20, 12, 0, 0, 0, time.UTC) - if _, _, err := repository.InsertUsageEvents(db, []models.UsageEvent{{ - EventKey: "seed-event", - SnapshotRunID: 1, - APIGroupKey: "provider-a", - Model: "claude-sonnet", - Timestamp: seedTime, - Source: "seed-source", - AuthIndex: "1", - TotalTokens: 10, - }}); err != nil { - t.Fatalf("seed usage event: %v", err) - } - - service := NewSyncServiceWithClient(db, "https://cpa.example.com", stubExportFetcher{result: &cpa.ExportResult{ - StatusCode: 200, - Payload: cpa.UsageExport{ - Version: 1, - Usage: cpa.StatisticsSnapshot{APIs: map[string]cpa.APISnapshot{ - "provider-a": {Models: map[string]cpa.ModelSnapshot{ - "claude-sonnet": {Details: []cpa.RequestDetail{{ - Source: "zero-ts-source", - AuthIndex: "5", - Tokens: cpa.TokenStats{InputTokens: 3, OutputTokens: 4}, - }}}, - }}, - }}, - }, - }}) - - result, err := service.SyncNow(context.Background()) - if err != nil { - t.Fatalf("SyncNow returned error: %v", err) - } - if result.InsertedEvents != 1 { - t.Fatalf("expected zero timestamp event to be kept, got %+v", result) - } +func (s *trackingMetadataFetcher) providerCalls() int { + return s.geminiCalls + s.claudeCalls + s.codexCalls + s.vertexCalls + s.openAICalls } func TestPullRedisUsageInboxOnlyStoresPendingRows(t *testing.T) { @@ -448,7 +178,7 @@ func TestPullRedisUsageInboxOnlyStoresPendingRows(t *testing.T) { if err := db.First(&inbox).Error; err != nil { t.Fatalf("load inbox row: %v", err) } - if inbox.Status != repository.RedisUsageInboxStatusPending || inbox.UsageEventKey != "" || inbox.SnapshotRunID != nil { + if inbox.Status != repository.RedisUsageInboxStatusPending || inbox.UsageEventKey != "" { t.Fatalf("expected pending inbox row without processing links, got %+v", inbox) } var eventCount int64 @@ -458,20 +188,13 @@ func TestPullRedisUsageInboxOnlyStoresPendingRows(t *testing.T) { if eventCount != 0 { t.Fatalf("expected pull not to write usage events, got %d", eventCount) } - var snapshotCount int64 - if err := db.Model(&models.SnapshotRun{}).Count(&snapshotCount).Error; err != nil { - t.Fatalf("count snapshot runs: %v", err) - } - if snapshotCount != 0 { - t.Fatalf("expected pull not to write snapshot runs, got %d", snapshotCount) - } } func TestProcessRedisUsageInboxPersistsEventsWithoutSnapshot(t *testing.T) { db := openSyncTestDatabase(t) rows, err := repository.InsertRedisUsageInboxMessages(db, []repository.RedisInboxInsert{{ QueueKey: cpa.ManagementUsageQueueKey, - RawMessage: `{"timestamp":"2026-04-27T08:00:00Z","provider":"claude","model":"sonnet","request_id":"process-only","tokens":{"input_tokens":1,"output_tokens":2}}`, + RawMessage: `{"timestamp":"2026-04-27T08:00:00Z","provider":"claude","endpoint":"/v1/messages","auth_type":"api_key","model":"sonnet","request_id":"process-only","tokens":{"input_tokens":1,"output_tokens":2}}`, PoppedAt: time.Date(2026, 4, 27, 8, 0, 0, 0, time.UTC), }}) if err != nil { @@ -482,33 +205,64 @@ func TestProcessRedisUsageInboxPersistsEventsWithoutSnapshot(t *testing.T) { RedisQueue: staticRedisQueue{err: errors.New("redis should not be popped while processing inbox")}, }) - result, err := service.ProcessRedisUsageInbox(context.Background(), false) + result, err := service.ProcessRedisUsageInbox(context.Background()) if err != nil { t.Fatalf("ProcessRedisUsageInbox returned error: %v", err) } - if result == nil || result.Status != "completed" || result.InsertedEvents != 1 || result.SnapshotRunID != 0 { + if result == nil || result.Status != "completed" || result.InsertedEvents != 1 { t.Fatalf("unexpected process result: %+v", result) } var event models.UsageEvent if err := db.First(&event).Error; err != nil { t.Fatalf("load usage event: %v", err) } - if event.EventKey != "process-only" || event.SnapshotRunID != 0 { + if event.EventKey != "process-only" { t.Fatalf("expected Redis event without snapshot run id, got %+v", event) } + if event.Provider != "claude" || event.Endpoint != "/v1/messages" || event.AuthType != "apikey" || event.RequestID != "process-only" { + t.Fatalf("expected Redis identity fields to persist, got %+v", event) + } var inbox models.RedisUsageInbox if err := db.First(&inbox, rows[0].ID).Error; err != nil { t.Fatalf("load inbox row: %v", err) } - if inbox.Status != repository.RedisUsageInboxStatusProcessed || inbox.SnapshotRunID != nil || inbox.UsageEventKey != "process-only" { + if inbox.Status != repository.RedisUsageInboxStatusProcessed || inbox.UsageEventKey != "process-only" { t.Fatalf("expected processed inbox row without snapshot link, got %+v", inbox) } - var snapshotCount int64 - if err := db.Model(&models.SnapshotRun{}).Count(&snapshotCount).Error; err != nil { - t.Fatalf("count snapshot runs: %v", err) +} + +func TestProcessRedisUsageInboxDoesNotFetchMetadata(t *testing.T) { + db := openSyncTestDatabase(t) + metadata := &trackingMetadataFetcher{} + rows, err := repository.InsertRedisUsageInboxMessages(db, []repository.RedisInboxInsert{{ + QueueKey: cpa.ManagementUsageQueueKey, + RawMessage: `{"timestamp":"2026-04-27T08:00:00Z","provider":"claude","model":"sonnet","request_id":"redis-no-metadata","tokens":{"input_tokens":1,"output_tokens":2}}`, + PoppedAt: time.Date(2026, 4, 27, 8, 0, 0, 0, time.UTC), + }}) + if err != nil { + t.Fatalf("seed inbox row: %v", err) } - if snapshotCount != 0 { - t.Fatalf("expected Redis processing not to write snapshot runs, got %d", snapshotCount) + service := NewSyncServiceWithOptions(db, SyncServiceOptions{ + BaseURL: "https://cpa.example.com", + MetadataFetcher: metadata, + }) + + result, err := service.ProcessRedisUsageInbox(context.Background()) + if err != nil { + t.Fatalf("ProcessRedisUsageInbox returned error: %v", err) + } + if result == nil || result.Status != "completed" || result.InsertedEvents != 1 { + t.Fatalf("unexpected process result: %+v", result) + } + if metadata.authCalls != 0 || metadata.providerCalls() != 0 { + t.Fatalf("expected redis processing not to fetch metadata, got auth=%d provider=%d", metadata.authCalls, metadata.providerCalls()) + } + var inbox models.RedisUsageInbox + if err := db.First(&inbox, rows[0].ID).Error; err != nil { + t.Fatalf("load inbox row: %v", err) + } + if inbox.Status != repository.RedisUsageInboxStatusProcessed || inbox.UsageEventKey != "redis-no-metadata" { + t.Fatalf("expected inbox row processed, got %+v", inbox) } } @@ -521,7 +275,7 @@ func TestSyncRedisBatchSkipsEmptyBatchWithoutSnapshotOrMetadata(t *testing.T) { MetadataFetcher: metadata, }) - result, err := service.SyncRedisBatch(context.Background(), true) + result, err := service.SyncRedisBatch(context.Background()) if err != nil { t.Fatalf("SyncRedisBatch returned error: %v", err) } @@ -532,13 +286,6 @@ func TestSyncRedisBatchSkipsEmptyBatchWithoutSnapshotOrMetadata(t *testing.T) { t.Fatalf("expected metadata fetch to be skipped for empty batch, got auth=%d provider=%d", metadata.authCalls, metadata.providerCalls()) } - var snapshotCount int64 - if err := db.Model(&models.SnapshotRun{}).Count(&snapshotCount).Error; err != nil { - t.Fatalf("count snapshot runs: %v", err) - } - if snapshotCount != 0 { - t.Fatalf("expected no snapshot runs for empty batch, got %d", snapshotCount) - } } func TestSyncRedisBatchPersistsNonEmptyBatchWithoutMetadata(t *testing.T) { @@ -550,7 +297,7 @@ func TestSyncRedisBatchPersistsNonEmptyBatchWithoutMetadata(t *testing.T) { MetadataFetcher: metadata, }) - result, err := service.SyncRedisBatch(context.Background(), false) + result, err := service.SyncRedisBatch(context.Background()) if err != nil { t.Fatalf("SyncRedisBatch returned error: %v", err) } @@ -561,25 +308,18 @@ func TestSyncRedisBatchPersistsNonEmptyBatchWithoutMetadata(t *testing.T) { t.Fatalf("expected metadata fetch to be skipped, got auth=%d provider=%d", metadata.authCalls, metadata.providerCalls()) } - var snapshotCount int64 - if err := db.Model(&models.SnapshotRun{}).Count(&snapshotCount).Error; err != nil { - t.Fatalf("count snapshot runs: %v", err) - } - if snapshotCount != 0 { - t.Fatalf("expected Redis batch not to create snapshot runs, got %d", snapshotCount) - } var event models.UsageEvent if err := db.First(&event).Error; err != nil { t.Fatalf("load usage event: %v", err) } - if event.EventKey != "redis-1" || event.SnapshotRunID != 0 { + if event.EventKey != "redis-1" { t.Fatalf("unexpected usage event: %+v", event) } var inbox models.RedisUsageInbox if err := db.First(&inbox).Error; err != nil { t.Fatalf("load inbox row: %v", err) } - if inbox.Status != repository.RedisUsageInboxStatusProcessed || inbox.SnapshotRunID != nil || inbox.UsageEventKey != "redis-1" { + if inbox.Status != repository.RedisUsageInboxStatusProcessed || inbox.UsageEventKey != "redis-1" { t.Fatalf("expected processed inbox row without snapshot link, got %+v", inbox) } } @@ -594,7 +334,7 @@ func TestSyncRedisBatchPersistsValidRowsWhenBatchContainsMalformedMessage(t *tes }}, }) - result, err := service.SyncRedisBatch(context.Background(), false) + result, err := service.SyncRedisBatch(context.Background()) if err == nil || !strings.Contains(err.Error(), "decode redis usage message") { t.Fatalf("expected decode warning, got %v", err) } @@ -632,7 +372,7 @@ func TestSyncRedisBatchMarksMalformedOnlyBatchWithoutSnapshot(t *testing.T) { RedisQueue: staticRedisQueue{messages: []string{`{bad-json}`}}, }) - result, err := service.SyncRedisBatch(context.Background(), false) + result, err := service.SyncRedisBatch(context.Background()) if err == nil || !strings.Contains(err.Error(), "decode redis usage message") { t.Fatalf("expected decode warning, got %v", err) } @@ -640,14 +380,6 @@ func TestSyncRedisBatchMarksMalformedOnlyBatchWithoutSnapshot(t *testing.T) { t.Fatalf("expected warning result, got %+v", result) } - var snapshotCount int64 - if err := db.Model(&models.SnapshotRun{}).Count(&snapshotCount).Error; err != nil { - t.Fatalf("count snapshot runs: %v", err) - } - if snapshotCount != 0 { - t.Fatalf("expected no snapshot for malformed-only batch, got %d", snapshotCount) - } - var inbox models.RedisUsageInbox if err := db.First(&inbox).Error; err != nil { t.Fatalf("load inbox row: %v", err) @@ -673,7 +405,7 @@ func TestSyncRedisBatchProcessesPendingInboxBeforePoppingRedis(t *testing.T) { RedisQueue: staticRedisQueue{err: errors.New("redis should not be popped while inbox is pending")}, }) - result, err := service.SyncRedisBatch(context.Background(), false) + result, err := service.SyncRedisBatch(context.Background()) if err != nil { t.Fatalf("SyncRedisBatch returned error: %v", err) } @@ -700,11 +432,10 @@ func TestSyncRedisBatchProcessesPendingInboxBeforePoppingRedis(t *testing.T) { func TestSyncRedisBatchDoesNotWatermarkFilterRedisInboxEvents(t *testing.T) { db := openSyncTestDatabase(t) if _, _, err := repository.InsertUsageEvents(db, []models.UsageEvent{{ - EventKey: "future-watermark", - SnapshotRunID: 1, - APIGroupKey: "claude", - Model: "sonnet", - Timestamp: time.Date(2026, 4, 28, 8, 0, 0, 0, time.UTC), + EventKey: "future-watermark", + APIGroupKey: "claude", + Model: "sonnet", + Timestamp: time.Date(2026, 4, 28, 8, 0, 0, 0, time.UTC), }}); err != nil { t.Fatalf("seed future event: %v", err) } @@ -715,7 +446,7 @@ func TestSyncRedisBatchDoesNotWatermarkFilterRedisInboxEvents(t *testing.T) { }}, }) - result, err := service.SyncRedisBatch(context.Background(), false) + result, err := service.SyncRedisBatch(context.Background()) if err != nil { t.Fatalf("SyncRedisBatch returned error: %v", err) } @@ -748,7 +479,7 @@ func TestSyncRedisBatchRetriesProcessFailedInboxBeforePoppingRedis(t *testing.T) RedisQueue: staticRedisQueue{err: errors.New("redis should not be popped while process_failed inbox is retryable")}, }) - result, err := service.SyncRedisBatch(context.Background(), false) + result, err := service.SyncRedisBatch(context.Background()) if err != nil { t.Fatalf("SyncRedisBatch returned error: %v", err) } @@ -766,13 +497,13 @@ func TestSyncRedisBatchRetriesProcessFailedInboxBeforePoppingRedis(t *testing.T) func TestSyncNowInRedisModeUsesDurableInbox(t *testing.T) { db := openSyncTestDatabase(t) + metadata := &trackingMetadataFetcher{} service := NewSyncServiceWithOptions(db, SyncServiceOptions{ - BaseURL: "https://cpa.example.com", - UsageSyncMode: "redis", + BaseURL: "https://cpa.example.com", RedisQueue: staticRedisQueue{messages: []string{ `{"timestamp":"2026-04-27T08:00:00Z","provider":"claude","model":"sonnet","request_id":"sync-now-redis","tokens":{"input_tokens":1,"output_tokens":2}}`, }}, - MetadataFetcher: stubExportFetcher{}, + MetadataFetcher: metadata, }) result, err := service.SyncNow(context.Background()) @@ -782,6 +513,9 @@ func TestSyncNowInRedisModeUsesDurableInbox(t *testing.T) { if result == nil || result.InsertedEvents != 1 { t.Fatalf("unexpected SyncNow result: %+v", result) } + if metadata.authCalls != 0 || metadata.providerCalls() != 0 { + t.Fatalf("expected SyncNow not to fetch metadata, got auth=%d provider=%d", metadata.authCalls, metadata.providerCalls()) + } var inbox models.RedisUsageInbox if err := db.First(&inbox).Error; err != nil { t.Fatalf("load inbox row: %v", err) @@ -791,271 +525,354 @@ func TestSyncNowInRedisModeUsesDurableInbox(t *testing.T) { } } -func TestLegacyThenRedisEquivalentRequestDedupesAcrossPaths(t *testing.T) { +func TestSyncRedisBatchKeepsRedisRequestIDWhenEquivalentCanonicalEventExists(t *testing.T) { db := openSyncTestDatabase(t) timestamp := time.Date(2026, 4, 27, 8, 0, 0, 0, time.UTC) tokens := cpa.TokenStats{InputTokens: 10, OutputTokens: 20, ReasoningTokens: 5, CachedTokens: 4, TotalTokens: 39} - legacyService := NewSyncServiceWithOptions(db, SyncServiceOptions{ - BaseURL: "https://cpa.example.com", - Client: stubExportFetcher{result: equivalentExportResult("external-api-key", "claude-sonnet", timestamp, "codex-a", "1", false, 123, tokens)}, - }) - - first, err := legacyService.SyncNow(context.Background()) - if err != nil { - t.Fatalf("legacy SyncNow returned error: %v", err) - } - if first.InsertedEvents != 1 || first.DedupedEvents != 0 { - t.Fatalf("unexpected first sync result: %+v", first) + canonicalKey := BuildEventKey("external-api-key", "claude-sonnet", timestamp, "codex-a", "1", false, tokens) + if _, _, err := repository.InsertUsageEvents(db, []models.UsageEvent{{ + EventKey: canonicalKey, + APIGroupKey: "external-api-key", + Model: "claude-sonnet", + Timestamp: timestamp, + Source: "codex-a", + AuthIndex: "1", + Failed: false, + LatencyMS: 123, + InputTokens: tokens.InputTokens, + OutputTokens: tokens.OutputTokens, + ReasoningTokens: tokens.ReasoningTokens, + CachedTokens: tokens.CachedTokens, + TotalTokens: tokens.TotalTokens, + }}); err != nil { + t.Fatalf("seed canonical usage event: %v", err) } - redisService := NewSyncServiceWithOptions(db, SyncServiceOptions{ + service := NewSyncServiceWithOptions(db, SyncServiceOptions{ BaseURL: "https://cpa.example.com", RedisQueue: staticRedisQueue{messages: []string{ - equivalentRedisMessage("external-api-key", "claude-sonnet", timestamp, "codex-a", "1", false, 123, tokens, "redis-request-id"), + equivalentRedisMessage("external-api-key", "claude-sonnet", timestamp, "codex-a", "1", false, 123, tokens, "redis-request-canonical"), }}, }) - second, err := redisService.SyncRedisBatch(context.Background(), false) + result, err := service.SyncRedisBatch(context.Background()) if err != nil { - t.Fatalf("redis SyncRedisBatch returned error: %v", err) + t.Fatalf("SyncRedisBatch returned error: %v", err) } - if second.InsertedEvents != 0 || second.DedupedEvents != 1 { - t.Fatalf("expected Redis duplicate to dedupe against legacy event, got %+v", second) + if result.InsertedEvents != 1 || result.DedupedEvents != 0 { + t.Fatalf("expected Redis request_id event to insert separately from canonical event, got %+v", result) + } + assertUsageEventCount(t, db, 2) + var inbox models.RedisUsageInbox + if err := db.First(&inbox).Error; err != nil { + t.Fatalf("load inbox row: %v", err) + } + if inbox.UsageEventKey != "redis-request-canonical" { + t.Fatalf("expected inbox to keep Redis request_id event key, got %+v", inbox) } - assertUsageEventCount(t, db, 1) } -func TestRedisThenLegacyEquivalentRequestDedupesAcrossPaths(t *testing.T) { +func TestSyncRedisBatchKeepsDistinctRedisRequestIDsWithSameCanonicalFields(t *testing.T) { db := openSyncTestDatabase(t) timestamp := time.Date(2026, 4, 27, 8, 0, 0, 0, time.UTC) tokens := cpa.TokenStats{InputTokens: 10, OutputTokens: 20, ReasoningTokens: 5, CachedTokens: 4, TotalTokens: 39} - redisService := NewSyncServiceWithOptions(db, SyncServiceOptions{ + service := NewSyncServiceWithOptions(db, SyncServiceOptions{ BaseURL: "https://cpa.example.com", RedisQueue: staticRedisQueue{messages: []string{ - equivalentRedisMessage("external-api-key", "claude-sonnet", timestamp, "codex-a", "1", false, 123, tokens, "redis-request-id"), + equivalentRedisMessage("external-api-key", "claude-sonnet", timestamp, "codex-a", "1", false, 123, tokens, "redis-request-1"), + equivalentRedisMessage("external-api-key", "claude-sonnet", timestamp, "codex-a", "1", false, 123, tokens, "redis-request-2"), }}, }) - first, err := redisService.SyncRedisBatch(context.Background(), false) - if err != nil { - t.Fatalf("redis SyncRedisBatch returned error: %v", err) - } - if first.InsertedEvents != 1 || first.DedupedEvents != 0 { - t.Fatalf("unexpected first sync result: %+v", first) - } - legacyService := NewSyncServiceWithOptions(db, SyncServiceOptions{ - BaseURL: "https://cpa.example.com", - Client: stubExportFetcher{result: equivalentExportResult("external-api-key", "claude-sonnet", timestamp, "codex-a", "1", false, 123, tokens)}, - }) - - second, err := legacyService.SyncNow(context.Background()) + result, err := service.SyncRedisBatch(context.Background()) if err != nil { - t.Fatalf("legacy SyncNow returned error: %v", err) - } - if second.InsertedEvents != 0 || second.DedupedEvents != 1 { - t.Fatalf("expected legacy duplicate to dedupe against Redis event, got %+v", second) - } - assertUsageEventCount(t, db, 1) - - var event models.UsageEvent - if err := db.First(&event).Error; err != nil { - t.Fatalf("load usage event: %v", err) + t.Fatalf("SyncRedisBatch returned error: %v", err) } - if event.EventKey != "redis-request-id" { - t.Fatalf("expected Redis request_id to be preserved, got %+v", event) + if result.InsertedEvents != 2 || result.DedupedEvents != 0 { + t.Fatalf("expected distinct Redis request IDs to insert separately, got %+v", result) } + assertUsageEventCount(t, db, 2) } -func TestSyncRedisBatchDedupesOnlyOneRedisRequestAgainstExistingLegacyCanonicalEvent(t *testing.T) { +func TestSyncRedisBatchWritesDebugLogsWithoutRawPayload(t *testing.T) { db := openSyncTestDatabase(t) - timestamp := time.Date(2026, 4, 27, 8, 0, 0, 0, time.UTC) - tokens := cpa.TokenStats{InputTokens: 10, OutputTokens: 20, ReasoningTokens: 5, CachedTokens: 4, TotalTokens: 39} - legacyService := NewSyncServiceWithOptions(db, SyncServiceOptions{ - BaseURL: "https://cpa.example.com", - Client: stubExportFetcher{result: equivalentExportResult("external-api-key", "claude-sonnet", timestamp, "codex-a", "1", false, 123, tokens)}, - }) - if _, err := legacyService.SyncNow(context.Background()); err != nil { - t.Fatalf("legacy SyncNow returned error: %v", err) - } - redisService := NewSyncServiceWithOptions(db, SyncServiceOptions{ + logs := captureSyncDebugLogs(t) + + service := NewSyncServiceWithOptions(db, SyncServiceOptions{ BaseURL: "https://cpa.example.com", RedisQueue: staticRedisQueue{messages: []string{ - equivalentRedisMessage("external-api-key", "claude-sonnet", timestamp, "codex-a", "1", false, 123, tokens, "redis-request-1"), - equivalentRedisMessage("external-api-key", "claude-sonnet", timestamp, "codex-a", "1", false, 123, tokens, "redis-request-2"), + `{"timestamp":"2026-04-27T08:00:00Z","provider":"claude","model":"sonnet","request_id":"redis-log","api_key":"raw-secret-key","tokens":{"input_tokens":1,"output_tokens":2}}`, }}, }) - result, err := redisService.SyncRedisBatch(context.Background(), false) + _, err := service.SyncRedisBatch(context.Background()) if err != nil { t.Fatalf("SyncRedisBatch returned error: %v", err) } - if result.InsertedEvents != 1 || result.DedupedEvents != 1 { - t.Fatalf("expected one Redis request to dedupe against legacy and one to remain distinct, got %+v", result) + output := logs.String() + for _, expected := range []string{ + "redis usage batch popped", + "redis usage inbox rows inserted", + "redis usage inbox rows processed", + } { + if !strings.Contains(output, expected) { + t.Fatalf("expected debug log %q in output:\n%s", expected, output) + } + } + if strings.Contains(output, "raw-secret-key") || strings.Contains(output, "redis-log") { + t.Fatalf("debug logs should not include raw payload fields, got:\n%s", output) } - assertUsageEventCount(t, db, 2) } -func TestSyncRedisBatchKeepsDistinctRedisRequestIDsWithSameCanonicalFields(t *testing.T) { +func TestSyncMetadataRefreshesMetadataWithoutSnapshot(t *testing.T) { db := openSyncTestDatabase(t) - timestamp := time.Date(2026, 4, 27, 8, 0, 0, 0, time.UTC) - tokens := cpa.TokenStats{InputTokens: 10, OutputTokens: 20, ReasoningTokens: 5, CachedTokens: 4, TotalTokens: 39} + metadata := &trackingMetadataFetcher{} service := NewSyncServiceWithOptions(db, SyncServiceOptions{ - BaseURL: "https://cpa.example.com", - RedisQueue: staticRedisQueue{messages: []string{ - equivalentRedisMessage("external-api-key", "claude-sonnet", timestamp, "codex-a", "1", false, 123, tokens, "redis-request-1"), - equivalentRedisMessage("external-api-key", "claude-sonnet", timestamp, "codex-a", "1", false, 123, tokens, "redis-request-2"), - }}, + BaseURL: "https://cpa.example.com", + MetadataFetcher: metadata, }) - result, err := service.SyncRedisBatch(context.Background(), false) - if err != nil { - t.Fatalf("SyncRedisBatch returned error: %v", err) + if err := service.SyncMetadata(context.Background()); err != nil { + t.Fatalf("SyncMetadata returned error: %v", err) } - if result.InsertedEvents != 2 || result.DedupedEvents != 0 { - t.Fatalf("expected distinct Redis request IDs to insert separately, got %+v", result) + if metadata.authCalls != 1 || metadata.providerCalls() != 5 { + t.Fatalf("expected metadata fetch once, got auth=%d provider=%d", metadata.authCalls, metadata.providerCalls()) } - assertUsageEventCount(t, db, 2) } -func TestSyncRedisBatchRecordsPersistedEventKeyForLegacyDuplicateInboxRow(t *testing.T) { +func TestSyncMetadataWritesAuthFilesToUsageIdentities(t *testing.T) { db := openSyncTestDatabase(t) - timestamp := time.Date(2026, 4, 27, 8, 0, 0, 0, time.UTC) - tokens := cpa.TokenStats{InputTokens: 10, OutputTokens: 20, ReasoningTokens: 5, CachedTokens: 4, TotalTokens: 39} - legacyService := NewSyncServiceWithOptions(db, SyncServiceOptions{ + service := NewSyncServiceWithOptions(db, SyncServiceOptions{ BaseURL: "https://cpa.example.com", - Client: stubExportFetcher{result: equivalentExportResult("external-api-key", "claude-sonnet", timestamp, "codex-a", "1", false, 123, tokens)}, + MetadataFetcher: stubMetadataFetcher{authFilesResult: &cpa.AuthFilesResult{StatusCode: 200, Payload: cpa.AuthFilesResponse{Files: []cpa.AuthFile{{ + AuthIndex: "auth-1", + Name: "Fallback Name", + Email: "user@example.com", + Type: "claude", + Provider: "Claude", + Label: "Label Name", + }, { + AuthIndex: "auth-2", + Name: "Name Fallback", + Type: "gemini", + Provider: "Gemini", + Label: "Label Fallback", + }, { + AuthIndex: "auth-3", + Name: "Name Fallback", + Type: "codex", + Provider: "Codex", + }, { + AuthIndex: "auth-4", + Type: "vertex", + Provider: "Vertex", + }}}}}, }) - first, err := legacyService.SyncNow(context.Background()) + + if err := service.SyncMetadata(context.Background()); err != nil { + t.Fatalf("SyncMetadata returned error: %v", err) + } + items, err := repository.ListUsageIdentities(context.Background(), db) if err != nil { - t.Fatalf("legacy SyncNow returned error: %v", err) + t.Fatalf("list usage identities: %v", err) + } + byIdentity := usageIdentitiesByIdentity(items) + first := byIdentity["auth-1"] + if first.Name != "user@example.com" || first.AuthType != models.UsageIdentityAuthTypeAuthFile || first.AuthTypeName != "oauth" || first.Identity != "auth-1" || first.Type != "claude" || first.Provider != "Claude" || first.IsDeleted { + t.Fatalf("unexpected auth usage identity for auth-1: %+v", first) + } + second := byIdentity["auth-2"] + if second.Name != "Label Fallback" || second.AuthTypeName != "oauth" || second.Identity != "auth-2" || second.Type != "gemini" || second.Provider != "Gemini" || second.IsDeleted { + t.Fatalf("unexpected auth usage identity for auth-2: %+v", second) } - redisService := NewSyncServiceWithOptions(db, SyncServiceOptions{ + third := byIdentity["auth-3"] + if third.Name != "Name Fallback" || third.AuthTypeName != "oauth" || third.Identity != "auth-3" || third.Type != "codex" || third.Provider != "Codex" || third.IsDeleted { + t.Fatalf("unexpected auth usage identity for auth-3: %+v", third) + } + fourth := byIdentity["auth-4"] + if fourth.Name != "auth-4" || fourth.AuthTypeName != "oauth" || fourth.Identity != "auth-4" || fourth.Type != "vertex" || fourth.Provider != "Vertex" || fourth.IsDeleted { + t.Fatalf("unexpected auth usage identity for auth-4: %+v", fourth) + } + assertTableNotExists(t, db, "auth_files") +} + +func TestSyncMetadataWritesProviderMetadataToUsageIdentities(t *testing.T) { + db := openSyncTestDatabase(t) + service := NewSyncServiceWithOptions(db, SyncServiceOptions{ BaseURL: "https://cpa.example.com", - RedisQueue: staticRedisQueue{messages: []string{ - equivalentRedisMessage("external-api-key", "claude-sonnet", timestamp, "codex-a", "1", false, 123, tokens, "redis-request-id"), + MetadataFetcher: stubMetadataFetcher{providerConfig: cpa.ProviderMetadataConfig{ + ClaudeAPIKeys: []cpa.ProviderKeyConfig{{APIKey: "claude-key", Prefix: "claude-prefix", Name: "Claude Team"}}, }}, }) - if _, err := redisService.SyncRedisBatch(context.Background(), false); err != nil { - t.Fatalf("redis SyncRedisBatch returned error: %v", err) + if err := service.SyncMetadata(context.Background()); err != nil { + t.Fatalf("SyncMetadata returned error: %v", err) } - var event models.UsageEvent - if err := db.First(&event, "snapshot_run_id = ?", first.SnapshotRunID).Error; err != nil { - t.Fatalf("load legacy usage event: %v", err) + items, err := repository.ListUsageIdentities(context.Background(), db) + if err != nil { + t.Fatalf("list usage identities: %v", err) } - var inbox models.RedisUsageInbox - if err := db.First(&inbox).Error; err != nil { - t.Fatalf("load redis inbox row: %v", err) + byIdentity := usageIdentitiesByIdentity(items) + apiKey := byIdentity["claude-key"] + if apiKey.Name != "Claude Team" || apiKey.AuthType != models.UsageIdentityAuthTypeAIProvider || apiKey.AuthTypeName != "apikey" || apiKey.Identity != "claude-key" || apiKey.Type != "claude" || apiKey.Provider != "Claude Team" || apiKey.IsDeleted { + t.Fatalf("unexpected provider usage identity for api key: %+v", apiKey) } - if inbox.UsageEventKey != event.EventKey { - t.Fatalf("expected inbox to reference persisted event key %q, got %+v", event.EventKey, inbox) + if _, ok := byIdentity["claude-prefix"]; ok { + t.Fatalf("expected provider prefix not to be stored as usage identity, got %+v", byIdentity["claude-prefix"]) } + assertTableNotExists(t, db, "provider_metadata") } -func TestSyncRedisBatchWritesDebugLogsWithoutRawPayload(t *testing.T) { +func TestSyncMetadataKeepsProviderIdentityWhenPrefixEqualsAPIKey(t *testing.T) { db := openSyncTestDatabase(t) - logs := captureSyncDebugLogs(t) - service := NewSyncServiceWithOptions(db, SyncServiceOptions{ BaseURL: "https://cpa.example.com", - RedisQueue: staticRedisQueue{messages: []string{ - `{"timestamp":"2026-04-27T08:00:00Z","provider":"claude","model":"sonnet","request_id":"redis-log","api_key":"raw-secret-key","tokens":{"input_tokens":1,"output_tokens":2}}`, + MetadataFetcher: stubMetadataFetcher{providerConfig: cpa.ProviderMetadataConfig{ + ClaudeAPIKeys: []cpa.ProviderKeyConfig{{APIKey: "same-value", Prefix: "same-value", Name: "Claude Same"}}, }}, }) - _, err := service.SyncRedisBatch(context.Background(), false) - if err != nil { - t.Fatalf("SyncRedisBatch returned error: %v", err) + if err := service.SyncMetadata(context.Background()); err != nil { + t.Fatalf("SyncMetadata returned error: %v", err) } - output := logs.String() - for _, expected := range []string{ - "redis usage batch popped", - "redis usage inbox rows inserted", - "redis usage inbox rows processed", - } { - if !strings.Contains(output, expected) { - t.Fatalf("expected debug log %q in output:\n%s", expected, output) - } + var identity models.UsageIdentity + if err := db.Where("auth_type = ? AND identity = ?", models.UsageIdentityAuthTypeAIProvider, "same-value").First(&identity).Error; err != nil { + t.Fatalf("load protected api key usage identity: %v", err) } - if strings.Contains(output, "raw-secret-key") || strings.Contains(output, "redis-log") { - t.Fatalf("debug logs should not include raw payload fields, got:\n%s", output) + if identity.IsDeleted || identity.Type != "claude" || identity.Provider != "Claude Same" { + t.Fatalf("expected api key matching prefix to remain active, got %+v", identity) } } -func TestSyncOnceWritesCoreDebugLogsForLegacyPull(t *testing.T) { +func TestSyncMetadataDoesNotUseOpenAICompatibilityPrefixAsDisplayName(t *testing.T) { db := openSyncTestDatabase(t) - logs := captureSyncDebugLogs(t) service := NewSyncServiceWithOptions(db, SyncServiceOptions{ BaseURL: "https://cpa.example.com", - Client: stubExportFetcher{result: successfulExportResult([]byte(`{"version":1}`))}, + MetadataFetcher: stubMetadataFetcher{providerConfig: cpa.ProviderMetadataConfig{ + OpenAICompatibility: []cpa.OpenAICompatibilityConfig{{ + Prefix: "https://proxy.internal/v1", + APIKeyEntries: []cpa.OpenAIApiKeyEntry{{APIKey: "openai-compatible-key"}}, + }}, + }}, }) - _, err := service.SyncNow(context.Background()) + if err := service.SyncMetadata(context.Background()); err != nil { + t.Fatalf("SyncMetadata returned error: %v", err) + } + items, err := repository.ListUsageIdentities(context.Background(), db) if err != nil { - t.Fatalf("SyncNow returned error: %v", err) + t.Fatalf("list usage identities: %v", err) } - output := logs.String() - for _, expected := range []string{ - "legacy usage pull started", - "legacy usage pull finished", - "usage persistence started", - "usage events insert finished", - "snapshot run finalized", - } { - if !strings.Contains(output, expected) { - t.Fatalf("expected debug log %q in output:\n%s", expected, output) - } + byIdentity := usageIdentitiesByIdentity(items) + identity := byIdentity["openai-compatible-key"] + if identity.Identity != "openai-compatible-key" { + t.Fatalf("expected OpenAI compatibility api key usage identity, got %+v", identity) + } + if identity.Name != "openai" || identity.Provider != "openai" { + t.Fatalf("expected raw OpenAI compatibility prefix not to be used as display value, got %+v", identity) + } + if _, ok := byIdentity["https://proxy.internal/v1"]; ok { + t.Fatalf("expected OpenAI compatibility prefix not to create usage identity, got %+v", items) } } -func TestSyncRedisBatchReturnsMetadataWarningAfterPersistingEvents(t *testing.T) { +func TestSyncMetadataUsageIdentityPartialFailureKeepsFailedProviderType(t *testing.T) { db := openSyncTestDatabase(t) - metadata := &trackingMetadataFetcher{authErr: errors.New("metadata unavailable")} + now := time.Date(2026, 5, 4, 9, 0, 0, 0, time.UTC) + oldDeletedAt := time.Date(2026, 5, 1, 9, 0, 0, 0, time.UTC) + if err := db.Create(&[]models.UsageIdentity{{ + Name: "Old Gemini", + AuthType: models.UsageIdentityAuthTypeAIProvider, + AuthTypeName: "apikey", + Identity: "old-gemini-key", + Type: "gemini", + Provider: "Old Gemini", + }, { + Name: "Old Claude", + AuthType: models.UsageIdentityAuthTypeAIProvider, + AuthTypeName: "apikey", + Identity: "old-claude-key", + Type: "claude", + Provider: "Old Claude", + DeletedAt: &oldDeletedAt, + }}).Error; err != nil { + t.Fatalf("seed usage identities: %v", err) + } service := NewSyncServiceWithOptions(db, SyncServiceOptions{ - BaseURL: "https://cpa.example.com", - RedisQueue: staticRedisQueue{messages: []string{`{"timestamp":"2026-04-27T08:00:00Z","provider":"claude","model":"sonnet","request_id":"redis-2","tokens":{"input_tokens":1,"output_tokens":2}}`}}, - MetadataFetcher: metadata, + BaseURL: "https://cpa.example.com", + Now: func() time.Time { return now }, + MetadataFetcher: stubMetadataFetcher{ + providerConfig: cpa.ProviderMetadataConfig{ClaudeAPIKeys: []cpa.ProviderKeyConfig{{APIKey: "new-claude-key", Prefix: "new-claude-prefix", Name: "New Claude"}}}, + geminiErr: errors.New("gemini unavailable"), + }, }) - result, err := service.SyncRedisBatch(context.Background(), true) - if err == nil || !strings.Contains(err.Error(), "metadata unavailable") { - t.Fatalf("expected metadata warning error, got %v", err) + err := service.SyncMetadata(context.Background()) + if err == nil || !strings.Contains(err.Error(), "gemini unavailable") { + t.Fatalf("expected provider metadata warning, got %v", err) } - if result == nil || result.Status != "completed_with_warnings" || result.InsertedEvents != 1 { - t.Fatalf("expected warning result with persisted event, got %+v", result) + items, listErr := repository.ListUsageIdentities(context.Background(), db) + if listErr != nil { + t.Fatalf("list usage identities: %v", listErr) } - if metadata.authCalls != 1 || metadata.providerCalls() != 5 { - t.Fatalf("expected metadata fetch once, got auth=%d provider=%d", metadata.authCalls, metadata.providerCalls()) + byIdentity := usageIdentitiesByIdentity(items) + if oldGemini := byIdentity["old-gemini-key"]; oldGemini.Identity == "" || oldGemini.IsDeleted || oldGemini.DeletedAt != nil { + t.Fatalf("expected failed gemini identity to remain untouched, got %+v", oldGemini) + } + if oldClaude := byIdentity["old-claude-key"]; oldClaude.Identity == "" || !oldClaude.IsDeleted || oldClaude.DeletedAt == nil || !oldClaude.DeletedAt.Equal(now) { + t.Fatalf("expected stale successful claude identity to be deleted at sync time, got %+v", oldClaude) + } + if newClaude := byIdentity["new-claude-key"]; newClaude.Identity == "" || newClaude.IsDeleted { + t.Fatalf("expected new claude identity to be active, got %+v", newClaude) } } -func TestSyncMetadataRefreshesMetadataWithoutSnapshot(t *testing.T) { +func TestSyncMetadataAggregatesUsageIdentityStatsAfterUpsert(t *testing.T) { db := openSyncTestDatabase(t) - metadata := &trackingMetadataFetcher{} + eventTime := time.Date(2026, 5, 4, 8, 0, 0, 0, time.UTC) + now := time.Date(2026, 5, 4, 9, 0, 0, 0, time.UTC) + if _, _, err := repository.InsertUsageEvents(db, []models.UsageEvent{{ + EventKey: "auth-stat-event", + AuthType: "oauth", + AuthIndex: "auth-stat", + Model: "sonnet", + Timestamp: eventTime, + InputTokens: 11, + OutputTokens: 13, + TotalTokens: 24, + }}); err != nil { + t.Fatalf("seed usage event: %v", err) + } service := NewSyncServiceWithOptions(db, SyncServiceOptions{ - BaseURL: "https://cpa.example.com", - MetadataFetcher: metadata, + BaseURL: "https://cpa.example.com", + Now: func() time.Time { return now }, + MetadataFetcher: stubMetadataFetcher{authFilesResult: &cpa.AuthFilesResult{StatusCode: 200, Payload: cpa.AuthFilesResponse{Files: []cpa.AuthFile{{ + AuthIndex: "auth-stat", + Email: "stats@example.com", + Type: "claude", + Provider: "Claude", + }}}}}, }) if err := service.SyncMetadata(context.Background()); err != nil { t.Fatalf("SyncMetadata returned error: %v", err) } - if metadata.authCalls != 1 || metadata.providerCalls() != 5 { - t.Fatalf("expected metadata fetch once, got auth=%d provider=%d", metadata.authCalls, metadata.providerCalls()) + var identity models.UsageIdentity + if err := db.Where("identity = ?", "auth-stat").First(&identity).Error; err != nil { + t.Fatalf("load usage identity: %v", err) } - var snapshotCount int64 - if err := db.Model(&models.SnapshotRun{}).Count(&snapshotCount).Error; err != nil { - t.Fatalf("count snapshot runs: %v", err) + if identity.TotalRequests != 1 || identity.SuccessCount != 1 || identity.InputTokens != 11 || identity.OutputTokens != 13 || identity.TotalTokens != 24 || identity.LastAggregatedUsageEventID == 0 || identity.StatsUpdatedAt == nil || !identity.StatsUpdatedAt.Equal(now) { + t.Fatalf("expected usage identity stats aggregated after metadata upsert, got %+v", identity) } - if snapshotCount != 0 { - t.Fatalf("expected metadata sync not to create snapshots, got %d", snapshotCount) + if identity.FirstUsedAt == nil || !identity.FirstUsedAt.Equal(eventTime) || identity.LastUsedAt == nil || !identity.LastUsedAt.Equal(eventTime) { + t.Fatalf("expected usage identity first/last usage times from seeded event, got %+v", identity) } } -func TestSyncMetadataPersistsProviderMetadataFromDedicatedEndpoints(t *testing.T) { +func TestSyncMetadataPersistsProviderUsageIdentitiesFromDedicatedEndpoints(t *testing.T) { db := openSyncTestDatabase(t) service := NewSyncServiceWithOptions(db, SyncServiceOptions{ BaseURL: "https://cpa.example.com", - MetadataFetcher: stubExportFetcher{providerConfig: cpa.ProviderMetadataConfig{ + MetadataFetcher: stubMetadataFetcher{providerConfig: cpa.ProviderMetadataConfig{ GeminiAPIKeys: []cpa.ProviderKeyConfig{{APIKey: "gemini-key", Prefix: "gemini-prefix", Name: "Gemini"}}, ClaudeAPIKeys: []cpa.ProviderKeyConfig{{APIKey: "claude-key", Prefix: "claude-prefix", Name: "Claude"}}, OpenAICompatibility: []cpa.OpenAICompatibilityConfig{{ @@ -1069,20 +886,30 @@ func TestSyncMetadataPersistsProviderMetadataFromDedicatedEndpoints(t *testing.T if err := service.SyncMetadata(context.Background()); err != nil { t.Fatalf("SyncMetadata returned error: %v", err) } - items, err := repository.ListProviderMetadata(db) + items, err := repository.ListUsageIdentities(context.Background(), db) if err != nil { - t.Fatalf("list provider metadata: %v", err) + t.Fatalf("list usage identities: %v", err) } - if len(items) != 6 { - t.Fatalf("expected provider metadata rows from dedicated endpoints, got %+v", items) + providerItems := usageIdentitiesByIdentity(items) + for _, expected := range []string{"gemini-key", "claude-key", "custom-key"} { + identity := providerItems[expected] + if identity.Identity != expected || identity.AuthType != models.UsageIdentityAuthTypeAIProvider || identity.AuthTypeName != "apikey" || identity.IsDeleted { + t.Fatalf("expected active provider usage identity %q, got %+v", expected, identity) + } + } + for _, prefix := range []string{"gemini-prefix", "claude-prefix", "custom-openai"} { + if _, ok := providerItems[prefix]; ok { + t.Fatalf("expected provider prefix %q not to create usage identity, got %+v", prefix, items) + } } + assertTableNotExists(t, db, "provider_metadata") } -func TestSyncMetadataPersistsSuccessfulProviderMetadataWhenOneEndpointFails(t *testing.T) { +func TestSyncMetadataPersistsSuccessfulProviderUsageIdentitiesWhenOneEndpointFails(t *testing.T) { db := openSyncTestDatabase(t) service := NewSyncServiceWithOptions(db, SyncServiceOptions{ BaseURL: "https://cpa.example.com", - MetadataFetcher: stubExportFetcher{ + MetadataFetcher: stubMetadataFetcher{ providerConfig: cpa.ProviderMetadataConfig{ClaudeAPIKeys: []cpa.ProviderKeyConfig{{APIKey: "claude-key", Prefix: "claude-prefix", Name: "Claude"}}}, geminiErr: errors.New("gemini unavailable"), }, @@ -1092,35 +919,47 @@ func TestSyncMetadataPersistsSuccessfulProviderMetadataWhenOneEndpointFails(t *t if err == nil || !strings.Contains(err.Error(), "gemini unavailable") { t.Fatalf("expected provider metadata warning, got %v", err) } - items, listErr := repository.ListProviderMetadata(db) + items, listErr := repository.ListUsageIdentities(context.Background(), db) if listErr != nil { - t.Fatalf("list provider metadata: %v", listErr) + t.Fatalf("list usage identities: %v", listErr) } - if len(items) != 2 || items[0].ProviderType != "claude" { - t.Fatalf("expected successful provider metadata to persist, got %+v", items) + byIdentity := usageIdentitiesByIdentity(items) + identity := byIdentity["claude-key"] + if identity.Identity != "claude-key" || identity.Type != "claude" || identity.AuthType != models.UsageIdentityAuthTypeAIProvider || identity.IsDeleted { + t.Fatalf("expected successful provider usage identity to persist, got %+v", identity) + } + if _, ok := byIdentity["claude-prefix"]; ok { + t.Fatalf("expected successful provider prefix not to create usage identity, got %+v", items) + } + if _, ok := byIdentity["gemini-key"]; ok { + t.Fatalf("expected failed gemini endpoint not to create usage identity, got %+v", items) } } -func TestSyncMetadataKeepsFailedProviderRowsDuringPartialFailure(t *testing.T) { +func TestSyncMetadataKeepsFailedProviderUsageIdentitiesDuringPartialFailure(t *testing.T) { db := openSyncTestDatabase(t) - if err := repository.ReplaceProviderMetadata(db, []repository.ProviderMetadataInput{{ - LookupKey: "old-gemini-key", - ProviderType: "gemini", - DisplayName: "Old Gemini", - ProviderKey: "gemini:Old Gemini", - MatchKind: "api_key", + now := time.Date(2026, 5, 4, 9, 30, 0, 0, time.UTC) + if err := db.Create(&[]models.UsageIdentity{{ + Name: "Old Gemini", + AuthType: models.UsageIdentityAuthTypeAIProvider, + AuthTypeName: "apikey", + Identity: "old-gemini-key", + Type: "gemini", + Provider: "Old Gemini", }, { - LookupKey: "old-claude-key", - ProviderType: "claude", - DisplayName: "Old Claude", - ProviderKey: "claude:Old Claude", - MatchKind: "api_key", - }}); err != nil { - t.Fatalf("seed provider metadata: %v", err) + Name: "Old Claude", + AuthType: models.UsageIdentityAuthTypeAIProvider, + AuthTypeName: "apikey", + Identity: "old-claude-key", + Type: "claude", + Provider: "Old Claude", + }}).Error; err != nil { + t.Fatalf("seed usage identities: %v", err) } service := NewSyncServiceWithOptions(db, SyncServiceOptions{ BaseURL: "https://cpa.example.com", - MetadataFetcher: stubExportFetcher{ + Now: func() time.Time { return now }, + MetadataFetcher: stubMetadataFetcher{ providerConfig: cpa.ProviderMetadataConfig{ClaudeAPIKeys: []cpa.ProviderKeyConfig{{APIKey: "new-claude-key", Prefix: "new-claude-prefix", Name: "New Claude"}}}, geminiErr: errors.New("gemini unavailable"), }, @@ -1130,50 +969,55 @@ func TestSyncMetadataKeepsFailedProviderRowsDuringPartialFailure(t *testing.T) { if err == nil || !strings.Contains(err.Error(), "gemini unavailable") { t.Fatalf("expected provider metadata warning, got %v", err) } - items, listErr := repository.ListProviderMetadata(db) + items, listErr := repository.ListUsageIdentities(context.Background(), db) if listErr != nil { - t.Fatalf("list provider metadata: %v", listErr) + t.Fatalf("list usage identities: %v", listErr) } - lookupKeys := make(map[string]struct{}, len(items)) - for _, item := range items { - lookupKeys[item.LookupKey] = struct{}{} + byIdentity := usageIdentitiesByIdentity(items) + if oldGemini := byIdentity["old-gemini-key"]; oldGemini.Identity == "" || oldGemini.IsDeleted || oldGemini.DeletedAt != nil { + t.Fatalf("expected failed gemini usage identity to remain untouched, got %+v", oldGemini) } - for _, expected := range []string{"old-gemini-key", "new-claude-key", "new-claude-prefix"} { - if _, ok := lookupKeys[expected]; !ok { - t.Fatalf("expected provider metadata %q to exist after partial failure, got %+v", expected, items) - } + if oldClaude := byIdentity["old-claude-key"]; oldClaude.Identity == "" || !oldClaude.IsDeleted || oldClaude.DeletedAt == nil || !oldClaude.DeletedAt.Equal(now) { + t.Fatalf("expected stale successful claude usage identity to be deleted, got %+v", oldClaude) + } + newClaude := byIdentity["new-claude-key"] + if newClaude.Identity != "new-claude-key" || newClaude.IsDeleted { + t.Fatalf("expected active replacement usage identity, got %+v", newClaude) } - if _, ok := lookupKeys["old-claude-key"]; ok { - t.Fatalf("expected stale successful claude row to be replaced, got %+v", items) + if _, ok := byIdentity["new-claude-prefix"]; ok { + t.Fatalf("expected replacement prefix not to create usage identity, got %+v", items) } } -func TestSyncMetadataKeepsProviderRowsWhenEndpointReturnsNilResult(t *testing.T) { +func TestSyncMetadataKeepsProviderUsageIdentitiesWhenEndpointReturnsNilResult(t *testing.T) { db := openSyncTestDatabase(t) - if err := repository.ReplaceProviderMetadata(db, []repository.ProviderMetadataInput{{ - LookupKey: "old-gemini-key", - ProviderType: "gemini", - DisplayName: "Old Gemini", - ProviderKey: "gemini:Old Gemini", - MatchKind: "api_key", - }}); err != nil { - t.Fatalf("seed provider metadata: %v", err) + if err := db.Create(&models.UsageIdentity{ + Name: "Old Gemini", + AuthType: models.UsageIdentityAuthTypeAIProvider, + AuthTypeName: "apikey", + Identity: "old-gemini-key", + Type: "gemini", + Provider: "Old Gemini", + }).Error; err != nil { + t.Fatalf("seed usage identity: %v", err) } service := NewSyncServiceWithOptions(db, SyncServiceOptions{ BaseURL: "https://cpa.example.com", - MetadataFetcher: stubExportFetcher{geminiNilResult: true}, + MetadataFetcher: stubMetadataFetcher{geminiNilResult: true}, }) err := service.SyncMetadata(context.Background()) if err == nil || !strings.Contains(err.Error(), "gemini api keys response is nil") { t.Fatalf("expected nil gemini response warning, got %v", err) } - items, listErr := repository.ListProviderMetadata(db) + items, listErr := repository.ListUsageIdentities(context.Background(), db) if listErr != nil { - t.Fatalf("list provider metadata: %v", listErr) + t.Fatalf("list usage identities: %v", listErr) } - if len(items) != 1 || items[0].LookupKey != "old-gemini-key" { - t.Fatalf("expected old gemini metadata to remain, got %+v", items) + byIdentity := usageIdentitiesByIdentity(items) + oldGemini := byIdentity["old-gemini-key"] + if oldGemini.Identity == "" || oldGemini.IsDeleted || oldGemini.DeletedAt != nil { + t.Fatalf("expected old gemini usage identity to remain, got %+v", oldGemini) } } @@ -1184,42 +1028,10 @@ func TestSyncRedisBatchErrorDoesNotCreateSnapshot(t *testing.T) { RedisQueue: staticRedisQueue{err: errors.New("dial failed")}, }) - result, err := service.SyncRedisBatch(context.Background(), false) + result, err := service.SyncRedisBatch(context.Background()) if err == nil || result == nil || result.Status != "failed" { t.Fatalf("expected failed redis batch result, got result=%+v err=%v", result, err) } - var snapshotCount int64 - if countErr := db.Model(&models.SnapshotRun{}).Count(&snapshotCount).Error; countErr != nil { - t.Fatalf("count snapshot runs: %v", countErr) - } - if snapshotCount != 0 { - t.Fatalf("expected no snapshot runs after redis pop error, got %d", snapshotCount) - } -} - -func TestSyncOnceUsesRedisUsageFetcher(t *testing.T) { - db := openSyncTestDatabase(t) - service := NewSyncServiceWithOptions(db, SyncServiceOptions{ - BaseURL: "https://cpa.example.com", - UsageFetcher: redisUsageFetcher{queue: staticRedisQueue{messages: []string{`{"timestamp":"2026-04-27T08:00:00Z","provider":"claude","model":"sonnet","request_id":"redis-1","tokens":{"input_tokens":1,"output_tokens":2}}`}}}, - MetadataFetcher: stubExportFetcher{}, - }) - - result, err := service.SyncNow(context.Background()) - if err != nil { - t.Fatalf("SyncNow returned error: %v", err) - } - if result.InsertedEvents != 1 || result.HTTPStatus != 0 { - t.Fatalf("unexpected redis sync result: %+v", result) - } - - var event models.UsageEvent - if err := db.First(&event).Error; err != nil { - t.Fatalf("load usage event: %v", err) - } - if event.EventKey != "redis-1" || event.APIGroupKey != "claude" || event.Model != "sonnet" { - t.Fatalf("unexpected redis usage event: %+v", event) - } } func TestNewSyncServiceBuildsClientFromConfig(t *testing.T) { @@ -1237,29 +1049,6 @@ func TestNewSyncServiceBuildsClientFromConfig(t *testing.T) { } } -func equivalentExportResult(apiGroupKey, model string, timestamp time.Time, source, authIndex string, failed bool, latencyMS int64, tokens cpa.TokenStats) *cpa.ExportResult { - return &cpa.ExportResult{ - StatusCode: 200, - Body: []byte(`{"version":1}`), - Payload: cpa.UsageExport{ - Version: 1, - ExportedAt: timestamp.UTC(), - Usage: cpa.StatisticsSnapshot{APIs: map[string]cpa.APISnapshot{ - apiGroupKey: {Models: map[string]cpa.ModelSnapshot{ - model: {Details: []cpa.RequestDetail{{ - Timestamp: timestamp, - LatencyMS: latencyMS, - Source: source, - AuthIndex: authIndex, - Failed: failed, - Tokens: tokens, - }}}, - }}, - }}, - }, - } -} - func equivalentRedisMessage(apiGroupKey, model string, timestamp time.Time, source, authIndex string, failed bool, latencyMS int64, tokens cpa.TokenStats, requestID string) string { failedValue := "false" if failed { @@ -1272,6 +1061,14 @@ func int64String(value int64) string { return strconv.FormatInt(value, 10) } +func usageIdentitiesByIdentity(items []models.UsageIdentity) map[string]models.UsageIdentity { + byIdentity := make(map[string]models.UsageIdentity, len(items)) + for _, item := range items { + byIdentity[item.Identity] = item + } + return byIdentity +} + func assertUsageEventCount(t *testing.T, db *gorm.DB, expected int64) { t.Helper() var count int64 @@ -1283,31 +1080,10 @@ func assertUsageEventCount(t *testing.T, db *gorm.DB, expected int64) { } } -func successfulExportResult(body []byte) *cpa.ExportResult { - return &cpa.ExportResult{ - StatusCode: 200, - Body: body, - Payload: cpa.UsageExport{ - Version: 1, - ExportedAt: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), - Usage: cpa.StatisticsSnapshot{ - APIs: map[string]cpa.APISnapshot{ - "provider-a": { - Models: map[string]cpa.ModelSnapshot{ - "claude-sonnet": { - Details: []cpa.RequestDetail{{ - Timestamp: time.Date(2026, 4, 16, 9, 30, 0, 0, time.UTC), - LatencyMS: 123, - Source: "codex-a", - AuthIndex: "1", - Tokens: cpa.TokenStats{InputTokens: 10, OutputTokens: 20, ReasoningTokens: 5, TotalTokens: 35}, - }}, - }, - }, - }, - }, - }, - }, +func assertTableNotExists(t *testing.T, db *gorm.DB, table string) { + t.Helper() + if db.Migrator().HasTable(table) { + t.Fatalf("expected %s table not to exist", table) } } diff --git a/internal/service/usage.go b/internal/service/usage.go index 9d5f6830..f055d612 100644 --- a/internal/service/usage.go +++ b/internal/service/usage.go @@ -105,6 +105,8 @@ func (s *usageService) ListUsageEvents(_ context.Context, filter UsageFilter) (* Model: filter.Model, Source: filter.Source, AuthIndex: filter.AuthIndex, + AuthType: filter.AuthType, + Provider: filter.Provider, Result: filter.Result, }) if err != nil { @@ -117,6 +119,8 @@ func (s *usageService) ListUsageEvents(_ context.Context, filter UsageFilter) (* Timestamp: row.Timestamp, APIGroupKey: row.APIGroupKey, Model: row.Model, + AuthType: row.AuthType, + Provider: row.Provider, Source: row.Source, AuthIndex: row.AuthIndex, Failed: row.Failed, diff --git a/internal/service/usage_filter.go b/internal/service/usage_filter.go index dd2b18d7..eb0df651 100644 --- a/internal/service/usage_filter.go +++ b/internal/service/usage_filter.go @@ -17,6 +17,8 @@ type UsageFilter struct { Model string Source string AuthIndex string + AuthType string + Provider string Result string } @@ -42,6 +44,8 @@ type UsageEventRecord struct { Timestamp time.Time APIGroupKey string Model string + AuthType string + Provider string Source string AuthIndex string Failed bool diff --git a/internal/service/usage_filter_test.go b/internal/service/usage_filter_test.go index ae420940..76c56362 100644 --- a/internal/service/usage_filter_test.go +++ b/internal/service/usage_filter_test.go @@ -19,8 +19,8 @@ func TestUsageServiceGetUsageWithFilterDelegatesToFilteredSnapshot(t *testing.T) } closeTestDatabase(t, db) if _, _, err := repository.InsertUsageEvents(db, []models.UsageEvent{ - {EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), TotalTokens: 10}, - {EventKey: "event-2", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), TotalTokens: 20}, + {EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), TotalTokens: 10}, + {EventKey: "event-2", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), TotalTokens: 20}, }); err != nil { t.Fatalf("InsertUsageEvents returned error: %v", err) } @@ -52,8 +52,8 @@ func TestUsageServiceGetUsageOverviewDelegatesToFilteredOverview(t *testing.T) { t.Fatalf("UpsertModelPriceSetting returned error: %v", err) } if _, _, err := repository.InsertUsageEvents(db, []models.UsageEvent{ - {EventKey: "event-1", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), InputTokens: 1000, OutputTokens: 500, CachedTokens: 100, ReasoningTokens: 50, TotalTokens: 1650}, - {EventKey: "event-2", SnapshotRunID: 1, APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), InputTokens: 500, OutputTokens: 250, CachedTokens: 0, ReasoningTokens: 25, TotalTokens: 775}, + {EventKey: "event-1", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 9, 0, 0, 0, time.UTC), InputTokens: 1000, OutputTokens: 500, CachedTokens: 100, ReasoningTokens: 50, TotalTokens: 1650}, + {EventKey: "event-2", APIGroupKey: "provider-a", Model: "claude-sonnet", Timestamp: time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC), InputTokens: 500, OutputTokens: 250, CachedTokens: 0, ReasoningTokens: 25, TotalTokens: 775}, }); err != nil { t.Fatalf("InsertUsageEvents returned error: %v", err) } diff --git a/internal/service/usage_identities_service.go b/internal/service/usage_identities_service.go new file mode 100644 index 00000000..38a8686c --- /dev/null +++ b/internal/service/usage_identities_service.go @@ -0,0 +1,25 @@ +package service + +import ( + "context" + + "cpa-usage-keeper/internal/models" + "cpa-usage-keeper/internal/repository" + "gorm.io/gorm" +) + +type UsageIdentityProvider interface { + ListUsageIdentities(context.Context) ([]models.UsageIdentity, error) +} + +type usageIdentityService struct { + db *gorm.DB +} + +func NewUsageIdentityService(db *gorm.DB) UsageIdentityProvider { + return &usageIdentityService{db: db} +} + +func (s *usageIdentityService) ListUsageIdentities(ctx context.Context) ([]models.UsageIdentity, error) { + return repository.ListUsageIdentities(ctx, s.db) +} diff --git a/web/src/components/ui/icons.tsx b/web/src/components/ui/icons.tsx index 6b61280e..5fe2b2c7 100644 --- a/web/src/components/ui/icons.tsx +++ b/web/src/components/ui/icons.tsx @@ -396,16 +396,6 @@ export function IconSidebarProviders({ size = 20, ...props }: IconProps) { ); } -export function IconSidebarAuthFiles({ size = 20, ...props }: IconProps) { - return ( - - - - - - ); -} - export function IconSidebarOauth({ size = 20, ...props }: IconProps) { return ( diff --git a/web/src/components/usage/CredentialStatsCard.test.ts b/web/src/components/usage/CredentialStatsCard.test.ts index 246edb3e..2589d53d 100644 --- a/web/src/components/usage/CredentialStatsCard.test.ts +++ b/web/src/components/usage/CredentialStatsCard.test.ts @@ -1,29 +1,55 @@ import { describe, expect, it } from 'vitest'; -import type { UsageCredential } from '@/lib/types'; +import type { UsageIdentity } from '@/lib/types'; import { buildCredentialRows, getTopCredentialRows } from './CredentialStatsCard'; +const usageIdentity = (overrides: Partial): UsageIdentity => ({ + id: 1, + name: '', + auth_type: 1, + auth_type_name: 'oauth', + identity: '', + type: '', + provider: '', + total_requests: 0, + success_count: 0, + failure_count: 0, + input_tokens: 0, + output_tokens: 0, + reasoning_tokens: 0, + cached_tokens: 0, + total_tokens: 0, + last_aggregated_usage_event_id: 0, + is_deleted: false, + created_at: '2026-05-04T00:00:00Z', + updated_at: '2026-05-04T00:00:00Z', + ...overrides, +}); + describe('CredentialStatsCard helpers', () => { it('sorts credentials by total request count descending', () => { - const credentials: UsageCredential[] = [ - { - source: 'low', - source_key: 'low', + const credentials = [ + usageIdentity({ + id: 1, + identity: 'low', success_count: 1, - failure_count: 0, - total_count: 1, - }, - { - source: 'high', - source_key: 'high', + total_requests: 1, + }), + usageIdentity({ + id: 2, + name: 'High Provider', + auth_type: 2, + auth_type_name: 'apikey', + identity: 'sk-a***1234', + type: 'claude', success_count: 8, failure_count: 2, - total_count: 10, - }, - ]; + total_requests: 10, + }), + ] satisfies UsageIdentity[]; const rows = buildCredentialRows(credentials); - expect(rows.map((row) => row.displayName)).toEqual(['high', 'low']); + expect(rows.map((row) => row.displayName)).toEqual(['High Provider', 'low']); expect(rows[0]).toMatchObject({ success: 8, failure: 2, @@ -32,38 +58,49 @@ describe('CredentialStatsCard helpers', () => { }); }); + it('prefers identity type over auth type name for the credential tag', () => { + const credentials = [ + usageIdentity({ + auth_type_name: 'apikey', + identity: 'sk-a***1234', + type: 'openai', + }), + ] satisfies UsageIdentity[]; + + const rows = buildCredentialRows(credentials); + + expect(rows[0].type).toBe('openai'); + }); + it('falls back to success plus failure when total count is empty', () => { - const rows = buildCredentialRows([ - { - source: 'fallback-total', - source_key: 'fallback-total', + const credentials = [ + usageIdentity({ + identity: 'fallback-total', success_count: 3, failure_count: 2, - total_count: 0, - }, - ]); + total_requests: 0, + }), + ] satisfies UsageIdentity[]; + + const rows = buildCredentialRows(credentials); expect(rows[0].total).toBe(5); expect(rows[0].successRate).toBe(60); }); it('returns only the top 10 non-empty credential rows', () => { - const credentials: UsageCredential[] = [ - { - source: 'empty', - source_key: 'empty', - success_count: 0, - failure_count: 0, - total_count: 0, - }, - ...Array.from({ length: 12 }, (_, index) => ({ - source: `credential-${index + 1}`, - source_key: `credential-${index + 1}`, + const credentials = [ + usageIdentity({ + id: 1, + identity: 'empty', + }), + ...Array.from({ length: 12 }, (_, index) => usageIdentity({ + id: index + 2, + identity: `credential-${index + 1}`, success_count: index + 1, - failure_count: 0, - total_count: index + 1, + total_requests: index + 1, })), - ]; + ] satisfies UsageIdentity[]; const rows = buildCredentialRows(credentials); const topRows = getTopCredentialRows(rows); diff --git a/web/src/components/usage/CredentialStatsCard.tsx b/web/src/components/usage/CredentialStatsCard.tsx index 14ee3a63..86273fd4 100644 --- a/web/src/components/usage/CredentialStatsCard.tsx +++ b/web/src/components/usage/CredentialStatsCard.tsx @@ -4,11 +4,11 @@ import { Bar } from 'react-chartjs-2'; import type { ChartData, ChartOptions, TooltipItem } from 'chart.js'; import { Card } from '@/components/ui/Card'; import { formatCompactNumber } from '@/utils/usage'; -import type { UsageCredential } from '@/lib/types'; +import type { UsageIdentity } from '@/lib/types'; import styles from '@/pages/UsagePage.module.scss'; export interface CredentialStatsCardProps { - credentials: UsageCredential[]; + credentials: UsageIdentity[]; loading: boolean; } @@ -22,15 +22,15 @@ export interface CredentialRow { successRate: number; } -export function buildCredentialRows(credentials: UsageCredential[]): CredentialRow[] { +export function buildCredentialRows(credentials: UsageIdentity[]): CredentialRow[] { return credentials .map((credential) => { - const displayName = String(credential.source ?? '').trim() || '-'; - const sourceType = String(credential.source_type ?? '').trim(); - const key = String(credential.source_key ?? '').trim() || displayName; + const displayName = String(credential.name || credential.identity || '').trim() || '-'; + const sourceType = String(credential.type || credential.auth_type_name || '').trim(); + const key = String(credential.id || credential.identity || '').trim() || displayName; const success = Number(credential.success_count) || 0; const failure = Number(credential.failure_count) || 0; - const total = Number(credential.total_count) || success + failure; + const total = Number(credential.total_requests) || success + failure; return { key, displayName, diff --git a/web/src/components/usage/PriceSettingsCard.tsx b/web/src/components/usage/PriceSettingsCard.tsx index aa8157a9..ae2f74ba 100644 --- a/web/src/components/usage/PriceSettingsCard.tsx +++ b/web/src/components/usage/PriceSettingsCard.tsx @@ -156,7 +156,7 @@ export function PriceSettingsCard({ subtitle={t('usage_stats.model_price_settings_subtitle')} /> } - className={styles.detailsFixedCard} + className={`${styles.detailsFixedCard} ${styles.pricingFixedCard}`} >
{loading && modelNames.length === 0 && Object.keys(modelPrices).length === 0 ? ( diff --git a/web/src/lib/api.test.ts b/web/src/lib/api.test.ts index 1f7d74dd..a707d166 100644 --- a/web/src/lib/api.test.ts +++ b/web/src/lib/api.test.ts @@ -1,5 +1,5 @@ import { afterEach, describe, expect, it, vi } from 'vitest'; -import { fetchUsageEventFilterOptions, fetchUsageEvents, triggerSync } from './api'; +import { fetchUsageEventFilterOptions, fetchUsageEvents, fetchUsageIdentities, triggerSync } from './api'; describe('fetchUsageEvents', () => { afterEach(() => { @@ -7,7 +7,7 @@ describe('fetchUsageEvents', () => { vi.unstubAllGlobals(); }); - it('loads stable filter options without event pagination or selected filters', async () => { + it('loads stable filter options without query params', async () => { vi.stubGlobal('window', { __APP_BASE_PATH__: undefined }); const fetchMock = vi.spyOn(globalThis, 'fetch').mockResolvedValue({ ok: true, @@ -15,16 +15,17 @@ describe('fetchUsageEvents', () => { } as Response); const signal = new AbortController().signal; - const response = await fetchUsageEventFilterOptions('custom', '2026-04-20T00:00:00Z', '2026-04-21T00:00:00Z', signal); + const response = await fetchUsageEventFilterOptions(signal); const [url, init] = fetchMock.mock.calls[0]; const parsed = new URL(String(url), 'http://localhost'); expect(response.models).toEqual(['claude-sonnet']); expect(parsed.pathname).toBe('/api/v1/usage/events/filters'); - expect(parsed.searchParams.get('range')).toBe('custom'); - expect(parsed.searchParams.get('start')).toBe('2026-04-20T00:00:00Z'); - expect(parsed.searchParams.get('end')).toBe('2026-04-21T00:00:00Z'); + expect(parsed.search).toBe(''); + expect(parsed.searchParams.get('range')).toBeNull(); + expect(parsed.searchParams.get('start')).toBeNull(); + expect(parsed.searchParams.get('end')).toBeNull(); expect(parsed.searchParams.get('page')).toBeNull(); expect(parsed.searchParams.get('page_size')).toBeNull(); expect(parsed.searchParams.get('model')).toBeNull(); @@ -65,6 +66,51 @@ describe('fetchUsageEvents', () => { expect(init).toMatchObject({ credentials: 'include', signal }); }); + it('loads unified usage identities for credential stats', async () => { + vi.stubGlobal('window', { __APP_BASE_PATH__: undefined }); + const fetchMock = vi.spyOn(globalThis, 'fetch').mockResolvedValue({ + ok: true, + json: async () => ({ + identities: [ + { + id: 1, + name: 'Claude primary', + auth_type: 2, + auth_type_name: 'apikey', + identity: 'sk-a***1234', + type: 'claude', + provider: 'anthropic', + total_requests: 3, + success_count: 2, + failure_count: 1, + input_tokens: 10, + output_tokens: 20, + reasoning_tokens: 0, + cached_tokens: 0, + total_tokens: 30, + last_aggregated_usage_event_id: 9, + is_deleted: false, + created_at: '2026-05-04T00:00:00Z', + updated_at: '2026-05-04T00:00:00Z', + }, + ], + }), + } as Response); + const signal = new AbortController().signal; + + const response = await fetchUsageIdentities(signal); + + const [url, init] = fetchMock.mock.calls[0]; + const parsed = new URL(String(url), 'http://localhost'); + + expect(response.identities[0].identity).toBe('sk-a***1234'); + expect(response.identities[0].auth_type).toBe(2); + expect(typeof response.identities[0].auth_type).toBe('number'); + expect(parsed.pathname).toBe('/api/v1/usage/identities'); + expect(parsed.search).toBe(''); + expect(init).toMatchObject({ credentials: 'include', signal }); + }); + it('posts to the manual sync endpoint', async () => { vi.stubGlobal('window', { __APP_BASE_PATH__: undefined }); const fetchMock = vi.spyOn(globalThis, 'fetch').mockResolvedValue({ diff --git a/web/src/lib/api.ts b/web/src/lib/api.ts index d86e8b92..30744abf 100644 --- a/web/src/lib/api.ts +++ b/web/src/lib/api.ts @@ -1,4 +1,4 @@ -import type { AuthSessionResponse, PricingEntry, PricingResponse, StatusResponse, UsageAnalysisResponse, UsageEventFilterOptionsResponse, UsedModelsResponse, UsageCredentialsResponse, UsageEventsResponse, UsageOverviewResponse } from './types' +import type { AuthSessionResponse, PricingEntry, PricingResponse, StatusResponse, UsageAnalysisResponse, UsageEventFilterOptionsResponse, UsedModelsResponse, UsageIdentitiesResponse, UsageEventsResponse, UsageOverviewResponse } from './types' export class ApiError extends Error { status: number @@ -96,17 +96,8 @@ export interface FetchUsageEventsOptions { result?: string } -export async function fetchUsageEventFilterOptions(range: string, start?: string, end?: string, signal?: AbortSignal): Promise { - const params = new URLSearchParams() - params.set('range', range) - if (start) { - params.set('start', start) - } - if (end) { - params.set('end', end) - } - const query = params.toString() - const response = await apiFetch(`${apiPath('/usage/events/filters')}${query ? `?${query}` : ''}`, { signal }) +export async function fetchUsageEventFilterOptions(signal?: AbortSignal): Promise { + const response = await apiFetch(apiPath('/usage/events/filters'), { signal }) if (!response.ok) { await parseApiError(response, `Failed to load usage event filters: ${response.status}`) } @@ -148,19 +139,10 @@ export async function fetchUsageEvents(range: string, start?: string, end?: stri return response.json() } -export async function fetchUsageCredentials(range: string, start?: string, end?: string, signal?: AbortSignal): Promise { - const params = new URLSearchParams() - params.set('range', range) - if (start) { - params.set('start', start) - } - if (end) { - params.set('end', end) - } - const query = params.toString() - const response = await apiFetch(`${apiPath('/usage/credentials')}${query ? `?${query}` : ''}`, { signal }) +export async function fetchUsageIdentities(signal?: AbortSignal): Promise { + const response = await apiFetch(apiPath('/usage/identities'), { signal }) if (!response.ok) { - await parseApiError(response, `Failed to load usage credentials: ${response.status}`) + await parseApiError(response, `Failed to load usage identities: ${response.status}`) } return response.json() } diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 1fbe4802..25245cb8 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -161,17 +161,36 @@ export interface UsageEventFilterOptionsResponse { sources: UsageSourceFilterOption[] } -export interface UsageCredential { - source: string - source_type?: string - source_key?: string +export type UsageIdentityAuthType = 1 | 2 + +export interface UsageIdentity { + id: number + name: string + auth_type: UsageIdentityAuthType + auth_type_name: string + identity: string + type: string + provider: string + total_requests: number success_count: number failure_count: number - total_count: number -} - -export interface UsageCredentialsResponse { - credentials: UsageCredential[] + input_tokens: number + output_tokens: number + reasoning_tokens: number + cached_tokens: number + total_tokens: number + last_aggregated_usage_event_id: number + first_used_at?: string + last_used_at?: string + stats_updated_at?: string + is_deleted: boolean + created_at: string + updated_at: string + deleted_at?: string +} + +export interface UsageIdentitiesResponse { + identities: UsageIdentity[] } export interface UsageAnalysisModel { diff --git a/web/src/pages/UsagePage.logic.test.ts b/web/src/pages/UsagePage.logic.test.ts index 027ffb04..9c2b744f 100644 --- a/web/src/pages/UsagePage.logic.test.ts +++ b/web/src/pages/UsagePage.logic.test.ts @@ -1,5 +1,6 @@ import { afterEach, describe, expect, it, vi } from 'vitest'; import { buildCustomDateRangeQuery, getOverviewChartEndMs, getOverviewDisplayLoading, getOverviewHourWindowHours, getTimeRangeOptions, getUsageTabOptions, refreshPageData, sanitizeRequestEventFilters, scheduleOverviewAutoRefresh, syncCpaData } from './UsagePage'; +import { ApiError } from '@/lib/api'; import { filterUsageByWindow, type UsageFilterWindow } from '@/utils/usage'; import type { StatusResponse, UsageSnapshot } from '@/lib/types'; @@ -334,4 +335,37 @@ describe('UsagePage sync action', () => { expect(calls).toEqual(['sync', 'refresh', 'status', 'set-status']); expect(receivedStatus).toBe(refreshedStatus); }); + + it('reloads status and preserves the sync error when backend sync fails', async () => { + const calls: string[] = []; + let receivedStatus: StatusResponse | null = null; + const refreshedStatus: StatusResponse = { + running: true, + sync_running: false, + last_status: 'completed', + last_run_at: '2026-04-26T13:00:00.000Z', + }; + const syncError = new ApiError('metadata sync failed', 500); + + await expect(syncCpaData({ + triggerBackendSync: async () => { + calls.push('sync'); + throw syncError; + }, + refreshActiveTab: async () => { + calls.push('refresh'); + }, + refreshStatus: async () => { + calls.push('status'); + return refreshedStatus; + }, + onStatus: (status) => { + calls.push('set-status'); + receivedStatus = status; + }, + })).rejects.toBe(syncError); + + expect(calls).toEqual(['sync', 'status', 'set-status']); + expect(receivedStatus).toBe(refreshedStatus); + }); }); diff --git a/web/src/pages/UsagePage.module.scss b/web/src/pages/UsagePage.module.scss index a8d91675..a142c75c 100644 --- a/web/src/pages/UsagePage.module.scss +++ b/web/src/pages/UsagePage.module.scss @@ -1135,6 +1135,14 @@ } // Pricing Section (80%比例) +.pricingFixedCard { + @include mobile { + height: auto; + min-height: 0; + overflow: visible; + } +} + .pricingSection { display: flex; flex-direction: column; @@ -1215,6 +1223,12 @@ overflow: auto; -webkit-overflow-scrolling: touch; overscroll-behavior: contain; + + .pricingFixedCard & { + @include mobile { + overflow: visible; + } + } } .priceItem { @@ -1264,6 +1278,15 @@ display: flex; gap: 3px; flex-shrink: 0; + + @include mobile { + width: 100%; + + :global(.btn) { + flex: 1 1 0; + justify-content: center; + } + } } .editModalBody { diff --git a/web/src/pages/UsagePage.tsx b/web/src/pages/UsagePage.tsx index dd9b8dad..b8ed4517 100644 --- a/web/src/pages/UsagePage.tsx +++ b/web/src/pages/UsagePage.tsx @@ -14,8 +14,8 @@ import { Legend, Filler } from 'chart.js'; -import { ApiError, fetchStatus, fetchUsageAnalysis, fetchUsageCredentials, fetchUsageEventFilterOptions, fetchUsageEvents, triggerSync } from '@/lib/api'; -import type { StatusResponse, UsageAnalysisResponse, UsageCredential, UsageEvent, UsageSourceFilterOption } from '@/lib/types'; +import { ApiError, fetchStatus, fetchUsageAnalysis, fetchUsageEventFilterOptions, fetchUsageEvents, fetchUsageIdentities, triggerSync } from '@/lib/api'; +import type { StatusResponse, UsageAnalysisResponse, UsageEvent, UsageIdentity, UsageSourceFilterOption } from '@/lib/types'; import { Button } from '@/components/ui/Button'; import { LoadingSpinner } from '@/components/ui/LoadingSpinner'; import { Select } from '@/components/ui/Select'; @@ -200,10 +200,22 @@ export const scheduleOverviewAutoRefresh = ({ }; export const syncCpaData = async ({ triggerBackendSync, refreshActiveTab, refreshStatus, onStatus }: SyncCpaDataOptions) => { - await triggerBackendSync(); - await refreshActiveTab(); - const nextStatus = await refreshStatus(); - onStatus(nextStatus); + try { + await triggerBackendSync(); + await refreshActiveTab(); + const nextStatus = await refreshStatus(); + onStatus(nextStatus); + } catch (error) { + if (!(error instanceof ApiError && error.status === 401)) { + try { + const nextStatus = await refreshStatus(); + onStatus(nextStatus); + } catch { + // 忽略状态刷新失败,继续抛出原始同步错误。 + } + } + throw error; + } }; export const sanitizeRequestEventFilters = ( @@ -447,7 +459,7 @@ export function UsagePage({ onAuthRequired }: { onAuthRequired?: () => void }) { const [manualSyncLoading, setManualSyncLoading] = useState(false); const [credentialsLoading, setCredentialsLoading] = useState(false); const [credentialsError, setCredentialsError] = useState(''); - const [credentialsData, setCredentialsData] = useState([]); + const [credentialsData, setCredentialsData] = useState([]); const credentialsRequestControllerRef = useRef(null); const [analysisLoading, setAnalysisLoading] = useState(false); const [analysisError, setAnalysisError] = useState(''); @@ -672,17 +684,8 @@ export function UsagePage({ onAuthRequired }: { onAuthRequired?: () => void }) { }, [customTimeRange.end, customTimeRange.start, timeRange]); const loadEventFilterOptions = useCallback(async () => { - const queryWindow = getEventQueryWindow(); - const optionsKey = `${timeRange}|${queryWindow.start ?? ''}|${queryWindow.end ?? ''}`; + const optionsKey = 'stable'; - if (!queryWindow.valid) { - eventsFilterOptionsRequestControllerRef.current?.abort(); - eventsFilterOptionsRequestControllerRef.current = null; - loadedEventsFilterOptionsKeyRef.current = ''; - setEventsModelOptions([]); - setEventsSourceOptions([]); - return; - } if (loadedEventsFilterOptionsKeyRef.current === optionsKey) { return; } @@ -692,7 +695,7 @@ export function UsagePage({ onAuthRequired }: { onAuthRequired?: () => void }) { eventsFilterOptionsRequestControllerRef.current = controller; try { - const response = await fetchUsageEventFilterOptions(timeRange, queryWindow.start, queryWindow.end, controller.signal); + const response = await fetchUsageEventFilterOptions(controller.signal); if (eventsFilterOptionsRequestControllerRef.current !== controller) { return; } @@ -716,7 +719,7 @@ export function UsagePage({ onAuthRequired }: { onAuthRequired?: () => void }) { eventsFilterOptionsRequestControllerRef.current = null; } } - }, [getEventQueryWindow, onAuthRequired, timeRange]); + }, [onAuthRequired]); const loadEvents = useCallback(async () => { const queryWindow = getEventQueryWindow(); @@ -802,27 +805,6 @@ export function UsagePage({ onAuthRequired }: { onAuthRequired?: () => void }) { }, [resetEventsPage]); const loadCredentials = useCallback(async () => { - if (timeRange === 'custom') { - if (!customTimeRange.start || !customTimeRange.end) { - credentialsRequestControllerRef.current?.abort(); - credentialsRequestControllerRef.current = null; - setCredentialsData([]); - setCredentialsError(''); - setCredentialsLoading(false); - return; - } - const startMs = parseCustomDateStart(customTimeRange.start); - const endMs = parseCustomDateEnd(customTimeRange.end); - if (startMs === undefined || endMs === undefined || startMs > endMs) { - credentialsRequestControllerRef.current?.abort(); - credentialsRequestControllerRef.current = null; - setCredentialsData([]); - setCredentialsError(''); - setCredentialsLoading(false); - return; - } - } - credentialsRequestControllerRef.current?.abort(); const controller = new AbortController(); credentialsRequestControllerRef.current = controller; @@ -831,12 +813,11 @@ export function UsagePage({ onAuthRequired }: { onAuthRequired?: () => void }) { setCredentialsError(''); setCredentialsData([]); try { - const queryWindow = timeRange === 'custom' ? buildCustomDateRangeQuery({ start: customTimeRange.start, end: customTimeRange.end }) : { start: undefined, end: undefined }; - const response = await fetchUsageCredentials(timeRange, queryWindow.start, queryWindow.end, controller.signal); + const response = await fetchUsageIdentities(controller.signal); if (credentialsRequestControllerRef.current !== controller) { return; } - setCredentialsData(response.credentials); + setCredentialsData(response.identities); } catch (error) { if (controller.signal.aborted) { return; @@ -848,14 +829,14 @@ export function UsagePage({ onAuthRequired }: { onAuthRequired?: () => void }) { onAuthRequired?.(); return; } - setCredentialsError(error instanceof Error ? error.message : 'Failed to load usage credentials'); + setCredentialsError(error instanceof Error ? error.message : 'Failed to load usage identities'); } finally { if (credentialsRequestControllerRef.current === controller) { setCredentialsLoading(false); credentialsRequestControllerRef.current = null; } } - }, [customTimeRange.end, customTimeRange.start, onAuthRequired, timeRange]); + }, [onAuthRequired]); const refreshActiveTab = useCallback(async () => { if (activeTab === 'events') { diff --git a/web/src/services/api/authFiles.ts b/web/src/services/api/authFiles.ts deleted file mode 100644 index b8eaab98..00000000 --- a/web/src/services/api/authFiles.ts +++ /dev/null @@ -1,24 +0,0 @@ -import { ApiError, apiPath } from '@/lib/api'; -import type { AuthFileItem } from '@/types/authFile'; - -export async function fetchAuthFiles(signal?: AbortSignal): Promise<{ files: AuthFileItem[] }> { - const response = await fetch(apiPath('/auth-files'), { - credentials: 'include', - signal, - }); - if (!response.ok) { - let message = `Failed to load auth files: ${response.status}`; - try { - const payload = await response.json() as { error?: string }; - if (payload.error) message = payload.error; - } catch { - // ignore invalid error payloads - } - throw new ApiError(message, response.status); - } - return response.json(); -} - -export const authFilesApi = { - list: fetchAuthFiles, -}; diff --git a/web/src/services/api/providerMetadata.ts b/web/src/services/api/providerMetadata.ts deleted file mode 100644 index db651789..00000000 --- a/web/src/services/api/providerMetadata.ts +++ /dev/null @@ -1,24 +0,0 @@ -import { ApiError, apiPath } from '@/lib/api'; -import type { ProviderMetadataItem } from '@/types/providerMetadata'; - -export async function fetchProviderMetadata(signal?: AbortSignal): Promise<{ items: ProviderMetadataItem[] }> { - const response = await fetch(apiPath('/provider-metadata'), { - credentials: 'include', - signal, - }); - if (!response.ok) { - let message = `Failed to load provider metadata: ${response.status}`; - try { - const payload = (await response.json()) as { error?: string }; - if (payload.error) message = payload.error; - } catch { - // ignore invalid error payloads - } - throw new ApiError(message, response.status); - } - return response.json(); -} - -export const providerMetadataApi = { - list: fetchProviderMetadata, -}; diff --git a/web/src/types/authFile.ts b/web/src/types/authFile.ts deleted file mode 100644 index a6e45ae2..00000000 --- a/web/src/types/authFile.ts +++ /dev/null @@ -1,8 +0,0 @@ -export interface AuthFileItem { - auth_index?: string; - authIndex?: string; - name?: string; - email?: string; - type?: string; - provider?: string; -} diff --git a/web/src/types/providerMetadata.ts b/web/src/types/providerMetadata.ts deleted file mode 100644 index 03bf4443..00000000 --- a/web/src/types/providerMetadata.ts +++ /dev/null @@ -1,6 +0,0 @@ -export interface ProviderMetadataItem { - lookup_key: string; - provider_type?: string; - display_name?: string; - provider_key?: string; -} diff --git a/web/src/types/sourceInfo.ts b/web/src/types/sourceInfo.ts deleted file mode 100644 index f5b2750a..00000000 --- a/web/src/types/sourceInfo.ts +++ /dev/null @@ -1,5 +0,0 @@ -export interface CredentialInfo { - name: string; - type: string; - key?: string; -} diff --git a/web/src/utils/sourceResolver.ts b/web/src/utils/sourceResolver.ts deleted file mode 100644 index 19a2a60a..00000000 --- a/web/src/utils/sourceResolver.ts +++ /dev/null @@ -1,145 +0,0 @@ -import type { GeminiKeyConfig, OpenAIProviderConfig, ProviderKeyConfig } from '@/types'; -import type { CredentialInfo } from '@/types/sourceInfo'; - -interface SourceInfoInput { - geminiApiKeys: GeminiKeyConfig[]; - claudeApiKeys: ProviderKeyConfig[]; - codexApiKeys: ProviderKeyConfig[]; - vertexApiKeys: ProviderKeyConfig[]; - openaiCompatibility: OpenAIProviderConfig[]; -} - -export interface ResolvedSourceDisplay { - displayName: string; - type: string; - key: string; -} - -function maskValue(value: string): string { - const normalized = value.trim(); - if (!normalized || normalized === '-') return '-'; - if (normalized.length <= 4) return '*'.repeat(normalized.length); - if (normalized.length <= 8) return `${normalized.slice(0, 1)}${'*'.repeat(normalized.length - 2)}${normalized.slice(-1)}`; - return `${normalized.slice(0, 4)}${'*'.repeat(normalized.length - 8)}${normalized.slice(-4)}`; -} - -function looksLikeEmail(value: string): boolean { - const normalized = value.trim(); - return /^[^\s@]+@[^\s@]+\.[^\s@]+$/.test(normalized); -} - -function inferProviderType(sourceKey: string): string { - const value = sourceKey.trim().toLowerCase(); - if (!value) return ''; - if (value.startsWith('sk-ant-') || value.includes('anthropic') || value.includes('claude')) return 'claude'; - if (value.startsWith('sk-proj-') || value.startsWith('sk-') || value.includes('openai') || value.includes('gpt')) return 'openai'; - if (value.startsWith('aiza') || value.includes('gemini')) return 'gemini'; - if (value.includes('vertex')) return 'vertex'; - if (value.includes('codex')) return 'codex'; - if (value.includes('ampcode')) return 'ampcode'; - return ''; -} - -function buildFallbackSource(sourceKey: string): ResolvedSourceDisplay { - if (looksLikeEmail(sourceKey)) { - return { - displayName: sourceKey, - type: '', - key: `email:${sourceKey}`, - }; - } - - const inferredType = inferProviderType(sourceKey); - if (inferredType) { - return { - displayName: inferredType, - type: inferredType, - key: `provider:fallback:${inferredType}`, - }; - } - - const masked = maskValue(sourceKey || '-'); - return { - displayName: masked, - type: '', - key: `raw:${masked}`, - }; -} - -export function buildSourceInfoMap({ - geminiApiKeys, - claudeApiKeys, - codexApiKeys, - vertexApiKeys, - openaiCompatibility -}: SourceInfoInput): Map { - const map = new Map(); - - const addProviderEntries = (items: Array, type: string) => { - items.forEach((item, index) => { - const label = item.prefix?.trim() || item.name?.trim() || `${type} #${index + 1}`; - const key = `${type}:${label}`; - if (item.apiKey) map.set(item.apiKey, { name: label, type, key }); - if (item.prefix) map.set(item.prefix, { name: label, type, key }); - }); - }; - - addProviderEntries(geminiApiKeys, 'gemini'); - addProviderEntries(claudeApiKeys, 'claude'); - addProviderEntries(codexApiKeys, 'codex'); - addProviderEntries(vertexApiKeys, 'vertex'); - - openaiCompatibility.forEach((provider, index) => { - const label = provider.name?.trim() || provider.prefix?.trim() || `openai #${index + 1}`; - const key = `openai:${label}`; - if (provider.prefix) map.set(provider.prefix, { name: label, type: 'openai', key }); - provider.apiKeyEntries?.forEach((entry) => { - if (entry.apiKey) map.set(entry.apiKey, { name: label, type: 'openai', key }); - }); - }); - - return map; -} - -export function resolveSourceDisplay( - sourceRaw: string, - authIndex: unknown, - sourceInfoMap: Map, - authFileMap: Map, - providerMetadataMap?: Map -): ResolvedSourceDisplay { - const sourceKey = sourceRaw?.trim(); - const normalizedAuthIndex = - authIndex === null || authIndex === undefined || String(authIndex).trim() === '' - ? '' - : String(authIndex).trim(); - - if (normalizedAuthIndex && authFileMap.has(normalizedAuthIndex)) { - const source = authFileMap.get(normalizedAuthIndex)!; - return { - displayName: source.name, - type: source.type, - key: source.key || `auth:${normalizedAuthIndex}`, - }; - } - - if (sourceKey && providerMetadataMap?.has(sourceKey)) { - const source = providerMetadataMap.get(sourceKey)!; - return { - displayName: source.name, - type: source.type, - key: source.key || `provider:${source.type}:${source.name}`, - }; - } - - if (sourceKey && sourceInfoMap.has(sourceKey)) { - const source = sourceInfoMap.get(sourceKey)!; - return { - displayName: source.name, - type: source.type, - key: source.key || `provider:${source.type}:${source.name}`, - }; - } - - return buildFallbackSource(sourceKey || '-'); -}