@@ -7,6 +7,7 @@ import { PlatformEvents } from '@/lib/core/telemetry'
77import { generateRequestId } from '@/lib/core/utils/request'
88import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
99import { ALL_TAG_SLOTS } from '@/lib/knowledge/constants'
10+ import { DEFAULT_RERANKER_MODEL , rerank , SUPPORTED_RERANKER_MODELS } from '@/lib/knowledge/reranker'
1011import { getDocumentTagDefinitions } from '@/lib/knowledge/tags/service'
1112import { buildUndefinedTagsError , validateTagValue } from '@/lib/knowledge/tags/utils'
1213import type { StructuredFilter } from '@/lib/knowledge/types'
@@ -21,6 +22,7 @@ import {
2122 type SearchResult ,
2223} from '@/app/api/knowledge/search/utils'
2324import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
25+ import { getRerankModelPricing } from '@/providers/models'
2426import { calculateCost } from '@/providers/utils'
2527
2628const logger = createLogger ( 'VectorSearchAPI' )
@@ -59,6 +61,11 @@ const VectorSearchSchema = z
5961 . optional ( )
6062 . nullable ( )
6163 . transform ( ( val ) => val || undefined ) ,
64+ rerankerEnabled : z . boolean ( ) . optional ( ) . default ( false ) ,
65+ rerankerModel : z
66+ . enum ( SUPPORTED_RERANKER_MODELS as unknown as [ string , ...string [ ] ] )
67+ . optional ( )
68+ . default ( DEFAULT_RERANKER_MODEL ) ,
6269 } )
6370 . refine (
6471 ( data ) => {
@@ -235,11 +242,40 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
235242 )
236243 }
237244
238- const workspaceId = accessChecks . find ( ( ac ) => ac ?. hasAccess ) ?. knowledgeBase ?. workspaceId
245+ const accessibleKbs = accessChecks
246+ . filter (
247+ (
248+ ac
249+ ) : ac is {
250+ hasAccess : true
251+ knowledgeBase : {
252+ id : string
253+ embeddingModel : string
254+ workspaceId ?: string | null
255+ }
256+ } => Boolean ( ac ?. hasAccess )
257+ )
258+ . map ( ( ac ) => ac . knowledgeBase )
259+ const workspaceId = accessibleKbs [ 0 ] ?. workspaceId
260+
261+ const useReranker = validatedData . rerankerEnabled && Boolean ( validatedData . query ?. trim ( ) )
262+ const rerankerModel = useReranker ? validatedData . rerankerModel : null
263+
264+ const embeddingModels = Array . from ( new Set ( accessibleKbs . map ( ( kb ) => kb . embeddingModel ) ) )
265+ if ( embeddingModels . length > 1 ) {
266+ return NextResponse . json (
267+ {
268+ error :
269+ 'Selected knowledge bases use different embedding models and cannot be searched together. Search them separately.' ,
270+ } ,
271+ { status : 400 }
272+ )
273+ }
274+ const queryEmbeddingModel = embeddingModels [ 0 ]
239275
240276 const hasQuery = validatedData . query && validatedData . query . trim ( ) . length > 0
241277 const queryEmbeddingPromise = hasQuery
242- ? generateSearchEmbedding ( validatedData . query ! , undefined , workspaceId )
278+ ? generateSearchEmbedding ( validatedData . query ! , queryEmbeddingModel , workspaceId )
243279 : Promise . resolve ( null )
244280
245281 // Check if any requested knowledge bases were not accessible
@@ -278,6 +314,10 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
278314
279315 const hasFilters = structuredFilters && structuredFilters . length > 0
280316
317+ // Oversample candidates when reranking so the reranker has more to choose from.
318+ // Cap at 100 to bound Cohere request cost (1 search unit = ≤100 docs).
319+ const candidateTopK = useReranker ? Math . min ( 100 , validatedData . topK * 4 ) : validatedData . topK
320+
281321 if ( ! hasQuery && hasFilters ) {
282322 // Tag-only search without vector similarity
283323 results = await handleTagOnlySearch ( {
@@ -291,24 +331,24 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
291331 `[${ requestId } ] Executing tag + vector search with filters:` ,
292332 structuredFilters
293333 )
294- const strategy = getQueryStrategy ( accessibleKbIds . length , validatedData . topK )
334+ const strategy = getQueryStrategy ( accessibleKbIds . length , candidateTopK )
295335 const queryVector = JSON . stringify ( await queryEmbeddingPromise )
296336
297337 results = await handleTagAndVectorSearch ( {
298338 knowledgeBaseIds : accessibleKbIds ,
299- topK : validatedData . topK ,
339+ topK : candidateTopK ,
300340 structuredFilters,
301341 queryVector,
302342 distanceThreshold : strategy . distanceThreshold ,
303343 } )
304344 } else if ( hasQuery && ! hasFilters ) {
305345 // Vector-only search
306- const strategy = getQueryStrategy ( accessibleKbIds . length , validatedData . topK )
346+ const strategy = getQueryStrategy ( accessibleKbIds . length , candidateTopK )
307347 const queryVector = JSON . stringify ( await queryEmbeddingPromise )
308348
309349 results = await handleVectorOnlySearch ( {
310350 knowledgeBaseIds : accessibleKbIds ,
311- topK : validatedData . topK ,
351+ topK : candidateTopK ,
312352 queryVector,
313353 distanceThreshold : strategy . distanceThreshold ,
314354 } )
@@ -323,13 +363,54 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
323363 )
324364 }
325365
366+ // Optional Cohere rerank pass on top of vector results.
367+ const rerankedScores = new Map < string , number > ( )
368+ let rerankApplied = false
369+ if ( useReranker && rerankerModel && results . length > 0 ) {
370+ const candidateCount = results . length
371+ try {
372+ const ranked = await rerank (
373+ validatedData . query ! ,
374+ results . map ( ( r ) => ( { id : r . id , text : r . content } ) ) ,
375+ { model : rerankerModel , topN : validatedData . topK , workspaceId }
376+ )
377+ if ( ranked . length === 0 ) {
378+ logger . warn (
379+ `[${ requestId } ] Reranker returned 0 results; falling back to vector ordering` ,
380+ { model : rerankerModel , candidateCount }
381+ )
382+ results = results . slice ( 0 , validatedData . topK )
383+ } else {
384+ const idToResult = new Map ( results . map ( ( r ) => [ r . id , r ] ) )
385+ results = ranked
386+ . map ( ( r ) => idToResult . get ( r . item . id ) )
387+ . filter ( ( r ) : r is SearchResult => Boolean ( r ) )
388+ for ( const r of ranked ) rerankedScores . set ( r . item . id , r . relevanceScore )
389+ rerankApplied = true
390+ logger . info ( `[${ requestId } ] Reranked ${ candidateCount } → ${ results . length } results` , {
391+ model : rerankerModel ,
392+ } )
393+ }
394+ } catch ( error ) {
395+ logger . warn ( `[${ requestId } ] Reranker failed; falling back to vector ordering` , {
396+ error : error instanceof Error ? error . message : 'Unknown error' ,
397+ model : rerankerModel ,
398+ candidateCount,
399+ workspaceId,
400+ } )
401+ results = results . slice ( 0 , validatedData . topK )
402+ }
403+ } else if ( useReranker ) {
404+ results = results . slice ( 0 , validatedData . topK )
405+ }
406+
326407 // Calculate cost for the embedding (with fallback if calculation fails)
327408 let cost = null
328409 let tokenCount = null
329410 if ( hasQuery ) {
330411 try {
331412 tokenCount = estimateTokenCount ( validatedData . query ! , 'openai' )
332- cost = calculateCost ( 'text-embedding-3-small' , tokenCount . count , 0 , false )
413+ cost = calculateCost ( queryEmbeddingModel , tokenCount . count , 0 , false )
333414 } catch ( error ) {
334415 logger . warn ( `[${ requestId } ] Failed to calculate cost for search query` , {
335416 error : error instanceof Error ? error . message : 'Unknown error' ,
@@ -338,6 +419,31 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
338419 }
339420 }
340421
422+ // Add Cohere rerank cost (1 search unit per call, since we cap candidates ≤100).
423+ let rerankerCost = 0
424+ if ( rerankApplied && rerankerModel ) {
425+ const pricing = getRerankModelPricing ( rerankerModel )
426+ if ( pricing ) {
427+ rerankerCost = pricing . perSearchUnit
428+ if ( cost ) {
429+ cost = {
430+ ...cost ,
431+ input : cost . input + rerankerCost ,
432+ total : cost . total + rerankerCost ,
433+ }
434+ } else {
435+ cost = {
436+ input : rerankerCost ,
437+ output : 0 ,
438+ total : rerankerCost ,
439+ pricing : { input : 0 , output : 0 , updatedAt : pricing . updatedAt } ,
440+ }
441+ }
442+ } else {
443+ logger . warn ( `[${ requestId } ] No pricing entry for rerank model ${ rerankerModel } ` )
444+ }
445+ }
446+
341447 // Fetch tag definitions for display name mapping (reuse the same fetch from filtering)
342448 const tagDefsResults = await Promise . all (
343449 accessibleKbIds . map ( async ( kbId ) => {
@@ -400,33 +506,36 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
400506 }
401507 } )
402508
509+ const rerankerScore = rerankedScores . get ( result . id )
403510 return {
404511 documentId : result . documentId ,
405512 documentName : documentNameMap [ result . documentId ] || undefined ,
406513 content : result . content ,
407514 chunkIndex : result . chunkIndex ,
408515 metadata : tags , // Clean display name mapped tags
409516 similarity : hasQuery ? 1 - result . distance : 1 , // Perfect similarity for tag-only searches
517+ ...( rerankerScore !== undefined && { rerankerScore } ) ,
410518 }
411519 } ) ,
412520 query : validatedData . query || '' ,
413521 knowledgeBaseIds : accessibleKbIds ,
414522 knowledgeBaseId : accessibleKbIds [ 0 ] ,
415523 topK : validatedData . topK ,
416524 totalResults : results . length ,
417- ...( cost && tokenCount
525+ ...( cost
418526 ? {
419527 cost : {
420528 input : cost . input ,
421529 output : cost . output ,
422530 total : cost . total ,
423531 tokens : {
424- prompt : tokenCount . count ,
532+ prompt : tokenCount ? .count ?? 0 ,
425533 completion : 0 ,
426- total : tokenCount . count ,
534+ total : tokenCount ? .count ?? 0 ,
427535 } ,
428- model : 'text-embedding-3-small' ,
536+ model : queryEmbeddingModel ,
429537 pricing : cost . pricing ,
538+ ...( rerankApplied ? { rerankerCost, rerankerModel, rerankerSearchUnits : 1 } : { } ) ,
430539 } ,
431540 }
432541 : { } ) ,
0 commit comments