diff --git a/packages/ai-bot/main.ts b/packages/ai-bot/main.ts index f7fcda49cea..532c9a85d86 100644 --- a/packages/ai-bot/main.ts +++ b/packages/ai-bot/main.ts @@ -40,7 +40,10 @@ import { import type { MatrixEvent as DiscreteMatrixEvent } from 'https://cardstack.com/base/matrix-event'; import * as Sentry from '@sentry/node'; -import { saveUsageCost } from '@cardstack/billing/ai-billing'; +import { + spendUsageCost, + fetchGenerationCostWithBackoff, +} from '@cardstack/billing/ai-billing'; import { PgAdapter } from '@cardstack/postgres'; import type { ChatCompletionMessageParam } from 'openai/resources'; import type { OpenAIError } from 'openai/error'; @@ -86,22 +89,41 @@ class Assistant { this.aiBotInstanceId = aiBotInstanceId; } - async trackAiUsageCost(matrixUserId: string, generationId: string) { + async trackAiUsageCost( + matrixUserId: string, + opts: { costInUsd?: number; generationId?: string }, + ) { if (trackAiUsageCostPromises.has(matrixUserId)) { return; } - // intentionally do not await saveUsageCost promise - it has a backoff mechanism to retry if the cost is not immediately available so we don't want to block the main thread - trackAiUsageCostPromises.set( - matrixUserId, - saveUsageCost( - this.pgAdapter, - matrixUserId, - generationId, - process.env.OPENROUTER_API_KEY!, - ).finally(() => { - trackAiUsageCostPromises.delete(matrixUserId); - }), - ); + const promise = (async () => { + let { costInUsd, generationId } = opts; + if ( + typeof costInUsd === 'number' && + Number.isFinite(costInUsd) && + costInUsd > 0 + ) { + await spendUsageCost(this.pgAdapter, matrixUserId, costInUsd); + } else if (generationId) { + log.info( + `No inline cost for user ${matrixUserId}, falling back to generation cost API (generationId: ${generationId})`, + ); + const fetchedCost = await fetchGenerationCostWithBackoff( + generationId, + process.env.OPENROUTER_API_KEY!, + ); + if (fetchedCost !== null) { + await spendUsageCost(this.pgAdapter, matrixUserId, fetchedCost); + } + } else { + log.warn( + `No usage cost and no generation ID for user ${matrixUserId}, skipping credit deduction`, + ); + } + })().finally(() => { + trackAiUsageCostPromises.delete(matrixUserId); + }); + trackAiUsageCostPromises.set(matrixUserId, promise); } getResponse(prompt: PromptParts, senderMatrixUserId?: string) { @@ -288,10 +310,9 @@ Common issues are: isCanceled: true, }); if (activeGeneration.lastGeneratedChunkId) { - await assistant.trackAiUsageCost( - senderMatrixUserId, - activeGeneration.lastGeneratedChunkId, - ); + await assistant.trackAiUsageCost(senderMatrixUserId, { + generationId: activeGeneration.lastGeneratedChunkId, + }); } activeGenerations.delete(room.roomId); } @@ -448,6 +469,7 @@ Common issues are: let chunkHandlingError: string | undefined; let generationId: string | undefined; + let costInUsd: number | undefined; log.info( `[${eventId}] Starting generation with model %s`, promptParts.model, @@ -471,6 +493,9 @@ Common issues are: }); } generationId = chunk.id; + if (chunk.usage && (chunk.usage as any).cost != null) { + costInUsd = (chunk.usage as any).cost; + } let activeGeneration = activeGenerations.get(room.roomId); if (activeGeneration) { activeGeneration.lastGeneratedChunkId = generationId; @@ -525,9 +550,10 @@ Common issues are: await responder.onError(error as OpenAIError); } } finally { - if (generationId) { - assistant.trackAiUsageCost(senderMatrixUserId, generationId); - } + assistant.trackAiUsageCost(senderMatrixUserId, { + costInUsd, + generationId, + }); activeGenerations.delete(room.roomId); } diff --git a/packages/billing/ai-billing.ts b/packages/billing/ai-billing.ts index 97197b90415..d65b6ecb468 100644 --- a/packages/billing/ai-billing.ts +++ b/packages/billing/ai-billing.ts @@ -109,50 +109,7 @@ export async function spendUsageCost( } } -export async function saveUsageCost( - dbAdapter: DBAdapter, - matrixUserId: string, - generationId: string, - openRouterApiKey: string, -) { - try { - // Generation data is sometimes not immediately available, so we retry a couple of times until we are able to get the cost - let costInUsd = await fetchGenerationCostWithBackoff( - generationId, - openRouterApiKey, - ); - - if (costInUsd === null) { - Sentry.captureException( - new Error( - `Failed to fetch generation cost after retries (generationId: ${generationId})`, - ), - ); - return; - } - - let creditsConsumed = Math.round(costInUsd * CREDITS_PER_USD); - - let user = await getUserByMatrixUserId(dbAdapter, matrixUserId); - - if (!user) { - throw new Error( - `should not happen: user with matrix id ${matrixUserId} not found in the users table`, - ); - } - - await spendCredits(dbAdapter, user.id, creditsConsumed); - } catch (err) { - log.error( - `Failed to track AI usage (matrixUserId: ${matrixUserId}, generationId: ${generationId}):`, - err, - ); - Sentry.captureException(err); - // Don't throw, because we don't want to crash the application over this - } -} - -async function fetchGenerationCostWithBackoff( +export async function fetchGenerationCostWithBackoff( generationId: string, openRouterApiKey: string, ): Promise { @@ -202,7 +159,6 @@ async function fetchGenerationCost( }, ); - // 404 means generation data probably isn't available yet - return null to trigger retry if (response.status === 404) { return null; } @@ -224,24 +180,3 @@ async function fetchGenerationCost( return data.data.total_cost; } - -export function extractGenerationIdFromResponse( - response: any, -): string | undefined { - // OpenRouter responses typically include a generation_id in the response - // This might be in different places depending on the endpoint - if (response.id) { - return response.id; - } - - if (response.choices && response.choices[0] && response.choices[0].id) { - return response.choices[0].id; - } - - // For chat completions, the generation ID might be in usage - if (response.usage && response.usage.generation_id) { - return response.usage.generation_id; - } - - return undefined; -} diff --git a/packages/realm-server/handlers/handle-request-forward.ts b/packages/realm-server/handlers/handle-request-forward.ts index 84971433bd9..b3dda87d842 100644 --- a/packages/realm-server/handlers/handle-request-forward.ts +++ b/packages/realm-server/handlers/handle-request-forward.ts @@ -57,6 +57,7 @@ async function handleStreamingRequest( if (!reader) throw new Error('No readable stream available'); let generationId: string | undefined; + let costInUsd: number | undefined; let lastPing = Date.now(); await proxySSE( @@ -64,26 +65,25 @@ async function handleStreamingRequest( async (data) => { // Handle end of stream if (data === '[DONE]') { - if (generationId) { - // Save cost in the background so we don't block the stream on OpenRouter's generation cost API. - // Chain per-user promises so costs are recorded sequentially. - const previousPromise = - pendingCostPromises.get(matrixUserId) ?? Promise.resolve(); - const costPromise = previousPromise - .then(() => - endpointConfig.creditStrategy.saveUsageCost( - dbAdapter, - matrixUserId, - { id: generationId }, - ), - ) - .finally(() => { - if (pendingCostPromises.get(matrixUserId) === costPromise) { - pendingCostPromises.delete(matrixUserId); - } - }); - pendingCostPromises.set(matrixUserId, costPromise); - } + // Deduct credits using the cost from the streaming response. + // Chain per-user promises so costs are recorded sequentially. + const previousPromise = + pendingCostPromises.get(matrixUserId) ?? Promise.resolve(); + const costPromise = previousPromise + .then(() => + endpointConfig.creditStrategy.saveUsageCost( + dbAdapter, + matrixUserId, + { id: generationId, usage: { cost: costInUsd } }, + ), + ) + .finally(() => { + if (pendingCostPromises.get(matrixUserId) === costPromise) { + pendingCostPromises.delete(matrixUserId); + } + }); + pendingCostPromises.set(matrixUserId, costPromise); + ctxt.res.write(`data: [DONE]\n\n`); return 'stop'; } @@ -95,6 +95,10 @@ async function handleStreamingRequest( if (!generationId && dataObj.id) { generationId = dataObj.id; } + + if (dataObj.usage?.cost != null) { + costInUsd = dataObj.usage.cost; + } } catch { log.warn('Invalid JSON in streaming response:', data); } @@ -499,46 +503,22 @@ export default function handleRequestForward({ const responseData = await externalResponse.json(); - // 6. Deduct credits in the background using the cost from the response, - // or fall back to saveUsageCost when the cost is not provided. - const costInUsd = responseData?.usage?.cost; + // 6. Deduct credits in the background using the cost from the response. const previousPromise = pendingCostPromises.get(matrixUserId) ?? Promise.resolve(); - let costPromise: Promise; - - if ( - typeof costInUsd === 'number' && - Number.isFinite(costInUsd) && - costInUsd > 0 - ) { - costPromise = previousPromise - .then(() => - destinationConfig.creditStrategy.spendUsageCost( - dbAdapter, - matrixUserId, - costInUsd, - ), - ) - .finally(() => { - if (pendingCostPromises.get(matrixUserId) === costPromise) { - pendingCostPromises.delete(matrixUserId); - } - }); - } else { - costPromise = previousPromise - .then(() => - destinationConfig.creditStrategy.saveUsageCost( - dbAdapter, - matrixUserId, - responseData, - ), - ) - .finally(() => { - if (pendingCostPromises.get(matrixUserId) === costPromise) { - pendingCostPromises.delete(matrixUserId); - } - }); - } + const costPromise = previousPromise + .then(() => + destinationConfig.creditStrategy.saveUsageCost( + dbAdapter, + matrixUserId, + responseData, + ), + ) + .finally(() => { + if (pendingCostPromises.get(matrixUserId) === costPromise) { + pendingCostPromises.delete(matrixUserId); + } + }); pendingCostPromises.set(matrixUserId, costPromise); // 7. Return response diff --git a/packages/realm-server/lib/credit-strategies.ts b/packages/realm-server/lib/credit-strategies.ts index e97c90bdc02..cfd1ffdc6b1 100644 --- a/packages/realm-server/lib/credit-strategies.ts +++ b/packages/realm-server/lib/credit-strategies.ts @@ -1,14 +1,16 @@ import { type DBAdapter, MINIMUM_AI_CREDITS_TO_CONTINUE, + logger, } from '@cardstack/runtime-common'; import { validateAICredits, - extractGenerationIdFromResponse, - saveUsageCost as saveUsageCostFromBilling, spendUsageCost as spendUsageCostFromBilling, + fetchGenerationCostWithBackoff, } from '@cardstack/billing/ai-billing'; +const log = logger('credit-strategies'); + export interface CreditStrategy { name: string; validateCredits( @@ -24,11 +26,6 @@ export interface CreditStrategy { matrixUserId: string, response: any, ): Promise; - spendUsageCost( - dbAdapter: DBAdapter, - matrixUserId: string, - costInUsd: number, - ): Promise; } // Default AI Bot Credit Strategy (reused from AI bot) @@ -58,24 +55,34 @@ export class OpenRouterCreditStrategy implements CreditStrategy { matrixUserId: string, response: any, ): Promise { - const generationId = extractGenerationIdFromResponse(response); + const costInUsd = response?.usage?.cost; + if ( + typeof costInUsd === 'number' && + Number.isFinite(costInUsd) && + costInUsd > 0 + ) { + await spendUsageCostFromBilling(dbAdapter, matrixUserId, costInUsd); + return; + } + + const generationId = response?.id; if (generationId) { - await saveUsageCostFromBilling( - dbAdapter, - matrixUserId, + log.info( + `No inline cost for user ${matrixUserId}, falling back to generation cost API (generationId: ${generationId})`, + ); + const fetchedCost = await fetchGenerationCostWithBackoff( generationId, this.openRouterApiKey, ); + if (fetchedCost !== null) { + await spendUsageCostFromBilling(dbAdapter, matrixUserId, fetchedCost); + } + } else { + log.warn( + `No usage cost and no generation ID in response for user ${matrixUserId}, skipping credit deduction`, + ); } } - - async spendUsageCost( - dbAdapter: DBAdapter, - matrixUserId: string, - costInUsd: number, - ): Promise { - await spendUsageCostFromBilling(dbAdapter, matrixUserId, costInUsd); - } } // No Credit Strategy (for free endpoints) @@ -96,14 +103,6 @@ export class NoCreditStrategy implements CreditStrategy { ): Promise { // No-op for no-credit strategy } - - async spendUsageCost( - _dbAdapter: DBAdapter, - _matrixUserId: string, - _costInUsd: number, - ): Promise { - // No-op for no-credit strategy - } } // Credit Strategy Factory diff --git a/packages/realm-server/tests/request-forward-test.ts b/packages/realm-server/tests/request-forward-test.ts index 19785176d62..fb74ce989af 100644 --- a/packages/realm-server/tests/request-forward-test.ts +++ b/packages/realm-server/tests/request-forward-test.ts @@ -257,23 +257,23 @@ module(basename(__filename), function () { ); }); - test('should handle streaming requests', async function (assert) { + test('should handle streaming requests and deduct credits from inline cost', async function (assert) { // Mock external fetch calls const originalFetch = global.fetch; const mockFetch = sinon.stub(global, 'fetch'); - // Mock streaming response + // Mock streaming response with usage.cost in the final data chunk const mockStreamResponse = new Response( new ReadableStream({ start(controller) { controller.enqueue( new TextEncoder().encode( - 'data: {"id":"gen-stream-123","choices":[{"text":"Hello"}]}\n\n', + 'data: {"id":"gen-stream-123","choices":[{"delta":{"content":"Hello"}}]}\n\n', ), ); controller.enqueue( new TextEncoder().encode( - 'data: {"choices":[{"text":" world"}]}\n\n', + 'data: {"choices":[{"delta":{"content":" world"}}],"usage":{"prompt_tokens":10,"completion_tokens":5,"cost":0.002}}\n\n', ), ); controller.enqueue(new TextEncoder().encode('data: [DONE]\n\n')); @@ -286,27 +286,12 @@ module(basename(__filename), function () { }, ); - // Mock generation cost API response - const mockCostResponse = { - data: { - id: 'gen-stream-123', - total_cost: 0.002, - total_tokens: 100, - model: 'openai/gpt-3.5-turbo', - }, - }; - - // Set up fetch to return different responses based on URL + // Set up fetch to return streaming response (no generation cost API mock needed) mockFetch.callsFake( async (input: string | URL | Request, _init?: RequestInit) => { const url = typeof input === 'string' ? input : input.toString(); - if (url.includes('/generation?id=')) { - return new Response(JSON.stringify(mockCostResponse), { - status: 200, - headers: { 'content-type': 'application/json' }, - }); - } else if (url.includes('/chat/completions')) { + if (url.includes('/chat/completions')) { return mockStreamResponse; } else { return new Response(JSON.stringify({ error: 'Not found' }), { @@ -339,8 +324,6 @@ module(basename(__filename), function () { // Verify streaming response headers assert.strictEqual(response.status, 200, 'Should return 200 status'); - // Note: content-type header is not captured by supertest for streaming responses - // because it's sent immediately with flushHeaders(), but we can verify other SSE headers assert.strictEqual( response.headers['cache-control'], 'no-cache, no-store, must-revalidate', @@ -364,13 +347,130 @@ module(basename(__filename), function () { 'Should include first streaming data', ); assert.true( - responseText.includes('data: {"choices":[{"text":" world"}]}'), - 'Should include second streaming data', + responseText.includes('data: [DONE]'), + 'Should include end of stream marker', + ); + + // Verify credits were deducted from inline cost (0.002 USD * 1000 = 2 credits) + const user = await getUserByMatrixUserId( + dbAdapter, + '@testuser:localhost', ); + await waitUntil( + async () => { + const credits = await sumUpCreditsLedger(dbAdapter, { + creditType: ['extra_credit', 'extra_credit_used'], + userId: user!.id, + }); + return credits === 48; + }, + { timeoutMessage: 'Credits should be deducted (50 - 2 = 48)' }, + ); + } finally { + mockFetch.restore(); + global.fetch = originalFetch; + } + }); + + test('should fall back to generation cost API when inline cost is missing', async function (assert) { + // Mock streaming response WITHOUT usage.cost (simulates cancelled stream or missing cost) + const originalFetch = global.fetch; + const mockFetch = sinon.stub(global, 'fetch'); + + const mockStreamResponse = new Response( + new ReadableStream({ + start(controller) { + controller.enqueue( + new TextEncoder().encode( + 'data: {"id":"gen-no-cost-456","choices":[{"delta":{"content":"Hello"}}]}\n\n', + ), + ); + // No usage.cost in any chunk + controller.enqueue(new TextEncoder().encode('data: [DONE]\n\n')); + controller.close(); + }, + }), + { + status: 200, + headers: { + 'content-type': 'text/event-stream', + }, + }, + ); + + // Mock generation cost API response (fallback) + const mockCostResponse = { + data: { + id: 'gen-no-cost-456', + total_cost: 0.003, + }, + }; + + mockFetch.callsFake( + async (input: string | URL | Request, _init?: RequestInit) => { + const url = typeof input === 'string' ? input : input.toString(); + + if (url.includes('/generation?id=')) { + return new Response(JSON.stringify(mockCostResponse), { + status: 200, + headers: { 'content-type': 'application/json' }, + }); + } else if (url.includes('/chat/completions')) { + return mockStreamResponse; + } else { + return new Response(JSON.stringify({ error: 'Not found' }), { + status: 404, + }); + } + }, + ); + + try { + const jwt = createRealmServerJWT( + { user: '@testuser:localhost', sessionRoom: 'test-session-room' }, + realmSecretSeed, + ); + + const response = await request + .post('/_request-forward') + .set('Accept', 'text/event-stream') + .set('Content-Type', 'application/json') + .set('Authorization', `Bearer ${jwt}`) + .send({ + url: 'https://openrouter.ai/api/v1/chat/completions', + method: 'POST', + requestBody: JSON.stringify({ + model: 'openai/gpt-3.5-turbo', + messages: [{ role: 'user', content: 'Hello' }], + stream: true, + }), + stream: true, + }); + + assert.strictEqual(response.status, 200, 'Should return 200 status'); assert.true( - responseText.includes('data: [DONE]'), + response.text.includes('data: [DONE]'), 'Should include end of stream marker', ); + + // Verify credits were deducted via fallback (0.003 USD * 1000 = 3 credits) + const user = await getUserByMatrixUserId( + dbAdapter, + '@testuser:localhost', + ); + await waitUntil( + async () => { + const credits = await sumUpCreditsLedger(dbAdapter, { + creditType: ['extra_credit', 'extra_credit_used'], + userId: user!.id, + }); + return credits === 47; + }, + { + timeoutMessage: + 'Credits should be deducted via fallback (50 - 3 = 47)', + }, + ); } finally { mockFetch.restore(); global.fetch = originalFetch;