Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 68 additions & 32 deletions packages/ai-bot/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ import {
import type { MatrixEvent as DiscreteMatrixEvent } from 'https://cardstack.com/base/matrix-event';
import * as Sentry from '@sentry/node';

import { saveUsageCost } from '@cardstack/billing/ai-billing';
import {
spendUsageCost,
fetchGenerationCostWithBackoff,
} from '@cardstack/billing/ai-billing';
import { PgAdapter } from '@cardstack/postgres';
import type { ChatCompletionMessageParam } from 'openai/resources';
import { APIUserAbortError } from 'openai/error';
import type { OpenAIError } from 'openai/error';
import type { ChatCompletionStream } from 'openai/lib/ChatCompletionStream';
import { acquireRoomLock, releaseRoomLock } from './lib/queries';
Expand Down Expand Up @@ -86,22 +90,45 @@ class Assistant {
this.aiBotInstanceId = aiBotInstanceId;
}

async trackAiUsageCost(matrixUserId: string, generationId: string) {
async trackAiUsageCost(
matrixUserId: string,
opts: { costInUsd?: number; generationId?: string },
) {
if (trackAiUsageCostPromises.has(matrixUserId)) {
return;
}
// intentionally do not await saveUsageCost promise - it has a backoff mechanism to retry if the cost is not immediately available so we don't want to block the main thread
trackAiUsageCostPromises.set(
matrixUserId,
saveUsageCost(
this.pgAdapter,
matrixUserId,
generationId,
process.env.OPENROUTER_API_KEY!,
).finally(() => {
trackAiUsageCostPromises.delete(matrixUserId);
}),
);
const promise = (async () => {
let { costInUsd, generationId } = opts;
if (
typeof costInUsd === 'number' &&
Number.isFinite(costInUsd) &&
costInUsd > 0
) {
await spendUsageCost(this.pgAdapter, matrixUserId, costInUsd);
} else if (generationId) {
log.info(
`No inline cost for user ${matrixUserId}, falling back to generation cost API (generationId: ${generationId})`,
);
const fetchedCost = await fetchGenerationCostWithBackoff(
generationId,
process.env.OPENROUTER_API_KEY!,
);
if (fetchedCost !== null) {
await spendUsageCost(this.pgAdapter, matrixUserId, fetchedCost);
} else {
const message = `Failed to fetch generation cost for user ${matrixUserId} (generationId: ${generationId}), credit deduction skipped`;
log.error(message);
Sentry.captureMessage(message, 'error');
}
} else {
log.warn(
`No usage cost and no generation ID for user ${matrixUserId}, skipping credit deduction`,
);
}
})().finally(() => {
trackAiUsageCostPromises.delete(matrixUserId);
});
trackAiUsageCostPromises.set(matrixUserId, promise);
}

getResponse(prompt: PromptParts, senderMatrixUserId?: string) {
Expand Down Expand Up @@ -284,16 +311,9 @@ Common issues are:
event.getType() === 'm.room.message')
) {
activeGeneration.runner.abort();
await activeGeneration.responder.finalize({
isCanceled: true,
});
if (activeGeneration.lastGeneratedChunkId) {
await assistant.trackAiUsageCost(
senderMatrixUserId,
activeGeneration.lastGeneratedChunkId,
);
}
activeGenerations.delete(room.roomId);
// Finalization, credit tracking, and cleanup are all
// handled by the streaming code path's catch/finally
// blocks after the APIUserAbortError is thrown.
}

if (isShuttingDown()) {
Expand Down Expand Up @@ -448,6 +468,7 @@ Common issues are:

let chunkHandlingError: string | undefined;
let generationId: string | undefined;
let costInUsd: number | undefined;
log.info(
`[${eventId}] Starting generation with model %s`,
promptParts.model,
Expand All @@ -471,6 +492,9 @@ Common issues are:
});
}
generationId = chunk.id;
if (chunk.usage && (chunk.usage as any).cost != null) {
costInUsd = (chunk.usage as any).cost;
}
let activeGeneration = activeGenerations.get(room.roomId);
if (activeGeneration) {
activeGeneration.lastGeneratedChunkId = generationId;
Expand Down Expand Up @@ -517,17 +541,29 @@ Common issues are:
);
log.info(`[${eventId}] Response finalized`);
} catch (error) {
log.error(`[${eventId}] Error during generation or finalization`);
log.error(error);
if (chunkHandlingError) {
await responder.onError(chunkHandlingError); // E.g. MatrixError: [413] event too large
// When the cancel handler aborts the runner,
// finalChatCompletion() throws APIUserAbortError.
// Finalize the responder with the canceled flag and let
// the finally block handle credit tracking.
if (error instanceof APIUserAbortError) {
log.info(`[${eventId}] Generation was canceled by user`);
await responder.finalize({ isCanceled: true });
} else {
await responder.onError(error as OpenAIError);
log.error(`[${eventId}] Error during generation or finalization`);
log.error(error);
if (chunkHandlingError) {
await responder.onError(chunkHandlingError); // E.g. MatrixError: [413] event too large
} else {
await responder.onError(error as OpenAIError);
}
}
} finally {
if (generationId) {
assistant.trackAiUsageCost(senderMatrixUserId, generationId);
}
// Always track cost here — this path has the best data
// (both costInUsd from inline chunks and generationId).
assistant.trackAiUsageCost(senderMatrixUserId, {
costInUsd,
generationId,
});
activeGenerations.delete(room.roomId);
}

Expand Down
53 changes: 47 additions & 6 deletions packages/base/card-api.gts
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,17 @@ function cardTypeFor(
.constructor as typeof BaseDef;
}

/**
 * Rejects field classes that declare their own `[deserialize]` static.
 *
 * A class is exempt when the `primitive` key is present on it (own or
 * inherited). For every other class, an own `[deserialize]` property causes
 * an error directing the author to a registered `fieldSerializer`.
 *
 * @param cardClass - the resolved class about to be asked to deserialize a resource
 * @throws Error when a class without `primitive` owns a `[deserialize]` property
 */
function assertNoDeserializeOverride(cardClass: typeof BaseDef) {
  // Classes carrying the `primitive` key are exempt from the check.
  if (primitive in cardClass) {
    return;
  }

  // Only an *own* property counts; an inherited [deserialize] is fine.
  const declaresOwnDeserialize = Object.prototype.hasOwnProperty.call(
    cardClass,
    deserialize,
  );
  if (!declaresOwnDeserialize) {
    return;
  }

  throw new Error(
    `${cardClass.name} overrides [deserialize] directly. Composite fields must use a registered fieldSerializer instead.`,
  );
}

class ContainsMany<FieldT extends FieldDefConstructor> implements Field<
FieldT,
any[] | null
Expand Down Expand Up @@ -723,6 +734,17 @@ class ContainsMany<FieldT extends FieldDefConstructor> implements Field<
}
return entry;
} else {
if (fieldSerializer in this.card) {
assertIsSerializerName(this.card[fieldSerializer]);
let serializer = getSerializer(this.card[fieldSerializer]);
entry = await serializer.deserialize(
entry,
relativeTo,
doc,
store,
opts,
);
}
let meta = metas[index];
let resource: LooseCardResource = {
attributes: entry,
Expand All @@ -745,9 +767,19 @@ class ContainsMany<FieldT extends FieldDefConstructor> implements Field<
}),
);
}
return (
await cardClassFromResource(resource, this.card, relativeTo)
)[deserialize](resource, relativeTo, doc, store, opts);
let cardClass = await cardClassFromResource(
resource,
this.card,
relativeTo,
);
assertNoDeserializeOverride(cardClass);
return cardClass[deserialize](
resource,
relativeTo,
doc,
store,
opts,
);
}
}),
),
Expand Down Expand Up @@ -959,6 +991,11 @@ class Contains<CardT extends FieldDefConstructor> implements Field<CardT, any> {
}
return value;
}
if (fieldSerializer in this.card) {
assertIsSerializerName(this.card[fieldSerializer]);
let serializer = getSerializer(this.card[fieldSerializer]);
value = await serializer.deserialize(value, relativeTo, doc, store, opts);
}
if (fieldMeta && Array.isArray(fieldMeta)) {
throw new Error(
`fieldMeta for contains field '${
Expand All @@ -983,9 +1020,13 @@ class Contains<CardT extends FieldDefConstructor> implements Field<CardT, any> {
]),
);
}
return (await cardClassFromResource(resource, this.card, relativeTo))[
deserialize
](resource, relativeTo, doc, store, opts);
let cardClass = await cardClassFromResource(
resource,
this.card,
relativeTo,
);
assertNoDeserializeOverride(cardClass);
return cardClass[deserialize](resource, relativeTo, doc, store, opts);
}

emptyValue(_instance: BaseDef) {
Expand Down
2 changes: 2 additions & 0 deletions packages/base/rich-markdown.gts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import {
extractCardReferenceUrls,
fieldSerializer,
relativeTo,
} from '@cardstack/runtime-common';

Expand Down Expand Up @@ -35,6 +36,7 @@ import MarkdownTemplate from './default-templates/markdown';
*/
export class RichMarkdownField extends FieldDef {
static displayName = 'Rich Markdown';
static [fieldSerializer] = 'string-to-content' as const;

/** The raw markdown text. Uses MarkdownField for textarea edit UI. */
@field content = contains(MarkdownField);
Expand Down
67 changes: 1 addition & 66 deletions packages/billing/ai-billing.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,50 +109,7 @@ export async function spendUsageCost(
}
}

/**
 * Looks up what a finished generation cost and deducts the equivalent
 * credits from the user who triggered it.
 *
 * The cost lookup is retried by `fetchGenerationCostWithBackoff` because
 * generation data is sometimes not available right away. If no cost can be
 * obtained the failure is reported to Sentry and nothing is deducted.
 *
 * This function never rejects: any error raised along the way is logged and
 * reported to Sentry, then swallowed, so that usage tracking cannot crash
 * the application.
 *
 * @param dbAdapter - database adapter used for the user lookup and the credit deduction
 * @param matrixUserId - Matrix ID of the user to charge
 * @param generationId - ID of the generation whose cost should be fetched
 * @param openRouterApiKey - API key passed through to the cost lookup
 */
export async function saveUsageCost(
  dbAdapter: DBAdapter,
  matrixUserId: string,
  generationId: string,
  openRouterApiKey: string,
) {
  try {
    const usdCost = await fetchGenerationCostWithBackoff(
      generationId,
      openRouterApiKey,
    );

    // A null cost means the retries were exhausted; report it and bail out.
    if (usdCost === null) {
      const lookupFailure = new Error(
        `Failed to fetch generation cost after retries (generationId: ${generationId})`,
      );
      Sentry.captureException(lookupFailure);
      return;
    }

    const creditsToSpend = Math.round(usdCost * CREDITS_PER_USD);
    const billedUser = await getUserByMatrixUserId(dbAdapter, matrixUserId);

    if (billedUser) {
      await spendCredits(dbAdapter, billedUser.id, creditsToSpend);
      return;
    }

    throw new Error(
      `should not happen: user with matrix id ${matrixUserId} not found in the users table`,
    );
  } catch (err) {
    // Deliberately not rethrown: a billing hiccup must not take the app down.
    log.error(
      `Failed to track AI usage (matrixUserId: ${matrixUserId}, generationId: ${generationId}):`,
      err,
    );
    Sentry.captureException(err);
  }
}

async function fetchGenerationCostWithBackoff(
export async function fetchGenerationCostWithBackoff(
generationId: string,
openRouterApiKey: string,
): Promise<number | null> {
Expand Down Expand Up @@ -202,7 +159,6 @@ async function fetchGenerationCost(
},
);

// 404 means generation data probably isn't available yet - return null to trigger retry
if (response.status === 404) {
return null;
}
Expand All @@ -224,24 +180,3 @@ async function fetchGenerationCost(

return data.data.total_cost;
}

/**
 * Pulls a generation ID out of a completion response.
 *
 * The ID is looked for in three places, in order, and the first truthy
 * value wins:
 *   1. `response.id`
 *   2. `response.choices[0].id`
 *   3. `response.usage.generation_id`
 *
 * A missing response (`null` / `undefined`) or a response with none of
 * these fields yields `undefined` rather than throwing.
 *
 * @param response - the raw response object; its shape is not validated
 * @returns the first generation ID found, or `undefined` when there is none
 */
export function extractGenerationIdFromResponse(
  response: any,
): string | undefined {
  // Optional chaining guards against a null/undefined response, which the
  // previous unguarded property access turned into a TypeError.
  if (response?.id) {
    return response.id;
  }

  // Some endpoints attach the ID to the first choice instead.
  const firstChoiceId = response?.choices?.[0]?.id;
  if (firstChoiceId) {
    return firstChoiceId;
  }

  // For chat completions, the generation ID might be in usage.
  const usageGenerationId = response?.usage?.generation_id;
  if (usageGenerationId) {
    return usageGenerationId;
  }

  return undefined;
}
Loading
Loading