diff --git a/backend/src/__tests__/cost-aggregation.integration.test.ts b/backend/src/__tests__/cost-aggregation.integration.test.ts index 673343c3fd..10f77415a4 100644 --- a/backend/src/__tests__/cost-aggregation.integration.test.ts +++ b/backend/src/__tests__/cost-aggregation.integration.test.ts @@ -1,5 +1,5 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { spyOn, @@ -12,12 +12,12 @@ import { } from 'bun:test' import * as messageCostTracker from '../llm-apis/message-cost-tracker' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { mainPrompt } from '../main-prompt' import * as agentRegistry from '../templates/agent-registry' import * as websocketAction from '../websockets/websocket-action' import type { AgentTemplate } from '../templates/types' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { ProjectFileContext } from '@codebuff/common/util/file' import type { WebSocket } from 'ws' @@ -99,8 +99,10 @@ class MockWebSocket { describe('Cost Aggregation Integration Tests', () => { let mockLocalAgentTemplates: Record let mockWebSocket: MockWebSocket + let agentRuntimeImpl: AgentRuntimeDeps beforeEach(async () => { + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } mockWebSocket = new MockWebSocket() // Setup mock agent templates @@ -150,33 +152,31 @@ describe('Cost Aggregation Integration Tests', () => { // Mock LLM streaming let callCount = 0 const creditHistory: number[] = [] - spyOn(aisdk, 'promptAiSdkStream').mockImplementation( - async function* (options) { - callCount++ - const credits = callCount === 1 ? 10 : 7 // Main agent vs subagent costs - creditHistory.push(credits) - - if (options.onCostCalculated) { - await options.onCostCalculated(credits) - } + agentRuntimeImpl.promptAiSdkStream = async function* (options) { + callCount++ + const credits = callCount === 1 ? 10 : 7 // Main agent vs subagent costs + creditHistory.push(credits) - // Simulate different responses based on call - if (callCount === 1) { - // Main agent spawns a subagent - yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write a simple hello world file"}]}\n', - } - } else { - // Subagent writes a file - yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "write_file", "path": "hello.txt", "instructions": "Create hello world file", "content": "Hello, World!"}\n', - } + if (options.onCostCalculated) { + await options.onCostCalculated(credits) + } + + // Simulate different responses based on call + if (callCount === 1) { + // Main agent spawns a subagent + yield { + type: 'text' as const, + text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write a simple hello world file"}]}\n', } - return 'mock-message-id' - }, - ) + } else { + // Subagent writes a file + yield { + type: 'text' as const, + text: '\n{"cb_tool_name": "write_file", "path": "hello.txt", "instructions": "Create hello world file", "content": "Hello, World!"}\n', + } + } + return 'mock-message-id' + } // Mock tool call execution spyOn(websocketAction, 'requestToolCall').mockImplementation( @@ -250,7 +250,7 @@ describe('Cost Aggregation Integration Tests', () => { } const result = await mainPrompt({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: mockWebSocket as unknown as WebSocket, action, userId: TEST_USER_ID, @@ -285,7 +285,7 @@ describe('Cost Aggregation Integration Tests', () => { // Call through websocket action handler to test full integration await websocketAction.callMainPrompt({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: mockWebSocket as unknown as WebSocket, action, userId: TEST_USER_ID, @@ -308,37 +308,35 @@ describe('Cost Aggregation Integration Tests', () => { it('should handle multi-level subagent hierarchies correctly', async () => { // Mock a more complex scenario with nested subagents let callCount = 0 - spyOn(aisdk, 'promptAiSdkStream').mockImplementation( - async function* (options) { - callCount++ + agentRuntimeImpl.promptAiSdkStream = async function* (options) { + callCount++ - if (options.onCostCalculated) { - await options.onCostCalculated(5) // Each call costs 5 credits - } + if (options.onCostCalculated) { + await options.onCostCalculated(5) // Each call costs 5 credits + } - if (callCount === 1) { - // Main agent spawns first-level subagent - yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Create files"}]}\n', - } - } else if (callCount === 2) { - // First-level subagent spawns second-level subagent - yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write specific file"}]}\n', - } - } else { - // Second-level subagent does actual work - yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "write_file", "path": "nested.txt", "instructions": "Create nested file", "content": "Nested content"}\n', - } + if (callCount === 1) { + // Main agent spawns first-level subagent + yield { + type: 'text' as const, + text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Create files"}]}\n', + } + } else if (callCount === 2) { + // First-level subagent spawns second-level subagent + yield { + type: 'text' as const, + text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write specific file"}]}\n', } + } else { + // Second-level subagent does actual work + yield { + type: 'text' as const, + text: '\n{"cb_tool_name": "write_file", "path": "nested.txt", "instructions": "Create nested file", "content": "Nested content"}\n', + } + } - return 'mock-message-id' - }, - ) + return 'mock-message-id' + } const sessionState = getInitialSessionState(mockFileContext) sessionState.mainAgentState.stepsRemaining = 10 @@ -355,7 +353,7 @@ describe('Cost Aggregation Integration Tests', () => { } const result = await mainPrompt({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: mockWebSocket as unknown as WebSocket, action, userId: TEST_USER_ID, @@ -373,29 +371,27 @@ describe('Cost Aggregation Integration Tests', () => { it('should maintain cost integrity when subagents fail', async () => { // Mock scenario where subagent fails after incurring partial costs let callCount = 0 - spyOn(aisdk, 'promptAiSdkStream').mockImplementation( - async function* (options) { - callCount++ + agentRuntimeImpl.promptAiSdkStream = async function* (options) { + callCount++ - if (options.onCostCalculated) { - await options.onCostCalculated(6) // Each call costs 6 credits - } + if (options.onCostCalculated) { + await options.onCostCalculated(6) // Each call costs 6 credits + } - if (callCount === 1) { - // Main agent spawns subagent - yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "This will fail"}]}\n', - } - } else { - // Subagent fails after incurring cost - yield { type: 'text' as const, text: 'Some response' } - throw new Error('Subagent execution failed') + if (callCount === 1) { + // Main agent spawns subagent + yield { + type: 'text' as const, + text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "This will fail"}]}\n', } + } else { + // Subagent fails after incurring cost + yield { type: 'text' as const, text: 'Some response' } + throw new Error('Subagent execution failed') + } - return 'mock-message-id' - }, - ) + return 'mock-message-id' + } const sessionState = getInitialSessionState(mockFileContext) sessionState.mainAgentState.agentType = 'base' @@ -413,7 +409,7 @@ describe('Cost Aggregation Integration Tests', () => { let result try { result = await mainPrompt({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: mockWebSocket as unknown as WebSocket, action, userId: TEST_USER_ID, @@ -462,7 +458,7 @@ describe('Cost Aggregation Integration Tests', () => { } await mainPrompt({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: mockWebSocket as unknown as WebSocket, action, userId: TEST_USER_ID, @@ -502,7 +498,7 @@ describe('Cost Aggregation Integration Tests', () => { // Call through websocket action to test server-side reset await websocketAction.callMainPrompt({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: mockWebSocket as unknown as WebSocket, action, userId: TEST_USER_ID, diff --git a/backend/src/__tests__/cost-aggregation.test.ts b/backend/src/__tests__/cost-aggregation.test.ts index e79d8a4057..bef5ab372d 100644 --- a/backend/src/__tests__/cost-aggregation.test.ts +++ b/backend/src/__tests__/cost-aggregation.test.ts @@ -1,4 +1,4 @@ -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getInitialAgentState, getInitialSessionState, @@ -180,7 +180,7 @@ describe('Cost Aggregation System', () => { } const result = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall: mockToolCall, fileContext: mockFileContext, @@ -260,7 +260,7 @@ describe('Cost Aggregation System', () => { } const result = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall: mockToolCall, fileContext: mockFileContext, @@ -417,7 +417,7 @@ describe('Cost Aggregation System', () => { } const result = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall: mockToolCall, fileContext: mockFileContext, diff --git a/backend/src/__tests__/fast-rewrite.test.ts b/backend/src/__tests__/fast-rewrite.test.ts index 19e50abf90..8db17d40b9 100644 --- a/backend/src/__tests__/fast-rewrite.test.ts +++ b/backend/src/__tests__/fast-rewrite.test.ts @@ -1,23 +1,19 @@ import path from 'path' import { TEST_USER_ID } from '@codebuff/common/old-constants' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { clearMockedModules, mockModule, } from '@codebuff/common/testing/mock-modules' -import { afterAll, beforeAll, describe, expect, it } from 'bun:test' +import { afterAll, beforeAll, beforeEach, describe, expect, it } from 'bun:test' import { createPatch } from 'diff' import { rewriteWithOpenAI } from '../fast-rewrite' -import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' -const logger: Logger = { - debug: () => {}, - info: () => {}, - warn: () => {}, - error: () => {}, -} +let agentRuntimeImpl: AgentRuntimeDeps describe.skip('rewriteWithOpenAI', () => { beforeAll(() => { @@ -42,6 +38,10 @@ describe.skip('rewriteWithOpenAI', () => { })) }) + beforeEach(() => { + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } + }) + afterAll(() => { clearMockedModules() }) @@ -53,15 +53,13 @@ describe.skip('rewriteWithOpenAI', () => { const expectedResult = await Bun.file(`${testDataDir}/expected.go`).text() const result = await rewriteWithOpenAI({ + ...agentRuntimeImpl, oldContent: originalContent, editSnippet, - filePath: 'taskruntoolcall.go', clientSessionId: 'clientSessionId', fingerprintId: 'fingerprintId', userInputId: 'userInputId', userId: TEST_USER_ID, - userMessage: undefined, - logger, }) const patch = createPatch('test.ts', expectedResult, result) diff --git a/backend/src/__tests__/generate-diffs-prompt.test.ts b/backend/src/__tests__/generate-diffs-prompt.test.ts index 3095094dcd..e61ca1329f 100644 --- a/backend/src/__tests__/generate-diffs-prompt.test.ts +++ b/backend/src/__tests__/generate-diffs-prompt.test.ts @@ -1,16 +1,8 @@ +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { expect, describe, it } from 'bun:test' import { parseAndGetDiffBlocksSingleFile } from '../generate-diffs-prompt' -import type { Logger } from '@codebuff/common/types/contracts/logger' - -const logger: Logger = { - debug: () => {}, - info: () => {}, - warn: () => {}, - error: () => {}, -} - describe('parseAndGetDiffBlocksSingleFile', () => { it('should parse diff blocks with newline before closing marker', () => { const oldContent = 'function test() {\n return true;\n}\n' @@ -26,9 +18,9 @@ function test() { >>>>>>> REPLACE` const result = parseAndGetDiffBlocksSingleFile({ + ...TEST_AGENT_RUNTIME_IMPL, newContent, oldFileContent: oldContent, - logger, }) console.log(JSON.stringify({ result })) @@ -55,9 +47,9 @@ function test() { }>>>>>>> REPLACE` const result = parseAndGetDiffBlocksSingleFile({ + ...TEST_AGENT_RUNTIME_IMPL, newContent, oldFileContent: oldContent, - logger, }) expect(result.diffBlocks.length).toBe(1) @@ -108,9 +100,9 @@ function subtract(a, b) { >>>>>>> REPLACE` const result = parseAndGetDiffBlocksSingleFile({ + ...TEST_AGENT_RUNTIME_IMPL, newContent, oldFileContent: oldContent, - logger, }) expect(result.diffBlocks.length).toBe(2) @@ -136,9 +128,9 @@ function subtract(a, b) { >>>>>>> REPLACE` const result = parseAndGetDiffBlocksSingleFile({ + ...TEST_AGENT_RUNTIME_IMPL, newContent, oldFileContent: oldContent, - logger, }) expect(result.diffBlocks.length).toBe(1) diff --git a/backend/src/__tests__/loop-agent-steps.test.ts b/backend/src/__tests__/loop-agent-steps.test.ts index 379dec178a..226907c018 100644 --- a/backend/src/__tests__/loop-agent-steps.test.ts +++ b/backend/src/__tests__/loop-agent-steps.test.ts @@ -1,7 +1,7 @@ import * as analytics from '@codebuff/common/analytics' import db from '@codebuff/common/db' import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { clearMockedModules, mockModule, @@ -25,15 +25,17 @@ import { withAppContext } from '../context/app-context' import { loopAgentSteps } from '../run-agent-step' import { clearAgentGeneratorCache } from '../run-programmatic-step' import { mockFileContext, MockWebSocket } from './test-utils' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import type { getAgentTemplate } from '../templates/agent-registry' import type { AgentTemplate } from '../templates/types' import type { StepGenerator } from '@codebuff/common/types/agent-template' -import type { AgentState } from '@codebuff/common/types/session-state' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { ParamsOf } from '@codebuff/common/types/function-params' +import type { AgentState } from '@codebuff/common/types/session-state' import type { WebSocket } from 'ws' +let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } + describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => { let mockTemplate: AgentTemplate let mockAgentState: AgentState @@ -94,8 +96,6 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => }) beforeEach(() => { - clearAgentGeneratorCache(testAgentRuntimeImpl) - llmCallCount = 0 // Setup spies for database operations @@ -113,14 +113,14 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => })), } as any) - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { llmCallCount++ yield { type: 'text' as const, text: `LLM response\n\n${getToolCallString('end_turn', {})}`, } return 'mock-message-id' - }) + } // Mock analytics spyOn(analytics, 'initAnalytics').mockImplementation(() => {}) @@ -165,9 +165,9 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => }) afterEach(() => { - clearAgentGeneratorCache(testAgentRuntimeImpl) - + clearAgentGeneratorCache(agentRuntimeImpl) mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) afterAll(() => { @@ -199,7 +199,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -245,7 +245,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -293,7 +293,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -340,7 +340,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -380,7 +380,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -412,7 +412,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -446,7 +446,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -497,7 +497,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -546,7 +546,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -598,7 +598,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => let llmCallNumber = 0 let capturedAgentState: AgentState | null = null - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { llmCallNumber++ if (llmCallNumber === 1) { // First call: agent tries to end turn without setting output @@ -627,13 +627,13 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } } return 'mock-message-id' - }) + } mockAgentState.output = undefined capturedAgentState = mockAgentState const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -689,7 +689,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => let llmCallNumber = 0 let capturedAgentState: AgentState | null = null - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { llmCallNumber++ // Agent sets output correctly on first call if (capturedAgentState) { @@ -700,13 +700,13 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => text: `Setting output\n\n${getToolCallString('set_output', { result: 'success' })}\n\n${getToolCallString('end_turn', {})}`, } return 'mock-message-id' - }) + } mockAgentState.output = undefined capturedAgentState = mockAgentState const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -742,17 +742,17 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } let llmCallNumber = 0 - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { llmCallNumber++ yield { type: 'text' as const, text: `Response without output\n\n${getToolCallString('end_turn', {})}`, } return 'mock-message-id' - }) + } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -795,7 +795,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => let llmCallNumber = 0 let capturedAgentState: AgentState | null = null - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { llmCallNumber++ if (llmCallNumber === 1) { // First call: agent does some work but doesn't end turn @@ -814,13 +814,13 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } } return 'mock-message-id' - }) + } mockAgentState.output = undefined capturedAgentState = mockAgentState const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', diff --git a/backend/src/__tests__/main-prompt.integration.test.ts b/backend/src/__tests__/main-prompt.integration.test.ts index e8b433a3bd..75f20e6c03 100644 --- a/backend/src/__tests__/main-prompt.integration.test.ts +++ b/backend/src/__tests__/main-prompt.integration.test.ts @@ -1,7 +1,7 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' // Mock imports needed for setup within the test -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getToolCallString } from '@codebuff/common/tools/utils' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { @@ -15,16 +15,18 @@ import { } from 'bun:test' import * as checkTerminalCommandModule from '../check-terminal-command' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { mainPrompt } from '../main-prompt' import * as websocketAction from '../websockets/websocket-action' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { PrintModeEvent } from '@codebuff/common/types/print-mode' import type { ProjectFileContext } from '@codebuff/common/util/file' import type { WebSocket } from 'ws' // --- Shared Mocks & Helpers --- +let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } + class MockWebSocket { send(msg: string) {} close() {} @@ -103,6 +105,7 @@ describe.skip('mainPrompt (Integration)', () => { afterEach(() => { mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) it('should delete a specified function while preserving other code', async () => { @@ -337,7 +340,9 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) { ) // Mock LLM calls - spyOn(aisdk, 'promptAiSdk').mockResolvedValue('Mocked non-stream AiSdk') + agentRuntimeImpl.promptAiSdk = async function () { + return 'Mocked non-stream AiSdk' + } const sessionState = getInitialSessionState(mockFileContext) sessionState.mainAgentState.messageHistory.push( @@ -377,7 +382,7 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) { } const { output, sessionState: finalSessionState } = await mainPrompt({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, @@ -419,7 +424,9 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) { ).mockResolvedValue(null) // Mock LLM calls - spyOn(aisdk, 'promptAiSdk').mockResolvedValue('Mocked non-stream AiSdk') + agentRuntimeImpl.promptAiSdk = async function () { + return 'Mocked non-stream AiSdk' + } const sessionState = getInitialSessionState(mockFileContext) sessionState.mainAgentState.messageHistory.push( @@ -459,7 +466,7 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) { } await mainPrompt({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, diff --git a/backend/src/__tests__/main-prompt.test.ts b/backend/src/__tests__/main-prompt.test.ts index 5dae356481..296c47907a 100644 --- a/backend/src/__tests__/main-prompt.test.ts +++ b/backend/src/__tests__/main-prompt.test.ts @@ -1,7 +1,7 @@ import * as bigquery from '@codebuff/bigquery' import * as analytics from '@codebuff/common/analytics' import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getToolCallString } from '@codebuff/common/tools/utils' import { AgentTemplateTypes, @@ -22,33 +22,30 @@ import * as checkTerminalCommandModule from '../check-terminal-command' import * as requestFilesPrompt from '../find-files/request-files-prompt' import * as getDocumentationForQueryModule from '../get-documentation-for-query' import * as liveUserInputs from '../live-user-inputs' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { mainPrompt } from '../main-prompt' import * as processFileBlockModule from '../process-file-block' import * as websocketAction from '../websockets/websocket-action' import type { AgentTemplate } from '@codebuff/common/types/agent-template' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { ProjectFileContext } from '@codebuff/common/util/file' -import type { Logger } from '@codebuff/common/types/contracts/logger' import type { WebSocket } from 'ws' +let agentRuntimeImpl: AgentRuntimeDeps + const mockAgentStream = (streamOutput: string) => { - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { yield { type: 'text' as const, text: streamOutput } return 'mock-message-id' - }) + } } describe('mainPrompt', () => { let mockLocalAgentTemplates: Record - const logger: Logger = { - debug: () => {}, - error: () => {}, - info: () => {}, - warn: () => {}, - } beforeEach(() => { + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } + // Setup common mock agent templates mockLocalAgentTemplates = { [AgentTemplateTypes.base]: { @@ -84,12 +81,10 @@ describe('mainPrompt', () => { stepPrompt: '', } satisfies AgentTemplate, } - }) - beforeEach(() => { // Mock analytics and tracing spyOn(analytics, 'initAnalytics').mockImplementation(() => {}) - analytics.initAnalytics({ logger }) // Initialize the mock + analytics.initAnalytics(agentRuntimeImpl) // Initialize the mock spyOn(analytics, 'trackEvent').mockImplementation(() => {}) spyOn(bigquery, 'insertTrace').mockImplementation(() => Promise.resolve(true), @@ -109,9 +104,6 @@ describe('mainPrompt', () => { ) // Mock LLM APIs - spyOn(aisdk, 'promptAiSdk').mockImplementation(() => - Promise.resolve('Test response'), - ) mockAgentStream('Test response') // Mock websocket actions @@ -231,13 +223,13 @@ describe('mainPrompt', () => { } const { sessionState: newSessionState, output } = await mainPrompt({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, localAgentTemplates: mockLocalAgentTemplates, - ...testAgentRuntimeImpl, }) // Verify that requestToolCall was called with the terminal command @@ -292,6 +284,7 @@ describe('mainPrompt', () => { } await mainPrompt({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, @@ -331,7 +324,6 @@ describe('mainPrompt', () => { stepPrompt: '', }, }, - ...testAgentRuntimeImpl, }) // Assert that requestToolCall was called exactly once @@ -371,13 +363,13 @@ describe('mainPrompt', () => { } const { output } = await mainPrompt({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, localAgentTemplates: mockLocalAgentTemplates, - ...testAgentRuntimeImpl, }) expect(output.type).toBeDefined() // Output should exist @@ -398,13 +390,13 @@ describe('mainPrompt', () => { } const { sessionState: newSessionState } = await mainPrompt({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, localAgentTemplates: mockLocalAgentTemplates, - ...testAgentRuntimeImpl, }) // When there's a new prompt, consecutiveAssistantMessages should be set to 1 @@ -429,13 +421,13 @@ describe('mainPrompt', () => { } const { sessionState: newSessionState } = await mainPrompt({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, localAgentTemplates: mockLocalAgentTemplates, - ...testAgentRuntimeImpl, }) // When there's no new prompt, consecutiveAssistantMessages should increment by 1 @@ -458,13 +450,13 @@ describe('mainPrompt', () => { } const { output } = await mainPrompt({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, localAgentTemplates: mockLocalAgentTemplates, - ...testAgentRuntimeImpl, }) expect(output.type).toBeDefined() // Output should exist even for empty response @@ -498,13 +490,13 @@ describe('mainPrompt', () => { } await mainPrompt({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, localAgentTemplates: mockLocalAgentTemplates, - ...testAgentRuntimeImpl, }) // Assert that requestToolCall was called exactly once diff --git a/backend/src/__tests__/malformed-tool-call.test.ts b/backend/src/__tests__/malformed-tool-call.test.ts index bafc933a4a..62f19644fb 100644 --- a/backend/src/__tests__/malformed-tool-call.test.ts +++ b/backend/src/__tests__/malformed-tool-call.test.ts @@ -1,7 +1,7 @@ import * as bigquery from '@codebuff/bigquery' import * as analytics from '@codebuff/common/analytics' import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getToolCallString } from '@codebuff/common/tools/utils' import { getInitialSessionState } from '@codebuff/common/types/session-state' import * as stringUtils from '@codebuff/common/util/string' @@ -16,17 +16,19 @@ import { } from 'bun:test' import { MockWebSocket, mockFileContext } from './test-utils' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { processStreamWithTools } from '../tools/stream-parser' import * as websocketAction from '../websockets/websocket-action' import type { AgentTemplate } from '../templates/types' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { Message, ToolMessage, } from '@codebuff/common/types/messages/codebuff-message' import type { WebSocket } from 'ws' +let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } + describe('malformed tool call error handling', () => { let testAgent: AgentTemplate let mockWs: MockWebSocket @@ -53,7 +55,7 @@ describe('malformed tool call error handling', () => { // Mock analytics and tracing spyOn(analytics, 'initAnalytics').mockImplementation(() => {}) - analytics.initAnalytics(testAgentRuntimeImpl) + analytics.initAnalytics(TEST_AGENT_RUNTIME_IMPL) spyOn(analytics, 'trackEvent').mockImplementation(() => {}) spyOn(bigquery, 'insertTrace').mockImplementation(() => Promise.resolve(true), @@ -72,9 +74,9 @@ describe('malformed tool call error handling', () => { })) // Mock LLM APIs - spyOn(aisdk, 'promptAiSdk').mockImplementation(() => - Promise.resolve('Test response'), - ) + agentRuntimeImpl.promptAiSdk = async function () { + return 'Test response' + } // Mock generateCompactId for consistent test results spyOn(stringUtils, 'generateCompactId').mockReturnValue('test-tool-call-id') @@ -82,6 +84,7 @@ describe('malformed tool call error handling', () => { afterEach(() => { mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) function createMockStream(chunks: string[]) { @@ -108,7 +111,7 @@ describe('malformed tool call error handling', () => { const agentState = sessionState.mainAgentState const result = await processStreamWithTools({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, stream, ws: mockWs as unknown as WebSocket, agentStepId: 'test-step', @@ -165,7 +168,7 @@ describe('malformed tool call error handling', () => { const agentState = sessionState.mainAgentState const result = await processStreamWithTools({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, stream, ws: mockWs as unknown as WebSocket, agentStepId: 'test-step', @@ -212,7 +215,7 @@ describe('malformed tool call error handling', () => { const agentState = sessionState.mainAgentState const result = await processStreamWithTools({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, stream, ws: mockWs as unknown as WebSocket, agentStepId: 'test-step', @@ -263,7 +266,7 @@ describe('malformed tool call error handling', () => { const agentState = sessionState.mainAgentState const result = await processStreamWithTools({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, stream, ws: mockWs as unknown as WebSocket, agentStepId: 'test-step', @@ -316,7 +319,7 @@ describe('malformed tool call error handling', () => { const agentState = sessionState.mainAgentState const result = await processStreamWithTools({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, stream, ws: mockWs as unknown as WebSocket, agentStepId: 'test-step', @@ -371,7 +374,7 @@ describe('malformed tool call error handling', () => { const agentState = sessionState.mainAgentState const result = await processStreamWithTools({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, stream, ws: mockWs as unknown as WebSocket, agentStepId: 'test-step', diff --git a/backend/src/__tests__/process-file-block.test.ts b/backend/src/__tests__/process-file-block.test.ts index 0fcd13bfba..05e266d945 100644 --- a/backend/src/__tests__/process-file-block.test.ts +++ b/backend/src/__tests__/process-file-block.test.ts @@ -1,4 +1,5 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { clearMockedModules, mockModule, @@ -9,14 +10,9 @@ import { applyPatch } from 'diff' import { processFileBlock } from '../process-file-block' -import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' -const logger: Logger = { - debug: () => {}, - info: () => {}, - warn: () => {}, - error: () => {}, -} +let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } describe('processFileBlockModule', () => { beforeAll(() => { @@ -74,6 +70,7 @@ describe('processFileBlockModule', () => { const expectedContent = 'function test() {\n return true;\n}' const result = await processFileBlock({ + ...agentRuntimeImpl, path: 'test.ts', instructions: undefined, initialContentPromise: Promise.resolve(null), @@ -85,7 +82,6 @@ describe('processFileBlockModule', () => { fingerprintId: 'fingerprintId', userInputId: 'userInputId', userId: TEST_USER_ID, - logger, }) expect(result).not.toBeNull() @@ -111,6 +107,7 @@ describe('processFileBlockModule', () => { '}\r\n' const result = await processFileBlock({ + ...agentRuntimeImpl, path: 'test.ts', instructions: undefined, initialContentPromise: Promise.resolve(oldContent), @@ -122,7 +119,6 @@ describe('processFileBlockModule', () => { fingerprintId: 'fingerprintId', userInputId: 'userInputId', userId: TEST_USER_ID, - logger, }) expect(result).not.toBeNull() @@ -144,6 +140,7 @@ describe('processFileBlockModule', () => { const newContent = 'function test() {\n return true;\n}\n' const result = await processFileBlock({ + ...agentRuntimeImpl, path: 'test.ts', instructions: undefined, initialContentPromise: Promise.resolve(oldContent), @@ -155,7 +152,6 @@ describe('processFileBlockModule', () => { fingerprintId: 'fingerprintId', userInputId: 'userInputId', userId: TEST_USER_ID, - logger, }) expect(result).not.toBeNull() @@ -170,6 +166,7 @@ describe('processFileBlockModule', () => { const newContent = 'const x = 1;\r\nconst z = 3;\r\n' const result = await processFileBlock({ + ...agentRuntimeImpl, path: 'test.ts', instructions: undefined, initialContentPromise: Promise.resolve(oldContent), @@ -181,7 +178,6 @@ describe('processFileBlockModule', () => { fingerprintId: 'fingerprintId', userInputId: 'userInputId', userId: TEST_USER_ID, - logger, }) expect(result).not.toBeNull() @@ -217,6 +213,7 @@ describe('processFileBlockModule', () => { '// ... existing code ...\nconst x = 1;\n// ... existing code ...' const result = await processFileBlock({ + ...agentRuntimeImpl, path: 'test.ts', instructions: undefined, initialContentPromise: Promise.resolve(null), @@ -228,7 +225,6 @@ describe('processFileBlockModule', () => { fingerprintId: 'fingerprintId', userInputId: 'userInputId', userId: TEST_USER_ID, - logger, }) expect(result).not.toBeNull() diff --git a/backend/src/__tests__/prompt-caching-subagents.test.ts b/backend/src/__tests__/prompt-caching-subagents.test.ts index 7eb91f4921..c22db30775 100644 --- a/backend/src/__tests__/prompt-caching-subagents.test.ts +++ b/backend/src/__tests__/prompt-caching-subagents.test.ts @@ -1,5 +1,5 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { spyOn, @@ -11,11 +11,11 @@ import { mock, } from 'bun:test' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { loopAgentSteps } from '../run-agent-step' import * as websocketAction from '../websockets/websocket-action' import type { AgentTemplate } from '../templates/types' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { ProjectFileContext } from '@codebuff/common/util/file' import type { WebSocket } from 'ws' @@ -56,6 +56,7 @@ class MockWebSocket { describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { let mockLocalAgentTemplates: Record let capturedMessages: Message[] = [] + let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } beforeEach(() => { capturedMessages = [] @@ -97,24 +98,22 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { } // Mock LLM API to capture messages and end turn immediately - spyOn(aisdk, 'promptAiSdkStream').mockImplementation( - async function* (options) { - // Capture the messages sent to the LLM - capturedMessages = options.messages - - // Simulate immediate end turn - yield { - type: 'text' as const, - text: 'Test response', - } + agentRuntimeImpl.promptAiSdkStream = async function* (options) { + // Capture the messages sent to the LLM + capturedMessages = options.messages - if (options.onCostCalculated) { - await options.onCostCalculated(1) - } + // Simulate immediate end turn + yield { + type: 'text' as const, + text: 'Test response', + } - return 'mock-message-id' - }, - ) + if (options.onCostCalculated) { + await options.onCostCalculated(1) + } + + return 'mock-message-id' + } // Mock file operations spyOn(websocketAction, 'requestFiles').mockImplementation( @@ -147,6 +146,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { afterEach(() => { mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) it('should inherit parent system prompt when inheritParentSystemPrompt is true', async () => { @@ -155,6 +155,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { // Run parent agent first to establish system prompt const parentResult = await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-parent', prompt: 'Parent task', @@ -167,7 +168,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, - ...testAgentRuntimeImpl, }) // Capture parent's messages which include the system prompt @@ -189,6 +189,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { } await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-child', prompt: 'Child task', @@ -202,7 +203,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { clientSessionId: 'test-session', onResponseChunk: () => {}, parentSystemPrompt: parentSystemPrompt, - ...testAgentRuntimeImpl, }) // Verify child uses parent's system prompt @@ -238,6 +238,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { // Run parent agent first const parentResult = await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-parent', prompt: 'Parent task', @@ -250,7 +251,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, - ...testAgentRuntimeImpl, }) const parentMessages = capturedMessages @@ -266,6 +266,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { } await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-child', prompt: 'Child task', @@ -279,7 +280,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { clientSessionId: 'test-session', onResponseChunk: () => {}, parentSystemPrompt: parentSystemPrompt, - ...testAgentRuntimeImpl, }) const childMessages = capturedMessages @@ -316,6 +316,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { // Run parent agent first await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-parent', prompt: 'Parent task', @@ -328,7 +329,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, - ...testAgentRuntimeImpl, }) const parentMessages = capturedMessages @@ -347,6 +347,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { } await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-child', prompt: 'Child task', @@ -360,7 +361,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { clientSessionId: 'test-session', onResponseChunk: () => {}, parentSystemPrompt: parentSystemPrompt, - ...testAgentRuntimeImpl, }) const childMessages = capturedMessages @@ -423,6 +423,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { // Run parent agent const parentResult = await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-parent', prompt: 'Parent task', @@ -435,7 +436,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, - ...testAgentRuntimeImpl, }) const parentMessages = capturedMessages @@ -451,6 +451,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { } await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-child', prompt: 'Child task', @@ -464,7 +465,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { clientSessionId: 'test-session', onResponseChunk: () => {}, parentSystemPrompt: parentSystemPrompt, - ...testAgentRuntimeImpl, }) const childMessages = capturedMessages @@ -505,6 +505,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { // Run parent agent first with some message history const parentResult = await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-parent', prompt: 'Parent task', @@ -523,7 +524,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, - ...testAgentRuntimeImpl, }) const parentMessages = capturedMessages @@ -542,6 +542,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { } await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-child', prompt: 'Child task', @@ -555,7 +556,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { clientSessionId: 'test-session', onResponseChunk: () => {}, parentSystemPrompt: parentSystemPrompt, - ...testAgentRuntimeImpl, }) const childMessages = capturedMessages diff --git a/backend/src/__tests__/read-docs-tool.test.ts b/backend/src/__tests__/read-docs-tool.test.ts index f4620d271f..ee9f712b07 100644 --- a/backend/src/__tests__/read-docs-tool.test.ts +++ b/backend/src/__tests__/read-docs-tool.test.ts @@ -1,7 +1,7 @@ import * as bigquery from '@codebuff/bigquery' import * as analytics from '@codebuff/common/analytics' import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getToolCallString } from '@codebuff/common/tools/utils' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { @@ -20,15 +20,17 @@ import * as requestFilesPrompt from '../find-files/request-files-prompt' import * as liveUserInputs from '../live-user-inputs' import { MockWebSocket, mockFileContext } from './test-utils' import * as context7Api from '../llm-apis/context7-api' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { runAgentStep } from '../run-agent-step' import { assembleLocalAgentTemplates } from '../templates/agent-registry' import * as websocketAction from '../websockets/websocket-action' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { WebSocket } from 'ws' +let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } + function mockAgentStream(content: string | string[]) { - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { if (typeof content === 'string') { content = [content] } @@ -36,7 +38,7 @@ function mockAgentStream(content: string | string[]) { yield { type: 'text' as const, text: chunk } } return 'mock-message-id' - }) + } } describe('read_docs tool with researcher agent', () => { @@ -56,7 +58,7 @@ describe('read_docs tool with researcher agent', () => { name: 'analytics.initAnalytics', spy: analyticsInitSpy, }) - analytics.initAnalytics(testAgentRuntimeImpl) + analytics.initAnalytics(agentRuntimeImpl) const trackEventSpy = spyOn(analytics, 'trackEvent').mockImplementation( () => {}, @@ -121,12 +123,6 @@ describe('read_docs tool with researcher agent', () => { spy: sendActionSpy, }) - // Mock LLM APIs - const promptAiSdkSpy = spyOn(aisdk, 'promptAiSdk').mockImplementation(() => - Promise.resolve('Test response'), - ) - mockedFunctions.push({ name: 'aisdk.promptAiSdk', spy: promptAiSdkSpy }) - // Mock other required modules const requestRelevantFilesSpy = spyOn( requestFilesPrompt, @@ -177,6 +173,7 @@ describe('read_docs tool with researcher agent', () => { afterEach(() => { mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) // MockWebSocket and mockFileContext imported from test-utils @@ -229,28 +226,6 @@ describe('read_docs tool with researcher agent', () => { await expect(insertTraceMock?.spy()).resolves.toBe(true) }) - test('async generator mock completes properly', async () => { - // Test that our async generator mock properly completes - const mockResponse = 'test response' - - mockAgentStream(mockResponse) - - const generator = aisdk.promptAiSdkStream({} as any) - const results = [] - - // Consume the generator - for await (const value of generator) { - results.push(value) - } - - // Should have yielded exactly one value and then completed - expect(results).toEqual([{ type: 'text', text: mockResponse }]) - - // Generator should be done - const { done } = await generator.next() - expect(done).toBe(true) - }) - test('should successfully fetch documentation with basic query', async () => { const mockDocumentation = 'React is a JavaScript library for building user interfaces...' @@ -286,12 +261,12 @@ describe('read_docs tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -359,12 +334,12 @@ describe('read_docs tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -409,12 +384,12 @@ describe('read_docs tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -476,12 +451,12 @@ describe('read_docs tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -542,12 +517,12 @@ describe('read_docs tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -609,12 +584,12 @@ describe('read_docs tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, diff --git a/backend/src/__tests__/request-files-prompt.test.ts b/backend/src/__tests__/request-files-prompt.test.ts index d11ab037a2..56552574d5 100644 --- a/backend/src/__tests__/request-files-prompt.test.ts +++ b/backend/src/__tests__/request-files-prompt.test.ts @@ -1,4 +1,5 @@ import { finetunedVertexModels } from '@codebuff/common/old-constants' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { clearMockedModules, mockModule, @@ -18,10 +19,13 @@ import * as OriginalRequestFilesPromptModule from '../find-files/request-files-p import * as geminiWithFallbacksModule from '../llm-apis/gemini-with-fallbacks' import type { CostMode } from '@codebuff/common/old-constants' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { ProjectFileContext } from '@codebuff/common/util/file' import type { Mock } from 'bun:test' +let agentRuntimeImpl: AgentRuntimeDeps + describe('requestRelevantFiles', () => { const mockMessages: Message[] = [{ role: 'user', content: 'test prompt' }] const mockSystem = 'test system' @@ -58,12 +62,6 @@ describe('requestRelevantFiles', () => { const mockUserId = 'user1' const mockCostMode: CostMode = 'normal' const mockRepoId = 'owner/repo' - const logger = { - debug: () => {}, - info: () => {}, - warn: () => {}, - error: () => {}, - } let getCustomFilePickerConfigForOrgSpy: any // Explicitly typed as any @@ -81,15 +79,6 @@ describe('requestRelevantFiles', () => { })), })) - mockModule('@codebuff/backend/util/logger', () => ({ - logger: { - info: bunMockFn(() => {}), - error: bunMockFn(() => {}), - warn: bunMockFn(() => {}), - debug: bunMockFn(() => {}), - }, - })) - mockModule('@codebuff/common/db', () => ({ default: { insert: bunMockFn(() => ({ @@ -109,6 +98,8 @@ describe('requestRelevantFiles', () => { }) beforeEach(() => { + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } + // If the spy was created in a previous test, restore it if ( getCustomFilePickerConfigForOrgSpy && @@ -134,6 +125,7 @@ describe('requestRelevantFiles', () => { it('should use default file counts and maxFiles when no custom config', async () => { await OriginalRequestFilesPromptModule.requestRelevantFiles({ + ...agentRuntimeImpl, messages: mockMessages, system: mockSystem, fileContext: mockFileContext, @@ -144,7 +136,6 @@ describe('requestRelevantFiles', () => { userInputId: mockUserInputId, userId: mockUserId, repoId: mockRepoId, - logger, }) expect( geminiWithFallbacksModule.promptFlashWithFallbacks, @@ -161,6 +152,7 @@ describe('requestRelevantFiles', () => { getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any) await OriginalRequestFilesPromptModule.requestRelevantFiles({ + ...agentRuntimeImpl, messages: mockMessages, system: mockSystem, fileContext: mockFileContext, @@ -171,7 +163,6 @@ describe('requestRelevantFiles', () => { userInputId: mockUserInputId, userId: mockUserId, repoId: mockRepoId, - logger, }) expect( geminiWithFallbacksModule.promptFlashWithFallbacks, @@ -187,6 +178,7 @@ describe('requestRelevantFiles', () => { getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any) const result = await OriginalRequestFilesPromptModule.requestRelevantFiles({ + ...agentRuntimeImpl, messages: mockMessages, system: mockSystem, fileContext: mockFileContext, @@ -197,7 +189,6 @@ describe('requestRelevantFiles', () => { userInputId: mockUserInputId, userId: mockUserId, repoId: mockRepoId, - logger, }) expect(result).toBeArray() if (result) { @@ -213,6 +204,7 @@ describe('requestRelevantFiles', () => { getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any) await OriginalRequestFilesPromptModule.requestRelevantFiles({ + ...agentRuntimeImpl, messages: mockMessages, system: mockSystem, fileContext: mockFileContext, @@ -223,7 +215,6 @@ describe('requestRelevantFiles', () => { userInputId: mockUserInputId, userId: mockUserId, repoId: mockRepoId, - logger, }) expect( geminiWithFallbacksModule.promptFlashWithFallbacks, @@ -242,6 +233,7 @@ describe('requestRelevantFiles', () => { getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any) await OriginalRequestFilesPromptModule.requestRelevantFiles({ + ...agentRuntimeImpl, messages: mockMessages, system: mockSystem, fileContext: mockFileContext, @@ -252,7 +244,6 @@ describe('requestRelevantFiles', () => { userInputId: mockUserInputId, userId: mockUserId, repoId: mockRepoId, - logger, }) const expectedModel = finetunedVertexModels.ft_filepicker_010 expect( diff --git a/backend/src/__tests__/run-agent-step-tools.test.ts b/backend/src/__tests__/run-agent-step-tools.test.ts index 114340e2f7..06bf87bdda 100644 --- a/backend/src/__tests__/run-agent-step-tools.test.ts +++ b/backend/src/__tests__/run-agent-step-tools.test.ts @@ -2,7 +2,7 @@ import * as bigquery from '@codebuff/bigquery' import * as analytics from '@codebuff/common/analytics' import db from '@codebuff/common/db' import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getToolCallString } from '@codebuff/common/tools/utils' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { @@ -18,18 +18,19 @@ import { // Mock imports import * as liveUserInputs from '../live-user-inputs' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { runAgentStep } from '../run-agent-step' import { clearAgentGeneratorCache } from '../run-programmatic-step' import { asUserMessage } from '../util/messages' import * as websocketAction from '../websockets/websocket-action' import type { AgentTemplate } from '../templates/types' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { ProjectFileContext } from '@codebuff/common/util/file' import type { WebSocket } from 'ws' describe('runAgentStep - set_output tool', () => { let testAgent: AgentTemplate + let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } beforeEach(async () => { // Create a test agent that supports set_output @@ -63,7 +64,7 @@ describe('runAgentStep - set_output tool', () => { // Mock analytics and tracing spyOn(analytics, 'initAnalytics').mockImplementation(() => {}) - analytics.initAnalytics(testAgentRuntimeImpl) + analytics.initAnalytics(agentRuntimeImpl) spyOn(analytics, 'trackEvent').mockImplementation(() => {}) spyOn(bigquery, 'insertTrace').mockImplementation(() => Promise.resolve(true), @@ -104,18 +105,19 @@ describe('runAgentStep - set_output tool', () => { // Don't mock requestToolCall for integration test - let real tool execution happen // Mock LLM APIs - spyOn(aisdk, 'promptAiSdk').mockImplementation(() => - Promise.resolve('Test response'), - ) - clearAgentGeneratorCache(testAgentRuntimeImpl) + agentRuntimeImpl.promptAiSdk = async function () { + return 'Test response' + } + clearAgentGeneratorCache(agentRuntimeImpl) }) afterEach(() => { mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) afterAll(() => { - clearAgentGeneratorCache(testAgentRuntimeImpl) + clearAgentGeneratorCache(agentRuntimeImpl) }) class MockWebSocket { @@ -159,10 +161,10 @@ describe('runAgentStep - set_output tool', () => { '\n\n' + getToolCallString('end_turn', {}) - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { yield { type: 'text' as const, text: mockResponse } return 'mock-message-id' - }) + } const sessionState = getInitialSessionState(mockFileContext) const agentState = sessionState.mainAgentState @@ -171,6 +173,7 @@ describe('runAgentStep - set_output tool', () => { } const result = await runAgentStep({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userId: TEST_USER_ID, userInputId: 'test-input', @@ -184,7 +187,6 @@ describe('runAgentStep - set_output tool', () => { prompt: 'Analyze the codebase', spawnParams: undefined, system: 'Test system prompt', - ...testAgentRuntimeImpl, }) expect(result.agentState.output).toEqual({ @@ -201,10 +203,10 @@ describe('runAgentStep - set_output tool', () => { findings: ['Bug in auth.ts', 'Missing validation'], }) + getToolCallString('end_turn', {}) - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { yield { type: 'text' as const, text: mockResponse } return 'mock-message-id' - }) + } const sessionState = getInitialSessionState(mockFileContext) const agentState = sessionState.mainAgentState @@ -213,6 +215,7 @@ describe('runAgentStep - set_output tool', () => { } const result = await runAgentStep({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userId: TEST_USER_ID, userInputId: 'test-input', @@ -226,7 +229,6 @@ describe('runAgentStep - set_output tool', () => { prompt: 'Analyze the codebase', spawnParams: undefined, system: 'Test system prompt', - ...testAgentRuntimeImpl, }) expect(result.agentState.output).toEqual({ @@ -244,10 +246,10 @@ describe('runAgentStep - set_output tool', () => { existingField: 'updated value', }) + getToolCallString('end_turn', {}) - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { yield { type: 'text' as const, text: mockResponse } return 'mock-message-id' - }) + } const sessionState = getInitialSessionState(mockFileContext) const agentState = sessionState.mainAgentState @@ -261,6 +263,7 @@ describe('runAgentStep - set_output tool', () => { } const result = await runAgentStep({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userId: TEST_USER_ID, userInputId: 'test-input', @@ -274,7 +277,6 @@ describe('runAgentStep - set_output tool', () => { prompt: 'Update the output', spawnParams: undefined, system: 'Test system prompt', - ...testAgentRuntimeImpl, }) expect(result.agentState.output).toEqual({ @@ -287,10 +289,10 @@ describe('runAgentStep - set_output tool', () => { const mockResponse = getToolCallString('set_output', {}) + getToolCallString('end_turn', {}) - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { yield { type: 'text' as const, text: mockResponse } return 'mock-message-id' - }) + } const sessionState = getInitialSessionState(mockFileContext) const agentState = sessionState.mainAgentState @@ -300,6 +302,7 @@ describe('runAgentStep - set_output tool', () => { } const result = await runAgentStep({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userId: TEST_USER_ID, userInputId: 'test-input', @@ -313,7 +316,6 @@ describe('runAgentStep - set_output tool', () => { prompt: 'Update with empty object', spawnParams: undefined, system: 'Test system prompt', - ...testAgentRuntimeImpl, }) // Should replace with empty object @@ -369,10 +371,10 @@ describe('runAgentStep - set_output tool', () => { ) // Mock the LLM stream to return a response that doesn't end the turn - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { yield { type: 'text' as const, text: 'Continuing with the analysis...' } // Non-empty response, no tool calls return 'mock-message-id' - }) + } const sessionState = getInitialSessionState(mockFileContext) const agentState = sessionState.mainAgentState @@ -396,6 +398,7 @@ describe('runAgentStep - set_output tool', () => { const initialMessageCount = agentState.messageHistory.length const result = await runAgentStep({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userId: TEST_USER_ID, userInputId: 'test-input', @@ -409,7 +412,6 @@ describe('runAgentStep - set_output tool', () => { prompt: 'Test the handleSteps functionality', spawnParams: undefined, system: 'Test system prompt', - ...testAgentRuntimeImpl, }) // Should end turn because toolCalls.length === 0 && toolResults.length === 0 from LLM processing @@ -519,7 +521,7 @@ describe('runAgentStep - set_output tool', () => { } // Mock the LLM stream to spawn the inline agent - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { yield { type: 'text' as const, text: getToolCallString('spawn_agent_inline', { @@ -528,7 +530,7 @@ describe('runAgentStep - set_output tool', () => { }), } return 'mock-message-id' - }) + } const sessionState = getInitialSessionState(mockFileContext) const agentState = sessionState.mainAgentState @@ -556,7 +558,7 @@ describe('runAgentStep - set_output tool', () => { ] const result = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userId: TEST_USER_ID, userInputId: 'test-input', diff --git a/backend/src/__tests__/spawn-agents-message-history.test.ts b/backend/src/__tests__/spawn-agents-message-history.test.ts index 60ac1a430b..25aed6a383 100644 --- a/backend/src/__tests__/spawn-agents-message-history.test.ts +++ b/backend/src/__tests__/spawn-agents-message-history.test.ts @@ -1,5 +1,5 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { describe, @@ -107,7 +107,7 @@ describe('Spawn Agents Message History', () => { ] const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -179,7 +179,7 @@ describe('Spawn Agents Message History', () => { ] const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -216,7 +216,7 @@ describe('Spawn Agents Message History', () => { const mockMessages: Message[] = [] // Empty message history const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -256,7 +256,7 @@ describe('Spawn Agents Message History', () => { ] const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, diff --git a/backend/src/__tests__/spawn-agents-permissions.test.ts b/backend/src/__tests__/spawn-agents-permissions.test.ts index 72c29d94a7..3db04e4aa2 100644 --- a/backend/src/__tests__/spawn-agents-permissions.test.ts +++ b/backend/src/__tests__/spawn-agents-permissions.test.ts @@ -1,5 +1,5 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { describe, @@ -235,7 +235,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createSpawnToolCall('thinker') const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -269,7 +269,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createSpawnToolCall('reviewer') // Try to spawn reviewer const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -305,7 +305,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createSpawnToolCall('nonexistent') const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -343,7 +343,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createSpawnToolCall('codebuff/thinker@1.0.0') const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -377,7 +377,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createSpawnToolCall('thinker') // Simple name const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -414,7 +414,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createSpawnToolCall('codebuff/thinker@2.0.0') const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -462,7 +462,7 @@ describe('Spawn Agents Permissions', () => { } const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -517,7 +517,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createInlineSpawnToolCall('thinker') const { result } = handleSpawnAgentInline({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -549,7 +549,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createInlineSpawnToolCall('reviewer') // Try to spawn reviewer const { result } = handleSpawnAgentInline({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -582,7 +582,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createInlineSpawnToolCall('nonexistent') const { result } = handleSpawnAgentInline({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -614,7 +614,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createInlineSpawnToolCall('codebuff/thinker@1.0.0') const { result } = handleSpawnAgentInline({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -646,7 +646,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createInlineSpawnToolCall('thinker') // Simple name const { result } = handleSpawnAgentInline({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -681,7 +681,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createInlineSpawnToolCall('codebuff/thinker@2.0.0') const { result } = handleSpawnAgentInline({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -713,7 +713,7 @@ describe('Spawn Agents Permissions', () => { expect(() => { handleSpawnAgentInline({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, diff --git a/backend/src/__tests__/subagent-streaming.test.ts b/backend/src/__tests__/subagent-streaming.test.ts index 494facdbb8..47500691f1 100644 --- a/backend/src/__tests__/subagent-streaming.test.ts +++ b/backend/src/__tests__/subagent-streaming.test.ts @@ -1,5 +1,5 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { afterAll, @@ -146,7 +146,7 @@ describe('Subagent Streaming', () => { } const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -224,7 +224,7 @@ describe('Subagent Streaming', () => { } const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, diff --git a/backend/src/__tests__/web-search-tool.test.ts b/backend/src/__tests__/web-search-tool.test.ts index c384896b2b..82e094b25f 100644 --- a/backend/src/__tests__/web-search-tool.test.ts +++ b/backend/src/__tests__/web-search-tool.test.ts @@ -4,7 +4,7 @@ process.env.LINKUP_API_KEY = 'test-api-key' import * as bigquery from '@codebuff/bigquery' import * as analytics from '@codebuff/common/analytics' import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getToolCallString } from '@codebuff/common/tools/utils' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { @@ -23,15 +23,16 @@ import * as requestFilesPrompt from '../find-files/request-files-prompt' import * as liveUserInputs from '../live-user-inputs' import { MockWebSocket, mockFileContext } from './test-utils' import * as linkupApi from '../llm-apis/linkup-api' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { runAgentStep } from '../run-agent-step' import { assembleLocalAgentTemplates } from '../templates/agent-registry' import * as websocketAction from '../websockets/websocket-action' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { WebSocket } from 'ws' +let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } function mockAgentStream(content: string | string[]) { - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { if (typeof content === 'string') { content = [content] } @@ -39,14 +40,14 @@ function mockAgentStream(content: string | string[]) { yield { type: 'text' as const, text: chunk } } return 'mock-message-id' - }) + } } describe('web_search tool with researcher agent', () => { beforeEach(() => { // Mock analytics and tracing spyOn(analytics, 'initAnalytics').mockImplementation(() => {}) - analytics.initAnalytics(testAgentRuntimeImpl) + analytics.initAnalytics(agentRuntimeImpl) spyOn(analytics, 'trackEvent').mockImplementation(() => {}) spyOn(bigquery, 'insertTrace').mockImplementation(() => Promise.resolve(true), @@ -65,9 +66,9 @@ describe('web_search tool with researcher agent', () => { })) // Mock LLM APIs - spyOn(aisdk, 'promptAiSdk').mockImplementation(() => - Promise.resolve('Test response'), - ) + agentRuntimeImpl.promptAiSdk = async function () { + return 'Test response' + } // Mock other required modules spyOn(requestFilesPrompt, 'requestRelevantFiles').mockImplementation( @@ -84,6 +85,7 @@ describe('web_search tool with researcher agent', () => { afterEach(() => { mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) // MockWebSocket and mockFileContext imported from test-utils @@ -114,12 +116,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -164,12 +166,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -223,12 +225,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -267,12 +269,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -325,12 +327,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -382,12 +384,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -429,12 +431,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -488,12 +490,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, diff --git a/backend/src/admin/grade-runs.ts b/backend/src/admin/grade-runs.ts index 967c684606..287d300154 100644 --- a/backend/src/admin/grade-runs.ts +++ b/backend/src/admin/grade-runs.ts @@ -1,9 +1,8 @@ import { models, TEST_USER_ID } from '@codebuff/common/old-constants' import { closeXml } from '@codebuff/common/util/xml' -import { promptAiSdk } from '../llm-apis/vercel-ai-sdk/ai-sdk' - import type { Relabel, GetRelevantFilesTrace } from '@codebuff/bigquery' +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' const PROMPT = ` @@ -100,9 +99,10 @@ function extractResponse(response: string): { export async function gradeRun(params: { trace: GetRelevantFilesTrace relabels: Relabel[] + promptAiSdk: PromptAiSdkFn logger: Logger }) { - const { trace, relabels, logger } = params + const { trace, relabels, promptAiSdk, logger } = params const messages = trace.payload.messages const originalOutput = trace.payload.output diff --git a/backend/src/admin/relabelRuns.ts b/backend/src/admin/relabelRuns.ts index e40fd56c4a..f52982b519 100644 --- a/backend/src/admin/relabelRuns.ts +++ b/backend/src/admin/relabelRuns.ts @@ -24,8 +24,10 @@ import type { GetRelevantFilesTrace, Relabel, } from '@codebuff/bigquery' -import type { Message } from '@codebuff/common/types/messages/codebuff-message' +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' +import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { Request, Response } from 'express' // --- GET Handler Logic --- @@ -153,6 +155,7 @@ export async function relabelForUserHandler(params: { const relaceResults = relabelUsingFullFilesForUser({ userId, limit, + promptAiSdk, logger, }) @@ -265,11 +268,16 @@ export async function relabelForUserHandler(params: { } } -async function relabelUsingFullFilesForUser(params: { - userId: string - limit: number - logger: Logger -}) { +async function relabelUsingFullFilesForUser( + params: { + userId: string + limit: number + logger: Logger + } & ParamsExcluding< + typeof relabelWithClaudeWithFullFileContext, + 'trace' | 'fileBlobs' | 'model' + >, +) { const { userId, limit, logger } = params // TODO: We need to figure out changing _everything_ to use `getTracesAndAllDataForUser` const tracesBundles = await getTracesAndAllDataForUser(userId) @@ -308,10 +316,10 @@ async function relabelUsingFullFilesForUser(params: { ) { relabelPromises.push( relabelWithClaudeWithFullFileContext({ + ...params, trace, fileBlobs, model, - logger, }), ) didRelabel = true @@ -392,9 +400,10 @@ export async function relabelWithClaudeWithFullFileContext(params: { fileBlobs: GetExpandedFileContextForTrainingBlobTrace model: string dataset?: string + promptAiSdk: PromptAiSdkFn logger: Logger }) { - const { trace, fileBlobs, model, dataset, logger } = params + const { trace, fileBlobs, model, dataset, promptAiSdk, logger } = params if (dataset) { await setupBigQuery({ dataset, logger }) } diff --git a/backend/src/check-terminal-command.ts b/backend/src/check-terminal-command.ts index 1c1a2931bc..b1973b44b0 100644 --- a/backend/src/check-terminal-command.ts +++ b/backend/src/check-terminal-command.ts @@ -1,8 +1,7 @@ import { models } from '@codebuff/common/old-constants' import { withTimeout } from '@codebuff/common/util/promise' -import { promptAiSdk } from './llm-apis/vercel-ai-sdk/ai-sdk' - +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' import type { ParamsExcluding } from '@codebuff/common/types/function-params' @@ -13,10 +12,11 @@ import type { ParamsExcluding } from '@codebuff/common/types/function-params' export async function checkTerminalCommand( params: { prompt: string + promptAiSdk: PromptAiSdkFn logger: Logger - } & ParamsExcluding, + } & ParamsExcluding, ): Promise { - const { prompt, logger } = params + const { prompt, promptAiSdk, logger } = params if (!prompt?.trim()) { return null } diff --git a/backend/src/fast-rewrite.ts b/backend/src/fast-rewrite.ts index 49212bd584..f6709909e9 100644 --- a/backend/src/fast-rewrite.ts +++ b/backend/src/fast-rewrite.ts @@ -5,52 +5,33 @@ import { generateCompactId, hasLazyEdit } from '@codebuff/common/util/string' import { promptFlashWithFallbacks } from './llm-apis/gemini-with-fallbacks' import { promptRelaceAI } from './llm-apis/relace-api' -import { promptAiSdk } from './llm-apis/vercel-ai-sdk/ai-sdk' import type { CodebuffToolMessage } from '@codebuff/common/tools/list' +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { Message, ToolMessage, } from '@codebuff/common/types/messages/codebuff-message' -export async function fastRewrite(params: { - initialContent: string - editSnippet: string - filePath: string - instructions: string | undefined - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - userMessage: string | undefined - logger: Logger -}) { - const { - initialContent, - editSnippet, - filePath, - instructions, - clientSessionId, - fingerprintId, - userInputId, - userId, - userMessage, - logger, - } = params +export async function fastRewrite( + params: { + initialContent: string + editSnippet: string + filePath: string + userMessage: string | undefined + logger: Logger + } & ParamsExcluding & + ParamsExcluding, +) { + const { initialContent, editSnippet, filePath, userMessage, logger } = params const relaceStartTime = Date.now() const messageId = generateCompactId('cb-') let response = await promptRelaceAI({ + ...params, initialCode: initialContent, - editSnippet, - instructions, - clientSessionId, - fingerprintId, - userInputId, - userId, - userMessage, messageId, - logger, }) const relaceDuration = Date.now() - relaceStartTime @@ -62,15 +43,8 @@ export async function fastRewrite(params: { ) { const relaceResponse = response response = await rewriteWithOpenAI({ + ...params, oldContent: initialContent, - editSnippet, - filePath, - clientSessionId, - fingerprintId, - userInputId, - userId, - userMessage, - logger, }) logger.debug( { filePath, relaceResponse, openaiResponse: response, messageId }, @@ -97,23 +71,21 @@ export async function fastRewrite(params: { export async function rewriteWithOpenAI(params: { oldContent: string editSnippet: string - filePath: string clientSessionId: string fingerprintId: string userInputId: string userId: string | undefined - userMessage: string | undefined + promptAiSdk: PromptAiSdkFn logger: Logger }): Promise { const { oldContent, editSnippet, - filePath, clientSessionId, fingerprintId, userInputId, userId, - userMessage, + promptAiSdk, logger, } = params const prompt = `You are an expert programmer tasked with implementing changes to a file. Please rewrite the file to implement the changes shown in the edit snippet, while preserving the original formatting and behavior of unchanged parts. @@ -158,28 +130,22 @@ Please output just the complete updated file content with the edit applied and n * sketches an update to a single function, but forgets to add ... existing code ... * above and below the function. */ -export const shouldAddFilePlaceholders = async (params: { - filePath: string - oldContent: string - rewrittenNewContent: string - messageHistory: Message[] - fullResponse: string - userId: string | undefined - clientSessionId: string - fingerprintId: string - userInputId: string - logger: Logger -}) => { +export const shouldAddFilePlaceholders = async ( + params: { + filePath: string + oldContent: string + rewrittenNewContent: string + messageHistory: Message[] + fullResponse: string + logger: Logger + } & ParamsExcluding, +) => { const { filePath, oldContent, rewrittenNewContent, messageHistory, fullResponse, - userId, - clientSessionId, - fingerprintId, - userInputId, logger, } = params const fileWasPreviouslyEdited = messageHistory @@ -244,13 +210,9 @@ Do not write anything else. }, ) const response = await promptFlashWithFallbacks({ + ...params, messages, - clientSessionId, - fingerprintId, - userInputId, model: models.openrouter_gemini2_5_flash, - userId, - logger, }) const shouldAddPlaceholderComments = response.includes('LOCAL_CHANGE_ONLY') logger.debug( diff --git a/backend/src/find-files/request-files-prompt.ts b/backend/src/find-files/request-files-prompt.ts index 4d29189d96..7f311d7836 100644 --- a/backend/src/find-files/request-files-prompt.ts +++ b/backend/src/find-files/request-files-prompt.ts @@ -14,7 +14,6 @@ import { range, shuffle, uniq } from 'lodash' import { CustomFilePickerConfigSchema } from './custom-file-picker-config' import { promptFlashWithFallbacks } from '../llm-apis/gemini-with-fallbacks' -import { promptAiSdk } from '../llm-apis/vercel-ai-sdk/ai-sdk' import { castAssistantMessage, messagesWithSystem, @@ -28,9 +27,11 @@ import type { GetExpandedFileContextForTrainingTrace, GetRelevantFilesTrace, } from '@codebuff/bigquery' +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' +import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { ProjectFileContext } from '@codebuff/common/util/file' -import type { Logger } from '@codebuff/common/types/contracts/logger' const NUMBER_OF_EXAMPLE_FILES = 100 const MAX_FILES_PER_REQUEST = 30 @@ -120,32 +121,25 @@ function isValidFilePickerModelName( return Object.keys(finetunedVertexModels).includes(modelName) } -export async function requestRelevantFiles(params: { - messages: Message[] - system: string | Array - fileContext: ProjectFileContext - assistantPrompt: string | null - agentStepId: string - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - repoId: string | undefined - logger: Logger -}) { - const { - messages, - system, - fileContext, - assistantPrompt, - agentStepId, - clientSessionId, - fingerprintId, - userInputId, - userId, - repoId, - logger, - } = params +export async function requestRelevantFiles( + params: { + messages: Message[] + system: string | Array + fileContext: ProjectFileContext + assistantPrompt: string | null + agentStepId: string + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + repoId: string | undefined + logger: Logger + } & ParamsExcluding< + typeof getRelevantFiles, + 'messages' | 'userPrompt' | 'requestType' | 'modelId' + >, +) { + const { messages, fileContext, assistantPrompt, logger } = params // Check for organization custom file picker feature const requestContext = getRequestContext() const orgId = requestContext?.approvedOrgIdForRepo @@ -193,18 +187,11 @@ export async function requestRelevantFiles(params: { } const keyPromise = getRelevantFiles({ + ...params, messages: messagesExcludingLastIfByUser, - system, userPrompt: keyPrompt, requestType: 'Key', - agentStepId, - clientSessionId, - fingerprintId, - userInputId, - userId, - repoId, modelId: modelIdForRequest, - logger, }).catch((error) => { logger.error({ error }, 'Error requesting key files') return { files: [] as string[], duration: 0 } @@ -227,32 +214,18 @@ export async function requestRelevantFiles(params: { return candidateFiles.slice(0, maxFilesPerRequest) } -export async function requestRelevantFilesForTraining(params: { - messages: Message[] - system: string | Array - fileContext: ProjectFileContext - assistantPrompt: string | null - agentStepId: string - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - repoId: string | undefined - logger: Logger -}) { - const { - messages, - system, - fileContext, - assistantPrompt, - agentStepId, - clientSessionId, - fingerprintId, - userInputId, - userId, - repoId, - logger, - } = params +export async function requestRelevantFilesForTraining( + params: { + messages: Message[] + fileContext: ProjectFileContext + assistantPrompt: string | null + logger: Logger + } & ParamsExcluding< + typeof getRelevantFilesForTraining, + 'messages' | 'userPrompt' | 'requestType' + >, +) { + const { messages, fileContext, assistantPrompt, logger } = params const COUNT = 50 const lastMessage = messages[messages.length - 1] @@ -279,31 +252,17 @@ export async function requestRelevantFilesForTraining(params: { ) const keyFiles = await getRelevantFilesForTraining({ + ...params, messages: messagesExcludingLastIfByUser, - system, userPrompt: keyFilesPrompt, requestType: 'Key', - agentStepId, - clientSessionId, - fingerprintId, - userInputId, - userId, - repoId, - logger, }) const nonObviousFiles = await getRelevantFilesForTraining({ + ...params, messages: messagesExcludingLastIfByUser, - system, userPrompt: nonObviousPrompt, requestType: 'Non-Obvious', - agentStepId, - clientSessionId, - fingerprintId, - userInputId, - userId, - repoId, - logger, }) const candidateFiles = [...keyFiles.files, ...nonObviousFiles.files] @@ -315,20 +274,25 @@ export async function requestRelevantFilesForTraining(params: { return validatedFiles.slice(0, MAX_FILES_PER_REQUEST) } -async function getRelevantFiles(params: { - messages: Message[] - system: string | Array - userPrompt: string - requestType: string - agentStepId: string - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - repoId: string | undefined - modelId?: FinetunedVertexModel - logger: Logger -}) { +async function getRelevantFiles( + params: { + messages: Message[] + system: string | Array + userPrompt: string + requestType: string + agentStepId: string + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + repoId: string | undefined + modelId?: FinetunedVertexModel + logger: Logger + } & ParamsExcluding< + typeof promptFlashWithFallbacks, + 'messages' | 'model' | 'useFinetunedModel' + >, +) { const { messages, system, @@ -374,14 +338,10 @@ async function getRelevantFiles(params: { const finetunedModel = modelId ?? finetunedVertexModels.ft_filepicker_010 let response = await promptFlashWithFallbacks({ + ...params, messages: codebuffMessages, - clientSessionId, - userInputId, model: models.openrouter_gemini2_5_flash, - userId, useFinetunedModel: finetunedModel, - fingerprintId, - logger, }) const end = performance.now() const duration = end - start @@ -425,6 +385,7 @@ async function getRelevantFilesForTraining(params: { userInputId: string userId: string | undefined repoId: string | undefined + promptAiSdk: PromptAiSdkFn logger: Logger }) { const { @@ -438,6 +399,7 @@ async function getRelevantFilesForTraining(params: { userInputId, userId, repoId, + promptAiSdk, logger, } = params const bufferTokens = 100_000 diff --git a/backend/src/generate-diffs-prompt.ts b/backend/src/generate-diffs-prompt.ts index 11f0ac39c9..4575ab3c54 100644 --- a/backend/src/generate-diffs-prompt.ts +++ b/backend/src/generate-diffs-prompt.ts @@ -4,9 +4,9 @@ import { createSearchReplaceBlock, } from '@codebuff/common/util/file' -import { promptAiSdk } from './llm-apis/vercel-ai-sdk/ai-sdk' - +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' export const parseAndGetDiffBlocksSingleFile = (params: { newContent: string @@ -133,24 +133,27 @@ export const tryToDoStringReplacementWithExtraIndentation = (params: { return null } -export async function retryDiffBlocksPrompt(params: { - filePath: string - oldContent: string - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - diffBlocksThatDidntMatch: { searchContent: string; replaceContent: string }[] - logger: Logger -}) { +export async function retryDiffBlocksPrompt( + params: { + filePath: string + oldContent: string + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + diffBlocksThatDidntMatch: { + searchContent: string + replaceContent: string + }[] + promptAiSdk: PromptAiSdkFn + logger: Logger + } & ParamsExcluding, +) { const { filePath, oldContent, - clientSessionId, - fingerprintId, - userInputId, - userId, diffBlocksThatDidntMatch, + promptAiSdk, logger, } = params const newPrompt = @@ -168,13 +171,9 @@ The search content needs to match an exact substring of the old file content, wh Provide a new set of SEARCH/REPLACE changes to make the intended edit from the old file.`.trim() const response = await promptAiSdk({ + ...params, messages: [{ role: 'user', content: newPrompt }], model: models.openrouter_claude_sonnet_4, - clientSessionId, - fingerprintId, - userInputId, - userId, - logger, }) const { diffBlocks: newDiffBlocks, diff --git a/backend/src/get-documentation-for-query.ts b/backend/src/get-documentation-for-query.ts index 35dafa8cd5..4492216818 100644 --- a/backend/src/get-documentation-for-query.ts +++ b/backend/src/get-documentation-for-query.ts @@ -4,9 +4,13 @@ import { uniq } from 'lodash' import { z } from 'zod/v4' import { fetchContext7LibraryDocumentation } from './llm-apis/context7-api' -import { promptAiSdkStructured } from './llm-apis/vercel-ai-sdk/ai-sdk' +import type { PromptAiSdkStructuredFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { + ParamsExcluding, + ParamsOf, +} from '@codebuff/common/types/function-params' const DELIMITER = `\n\n----------------------------------------\n\n` @@ -20,15 +24,18 @@ const DELIMITER = `\n\n----------------------------------------\n\n` * @param options.userId The ID of the user making the request * @returns The documentation text chunks or null if no relevant docs found */ -export async function getDocumentationForQuery(params: { - query: string - tokens?: number - clientSessionId: string - userInputId: string - fingerprintId: string - userId?: string - logger: Logger -}): Promise { +export async function getDocumentationForQuery( + params: { + query: string + tokens?: number + clientSessionId: string + userInputId: string + fingerprintId: string + userId?: string + logger: Logger + } & ParamsOf & + ParamsExcluding, +): Promise { const { query, tokens, @@ -41,14 +48,7 @@ export async function getDocumentationForQuery(params: { const startTime = Date.now() // 1. Search for relevant libraries - const libraryResults = await suggestLibraries({ - query, - clientSessionId, - userInputId, - fingerprintId, - userId, - logger, - }) + const libraryResults = await suggestLibraries(params) if (!libraryResults || libraryResults.libraries.length === 0) { logger.info( @@ -104,6 +104,7 @@ export async function getDocumentationForQuery(params: { // 3. Filter relevant chunks using another LLM call const filterResults = await filterRelevantChunks({ + ...params, query, allChunks: allUniqueChunks, clientSessionId, @@ -163,10 +164,18 @@ const suggestLibraries = async (params: { userInputId: string fingerprintId: string userId?: string + promptAiSdkStructured: PromptAiSdkStructuredFn logger: Logger }) => { - const { query, clientSessionId, userInputId, fingerprintId, userId, logger } = - params + const { + query, + clientSessionId, + userInputId, + fingerprintId, + userId, + promptAiSdkStructured, + logger, + } = params const prompt = `You are an expert at documentation for libraries. Given a user's query return a list of (library name, topic) where each library name is the name of a library and topic is a keyword or phrase that specifies a topic within the library that is most relevant to the user's query. @@ -231,6 +240,7 @@ async function filterRelevantChunks(params: { userInputId: string fingerprintId: string userId?: string + promptAiSdkStructured: PromptAiSdkStructuredFn logger: Logger }): Promise<{ relevantChunks: string[]; geminiDuration: number } | null> { const { @@ -240,6 +250,7 @@ async function filterRelevantChunks(params: { userInputId, fingerprintId, userId, + promptAiSdkStructured, logger, } = params const prompt = `You are an expert at analyzing documentation queries. Given a user's query and a list of documentation chunks, determine which chunks are relevant to the query. Choose as few chunks as possible, likely none. Only include chunks if they are relevant to the user query. diff --git a/backend/src/impl/agent-runtime.ts b/backend/src/impl/agent-runtime.ts index a766de3ea7..53759bedea 100644 --- a/backend/src/impl/agent-runtime.ts +++ b/backend/src/impl/agent-runtime.ts @@ -1,13 +1,24 @@ import { addAgentStep, finishAgentRun, startAgentRun } from '../agent-run' +import { + promptAiSdk, + promptAiSdkStream, + promptAiSdkStructured, +} from '../llm-apis/vercel-ai-sdk/ai-sdk' import { logger } from '../util/logger' import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' -export const backendAgentRuntimeImpl: AgentRuntimeDeps = { - logger, - +export const BACKEND_AGENT_RUNTIME_IMPL: AgentRuntimeDeps = Object.freeze({ + // Database startAgentRun, finishAgentRun, addAgentStep, - // promptAiSdkStream, -} + + // LLM + promptAiSdkStream, + promptAiSdk, + promptAiSdkStructured, + + // Other + logger, +}) diff --git a/backend/src/llm-apis/gemini-with-fallbacks.ts b/backend/src/llm-apis/gemini-with-fallbacks.ts index 15763c43f9..4a2cb71350 100644 --- a/backend/src/llm-apis/gemini-with-fallbacks.ts +++ b/backend/src/llm-apis/gemini-with-fallbacks.ts @@ -1,14 +1,13 @@ import { openaiModels, openrouterModels } from '@codebuff/common/old-constants' -import { promptAiSdk } from './vercel-ai-sdk/ai-sdk' - import type { CostMode, FinetunedVertexModel, - Model, } from '@codebuff/common/old-constants' -import type { Message } from '@codebuff/common/types/messages/codebuff-message' +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' +import type { Message } from '@codebuff/common/types/messages/codebuff-message' /** * Prompts a Gemini model with fallback logic. @@ -35,38 +34,33 @@ import type { Logger } from '@codebuff/common/types/contracts/logger' * @returns A promise that resolves to the complete response string from the successful API call. * @throws If all API calls (primary and fallbacks) fail. */ -export async function promptFlashWithFallbacks(params: { - messages: Message[] - clientSessionId: string - fingerprintId: string - userInputId: string - model: Model - userId: string | undefined - maxTokens?: number - temperature?: number - costMode?: CostMode - useGPT4oInsteadOfClaude?: boolean - thinkingBudget?: number - useFinetunedModel?: FinetunedVertexModel | undefined - logger: Logger -}): Promise { +export async function promptFlashWithFallbacks( + params: { + messages: Message[] + costMode?: CostMode + useGPT4oInsteadOfClaude?: boolean + thinkingBudget?: number + useFinetunedModel?: FinetunedVertexModel | undefined + promptAiSdk: PromptAiSdkFn + logger: Logger + } & ParamsExcluding, +): Promise { const { messages, costMode, useGPT4oInsteadOfClaude, useFinetunedModel, + promptAiSdk, logger, - ...geminiOptions } = params // Try finetuned model first if enabled if (useFinetunedModel) { try { return await promptAiSdk({ - ...geminiOptions, + ...params, messages, model: useFinetunedModel, - logger, }) } catch (error) { logger.warn( @@ -78,14 +72,14 @@ export async function promptFlashWithFallbacks(params: { try { // First try Gemini - return await promptAiSdk({ ...geminiOptions, messages, logger }) + return await promptAiSdk({ ...params, messages }) } catch (error) { logger.warn( { error }, `Error calling Gemini API, falling back to ${useGPT4oInsteadOfClaude ? 'gpt-4o' : 'Claude'}`, ) return await promptAiSdk({ - ...geminiOptions, + ...params, messages, model: useGPT4oInsteadOfClaude ? openaiModels.gpt4o @@ -96,7 +90,6 @@ export async function promptFlashWithFallbacks(params: { experimental: openrouterModels.openrouter_claude_3_5_haiku, ask: openrouterModels.openrouter_claude_3_5_haiku, }[costMode ?? 'normal'], - logger, }) } } diff --git a/backend/src/llm-apis/relace-api.ts b/backend/src/llm-apis/relace-api.ts index 54a2ba5075..1daf04aa49 100644 --- a/backend/src/llm-apis/relace-api.ts +++ b/backend/src/llm-apis/relace-api.ts @@ -7,8 +7,8 @@ import { env } from '@codebuff/internal' import { saveMessage } from '../llm-apis/message-cost-tracker' import { countTokens } from '../util/token-counter' -import { promptAiSdk } from './vercel-ai-sdk/ai-sdk' +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' const timeoutPromise = (ms: number) => @@ -26,6 +26,7 @@ export async function promptRelaceAI(params: { userId: string | undefined messageId: string userMessage?: string + promptAiSdk: PromptAiSdkFn logger: Logger }) { const { @@ -38,6 +39,7 @@ export async function promptRelaceAI(params: { userId, userMessage, messageId, + promptAiSdk, logger, } = params const startTime = Date.now() diff --git a/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts b/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts index c2f57e10e4..995a218ced 100644 --- a/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts +++ b/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts @@ -17,9 +17,13 @@ import { openRouterLanguageModel } from '../openrouter' import { vertexFinetuned } from './vertex-finetuned' import type { Model, OpenAIModel } from '@codebuff/common/old-constants' -import type { ParamsExcluding } from '@codebuff/common/types/function-params' -import type { Logger } from '@codebuff/common/types/contracts/logger' -import type { Message } from '@codebuff/common/types/messages/codebuff-message' +import type { + PromptAiSdkFn, + PromptAiSdkStreamFn, + PromptAiSdkStructuredInput, + PromptAiSdkStructuredOutput, +} from '@codebuff/common/types/contracts/llm' +import type { ParamsOf } from '@codebuff/common/types/function-params' import type { OpenRouterProviderOptions, OpenRouterUsageAccounting, @@ -27,17 +31,6 @@ import type { import type { LanguageModel } from 'ai' import type { z } from 'zod/v4' -export type StreamChunk = - | { - type: 'text' - text: string - } - | { - type: 'reasoning' - text: string - } - | { type: 'error'; message: string } - // TODO: We'll want to add all our models here! const modelToAiSDKModel = (model: Model): LanguageModel => { if ( @@ -60,23 +53,9 @@ const modelToAiSDKModel = (model: Model): LanguageModel => { // TODO: Add retries & fallbacks: likely by allowing this to instead of "model" // also take an array of form [{model: Model, retries: number}, {model: Model, retries: number}...] // eg: [{model: "gemini-2.0-flash-001"}, {model: "vertex/gemini-2.0-flash-001"}, {model: "claude-3-5-haiku", retries: 3}] -export const promptAiSdkStream = async function* ( - params: { - messages: Message[] - clientSessionId: string - fingerprintId: string - model: Model - userId: string | undefined - chargeUser?: boolean - thinkingBudget?: number - userInputId: string - agentId?: string - maxRetries?: number - onCostCalculated?: (credits: number) => Promise - includeCacheControl?: boolean - logger: Logger - } & ParamsExcluding, -): AsyncGenerator { +export async function* promptAiSdkStream( + params: ParamsOf, +): ReturnType { const { logger } = params if ( !checkLiveUserInput({ ...params, clientSessionId: params.clientSessionId }) @@ -251,22 +230,9 @@ export const promptAiSdkStream = async function* ( } // TODO: figure out a nice way to unify stream & non-stream versions maybe? -export const promptAiSdk = async function ( - params: { - messages: Message[] - clientSessionId: string - fingerprintId: string - userInputId: string - model: Model - userId: string | undefined - chargeUser?: boolean - agentId?: string - onCostCalculated?: (credits: number) => Promise - includeCacheControl?: boolean - maxRetries?: number - logger: Logger - } & ParamsExcluding, -): Promise { +export async function promptAiSdk( + params: ParamsOf, +): ReturnType { const { logger } = params if (!checkLiveUserInput(params)) { @@ -321,24 +287,9 @@ export const promptAiSdk = async function ( } // Copied over exactly from promptAiSdk but with a schema -export const promptAiSdkStructured = async function (params: { - messages: Message[] - schema: z.ZodType - clientSessionId: string - fingerprintId: string - userInputId: string - model: Model - userId: string | undefined - maxTokens?: number - temperature?: number - timeout?: number - chargeUser?: boolean - agentId?: string - onCostCalculated?: (credits: number) => Promise - includeCacheControl?: boolean - maxRetries?: number - logger: Logger -}): Promise { +export async function promptAiSdkStructured( + params: PromptAiSdkStructuredInput, +): PromptAiSdkStructuredOutput { const { logger } = params if (!checkLiveUserInput(params)) { diff --git a/backend/src/main-prompt.ts b/backend/src/main-prompt.ts index 4206cd69d8..a4faa0ccae 100644 --- a/backend/src/main-prompt.ts +++ b/backend/src/main-prompt.ts @@ -11,14 +11,14 @@ import { requestToolCall } from './websockets/websocket-action' import type { AgentTemplate } from './templates/types' import type { ClientAction } from '@codebuff/common/actions' import type { CostMode } from '@codebuff/common/old-constants' +import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { PrintModeEvent } from '@codebuff/common/types/print-mode' import type { SessionState, AgentTemplateType, AgentOutput, } from '@codebuff/common/types/session-state' -import type { ParamsExcluding } from '@codebuff/common/types/function-params' -import type { Logger } from '@codebuff/common/types/contracts/logger' import type { WebSocket } from 'ws' export const mainPrompt = async ( @@ -42,7 +42,11 @@ export const mainPrompt = async ( | 'agentType' | 'fingerprintId' | 'fileContext' - >, + > & + ParamsExcluding< + typeof checkTerminalCommand, + 'prompt' | 'fingerprintId' | 'userInputId' + >, ): Promise<{ sessionState: SessionState output: AgentOutput @@ -156,12 +160,10 @@ export const mainPrompt = async ( // Check if this is a direct terminal command const startTime = Date.now() const terminalCommand = await checkTerminalCommand({ + ...params, prompt, - clientSessionId, fingerprintId, userInputId: promptId, - userId, - logger, }) const duration = Date.now() - startTime diff --git a/backend/src/process-file-block.ts b/backend/src/process-file-block.ts index 6d8101abb5..b74bad728e 100644 --- a/backend/src/process-file-block.ts +++ b/backend/src/process-file-block.ts @@ -8,26 +8,39 @@ import { parseAndGetDiffBlocksSingleFile, retryDiffBlocksPrompt, } from './generate-diffs-prompt' -import { promptAiSdk } from './llm-apis/vercel-ai-sdk/ai-sdk' import { countTokens } from './util/token-counter' -import type { Message } from '@codebuff/common/types/messages/codebuff-message' +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' +import type { Message } from '@codebuff/common/types/messages/codebuff-message' -export async function processFileBlock(params: { - path: string - instructions: string | undefined - initialContentPromise: Promise - newContent: string - messages: Message[] - fullResponse: string - lastUserPrompt: string | undefined - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - logger: Logger -}): Promise< +export async function processFileBlock( + params: { + path: string + initialContentPromise: Promise + newContent: string + messages: Message[] + fullResponse: string + lastUserPrompt: string | undefined + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + logger: Logger + } & ParamsExcluding< + typeof handleLargeFile, + 'oldContent' | 'editSnippet' | 'filePath' + > & + ParamsExcluding< + typeof fastRewrite, + 'initialContent' | 'editSnippet' | 'filePath' | 'userMessage' + > & + ParamsExcluding< + typeof shouldAddFilePlaceholders, + 'filePath' | 'oldContent' | 'rewrittenNewContent' | 'messageHistory' + >, +): Promise< | { tool: 'write_file' path: string @@ -43,7 +56,6 @@ export async function processFileBlock(params: { > { const { path, - instructions, initialContentPromise, newContent, messages, @@ -113,14 +125,10 @@ export async function processFileBlock(params: { ) if (tokenCount > LARGE_FILE_TOKEN_LIMIT) { const largeFileContent = await handleLargeFile({ + ...params, oldContent: normalizedInitialContent, editSnippet: normalizedEditSnippet, - clientSessionId, - fingerprintId, - userInputId, - userId, filePath: path, - logger, }) if (!largeFileContent) { @@ -135,44 +143,29 @@ export async function processFileBlock(params: { updatedContent = largeFileContent } else { updatedContent = await fastRewrite({ + ...params, initialContent: normalizedInitialContent, editSnippet: normalizedEditSnippet, filePath: path, - instructions, - clientSessionId, - fingerprintId, - userInputId, - userId, userMessage: lastUserPrompt, - logger, }) const shouldAddPlaceholders = await shouldAddFilePlaceholders({ + ...params, filePath: path, oldContent: normalizedInitialContent, rewrittenNewContent: updatedContent, messageHistory: messages, - fullResponse, - userId, - clientSessionId, - fingerprintId, - userInputId, - logger, }) if (shouldAddPlaceholders) { const placeholderComment = `... existing code ...` const updatedEditSnippet = `${placeholderComment}\n${updatedContent}\n${placeholderComment}` updatedContent = await fastRewrite({ + ...params, initialContent: normalizedInitialContent, editSnippet: updatedEditSnippet, filePath: path, - instructions, - clientSessionId, - fingerprintId, - userInputId, - userId, userMessage: lastUserPrompt, - logger, }) } } @@ -230,16 +223,22 @@ export async function processFileBlock(params: { const LARGE_FILE_TOKEN_LIMIT = 64_000 -export async function handleLargeFile(params: { - oldContent: string - editSnippet: string - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - filePath: string - logger: Logger -}): Promise { +export async function handleLargeFile( + params: { + oldContent: string + editSnippet: string + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + filePath: string + logger: Logger + promptAiSdk: PromptAiSdkFn + } & ParamsExcluding< + typeof retryDiffBlocksPrompt, + 'oldContent' | 'diffBlocksThatDidntMatch' + >, +): Promise { const { oldContent, editSnippet, @@ -248,6 +247,7 @@ export async function handleLargeFile(params: { userInputId, userId, filePath, + promptAiSdk, logger, } = params const startTime = Date.now() @@ -325,14 +325,9 @@ Please output just the SEARCH/REPLACE blocks like this: const { newDiffBlocks, newDiffBlocksThatDidntMatch } = await retryDiffBlocksPrompt({ - filePath, + ...params, oldContent: updatedContent, - clientSessionId, - fingerprintId, - userInputId, - userId, diffBlocksThatDidntMatch, - logger, }) if (newDiffBlocksThatDidntMatch.length > 0) { diff --git a/backend/src/prompt-agent-stream.ts b/backend/src/prompt-agent-stream.ts index bad7e3aca8..fa8239887d 100644 --- a/backend/src/prompt-agent-stream.ts +++ b/backend/src/prompt-agent-stream.ts @@ -1,12 +1,13 @@ import { providerModelNames } from '@codebuff/common/old-constants' -import { promptAiSdkStream } from './llm-apis/vercel-ai-sdk/ai-sdk' import { globalStopSequence } from './tools/constants' import type { AgentTemplate } from './templates/types' +import type { PromptAiSdkStreamFn } from '@codebuff/common/types/contracts/llm' +import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsOf } from '@codebuff/common/types/function-params' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { OpenRouterProviderOptions } from '@codebuff/internal/openrouter-ai-sdk' -import type { Logger } from '@codebuff/common/types/contracts/logger' export const getAgentStreamFromTemplate = (params: { clientSessionId: string @@ -19,6 +20,7 @@ export const getAgentStreamFromTemplate = (params: { template: AgentTemplate logger: Logger + promptAiSdkStream: PromptAiSdkStreamFn }) => { const { clientSessionId, @@ -30,6 +32,7 @@ export const getAgentStreamFromTemplate = (params: { includeCacheControl, template, logger, + promptAiSdkStream, } = params if (!template) { @@ -39,7 +42,7 @@ export const getAgentStreamFromTemplate = (params: { const { model } = template const getStream = (messages: Message[]) => { - const aiSdkStreamParams: Parameters[0] = { + const aiSdkStreamParams: ParamsOf = { messages, model, stopSequences: [globalStopSequence], diff --git a/backend/src/run-agent-step.ts b/backend/src/run-agent-step.ts index a2490a7c17..92746d5ff8 100644 --- a/backend/src/run-agent-step.ts +++ b/backend/src/run-agent-step.ts @@ -103,7 +103,11 @@ export const runAgentStep = async ( | 'agentTemplate' | 'agentContext' | 'fullResponse' - >, + > & + ParamsExcluding< + typeof getAgentStreamFromTemplate, + 'agentId' | 'template' | 'onCostCalculated' | 'includeCacheControl' + >, ): Promise<{ agentState: AgentState fullResponse: string @@ -246,10 +250,7 @@ export const runAgentStep = async ( const { model } = agentTemplate const { getStream } = getAgentStreamFromTemplate({ - clientSessionId, - fingerprintId, - userInputId, - userId, + ...params, agentId: agentState.agentId, template: agentTemplate, onCostCalculated: async (credits: number) => { @@ -270,7 +271,6 @@ export const runAgentStep = async ( } }, includeCacheControl: supportsCacheControl(agentTemplate.model), - logger, }) const iterationNum = agentState.messageHistory.length diff --git a/backend/src/tools/handlers/tool/find-files.ts b/backend/src/tools/handlers/tool/find-files.ts index 322fe12f84..60a0d450f7 100644 --- a/backend/src/tools/handlers/tool/find-files.ts +++ b/backend/src/tools/handlers/tool/find-files.ts @@ -10,14 +10,17 @@ import { renderReadFilesResult } from '../../../util/parse-tool-call-xml' import { countTokens, countTokensJson } from '../../../util/token-counter' import { requestFiles } from '../../../websockets/websocket-action' -import type { TextBlock } from '../../../llm-apis/claude' import type { CodebuffToolHandlerFunction } from '../handler-function-type' -import type { Logger } from '@codebuff/common/types/contracts/logger' import type { GetExpandedFileContextForTrainingBlobTrace } from '@codebuff/bigquery' import type { CodebuffToolCall, CodebuffToolOutput, } from '@codebuff/common/tools/list' +import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { + ParamsExcluding, + ParamsOf, +} from '@codebuff/common/types/function-params' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { ProjectFileContext } from '@codebuff/common/util/file' import type { WebSocket } from 'ws' @@ -26,24 +29,44 @@ import type { WebSocket } from 'ws' // TODO: We might want to be able to turn this on on a per-repo basis. const COLLECT_FULL_FILE_CONTEXT = false -export const handleFindFiles = ((params: { - previousToolCallFinished: Promise - toolCall: CodebuffToolCall<'find_files'> - logger: Logger - - fileContext: ProjectFileContext - agentStepId: string - clientSessionId: string - userInputId: string - - state: { - ws?: WebSocket - fingerprintId?: string - userId?: string - repoId?: string - messages?: Message[] - } -}): { result: Promise>; state: {} } => { +export const handleFindFiles = (( + params: { + previousToolCallFinished: Promise + toolCall: CodebuffToolCall<'find_files'> + logger: Logger + + fileContext: ProjectFileContext + agentStepId: string + clientSessionId: string + userInputId: string + + state: { + ws?: WebSocket + fingerprintId?: string + userId?: string + repoId?: string + messages?: Message[] + } + } & ParamsExcluding< + typeof requestRelevantFiles, + | 'messages' + | 'system' + | 'assistantPrompt' + | 'fingerprintId' + | 'userId' + | 'repoId' + > & + ParamsExcluding< + typeof uploadExpandedFileContextForTraining, + | 'ws' + | 'messages' + | 'system' + | 'assistantPrompt' + | 'fingerprintId' + | 'userId' + | 'repoId' + >, +): { result: Promise>; state: {} } => { const { previousToolCallFinished, toolCall, @@ -87,36 +110,29 @@ export const handleFindFiles = ((params: { CodebuffToolOutput<'find_files'> > = async () => { const requestedFiles = await requestRelevantFiles({ + ...params, messages, system, - fileContext, assistantPrompt: prompt, - agentStepId, - clientSessionId, fingerprintId, - userInputId, userId, repoId, - logger, }) if (requestedFiles && requestedFiles.length > 0) { const addedFiles = await getFileReadingUpdates(ws, requestedFiles) if (COLLECT_FULL_FILE_CONTEXT && addedFiles.length > 0) { - uploadExpandedFileContextForTraining( + uploadExpandedFileContextForTraining({ + ...params, ws, - { messages, system }, - fileContext, - prompt, - agentStepId, - clientSessionId, + messages, + system, + assistantPrompt: prompt, fingerprintId, - userInputId, userId, repoId, - logger, - ).catch((error) => { + }).catch((error) => { logger.error( { error }, 'Error uploading expanded file context for training', @@ -165,37 +181,26 @@ export const handleFindFiles = ((params: { }) satisfies CodebuffToolHandlerFunction<'find_files'> async function uploadExpandedFileContextForTraining( - ws: WebSocket, - { - messages, - system, - }: { - messages: Message[] - system: string | Array - }, - fileContext: ProjectFileContext, - assistantPrompt: string | null, - agentStepId: string, - clientSessionId: string, - fingerprintId: string, - userInputId: string, - userId: string | undefined, - repoId: string | undefined, - logger: Logger, + params: { + ws: WebSocket + agentStepId: string + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + logger: Logger + } & ParamsOf, ) { - const files = await requestRelevantFilesForTraining({ - messages, - system, - fileContext, - assistantPrompt, + const { + ws, agentStepId, clientSessionId, fingerprintId, userInputId, userId, - repoId, logger, - }) + } = params + const files = await requestRelevantFilesForTraining(params) const loadedFiles = await requestFiles({ ws, filePaths: files }) diff --git a/backend/src/tools/handlers/tool/write-file.ts b/backend/src/tools/handlers/tool/write-file.ts index c4cfe5a533..0624b15d58 100644 --- a/backend/src/tools/handlers/tool/write-file.ts +++ b/backend/src/tools/handlers/tool/write-file.ts @@ -3,14 +3,14 @@ import { partition } from 'lodash' import { processFileBlock } from '../../../process-file-block' import { requestOptionalFile } from '../../../websockets/websocket-action' -import type { Logger } from '@codebuff/common/types/contracts/logger' - import type { CodebuffToolHandlerFunction } from '../handler-function-type' import type { ClientToolCall, CodebuffToolCall, CodebuffToolOutput, } from '@codebuff/common/tools/list' +import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { WebSocket } from 'ws' @@ -63,45 +63,59 @@ export function getFileProcessingValues( return fileProcessingValues } -export const handleWriteFile = (({ - previousToolCallFinished, - toolCall, - - clientSessionId, - userInputId, +export const handleWriteFile = (( + params: { + previousToolCallFinished: Promise + toolCall: CodebuffToolCall<'write_file'> - requestClientToolCall, - writeToClient, + clientSessionId: string + userInputId: string - getLatestState, - state, - logger, -}: { - previousToolCallFinished: Promise - toolCall: CodebuffToolCall<'write_file'> - - clientSessionId: string - userInputId: string - - requestClientToolCall: ( - toolCall: ClientToolCall<'write_file'>, - ) => Promise> - writeToClient: (chunk: string) => void + requestClientToolCall: ( + toolCall: ClientToolCall<'write_file'>, + ) => Promise> + writeToClient: (chunk: string) => void - getLatestState: () => FileProcessingState - state: { - ws?: WebSocket - fingerprintId?: string - userId?: string - fullResponse?: string - prompt?: string - messages?: Message[] - } & OptionalFileProcessingState - logger: Logger -}): { + getLatestState: () => FileProcessingState + state: { + ws?: WebSocket + fingerprintId?: string + userId?: string + fullResponse?: string + prompt?: string + messages?: Message[] + } & OptionalFileProcessingState + logger: Logger + } & ParamsExcluding< + typeof processFileBlock, + | 'path' + | 'instructions' + | 'fingerprintId' + | 'userId' + | 'initialContentPromise' + | 'newContent' + | 'messages' + | 'fullResponse' + | 'lastUserPrompt' + >, +): { result: Promise> state: FileProcessingState } => { + const { + previousToolCallFinished, + toolCall, + + clientSessionId, + userInputId, + + requestClientToolCall, + writeToClient, + + getLatestState, + state, + logger, + } = params const { path, instructions, content } = toolCall.input const { ws, fingerprintId, userId, fullResponse, prompt } = state if (!ws) { @@ -143,6 +157,7 @@ export const handleWriteFile = (({ logger.debug({ path, content }, `write_file ${path}`) const newPromise = processFileBlock({ + ...params, path, instructions, initialContentPromise: latestContentPromise, diff --git a/backend/src/tools/stream-parser.ts b/backend/src/tools/stream-parser.ts index 49fa0fcc2f..b406a84af8 100644 --- a/backend/src/tools/stream-parser.ts +++ b/backend/src/tools/stream-parser.ts @@ -16,10 +16,12 @@ import { executeCustomToolCall, executeToolCall } from './tool-executor' import type { BatchStrReplaceState } from './batch-str-replace' import type { CustomToolCall, ExecuteToolCallParams } from './tool-executor' -import type { StreamChunk } from '../llm-apis/vercel-ai-sdk/ai-sdk' import type { AgentTemplate } from '../templates/types' import type { ToolName } from '@codebuff/common/tools/constants' import type { CodebuffToolCall } from '@codebuff/common/tools/list' +import type { StreamChunk } from '@codebuff/common/types/contracts/llm' +import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { Message, ToolMessage, @@ -28,8 +30,6 @@ import type { ToolResultPart } from '@codebuff/common/types/messages/content-par import type { PrintModeEvent } from '@codebuff/common/types/print-mode' import type { AgentState, Subgoal } from '@codebuff/common/types/session-state' import type { ProjectFileContext } from '@codebuff/common/util/file' -import type { ParamsExcluding } from '@codebuff/common/types/function-params' -import type { Logger } from '@codebuff/common/types/contracts/logger' import type { ToolCallPart } from 'ai' import type { WebSocket } from 'ws' diff --git a/backend/src/websockets/middleware.ts b/backend/src/websockets/middleware.ts index 1f95ae62ce..ae0528c168 100644 --- a/backend/src/websockets/middleware.ts +++ b/backend/src/websockets/middleware.ts @@ -16,7 +16,7 @@ import { getUserInfoFromAuthToken } from './auth' import { updateRequestContext } from './request-context' import { sendAction } from './websocket-action' import { withAppContext } from '../context/app-context' -import { backendAgentRuntimeImpl } from '../impl/agent-runtime' +import { BACKEND_AGENT_RUNTIME_IMPL } from '../impl/agent-runtime' import { checkAuth } from '../util/check-auth' import type { UserInfo } from './auth' @@ -164,7 +164,7 @@ export class WebSocketMiddleware { } } -export const protec = new WebSocketMiddleware(backendAgentRuntimeImpl) +export const protec = new WebSocketMiddleware(BACKEND_AGENT_RUNTIME_IMPL) protec.use(async ({ action, clientSessionId, logger }) => checkAuth({ diff --git a/backend/src/xml-stream-parser.ts b/backend/src/xml-stream-parser.ts index 4956cb9e32..9fefc26ef7 100644 --- a/backend/src/xml-stream-parser.ts +++ b/backend/src/xml-stream-parser.ts @@ -7,14 +7,14 @@ import { toolNameParam, } from '@codebuff/common/tools/constants' -import type { StreamChunk } from './llm-apis/vercel-ai-sdk/ai-sdk' import type { Model } from '@codebuff/common/old-constants' +import type { StreamChunk } from '@codebuff/common/types/contracts/llm' +import type { Logger } from '@codebuff/common/types/contracts/logger' import type { PrintModeError, PrintModeText, PrintModeToolCall, } from '@codebuff/common/types/print-mode' -import type { Logger } from '@codebuff/common/types/contracts/logger' const toolExtractionPattern = new RegExp( `${startToolTag}(.*?)${endToolTag}`, diff --git a/common/src/testing/impl/agent-runtime.ts b/common/src/testing/impl/agent-runtime.ts index 78b57edbae..90a4e0bce1 100644 --- a/common/src/testing/impl/agent-runtime.ts +++ b/common/src/testing/impl/agent-runtime.ts @@ -8,13 +8,23 @@ export const testLogger: Logger = { warn: () => {}, } -export const testAgentRuntimeImpl: AgentRuntimeDeps = { - logger: testLogger, - +export const TEST_AGENT_RUNTIME_IMPL: AgentRuntimeDeps = Object.freeze({ + // Database startAgentRun: async () => 'test-agent-run-id', finishAgentRun: async () => {}, addAgentStep: async () => 'test-agent-step-id', - // promptAiSdkStream: async function* () { - // throw new Error('promptAiSdkStream not implemented in test runtime') - // }, -} + + // LLM + promptAiSdkStream: async function* () { + throw new Error('promptAiSdkStream not implemented in test runtime') + }, + promptAiSdk: async function () { + throw new Error('promptAiSdk not implemented in test runtime') + }, + promptAiSdkStructured: async function () { + throw new Error('promptAiSdkStructured not implemented in test runtime') + }, + + // Other + logger: testLogger, +}) diff --git a/common/src/types/contracts/agent-runtime.ts b/common/src/types/contracts/agent-runtime.ts index b3431f3dfc..65d9358b2f 100644 --- a/common/src/types/contracts/agent-runtime.ts +++ b/common/src/types/contracts/agent-runtime.ts @@ -3,14 +3,24 @@ import type { FinishAgentRunFn, StartAgentRunFn, } from './database' +import type { + PromptAiSdkFn, + PromptAiSdkStreamFn, + PromptAiSdkStructuredFn, +} from './llm' import type { Logger } from './logger' export type AgentRuntimeDeps = { - logger: Logger - + // Database startAgentRun: StartAgentRunFn finishAgentRun: FinishAgentRunFn addAgentStep: AddAgentStepFn - // promptAiSdkStream: PromptAiSdkStreamFn + // LLM + promptAiSdkStream: PromptAiSdkStreamFn + promptAiSdk: PromptAiSdkFn + promptAiSdkStructured: PromptAiSdkStructuredFn + + // Other + logger: Logger } diff --git a/common/src/types/contracts/llm.ts b/common/src/types/contracts/llm.ts index 1846d48dfd..0a2ecb9ccb 100644 --- a/common/src/types/contracts/llm.ts +++ b/common/src/types/contracts/llm.ts @@ -1,8 +1,9 @@ import type { ParamsExcluding } from '../function-params' import type { Logger } from './logger' -import type { Message } from '../messages/codebuff-message' -import type { streamText } from 'ai' import type { Model } from '../../old-constants' +import type { Message } from '../messages/codebuff-message' +import type { generateText, streamText } from 'ai' +import type z from 'zod/v4' export type StreamChunk = | { @@ -32,3 +33,43 @@ export type PromptAiSdkStreamFn = ( logger: Logger } & ParamsExcluding, ) => AsyncGenerator + +export type PromptAiSdkFn = ( + params: { + messages: Message[] + clientSessionId: string + fingerprintId: string + userInputId: string + model: Model + userId: string | undefined + chargeUser?: boolean + agentId?: string + onCostCalculated?: (credits: number) => Promise + includeCacheControl?: boolean + maxRetries?: number + logger: Logger + } & ParamsExcluding, +) => Promise + +export type PromptAiSdkStructuredInput = { + messages: Message[] + schema: z.ZodType + clientSessionId: string + fingerprintId: string + userInputId: string + model: Model + userId: string | undefined + maxTokens?: number + temperature?: number + timeout?: number + chargeUser?: boolean + agentId?: string + onCostCalculated?: (credits: number) => Promise + includeCacheControl?: boolean + maxRetries?: number + logger: Logger +} +export type PromptAiSdkStructuredOutput = Promise +export type PromptAiSdkStructuredFn = ( + params: PromptAiSdkStructuredInput, +) => PromptAiSdkStructuredOutput diff --git a/evals/impl/agent-runtime.ts b/evals/impl/agent-runtime.ts index 25e9a6ac8a..373e8b637d 100644 --- a/evals/impl/agent-runtime.ts +++ b/evals/impl/agent-runtime.ts @@ -1,12 +1,19 @@ import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' -export const evalAgentRuntimeImpl: AgentRuntimeDeps = { +export const EVALS_AGENT_RUNTIME_IMPL: AgentRuntimeDeps = Object.freeze({ logger: console, startAgentRun: async () => 'test-agent-run-id', finishAgentRun: async () => {}, addAgentStep: async () => 'test-agent-step-id', - // promptAiSdkStream: async function* () { - // throw new Error('promptAiSdkStream not implemented in eval runtime') - // }, -} + + promptAiSdkStream: async function* () { + throw new Error('promptAiSdkStream not implemented in eval runtime') + }, + promptAiSdk: async function () { + throw new Error('promptAiSdk not implemented in eval runtime') + }, + promptAiSdkStructured: async function () { + throw new Error('promptAiSdkStructured not implemented in eval runtime') + }, +}) diff --git a/evals/scaffolding.ts b/evals/scaffolding.ts index 01911eb377..7e67347525 100644 --- a/evals/scaffolding.ts +++ b/evals/scaffolding.ts @@ -14,7 +14,7 @@ import { getSystemInfo } from '@codebuff/npm-app/utils/system-info' import { mock } from 'bun:test' import { blue } from 'picocolors' -import { evalAgentRuntimeImpl } from './impl/agent-runtime' +import { EVALS_AGENT_RUNTIME_IMPL } from './impl/agent-runtime' import { getAllFilePaths, getProjectFileTree, @@ -182,7 +182,7 @@ export async function runAgentStepScaffolding( }) const result = await runAgentStep({ - ...evalAgentRuntimeImpl, + ...EVALS_AGENT_RUNTIME_IMPL, ws: mockWs, userId: TEST_USER_ID, userInputId: generateCompactId(), diff --git a/knowledge.md b/knowledge.md index f04efcb543..1e598319b3 100644 --- a/knowledge.md +++ b/knowledge.md @@ -149,12 +149,13 @@ afterEach(() => { ```typescript // From main-prompt.test.ts - Mocking LLM APIs -spyOn(aisdk, 'promptAiSdk').mockImplementation(() => - Promise.resolve('Test response'), -) -spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* () { - yield 'Test response' -}) +agentRuntimeImpl.promptAiSdk = async function () { + return 'Test response' +} +agentRuntimeImpl.promptAiSdkStream = async function* () { + yield { type: 'text' as const, text: 'Test response' } + return 'mock-message-id' +} // From rage-detector.test.ts - Mocking Date spyOn(Date, 'now').mockImplementation(() => currentTime) diff --git a/scripts/ft-file-selection/grade-traces.ts b/scripts/ft-file-selection/grade-traces.ts index 7e7fdcedd3..b54d477240 100644 --- a/scripts/ft-file-selection/grade-traces.ts +++ b/scripts/ft-file-selection/grade-traces.ts @@ -1,3 +1,4 @@ +import { promptAiSdk } from '@codebuff/backend/llm-apis/vercel-ai-sdk/ai-sdk' import { getTracesAndRelabelsForUser, setupBigQuery } from '@codebuff/bigquery' import { gradeRun } from '../../backend/src/admin/grade-runs' @@ -46,7 +47,11 @@ async function gradeTraces({ logger }: { logger: Logger }) { batch.map(async (traceAndRelabels) => { try { console.log(`Grading trace ${traceAndRelabels.trace.id}`) - const result = await gradeRun({ ...traceAndRelabels, logger }) + const result = await gradeRun({ + ...traceAndRelabels, + promptAiSdk, + logger, + }) return { traceId: traceAndRelabels.trace.id, status: 'success', diff --git a/scripts/ft-file-selection/relabel-traces-with-context.ts b/scripts/ft-file-selection/relabel-traces-with-context.ts index 144f929c0b..0e4c63c39c 100644 --- a/scripts/ft-file-selection/relabel-traces-with-context.ts +++ b/scripts/ft-file-selection/relabel-traces-with-context.ts @@ -1,3 +1,4 @@ +import { promptAiSdk } from '@codebuff/backend/llm-apis/vercel-ai-sdk/ai-sdk' import { getTracesAndAllDataForUser, setupBigQuery } from '@codebuff/bigquery' import { models } from '@codebuff/common/old-constants' @@ -95,6 +96,7 @@ async function runTraces() { fileBlobs, model: MODEL_TO_TEST, dataset: DATASET, + promptAiSdk, logger: console, }) console.log(`Successfully stored relabel for trace ${trace.id}`) diff --git a/scripts/ft-file-selection/relabel-traces.ts b/scripts/ft-file-selection/relabel-traces.ts index 5ce518477f..f9b296e834 100644 --- a/scripts/ft-file-selection/relabel-traces.ts +++ b/scripts/ft-file-selection/relabel-traces.ts @@ -82,6 +82,7 @@ async function runTraces() { userInputId: 'relabel-trace-run', userId: 'relabel-trace-run', logger: console, + promptAiSdk, }) }