Skip to content

Commit 9614a28

Browse files
committed
feat(knowledge): add embedding model selection and Cohere reranker
1 parent 6081670 commit 9614a28

23 files changed

Lines changed: 753 additions & 147 deletions

File tree

apps/docs/content/docs/en/tools/knowledge.mdx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ Search for similar content in a knowledge base using vector similarity
4747
| `properties` | string | No | No description |
4848
| `tagName` | string | No | No description |
4949
| `tagValue` | string | No | No description |
50+
| `rerankerEnabled` | boolean | No | Whether to apply Cohere reranking to vector search results |
51+
| `rerankerModel` | string | No | Cohere rerank model to use \(one of: $\{SUPPORTED_RERANKER_MODELS.join\(', '\)\}\) |
5052
| `tagFilters` | string | No | No description |
5153

5254
#### Output

apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,11 @@ export const POST = withRouteHandler(
213213
accessCheck.knowledgeBase?.workspaceId
214214
)
215215

216+
const chunkEmbeddingModel =
217+
accessCheck.knowledgeBase?.embeddingModel ?? 'text-embedding-3-small'
216218
let cost = null
217219
try {
218-
cost = calculateCost('text-embedding-3-small', newChunk.tokenCount, 0, false)
220+
cost = calculateCost(chunkEmbeddingModel, newChunk.tokenCount, 0, false)
219221
} catch (error) {
220222
logger.warn(`[${requestId}] Failed to calculate cost for chunk upload`, {
221223
error: error instanceof Error ? error.message : 'Unknown error',
@@ -240,7 +242,7 @@ export const POST = withRouteHandler(
240242
completion: 0,
241243
total: newChunk.tokenCount,
242244
},
243-
model: 'text-embedding-3-small',
245+
model: chunkEmbeddingModel,
244246
pricing: cost.pricing,
245247
},
246248
}

apps/sim/app/api/knowledge/route.ts

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ import { getSession } from '@/lib/auth'
66
import { PlatformEvents } from '@/lib/core/telemetry'
77
import { generateRequestId } from '@/lib/core/utils/request'
88
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
9+
import {
10+
DEFAULT_EMBEDDING_MODEL,
11+
EMBEDDING_DIMENSIONS,
12+
SUPPORTED_EMBEDDING_MODEL_IDS,
13+
} from '@/lib/knowledge/embeddings'
914
import {
1015
createKnowledgeBase,
1116
getKnowledgeBases,
@@ -20,8 +25,10 @@ const CreateKnowledgeBaseSchema = z.object({
2025
name: z.string().min(1, 'Name is required'),
2126
description: z.string().optional(),
2227
workspaceId: z.string().min(1, 'Workspace ID is required'),
23-
embeddingModel: z.literal('text-embedding-3-small').default('text-embedding-3-small'),
24-
embeddingDimension: z.literal(1536).default(1536),
28+
embeddingModel: z
29+
.enum(SUPPORTED_EMBEDDING_MODEL_IDS as [string, ...string[]])
30+
.default(DEFAULT_EMBEDDING_MODEL),
31+
embeddingDimension: z.literal(EMBEDDING_DIMENSIONS).default(EMBEDDING_DIMENSIONS),
2532
chunkingConfig: z
2633
.object({
2734
maxSize: z.number().min(100).max(4000).default(1024),

apps/sim/app/api/knowledge/search/route.ts

Lines changed: 120 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import { PlatformEvents } from '@/lib/core/telemetry'
77
import { generateRequestId } from '@/lib/core/utils/request'
88
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
99
import { ALL_TAG_SLOTS } from '@/lib/knowledge/constants'
10+
import { DEFAULT_RERANKER_MODEL, rerank, SUPPORTED_RERANKER_MODELS } from '@/lib/knowledge/reranker'
1011
import { getDocumentTagDefinitions } from '@/lib/knowledge/tags/service'
1112
import { buildUndefinedTagsError, validateTagValue } from '@/lib/knowledge/tags/utils'
1213
import type { StructuredFilter } from '@/lib/knowledge/types'
@@ -21,6 +22,7 @@ import {
2122
type SearchResult,
2223
} from '@/app/api/knowledge/search/utils'
2324
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
25+
import { getRerankModelPricing } from '@/providers/models'
2426
import { calculateCost } from '@/providers/utils'
2527

2628
const logger = createLogger('VectorSearchAPI')
@@ -59,6 +61,11 @@ const VectorSearchSchema = z
5961
.optional()
6062
.nullable()
6163
.transform((val) => val || undefined),
64+
rerankerEnabled: z.boolean().optional().default(false),
65+
rerankerModel: z
66+
.enum(SUPPORTED_RERANKER_MODELS as unknown as [string, ...string[]])
67+
.optional()
68+
.default(DEFAULT_RERANKER_MODEL),
6269
})
6370
.refine(
6471
(data) => {
@@ -235,11 +242,40 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
235242
)
236243
}
237244

238-
const workspaceId = accessChecks.find((ac) => ac?.hasAccess)?.knowledgeBase?.workspaceId
245+
const accessibleKbs = accessChecks
246+
.filter(
247+
(
248+
ac
249+
): ac is {
250+
hasAccess: true
251+
knowledgeBase: {
252+
id: string
253+
embeddingModel: string
254+
workspaceId?: string | null
255+
}
256+
} => Boolean(ac?.hasAccess)
257+
)
258+
.map((ac) => ac.knowledgeBase)
259+
const workspaceId = accessibleKbs[0]?.workspaceId
260+
261+
const useReranker = validatedData.rerankerEnabled && Boolean(validatedData.query?.trim())
262+
const rerankerModel = useReranker ? validatedData.rerankerModel : null
263+
264+
const embeddingModels = Array.from(new Set(accessibleKbs.map((kb) => kb.embeddingModel)))
265+
if (embeddingModels.length > 1) {
266+
return NextResponse.json(
267+
{
268+
error:
269+
'Selected knowledge bases use different embedding models and cannot be searched together. Search them separately.',
270+
},
271+
{ status: 400 }
272+
)
273+
}
274+
const queryEmbeddingModel = embeddingModels[0]
239275

240276
const hasQuery = validatedData.query && validatedData.query.trim().length > 0
241277
const queryEmbeddingPromise = hasQuery
242-
? generateSearchEmbedding(validatedData.query!, undefined, workspaceId)
278+
? generateSearchEmbedding(validatedData.query!, queryEmbeddingModel, workspaceId)
243279
: Promise.resolve(null)
244280

245281
// Check if any requested knowledge bases were not accessible
@@ -278,6 +314,10 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
278314

279315
const hasFilters = structuredFilters && structuredFilters.length > 0
280316

317+
// Oversample candidates when reranking so the reranker has more to choose from.
318+
// Cap at 100 to bound Cohere request cost (1 search unit = ≤100 docs).
319+
const candidateTopK = useReranker ? Math.min(100, validatedData.topK * 4) : validatedData.topK
320+
281321
if (!hasQuery && hasFilters) {
282322
// Tag-only search without vector similarity
283323
results = await handleTagOnlySearch({
@@ -291,24 +331,24 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
291331
`[${requestId}] Executing tag + vector search with filters:`,
292332
structuredFilters
293333
)
294-
const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK)
334+
const strategy = getQueryStrategy(accessibleKbIds.length, candidateTopK)
295335
const queryVector = JSON.stringify(await queryEmbeddingPromise)
296336

297337
results = await handleTagAndVectorSearch({
298338
knowledgeBaseIds: accessibleKbIds,
299-
topK: validatedData.topK,
339+
topK: candidateTopK,
300340
structuredFilters,
301341
queryVector,
302342
distanceThreshold: strategy.distanceThreshold,
303343
})
304344
} else if (hasQuery && !hasFilters) {
305345
// Vector-only search
306-
const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK)
346+
const strategy = getQueryStrategy(accessibleKbIds.length, candidateTopK)
307347
const queryVector = JSON.stringify(await queryEmbeddingPromise)
308348

309349
results = await handleVectorOnlySearch({
310350
knowledgeBaseIds: accessibleKbIds,
311-
topK: validatedData.topK,
351+
topK: candidateTopK,
312352
queryVector,
313353
distanceThreshold: strategy.distanceThreshold,
314354
})
@@ -323,13 +363,54 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
323363
)
324364
}
325365

366+
// Optional Cohere rerank pass on top of vector results.
367+
const rerankedScores = new Map<string, number>()
368+
let rerankApplied = false
369+
if (useReranker && rerankerModel && results.length > 0) {
370+
const candidateCount = results.length
371+
try {
372+
const ranked = await rerank(
373+
validatedData.query!,
374+
results.map((r) => ({ id: r.id, text: r.content })),
375+
{ model: rerankerModel, topN: validatedData.topK, workspaceId }
376+
)
377+
if (ranked.length === 0) {
378+
logger.warn(
379+
`[${requestId}] Reranker returned 0 results; falling back to vector ordering`,
380+
{ model: rerankerModel, candidateCount }
381+
)
382+
results = results.slice(0, validatedData.topK)
383+
} else {
384+
const idToResult = new Map(results.map((r) => [r.id, r]))
385+
results = ranked
386+
.map((r) => idToResult.get(r.item.id))
387+
.filter((r): r is SearchResult => Boolean(r))
388+
for (const r of ranked) rerankedScores.set(r.item.id, r.relevanceScore)
389+
rerankApplied = true
390+
logger.info(`[${requestId}] Reranked ${candidateCount}${results.length} results`, {
391+
model: rerankerModel,
392+
})
393+
}
394+
} catch (error) {
395+
logger.warn(`[${requestId}] Reranker failed; falling back to vector ordering`, {
396+
error: error instanceof Error ? error.message : 'Unknown error',
397+
model: rerankerModel,
398+
candidateCount,
399+
workspaceId,
400+
})
401+
results = results.slice(0, validatedData.topK)
402+
}
403+
} else if (useReranker) {
404+
results = results.slice(0, validatedData.topK)
405+
}
406+
326407
// Calculate cost for the embedding (with fallback if calculation fails)
327408
let cost = null
328409
let tokenCount = null
329410
if (hasQuery) {
330411
try {
331412
tokenCount = estimateTokenCount(validatedData.query!, 'openai')
332-
cost = calculateCost('text-embedding-3-small', tokenCount.count, 0, false)
413+
cost = calculateCost(queryEmbeddingModel, tokenCount.count, 0, false)
333414
} catch (error) {
334415
logger.warn(`[${requestId}] Failed to calculate cost for search query`, {
335416
error: error instanceof Error ? error.message : 'Unknown error',
@@ -338,6 +419,31 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
338419
}
339420
}
340421

422+
// Add Cohere rerank cost (1 search unit per call, since we cap candidates ≤100).
423+
let rerankerCost = 0
424+
if (rerankApplied && rerankerModel) {
425+
const pricing = getRerankModelPricing(rerankerModel)
426+
if (pricing) {
427+
rerankerCost = pricing.perSearchUnit
428+
if (cost) {
429+
cost = {
430+
...cost,
431+
input: cost.input + rerankerCost,
432+
total: cost.total + rerankerCost,
433+
}
434+
} else {
435+
cost = {
436+
input: rerankerCost,
437+
output: 0,
438+
total: rerankerCost,
439+
pricing: { input: 0, output: 0, updatedAt: pricing.updatedAt },
440+
}
441+
}
442+
} else {
443+
logger.warn(`[${requestId}] No pricing entry for rerank model ${rerankerModel}`)
444+
}
445+
}
446+
341447
// Fetch tag definitions for display name mapping (reuse the same fetch from filtering)
342448
const tagDefsResults = await Promise.all(
343449
accessibleKbIds.map(async (kbId) => {
@@ -400,33 +506,36 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
400506
}
401507
})
402508

509+
const rerankerScore = rerankedScores.get(result.id)
403510
return {
404511
documentId: result.documentId,
405512
documentName: documentNameMap[result.documentId] || undefined,
406513
content: result.content,
407514
chunkIndex: result.chunkIndex,
408515
metadata: tags, // Clean display name mapped tags
409516
similarity: hasQuery ? 1 - result.distance : 1, // Perfect similarity for tag-only searches
517+
...(rerankerScore !== undefined && { rerankerScore }),
410518
}
411519
}),
412520
query: validatedData.query || '',
413521
knowledgeBaseIds: accessibleKbIds,
414522
knowledgeBaseId: accessibleKbIds[0],
415523
topK: validatedData.topK,
416524
totalResults: results.length,
417-
...(cost && tokenCount
525+
...(cost
418526
? {
419527
cost: {
420528
input: cost.input,
421529
output: cost.output,
422530
total: cost.total,
423531
tokens: {
424-
prompt: tokenCount.count,
532+
prompt: tokenCount?.count ?? 0,
425533
completion: 0,
426-
total: tokenCount.count,
534+
total: tokenCount?.count ?? 0,
427535
},
428-
model: 'text-embedding-3-small',
536+
model: queryEmbeddingModel,
429537
pricing: cost.pricing,
538+
...(rerankApplied ? { rerankerCost, rerankerModel, rerankerSearchUnits: 1 } : {}),
430539
},
431540
}
432541
: {}),

apps/sim/app/api/knowledge/utils.ts

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,10 @@ export interface EmbeddingData {
103103

104104
export interface KnowledgeBaseAccessResult {
105105
hasAccess: true
106-
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId' | 'workspaceId' | 'name'>
106+
knowledgeBase: Pick<
107+
KnowledgeBaseData,
108+
'id' | 'userId' | 'workspaceId' | 'name' | 'embeddingModel'
109+
>
107110
}
108111

109112
export interface KnowledgeBaseAccessDenied {
@@ -117,7 +120,10 @@ export type KnowledgeBaseAccessCheck = KnowledgeBaseAccessResult | KnowledgeBase
117120
export interface DocumentAccessResult {
118121
hasAccess: true
119122
document: DocumentData
120-
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId' | 'workspaceId' | 'name'>
123+
knowledgeBase: Pick<
124+
KnowledgeBaseData,
125+
'id' | 'userId' | 'workspaceId' | 'name' | 'embeddingModel'
126+
>
121127
}
122128

123129
export interface DocumentAccessDenied {
@@ -132,7 +138,10 @@ export interface ChunkAccessResult {
132138
hasAccess: true
133139
chunk: EmbeddingData
134140
document: DocumentData
135-
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId' | 'workspaceId' | 'name'>
141+
knowledgeBase: Pick<
142+
KnowledgeBaseData,
143+
'id' | 'userId' | 'workspaceId' | 'name' | 'embeddingModel'
144+
>
136145
}
137146

138147
export interface ChunkAccessDenied {
@@ -156,6 +165,7 @@ export async function checkKnowledgeBaseAccess(
156165
userId: knowledgeBase.userId,
157166
workspaceId: knowledgeBase.workspaceId,
158167
name: knowledgeBase.name,
168+
embeddingModel: knowledgeBase.embeddingModel,
159169
})
160170
.from(knowledgeBase)
161171
.where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt)))
@@ -200,6 +210,7 @@ export async function checkKnowledgeBaseWriteAccess(
200210
userId: knowledgeBase.userId,
201211
workspaceId: knowledgeBase.workspaceId,
202212
name: knowledgeBase.name,
213+
embeddingModel: knowledgeBase.embeddingModel,
203214
})
204215
.from(knowledgeBase)
205216
.where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt)))

apps/sim/app/api/v1/knowledge/route.ts

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@ import { AuditAction, AuditResourceType, recordAudit } from '@sim/audit'
22
import { type NextRequest, NextResponse } from 'next/server'
33
import { z } from 'zod'
44
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
5+
import {
6+
DEFAULT_EMBEDDING_MODEL,
7+
EMBEDDING_DIMENSIONS,
8+
SUPPORTED_EMBEDDING_MODEL_IDS,
9+
} from '@/lib/knowledge/embeddings'
510
import { createKnowledgeBase, getKnowledgeBases } from '@/lib/knowledge/service'
611
import {
712
authenticateRequest,
@@ -29,6 +34,9 @@ const CreateKBSchema = z.object({
2934
workspaceId: z.string().min(1, 'Workspace ID is required'),
3035
name: z.string().min(1, 'Name is required').max(255, 'Name must be 255 characters or less'),
3136
description: z.string().max(1000, 'Description must be 1000 characters or less').optional(),
37+
embeddingModel: z
38+
.enum(SUPPORTED_EMBEDDING_MODEL_IDS as [string, ...string[]])
39+
.default(DEFAULT_EMBEDDING_MODEL),
3240
chunkingConfig: ChunkingConfigSchema.optional().default({
3341
maxSize: 1024,
3442
minSize: 100,
@@ -81,7 +89,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
8189
const validation = validateSchema(CreateKBSchema, body.data)
8290
if (!validation.success) return validation.response
8391

84-
const { workspaceId, name, description, chunkingConfig } = validation.data
92+
const { workspaceId, name, description, embeddingModel, chunkingConfig } = validation.data
8593

8694
const accessError = await validateWorkspaceAccess(rateLimit, userId, workspaceId, 'write')
8795
if (accessError) return accessError
@@ -92,8 +100,8 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
92100
description,
93101
workspaceId,
94102
userId,
95-
embeddingModel: 'text-embedding-3-small',
96-
embeddingDimension: 1536,
103+
embeddingModel,
104+
embeddingDimension: EMBEDDING_DIMENSIONS,
97105
chunkingConfig: chunkingConfig ?? { maxSize: 1024, minSize: 100, overlap: 200 },
98106
},
99107
requestId

0 commit comments

Comments
 (0)