Skip to content

Commit 4490e07

Browse files
committed
fix billing account details for kb embeddings
1 parent bd6cf41 commit 4490e07

7 files changed

Lines changed: 179 additions & 20 deletions

File tree

apps/sim/app/api/auth/sso/providers/route.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ const logger = createLogger('SSOProvidersRoute')
1313
export const GET = withRouteHandler(async (request: NextRequest) => {
1414
try {
1515
const session = await getSession()
16-
const parsed = await parseRequest(listSsoProvidersContract, request, undefined)
16+
const parsed = await parseRequest(listSsoProvidersContract, request, {})
1717
if (!parsed.success) return parsed.response
1818
const { organizationId } = parsed.data.query
1919

apps/sim/app/api/wand/route.ts

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import { getCostMultiplier, isBillingEnabled } from '@/lib/core/config/feature-f
1414
import { generateRequestId } from '@/lib/core/utils/request'
1515
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
1616
import { enrichTableSchema } from '@/lib/table/llm/wand'
17+
import { getWorkspaceBilledAccountUserId } from '@/lib/workspaces/utils'
1718
import { verifyWorkspaceMembership } from '@/app/api/workflows/utils'
1819
import { extractResponseText, parseResponsesUsage } from '@/providers/openai/utils'
1920
import { getModelPricing } from '@/providers/utils'
@@ -86,7 +87,8 @@ Use this context to calculate relative dates like "yesterday", "last week", "beg
8687
}
8788

8889
async function updateUserStatsForWand(
89-
userId: string,
90+
billingUserId: string,
91+
workspaceId: string | null,
9092
usage: {
9193
prompt_tokens?: number
9294
completion_tokens?: number
@@ -128,7 +130,8 @@ async function updateUserStatsForWand(
128130
}
129131

130132
await recordUsage({
131-
userId,
133+
userId: billingUserId,
134+
workspaceId: workspaceId ?? undefined,
132135
entries: [
133136
{
134137
category: 'model',
@@ -143,7 +146,7 @@ async function updateUserStatsForWand(
143146
},
144147
})
145148

146-
await checkAndBillOverageThreshold(userId)
149+
await checkAndBillOverageThreshold(billingUserId)
147150
} catch (error) {
148151
logger.error(`[${requestId}] Failed to update user stats for wand usage`, error)
149152
}
@@ -223,6 +226,21 @@ export const POST = withRouteHandler(async (req: NextRequest) => {
223226
}
224227
}
225228

229+
let billingUserId = session.user.id
230+
if (workspaceId) {
231+
const workspaceBilledAccountUserId = await getWorkspaceBilledAccountUserId(workspaceId)
232+
if (!workspaceBilledAccountUserId) {
233+
logger.error(`[${requestId}] Unable to resolve billed account for workspace`, {
234+
workspaceId,
235+
})
236+
return NextResponse.json(
237+
{ success: false, error: 'Unable to resolve billing account for this workspace' },
238+
{ status: 500 }
239+
)
240+
}
241+
billingUserId = workspaceBilledAccountUserId
242+
}
243+
226244
let isBYOK = false
227245
let activeOpenAIKey = openaiApiKey
228246

@@ -339,7 +357,13 @@ export const POST = withRouteHandler(async (req: NextRequest) => {
339357
}
340358

341359
usageRecorded = true
342-
await updateUserStatsForWand(session.user.id, finalUsage, requestId, isBYOK)
360+
await updateUserStatsForWand(
361+
billingUserId,
362+
workspaceId,
363+
finalUsage,
364+
requestId,
365+
isBYOK
366+
)
343367
}
344368

345369
try {
@@ -556,7 +580,8 @@ export const POST = withRouteHandler(async (req: NextRequest) => {
556580
const usage = parseResponsesUsage(completion.usage)
557581
if (usage) {
558582
await updateUserStatsForWand(
559-
session.user.id,
583+
billingUserId,
584+
workspaceId,
560585
{
561586
prompt_tokens: usage.promptTokens,
562587
completion_tokens: usage.completionTokens,

apps/sim/app/workspace/[workspaceId]/w/hooks/use-import-workspace.ts

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ import { createLogger } from '@sim/logger'
33
import { generateId } from '@sim/utils/id'
44
import { useRouter } from 'next/navigation'
55
import { requestJson } from '@/lib/api/client/request'
6-
import type { ContractBodyInput } from '@/lib/api/contracts'
76
import {
87
createWorkflowContract,
98
createWorkspaceContract,
109
putWorkflowNormalizedStateContract,
10+
type WorkflowStateContractInput,
1111
workflowVariablesContract,
1212
} from '@/lib/api/contracts'
1313
import {
@@ -186,10 +186,31 @@ export function useImportWorkspace({ onSuccess }: UseImportWorkspaceProps = {})
186186
continue
187187
}
188188

189+
type ContractEdgeInput = WorkflowStateContractInput['edges'][number]
190+
191+
const sanitizedEdges: ContractEdgeInput[] = (workflowData.edges || []).map((edge) => {
192+
const { sourceHandle, targetHandle, ...rest } = edge
193+
const sanitized: ContractEdgeInput = { ...rest } as ContractEdgeInput
194+
if (typeof sourceHandle === 'string' && sourceHandle.length > 0) {
195+
sanitized.sourceHandle = sourceHandle
196+
}
197+
if (typeof targetHandle === 'string' && targetHandle.length > 0) {
198+
sanitized.targetHandle = targetHandle
199+
}
200+
return sanitized
201+
})
202+
203+
const workflowStateBody: WorkflowStateContractInput = {
204+
...workflowData,
205+
loops: workflowData.loops || {},
206+
parallels: workflowData.parallels || {},
207+
edges: sanitizedEdges,
208+
}
209+
189210
try {
190211
await requestJson(putWorkflowNormalizedStateContract, {
191212
params: { id: newWorkflow.id },
192-
body: workflowData as ContractBodyInput<typeof putWorkflowNormalizedStateContract>,
213+
body: workflowStateBody,
193214
})
194215
} catch (error) {
195216
logger.error(`Failed to save workflow state for ${newWorkflow.id}`, { error })

apps/sim/ee/audit-logs/hooks/audit-logs.ts

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
import { keepPreviousData, useInfiniteQuery } from '@tanstack/react-query'
2-
import type { z } from 'zod'
32
import { requestJson } from '@/lib/api/client/request'
4-
import type { ContractJsonResponse } from '@/lib/api/contracts'
5-
import {
6-
type enterpriseAuditLogEntrySchema,
7-
listAuditLogsContract,
8-
} from '@/lib/api/contracts/audit-logs'
3+
import { type AuditLogPage, listAuditLogsContract } from '@/lib/api/contracts/audit-logs'
94

105
export const auditLogKeys = {
116
all: ['audit-logs'] as const,
@@ -22,9 +17,6 @@ export interface AuditLogFilters {
2217
endDate?: string
2318
}
2419

25-
export type EnterpriseAuditLogEntry = z.output<typeof enterpriseAuditLogEntrySchema>
26-
type AuditLogPage = ContractJsonResponse<typeof listAuditLogsContract>
27-
2820
async function fetchAuditLogs(
2921
filters: AuditLogFilters,
3022
cursor?: string,

apps/sim/lib/api/contracts/audit-logs.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,16 @@ export const enterpriseAuditLogEntrySchema = z.object({
9999
createdAt: z.string(),
100100
})
101101

102+
export type EnterpriseAuditLogEntry = z.output<typeof enterpriseAuditLogEntrySchema>
103+
102104
export const listAuditLogsResponseSchema = z.object({
103105
success: z.boolean(),
104106
data: z.array(enterpriseAuditLogEntrySchema),
105107
nextCursor: z.string().optional(),
106108
})
107109

110+
export type AuditLogPage = z.output<typeof listAuditLogsResponseSchema>
111+
108112
export const listAuditLogsContract = defineRouteContract({
109113
method: 'GET',
110114
path: '/api/audit-logs',

apps/sim/lib/knowledge/documents/service.ts

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ import type { ProcessedDocumentTags } from '@/lib/knowledge/types'
4747
import { estimateTokenCount } from '@/lib/tokenization/estimators'
4848
import { deleteFile } from '@/lib/uploads/core/storage-service'
4949
import { extractStorageKey } from '@/lib/uploads/utils/file-utils'
50+
import { getWorkspaceBilledAccountUserId } from '@/lib/workspaces/utils'
5051
import type { DocumentProcessingPayload } from '@/background/knowledge-processing'
5152
import { calculateCost } from '@/providers/utils'
5253

@@ -433,6 +434,13 @@ export async function processDocumentAsync(
433434
}
434435

435436
const kbEmbeddingModel = kb[0].embeddingModel
437+
if (!kb[0].workspaceId) {
438+
throw new Error(`Knowledge base ${knowledgeBaseId} is missing workspace billing context`)
439+
}
440+
const billingUserId = await getWorkspaceBilledAccountUserId(kb[0].workspaceId)
441+
if (!billingUserId) {
442+
throw new Error(`Workspace ${kb[0].workspaceId} is missing billed account`)
443+
}
436444
let totalEmbeddingTokens = 0
437445
let embeddingIsBYOK = false
438446
let embeddingModelName = kbEmbeddingModel
@@ -625,7 +633,7 @@ export async function processDocumentAsync(
625633
const processingTime = Date.now() - startTime
626634
logger.info(`[${documentId}] Successfully processed document in ${processingTime}ms`)
627635

628-
if (!embeddingIsBYOK && totalEmbeddingTokens > 0 && kb[0].userId) {
636+
if (!embeddingIsBYOK && totalEmbeddingTokens > 0 && billingUserId) {
629637
try {
630638
const costMultiplier = getCostMultiplier()
631639
const { total: cost } = calculateCost(
@@ -637,7 +645,7 @@ export async function processDocumentAsync(
637645
)
638646
if (cost > 0) {
639647
await recordUsage({
640-
userId: kb[0].userId,
648+
userId: billingUserId,
641649
workspaceId: kb[0].workspaceId ?? undefined,
642650
entries: [
643651
{
@@ -652,7 +660,7 @@ export async function processDocumentAsync(
652660
totalTokensUsed: sql`total_tokens_used + ${totalEmbeddingTokens}`,
653661
},
654662
})
655-
await checkAndBillOverageThreshold(kb[0].userId)
663+
await checkAndBillOverageThreshold(billingUserId)
656664
} else {
657665
logger.warn(
658666
`[${documentId}] Embedding model "${embeddingModelName}" has no pricing entry — billing skipped`,

0 commit comments

Comments
 (0)