From 1f0554b9fd520adea95678b7fef1c9f85d5f8f5d Mon Sep 17 00:00:00 2001 From: Charles Lien Date: Thu, 9 Oct 2025 16:30:55 -0700 Subject: [PATCH 1/6] re-implement ai sdk changes --- .../cost-aggregation.integration.test.ts | 154 +++++++++--------- .../src/__tests__/cost-aggregation.test.ts | 8 +- .../__tests__/generate-diffs-prompt.test.ts | 18 +- .../src/__tests__/loop-agent-steps.test.ts | 60 +++---- .../__tests__/main-prompt.integration.test.ts | 6 +- backend/src/__tests__/main-prompt.test.ts | 30 ++-- .../src/__tests__/malformed-tool-call.test.ts | 16 +- .../prompt-caching-subagents.test.ts | 56 +++---- backend/src/__tests__/read-docs-tool.test.ts | 65 +++----- .../__tests__/run-agent-step-tools.test.ts | 47 +++--- .../spawn-agents-message-history.test.ts | 10 +- .../spawn-agents-permissions.test.ts | 30 ++-- .../src/__tests__/subagent-streaming.test.ts | 6 +- backend/src/__tests__/web-search-tool.test.ts | 43 ++--- backend/src/impl/agent-runtime.ts | 8 +- backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts | 35 +--- backend/src/prompt-agent-stream.ts | 9 +- backend/src/run-agent-step.ts | 12 +- backend/src/tools/stream-parser.ts | 6 +- backend/src/websockets/middleware.ts | 4 +- backend/src/xml-stream-parser.ts | 4 +- common/src/testing/impl/agent-runtime.ts | 18 +- common/src/types/contracts/agent-runtime.ts | 10 +- evals/impl/agent-runtime.ts | 11 +- evals/scaffolding.ts | 4 +- knowledge.md | 3 - 26 files changed, 316 insertions(+), 357 deletions(-) diff --git a/backend/src/__tests__/cost-aggregation.integration.test.ts b/backend/src/__tests__/cost-aggregation.integration.test.ts index 673343c3fd..18216240c1 100644 --- a/backend/src/__tests__/cost-aggregation.integration.test.ts +++ b/backend/src/__tests__/cost-aggregation.integration.test.ts @@ -1,5 +1,5 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { spyOn, @@ -12,12 +12,12 @@ import { } from 'bun:test' import * as messageCostTracker from '../llm-apis/message-cost-tracker' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { mainPrompt } from '../main-prompt' import * as agentRegistry from '../templates/agent-registry' import * as websocketAction from '../websockets/websocket-action' import type { AgentTemplate } from '../templates/types' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { ProjectFileContext } from '@codebuff/common/util/file' import type { WebSocket } from 'ws' @@ -99,6 +99,7 @@ class MockWebSocket { describe('Cost Aggregation Integration Tests', () => { let mockLocalAgentTemplates: Record let mockWebSocket: MockWebSocket + let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } beforeEach(async () => { mockWebSocket = new MockWebSocket() @@ -150,33 +151,31 @@ describe('Cost Aggregation Integration Tests', () => { // Mock LLM streaming let callCount = 0 const creditHistory: number[] = [] - spyOn(aisdk, 'promptAiSdkStream').mockImplementation( - async function* (options) { - callCount++ - const credits = callCount === 1 ? 10 : 7 // Main agent vs subagent costs - creditHistory.push(credits) - - if (options.onCostCalculated) { - await options.onCostCalculated(credits) - } + agentRuntimeImpl.promptAiSdkStream = async function* (options) { + callCount++ + const credits = callCount === 1 ? 10 : 7 // Main agent vs subagent costs + creditHistory.push(credits) - // Simulate different responses based on call - if (callCount === 1) { - // Main agent spawns a subagent - yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write a simple hello world file"}]}\n', - } - } else { - // Subagent writes a file - yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "write_file", "path": "hello.txt", "instructions": "Create hello world file", "content": "Hello, World!"}\n', - } + if (options.onCostCalculated) { + await options.onCostCalculated(credits) + } + + // Simulate different responses based on call + if (callCount === 1) { + // Main agent spawns a subagent + yield { + type: 'text' as const, + text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write a simple hello world file"}]}\n', } - return 'mock-message-id' - }, - ) + } else { + // Subagent writes a file + yield { + type: 'text' as const, + text: '\n{"cb_tool_name": "write_file", "path": "hello.txt", "instructions": "Create hello world file", "content": "Hello, World!"}\n', + } + } + return 'mock-message-id' + } // Mock tool call execution spyOn(websocketAction, 'requestToolCall').mockImplementation( @@ -231,6 +230,7 @@ describe('Cost Aggregation Integration Tests', () => { afterEach(() => { mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) it('should correctly aggregate costs across the entire main prompt flow', async () => { @@ -250,7 +250,7 @@ describe('Cost Aggregation Integration Tests', () => { } const result = await mainPrompt({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: mockWebSocket as unknown as WebSocket, action, userId: TEST_USER_ID, @@ -285,7 +285,7 @@ describe('Cost Aggregation Integration Tests', () => { // Call through websocket action handler to test full integration await websocketAction.callMainPrompt({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: mockWebSocket as unknown as WebSocket, action, userId: TEST_USER_ID, @@ -308,37 +308,35 @@ describe('Cost Aggregation Integration Tests', () => { it('should handle multi-level subagent hierarchies correctly', async () => { // Mock a more complex scenario with nested subagents let callCount = 0 - spyOn(aisdk, 'promptAiSdkStream').mockImplementation( - async function* (options) { - callCount++ + agentRuntimeImpl.promptAiSdkStream = async function* (options) { + callCount++ - if (options.onCostCalculated) { - await options.onCostCalculated(5) // Each call costs 5 credits - } + if (options.onCostCalculated) { + await options.onCostCalculated(5) // Each call costs 5 credits + } - if (callCount === 1) { - // Main agent spawns first-level subagent - yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Create files"}]}\n', - } - } else if (callCount === 2) { - // First-level subagent spawns second-level subagent - yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write specific file"}]}\n', - } - } else { - // Second-level subagent does actual work - yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "write_file", "path": "nested.txt", "instructions": "Create nested file", "content": "Nested content"}\n', - } + if (callCount === 1) { + // Main agent spawns first-level subagent + yield { + type: 'text' as const, + text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Create files"}]}\n', + } + } else if (callCount === 2) { + // First-level subagent spawns second-level subagent + yield { + type: 'text' as const, + text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write specific file"}]}\n', } + } else { + // Second-level subagent does actual work + yield { + type: 'text' as const, + text: '\n{"cb_tool_name": "write_file", "path": "nested.txt", "instructions": "Create nested file", "content": "Nested content"}\n', + } + } - return 'mock-message-id' - }, - ) + return 'mock-message-id' + } const sessionState = getInitialSessionState(mockFileContext) sessionState.mainAgentState.stepsRemaining = 10 @@ -355,7 +353,7 @@ describe('Cost Aggregation Integration Tests', () => { } const result = await mainPrompt({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: mockWebSocket as unknown as WebSocket, action, userId: TEST_USER_ID, @@ -373,29 +371,27 @@ describe('Cost Aggregation Integration Tests', () => { it('should maintain cost integrity when subagents fail', async () => { // Mock scenario where subagent fails after incurring partial costs let callCount = 0 - spyOn(aisdk, 'promptAiSdkStream').mockImplementation( - async function* (options) { - callCount++ + agentRuntimeImpl.promptAiSdkStream = async function* (options) { + callCount++ - if (options.onCostCalculated) { - await options.onCostCalculated(6) // Each call costs 6 credits - } + if (options.onCostCalculated) { + await options.onCostCalculated(6) // Each call costs 6 credits + } - if (callCount === 1) { - // Main agent spawns subagent - yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "This will fail"}]}\n', - } - } else { - // Subagent fails after incurring cost - yield { type: 'text' as const, text: 'Some response' } - throw new Error('Subagent execution failed') + if (callCount === 1) { + // Main agent spawns subagent + yield { + type: 'text' as const, + text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "This will fail"}]}\n', } + } else { + // Subagent fails after incurring cost + yield { type: 'text' as const, text: 'Some response' } + throw new Error('Subagent execution failed') + } - return 'mock-message-id' - }, - ) + return 'mock-message-id' + } const sessionState = getInitialSessionState(mockFileContext) sessionState.mainAgentState.agentType = 'base' @@ -413,7 +409,7 @@ describe('Cost Aggregation Integration Tests', () => { let result try { result = await mainPrompt({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: mockWebSocket as unknown as WebSocket, action, userId: TEST_USER_ID, @@ -462,7 +458,7 @@ describe('Cost Aggregation Integration Tests', () => { } await mainPrompt({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: mockWebSocket as unknown as WebSocket, action, userId: TEST_USER_ID, @@ -502,7 +498,7 @@ describe('Cost Aggregation Integration Tests', () => { // Call through websocket action to test server-side reset await websocketAction.callMainPrompt({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: mockWebSocket as unknown as WebSocket, action, userId: TEST_USER_ID, diff --git a/backend/src/__tests__/cost-aggregation.test.ts b/backend/src/__tests__/cost-aggregation.test.ts index e79d8a4057..bef5ab372d 100644 --- a/backend/src/__tests__/cost-aggregation.test.ts +++ b/backend/src/__tests__/cost-aggregation.test.ts @@ -1,4 +1,4 @@ -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getInitialAgentState, getInitialSessionState, @@ -180,7 +180,7 @@ describe('Cost Aggregation System', () => { } const result = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall: mockToolCall, fileContext: mockFileContext, @@ -260,7 +260,7 @@ describe('Cost Aggregation System', () => { } const result = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall: mockToolCall, fileContext: mockFileContext, @@ -417,7 +417,7 @@ describe('Cost Aggregation System', () => { } const result = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall: mockToolCall, fileContext: mockFileContext, diff --git a/backend/src/__tests__/generate-diffs-prompt.test.ts b/backend/src/__tests__/generate-diffs-prompt.test.ts index 3095094dcd..e61ca1329f 100644 --- a/backend/src/__tests__/generate-diffs-prompt.test.ts +++ b/backend/src/__tests__/generate-diffs-prompt.test.ts @@ -1,16 +1,8 @@ +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { expect, describe, it } from 'bun:test' import { parseAndGetDiffBlocksSingleFile } from '../generate-diffs-prompt' -import type { Logger } from '@codebuff/common/types/contracts/logger' - -const logger: Logger = { - debug: () => {}, - info: () => {}, - warn: () => {}, - error: () => {}, -} - describe('parseAndGetDiffBlocksSingleFile', () => { it('should parse diff blocks with newline before closing marker', () => { const oldContent = 'function test() {\n return true;\n}\n' @@ -26,9 +18,9 @@ function test() { >>>>>>> REPLACE` const result = parseAndGetDiffBlocksSingleFile({ + ...TEST_AGENT_RUNTIME_IMPL, newContent, oldFileContent: oldContent, - logger, }) console.log(JSON.stringify({ result })) @@ -55,9 +47,9 @@ function test() { }>>>>>>> REPLACE` const result = parseAndGetDiffBlocksSingleFile({ + ...TEST_AGENT_RUNTIME_IMPL, newContent, oldFileContent: oldContent, - logger, }) expect(result.diffBlocks.length).toBe(1) @@ -108,9 +100,9 @@ function subtract(a, b) { >>>>>>> REPLACE` const result = parseAndGetDiffBlocksSingleFile({ + ...TEST_AGENT_RUNTIME_IMPL, newContent, oldFileContent: oldContent, - logger, }) expect(result.diffBlocks.length).toBe(2) @@ -136,9 +128,9 @@ function subtract(a, b) { >>>>>>> REPLACE` const result = parseAndGetDiffBlocksSingleFile({ + ...TEST_AGENT_RUNTIME_IMPL, newContent, oldFileContent: oldContent, - logger, }) expect(result.diffBlocks.length).toBe(1) diff --git a/backend/src/__tests__/loop-agent-steps.test.ts b/backend/src/__tests__/loop-agent-steps.test.ts index 379dec178a..226907c018 100644 --- a/backend/src/__tests__/loop-agent-steps.test.ts +++ b/backend/src/__tests__/loop-agent-steps.test.ts @@ -1,7 +1,7 @@ import * as analytics from '@codebuff/common/analytics' import db from '@codebuff/common/db' import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { clearMockedModules, mockModule, @@ -25,15 +25,17 @@ import { withAppContext } from '../context/app-context' import { loopAgentSteps } from '../run-agent-step' import { clearAgentGeneratorCache } from '../run-programmatic-step' import { mockFileContext, MockWebSocket } from './test-utils' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import type { getAgentTemplate } from '../templates/agent-registry' import type { AgentTemplate } from '../templates/types' import type { StepGenerator } from '@codebuff/common/types/agent-template' -import type { AgentState } from '@codebuff/common/types/session-state' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { ParamsOf } from '@codebuff/common/types/function-params' +import type { AgentState } from '@codebuff/common/types/session-state' import type { WebSocket } from 'ws' +let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } + describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => { let mockTemplate: AgentTemplate let mockAgentState: AgentState @@ -94,8 +96,6 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => }) beforeEach(() => { - clearAgentGeneratorCache(testAgentRuntimeImpl) - llmCallCount = 0 // Setup spies for database operations @@ -113,14 +113,14 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => })), } as any) - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { llmCallCount++ yield { type: 'text' as const, text: `LLM response\n\n${getToolCallString('end_turn', {})}`, } return 'mock-message-id' - }) + } // Mock analytics spyOn(analytics, 'initAnalytics').mockImplementation(() => {}) @@ -165,9 +165,9 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => }) afterEach(() => { - clearAgentGeneratorCache(testAgentRuntimeImpl) - + clearAgentGeneratorCache(agentRuntimeImpl) mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) afterAll(() => { @@ -199,7 +199,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -245,7 +245,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -293,7 +293,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -340,7 +340,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -380,7 +380,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -412,7 +412,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -446,7 +446,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -497,7 +497,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -546,7 +546,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -598,7 +598,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => let llmCallNumber = 0 let capturedAgentState: AgentState | null = null - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { llmCallNumber++ if (llmCallNumber === 1) { // First call: agent tries to end turn without setting output @@ -627,13 +627,13 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } } return 'mock-message-id' - }) + } mockAgentState.output = undefined capturedAgentState = mockAgentState const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -689,7 +689,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => let llmCallNumber = 0 let capturedAgentState: AgentState | null = null - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { llmCallNumber++ // Agent sets output correctly on first call if (capturedAgentState) { @@ -700,13 +700,13 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => text: `Setting output\n\n${getToolCallString('set_output', { result: 'success' })}\n\n${getToolCallString('end_turn', {})}`, } return 'mock-message-id' - }) + } mockAgentState.output = undefined capturedAgentState = mockAgentState const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -742,17 +742,17 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } let llmCallNumber = 0 - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { llmCallNumber++ yield { type: 'text' as const, text: `Response without output\n\n${getToolCallString('end_turn', {})}`, } return 'mock-message-id' - }) + } const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', @@ -795,7 +795,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => let llmCallNumber = 0 let capturedAgentState: AgentState | null = null - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { llmCallNumber++ if (llmCallNumber === 1) { // First call: agent does some work but doesn't end turn @@ -814,13 +814,13 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => } } return 'mock-message-id' - }) + } mockAgentState.output = undefined capturedAgentState = mockAgentState const result = await runLoopAgentStepsWithContext({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userInputId: 'test-user-input', agentType: 'test-agent', diff --git a/backend/src/__tests__/main-prompt.integration.test.ts b/backend/src/__tests__/main-prompt.integration.test.ts index e8b433a3bd..0f13ba63b4 100644 --- a/backend/src/__tests__/main-prompt.integration.test.ts +++ b/backend/src/__tests__/main-prompt.integration.test.ts @@ -1,7 +1,7 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' // Mock imports needed for setup within the test -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getToolCallString } from '@codebuff/common/tools/utils' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { @@ -377,7 +377,7 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) { } const { output, sessionState: finalSessionState } = await mainPrompt({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, @@ -459,7 +459,7 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) { } await mainPrompt({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, diff --git a/backend/src/__tests__/main-prompt.test.ts b/backend/src/__tests__/main-prompt.test.ts index 5dae356481..25577a880f 100644 --- a/backend/src/__tests__/main-prompt.test.ts +++ b/backend/src/__tests__/main-prompt.test.ts @@ -1,7 +1,7 @@ import * as bigquery from '@codebuff/bigquery' import * as analytics from '@codebuff/common/analytics' import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getToolCallString } from '@codebuff/common/tools/utils' import { AgentTemplateTypes, @@ -22,21 +22,23 @@ import * as checkTerminalCommandModule from '../check-terminal-command' import * as requestFilesPrompt from '../find-files/request-files-prompt' import * as getDocumentationForQueryModule from '../get-documentation-for-query' import * as liveUserInputs from '../live-user-inputs' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { mainPrompt } from '../main-prompt' import * as processFileBlockModule from '../process-file-block' import * as websocketAction from '../websockets/websocket-action' import type { AgentTemplate } from '@codebuff/common/types/agent-template' -import type { ProjectFileContext } from '@codebuff/common/util/file' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ProjectFileContext } from '@codebuff/common/util/file' import type { WebSocket } from 'ws' +let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } + const mockAgentStream = (streamOutput: string) => { - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { yield { type: 'text' as const, text: streamOutput } return 'mock-message-id' - }) + } } describe('mainPrompt', () => { @@ -109,9 +111,6 @@ describe('mainPrompt', () => { ) // Mock LLM APIs - spyOn(aisdk, 'promptAiSdk').mockImplementation(() => - Promise.resolve('Test response'), - ) mockAgentStream('Test response') // Mock websocket actions @@ -177,6 +176,7 @@ describe('mainPrompt', () => { afterEach(() => { // Clear all mocks after each test mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) class MockWebSocket { @@ -231,13 +231,13 @@ describe('mainPrompt', () => { } const { sessionState: newSessionState, output } = await mainPrompt({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, localAgentTemplates: mockLocalAgentTemplates, - ...testAgentRuntimeImpl, }) // Verify that requestToolCall was called with the terminal command @@ -292,6 +292,7 @@ describe('mainPrompt', () => { } await mainPrompt({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, @@ -331,7 +332,6 @@ describe('mainPrompt', () => { stepPrompt: '', }, }, - ...testAgentRuntimeImpl, }) // Assert that requestToolCall was called exactly once @@ -371,13 +371,13 @@ describe('mainPrompt', () => { } const { output } = await mainPrompt({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, localAgentTemplates: mockLocalAgentTemplates, - ...testAgentRuntimeImpl, }) expect(output.type).toBeDefined() // Output should exist @@ -398,13 +398,13 @@ describe('mainPrompt', () => { } const { sessionState: newSessionState } = await mainPrompt({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, localAgentTemplates: mockLocalAgentTemplates, - ...testAgentRuntimeImpl, }) // When there's a new prompt, consecutiveAssistantMessages should be set to 1 @@ -429,13 +429,13 @@ describe('mainPrompt', () => { } const { sessionState: newSessionState } = await mainPrompt({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, localAgentTemplates: mockLocalAgentTemplates, - ...testAgentRuntimeImpl, }) // When there's no new prompt, consecutiveAssistantMessages should increment by 1 @@ -458,13 +458,13 @@ describe('mainPrompt', () => { } const { output } = await mainPrompt({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, localAgentTemplates: mockLocalAgentTemplates, - ...testAgentRuntimeImpl, }) expect(output.type).toBeDefined() // Output should exist even for empty response @@ -498,13 +498,13 @@ describe('mainPrompt', () => { } await mainPrompt({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, localAgentTemplates: mockLocalAgentTemplates, - ...testAgentRuntimeImpl, }) // Assert that requestToolCall was called exactly once diff --git a/backend/src/__tests__/malformed-tool-call.test.ts b/backend/src/__tests__/malformed-tool-call.test.ts index bafc933a4a..eec1caed18 100644 --- a/backend/src/__tests__/malformed-tool-call.test.ts +++ b/backend/src/__tests__/malformed-tool-call.test.ts @@ -1,7 +1,7 @@ import * as bigquery from '@codebuff/bigquery' import * as analytics from '@codebuff/common/analytics' import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getToolCallString } from '@codebuff/common/tools/utils' import { getInitialSessionState } from '@codebuff/common/types/session-state' import * as stringUtils from '@codebuff/common/util/string' @@ -53,7 +53,7 @@ describe('malformed tool call error handling', () => { // Mock analytics and tracing spyOn(analytics, 'initAnalytics').mockImplementation(() => {}) - analytics.initAnalytics(testAgentRuntimeImpl) + analytics.initAnalytics(TEST_AGENT_RUNTIME_IMPL) spyOn(analytics, 'trackEvent').mockImplementation(() => {}) spyOn(bigquery, 'insertTrace').mockImplementation(() => Promise.resolve(true), @@ -108,7 +108,7 @@ describe('malformed tool call error handling', () => { const agentState = sessionState.mainAgentState const result = await processStreamWithTools({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, stream, ws: mockWs as unknown as WebSocket, agentStepId: 'test-step', @@ -165,7 +165,7 @@ describe('malformed tool call error handling', () => { const agentState = sessionState.mainAgentState const result = await processStreamWithTools({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, stream, ws: mockWs as unknown as WebSocket, agentStepId: 'test-step', @@ -212,7 +212,7 @@ describe('malformed tool call error handling', () => { const agentState = sessionState.mainAgentState const result = await processStreamWithTools({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, stream, ws: mockWs as unknown as WebSocket, agentStepId: 'test-step', @@ -263,7 +263,7 @@ describe('malformed tool call error handling', () => { const agentState = sessionState.mainAgentState const result = await processStreamWithTools({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, stream, ws: mockWs as unknown as WebSocket, agentStepId: 'test-step', @@ -316,7 +316,7 @@ describe('malformed tool call error handling', () => { const agentState = sessionState.mainAgentState const result = await processStreamWithTools({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, stream, ws: mockWs as unknown as WebSocket, agentStepId: 'test-step', @@ -371,7 +371,7 @@ describe('malformed tool call error handling', () => { const agentState = sessionState.mainAgentState const result = await processStreamWithTools({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, stream, ws: mockWs as unknown as WebSocket, agentStepId: 'test-step', diff --git a/backend/src/__tests__/prompt-caching-subagents.test.ts b/backend/src/__tests__/prompt-caching-subagents.test.ts index 7eb91f4921..c22db30775 100644 --- a/backend/src/__tests__/prompt-caching-subagents.test.ts +++ b/backend/src/__tests__/prompt-caching-subagents.test.ts @@ -1,5 +1,5 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { spyOn, @@ -11,11 +11,11 @@ import { mock, } from 'bun:test' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { loopAgentSteps } from '../run-agent-step' import * as websocketAction from '../websockets/websocket-action' import type { AgentTemplate } from '../templates/types' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { ProjectFileContext } from '@codebuff/common/util/file' import type { WebSocket } from 'ws' @@ -56,6 +56,7 @@ class MockWebSocket { describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { let mockLocalAgentTemplates: Record let capturedMessages: Message[] = [] + let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } beforeEach(() => { capturedMessages = [] @@ -97,24 +98,22 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { } // Mock LLM API to capture messages and end turn immediately - spyOn(aisdk, 'promptAiSdkStream').mockImplementation( - async function* (options) { - // Capture the messages sent to the LLM - capturedMessages = options.messages - - // Simulate immediate end turn - yield { - type: 'text' as const, - text: 'Test response', - } + agentRuntimeImpl.promptAiSdkStream = async function* (options) { + // Capture the messages sent to the LLM + capturedMessages = options.messages - if (options.onCostCalculated) { - await options.onCostCalculated(1) - } + // Simulate immediate end turn + yield { + type: 'text' as const, + text: 'Test response', + } - return 'mock-message-id' - }, - ) + if (options.onCostCalculated) { + await options.onCostCalculated(1) + } + + return 'mock-message-id' + } // Mock file operations spyOn(websocketAction, 'requestFiles').mockImplementation( @@ -147,6 +146,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { afterEach(() => { mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) it('should inherit parent system prompt when inheritParentSystemPrompt is true', async () => { @@ -155,6 +155,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { // Run parent agent first to establish system prompt const parentResult = await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-parent', prompt: 'Parent task', @@ -167,7 +168,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, - ...testAgentRuntimeImpl, }) // Capture parent's messages which include the system prompt @@ -189,6 +189,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { } await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-child', prompt: 'Child task', @@ -202,7 +203,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { clientSessionId: 'test-session', onResponseChunk: () => {}, parentSystemPrompt: parentSystemPrompt, - ...testAgentRuntimeImpl, }) // Verify child uses parent's system prompt @@ -238,6 +238,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { // Run parent agent first const parentResult = await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-parent', prompt: 'Parent task', @@ -250,7 +251,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, - ...testAgentRuntimeImpl, }) const parentMessages = capturedMessages @@ -266,6 +266,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { } await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-child', prompt: 'Child task', @@ -279,7 +280,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { clientSessionId: 'test-session', onResponseChunk: () => {}, parentSystemPrompt: parentSystemPrompt, - ...testAgentRuntimeImpl, }) const childMessages = capturedMessages @@ -316,6 +316,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { // Run parent agent first await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-parent', prompt: 'Parent task', @@ -328,7 +329,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, - ...testAgentRuntimeImpl, }) const parentMessages = capturedMessages @@ -347,6 +347,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { } await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-child', prompt: 'Child task', @@ -360,7 +361,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { clientSessionId: 'test-session', onResponseChunk: () => {}, parentSystemPrompt: parentSystemPrompt, - ...testAgentRuntimeImpl, }) const childMessages = capturedMessages @@ -423,6 +423,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { // Run parent agent const parentResult = await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-parent', prompt: 'Parent task', @@ -435,7 +436,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, - ...testAgentRuntimeImpl, }) const parentMessages = capturedMessages @@ -451,6 +451,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { } await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-child', prompt: 'Child task', @@ -464,7 +465,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { clientSessionId: 'test-session', onResponseChunk: () => {}, parentSystemPrompt: parentSystemPrompt, - ...testAgentRuntimeImpl, }) const childMessages = capturedMessages @@ -505,6 +505,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { // Run parent agent first with some message history const parentResult = await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-parent', prompt: 'Parent task', @@ -523,7 +524,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, - ...testAgentRuntimeImpl, }) const parentMessages = capturedMessages @@ -542,6 +542,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { } await loopAgentSteps({ + ...agentRuntimeImpl, ws, userInputId: 'test-child', prompt: 'Child task', @@ -555,7 +556,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { clientSessionId: 'test-session', onResponseChunk: () => {}, parentSystemPrompt: parentSystemPrompt, - ...testAgentRuntimeImpl, }) const childMessages = capturedMessages diff --git a/backend/src/__tests__/read-docs-tool.test.ts b/backend/src/__tests__/read-docs-tool.test.ts index f4620d271f..ee9f712b07 100644 --- a/backend/src/__tests__/read-docs-tool.test.ts +++ b/backend/src/__tests__/read-docs-tool.test.ts @@ -1,7 +1,7 @@ import * as bigquery from '@codebuff/bigquery' import * as analytics from '@codebuff/common/analytics' import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getToolCallString } from '@codebuff/common/tools/utils' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { @@ -20,15 +20,17 @@ import * as requestFilesPrompt from '../find-files/request-files-prompt' import * as liveUserInputs from '../live-user-inputs' import { MockWebSocket, mockFileContext } from './test-utils' import * as context7Api from '../llm-apis/context7-api' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { runAgentStep } from '../run-agent-step' import { assembleLocalAgentTemplates } from '../templates/agent-registry' import * as websocketAction from '../websockets/websocket-action' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { WebSocket } from 'ws' +let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } + function mockAgentStream(content: string | string[]) { - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { if (typeof content === 'string') { content = [content] } @@ -36,7 +38,7 @@ function mockAgentStream(content: string | string[]) { yield { type: 'text' as const, text: chunk } } return 'mock-message-id' - }) + } } describe('read_docs tool with researcher agent', () => { @@ -56,7 +58,7 @@ describe('read_docs tool with researcher agent', () => { name: 'analytics.initAnalytics', spy: analyticsInitSpy, }) - analytics.initAnalytics(testAgentRuntimeImpl) + analytics.initAnalytics(agentRuntimeImpl) const trackEventSpy = spyOn(analytics, 'trackEvent').mockImplementation( () => {}, @@ -121,12 +123,6 @@ describe('read_docs tool with researcher agent', () => { spy: sendActionSpy, }) - // Mock LLM APIs - const promptAiSdkSpy = spyOn(aisdk, 'promptAiSdk').mockImplementation(() => - Promise.resolve('Test response'), - ) - mockedFunctions.push({ name: 'aisdk.promptAiSdk', spy: promptAiSdkSpy }) - // Mock other required modules const requestRelevantFilesSpy = spyOn( requestFilesPrompt, @@ -177,6 +173,7 @@ describe('read_docs tool with researcher agent', () => { afterEach(() => { mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) // MockWebSocket and mockFileContext imported from test-utils @@ -229,28 +226,6 @@ describe('read_docs tool with researcher agent', () => { await expect(insertTraceMock?.spy()).resolves.toBe(true) }) - test('async generator mock completes properly', async () => { - // Test that our async generator mock properly completes - const mockResponse = 'test response' - - mockAgentStream(mockResponse) - - const generator = aisdk.promptAiSdkStream({} as any) - const results = [] - - // Consume the generator - for await (const value of generator) { - results.push(value) - } - - // Should have yielded exactly one value and then completed - expect(results).toEqual([{ type: 'text', text: mockResponse }]) - - // Generator should be done - const { done } = await generator.next() - expect(done).toBe(true) - }) - test('should successfully fetch documentation with basic query', async () => { const mockDocumentation = 'React is a JavaScript library for building user interfaces...' @@ -286,12 +261,12 @@ describe('read_docs tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -359,12 +334,12 @@ describe('read_docs tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -409,12 +384,12 @@ describe('read_docs tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -476,12 +451,12 @@ describe('read_docs tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -542,12 +517,12 @@ describe('read_docs tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -609,12 +584,12 @@ describe('read_docs tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, diff --git a/backend/src/__tests__/run-agent-step-tools.test.ts b/backend/src/__tests__/run-agent-step-tools.test.ts index 114340e2f7..840ab53df8 100644 --- a/backend/src/__tests__/run-agent-step-tools.test.ts +++ b/backend/src/__tests__/run-agent-step-tools.test.ts @@ -2,7 +2,7 @@ import * as bigquery from '@codebuff/bigquery' import * as analytics from '@codebuff/common/analytics' import db from '@codebuff/common/db' import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getToolCallString } from '@codebuff/common/tools/utils' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { @@ -25,11 +25,13 @@ import { asUserMessage } from '../util/messages' import * as websocketAction from '../websockets/websocket-action' import type { AgentTemplate } from '../templates/types' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { ProjectFileContext } from '@codebuff/common/util/file' import type { WebSocket } from 'ws' describe('runAgentStep - set_output tool', () => { let testAgent: AgentTemplate + let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } beforeEach(async () => { // Create a test agent that supports set_output @@ -63,7 +65,7 @@ describe('runAgentStep - set_output tool', () => { // Mock analytics and tracing spyOn(analytics, 'initAnalytics').mockImplementation(() => {}) - analytics.initAnalytics(testAgentRuntimeImpl) + analytics.initAnalytics(agentRuntimeImpl) spyOn(analytics, 'trackEvent').mockImplementation(() => {}) spyOn(bigquery, 'insertTrace').mockImplementation(() => Promise.resolve(true), @@ -107,15 +109,16 @@ describe('runAgentStep - set_output tool', () => { spyOn(aisdk, 'promptAiSdk').mockImplementation(() => Promise.resolve('Test response'), ) - clearAgentGeneratorCache(testAgentRuntimeImpl) + clearAgentGeneratorCache(agentRuntimeImpl) }) afterEach(() => { mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) afterAll(() => { - clearAgentGeneratorCache(testAgentRuntimeImpl) + clearAgentGeneratorCache(agentRuntimeImpl) }) class MockWebSocket { @@ -159,10 +162,10 @@ describe('runAgentStep - set_output tool', () => { '\n\n' + getToolCallString('end_turn', {}) - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { yield { type: 'text' as const, text: mockResponse } return 'mock-message-id' - }) + } const sessionState = getInitialSessionState(mockFileContext) const agentState = sessionState.mainAgentState @@ -171,6 +174,7 @@ describe('runAgentStep - set_output tool', () => { } const result = await runAgentStep({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userId: TEST_USER_ID, userInputId: 'test-input', @@ -184,7 +188,6 @@ describe('runAgentStep - set_output tool', () => { prompt: 'Analyze the codebase', spawnParams: undefined, system: 'Test system prompt', - ...testAgentRuntimeImpl, }) expect(result.agentState.output).toEqual({ @@ -201,10 +204,10 @@ describe('runAgentStep - set_output tool', () => { findings: ['Bug in auth.ts', 'Missing validation'], }) + getToolCallString('end_turn', {}) - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { yield { type: 'text' as const, text: mockResponse } return 'mock-message-id' - }) + } const sessionState = getInitialSessionState(mockFileContext) const agentState = sessionState.mainAgentState @@ -213,6 +216,7 @@ describe('runAgentStep - set_output tool', () => { } const result = await runAgentStep({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userId: TEST_USER_ID, userInputId: 'test-input', @@ -226,7 +230,6 @@ describe('runAgentStep - set_output tool', () => { prompt: 'Analyze the codebase', spawnParams: undefined, system: 'Test system prompt', - ...testAgentRuntimeImpl, }) expect(result.agentState.output).toEqual({ @@ -244,10 +247,10 @@ describe('runAgentStep - set_output tool', () => { existingField: 'updated value', }) + getToolCallString('end_turn', {}) - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { yield { type: 'text' as const, text: mockResponse } return 'mock-message-id' - }) + } const sessionState = getInitialSessionState(mockFileContext) const agentState = sessionState.mainAgentState @@ -261,6 +264,7 @@ describe('runAgentStep - set_output tool', () => { } const result = await runAgentStep({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userId: TEST_USER_ID, userInputId: 'test-input', @@ -274,7 +278,6 @@ describe('runAgentStep - set_output tool', () => { prompt: 'Update the output', spawnParams: undefined, system: 'Test system prompt', - ...testAgentRuntimeImpl, }) expect(result.agentState.output).toEqual({ @@ -287,10 +290,10 @@ describe('runAgentStep - set_output tool', () => { const mockResponse = getToolCallString('set_output', {}) + getToolCallString('end_turn', {}) - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { yield { type: 'text' as const, text: mockResponse } return 'mock-message-id' - }) + } const sessionState = getInitialSessionState(mockFileContext) const agentState = sessionState.mainAgentState @@ -300,6 +303,7 @@ describe('runAgentStep - set_output tool', () => { } const result = await runAgentStep({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userId: TEST_USER_ID, userInputId: 'test-input', @@ -313,7 +317,6 @@ describe('runAgentStep - set_output tool', () => { prompt: 'Update with empty object', spawnParams: undefined, system: 'Test system prompt', - ...testAgentRuntimeImpl, }) // Should replace with empty object @@ -369,10 +372,10 @@ describe('runAgentStep - set_output tool', () => { ) // Mock the LLM stream to return a response that doesn't end the turn - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { yield { type: 'text' as const, text: 'Continuing with the analysis...' } // Non-empty response, no tool calls return 'mock-message-id' - }) + } const sessionState = getInitialSessionState(mockFileContext) const agentState = sessionState.mainAgentState @@ -396,6 +399,7 @@ describe('runAgentStep - set_output tool', () => { const initialMessageCount = agentState.messageHistory.length const result = await runAgentStep({ + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userId: TEST_USER_ID, userInputId: 'test-input', @@ -409,7 +413,6 @@ describe('runAgentStep - set_output tool', () => { prompt: 'Test the handleSteps functionality', spawnParams: undefined, system: 'Test system prompt', - ...testAgentRuntimeImpl, }) // Should end turn because toolCalls.length === 0 && toolResults.length === 0 from LLM processing @@ -519,7 +522,7 @@ describe('runAgentStep - set_output tool', () => { } // Mock the LLM stream to spawn the inline agent - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { yield { type: 'text' as const, text: getToolCallString('spawn_agent_inline', { @@ -528,7 +531,7 @@ describe('runAgentStep - set_output tool', () => { }), } return 'mock-message-id' - }) + } const sessionState = getInitialSessionState(mockFileContext) const agentState = sessionState.mainAgentState @@ -556,7 +559,7 @@ describe('runAgentStep - set_output tool', () => { ] const result = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, userId: TEST_USER_ID, userInputId: 'test-input', diff --git a/backend/src/__tests__/spawn-agents-message-history.test.ts b/backend/src/__tests__/spawn-agents-message-history.test.ts index 60ac1a430b..25aed6a383 100644 --- a/backend/src/__tests__/spawn-agents-message-history.test.ts +++ b/backend/src/__tests__/spawn-agents-message-history.test.ts @@ -1,5 +1,5 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { describe, @@ -107,7 +107,7 @@ describe('Spawn Agents Message History', () => { ] const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -179,7 +179,7 @@ describe('Spawn Agents Message History', () => { ] const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -216,7 +216,7 @@ describe('Spawn Agents Message History', () => { const mockMessages: Message[] = [] // Empty message history const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -256,7 +256,7 @@ describe('Spawn Agents Message History', () => { ] const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, diff --git a/backend/src/__tests__/spawn-agents-permissions.test.ts b/backend/src/__tests__/spawn-agents-permissions.test.ts index 72c29d94a7..3db04e4aa2 100644 --- a/backend/src/__tests__/spawn-agents-permissions.test.ts +++ b/backend/src/__tests__/spawn-agents-permissions.test.ts @@ -1,5 +1,5 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { describe, @@ -235,7 +235,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createSpawnToolCall('thinker') const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -269,7 +269,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createSpawnToolCall('reviewer') // Try to spawn reviewer const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -305,7 +305,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createSpawnToolCall('nonexistent') const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -343,7 +343,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createSpawnToolCall('codebuff/thinker@1.0.0') const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -377,7 +377,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createSpawnToolCall('thinker') // Simple name const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -414,7 +414,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createSpawnToolCall('codebuff/thinker@2.0.0') const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -462,7 +462,7 @@ describe('Spawn Agents Permissions', () => { } const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -517,7 +517,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createInlineSpawnToolCall('thinker') const { result } = handleSpawnAgentInline({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -549,7 +549,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createInlineSpawnToolCall('reviewer') // Try to spawn reviewer const { result } = handleSpawnAgentInline({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -582,7 +582,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createInlineSpawnToolCall('nonexistent') const { result } = handleSpawnAgentInline({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -614,7 +614,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createInlineSpawnToolCall('codebuff/thinker@1.0.0') const { result } = handleSpawnAgentInline({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -646,7 +646,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createInlineSpawnToolCall('thinker') // Simple name const { result } = handleSpawnAgentInline({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -681,7 +681,7 @@ describe('Spawn Agents Permissions', () => { const toolCall = createInlineSpawnToolCall('codebuff/thinker@2.0.0') const { result } = handleSpawnAgentInline({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -713,7 +713,7 @@ describe('Spawn Agents Permissions', () => { expect(() => { handleSpawnAgentInline({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, diff --git a/backend/src/__tests__/subagent-streaming.test.ts b/backend/src/__tests__/subagent-streaming.test.ts index 494facdbb8..47500691f1 100644 --- a/backend/src/__tests__/subagent-streaming.test.ts +++ b/backend/src/__tests__/subagent-streaming.test.ts @@ -1,5 +1,5 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { afterAll, @@ -146,7 +146,7 @@ describe('Subagent Streaming', () => { } const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, @@ -224,7 +224,7 @@ describe('Subagent Streaming', () => { } const { result } = handleSpawnAgents({ - ...testAgentRuntimeImpl, + ...TEST_AGENT_RUNTIME_IMPL, previousToolCallFinished: Promise.resolve(), toolCall, fileContext: mockFileContext, diff --git a/backend/src/__tests__/web-search-tool.test.ts b/backend/src/__tests__/web-search-tool.test.ts index c384896b2b..6190db3916 100644 --- a/backend/src/__tests__/web-search-tool.test.ts +++ b/backend/src/__tests__/web-search-tool.test.ts @@ -4,7 +4,7 @@ process.env.LINKUP_API_KEY = 'test-api-key' import * as bigquery from '@codebuff/bigquery' import * as analytics from '@codebuff/common/analytics' import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getToolCallString } from '@codebuff/common/tools/utils' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { @@ -28,10 +28,12 @@ import { runAgentStep } from '../run-agent-step' import { assembleLocalAgentTemplates } from '../templates/agent-registry' import * as websocketAction from '../websockets/websocket-action' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { WebSocket } from 'ws' +let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } function mockAgentStream(content: string | string[]) { - spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { + agentRuntimeImpl.promptAiSdkStream = async function* ({}) { if (typeof content === 'string') { content = [content] } @@ -39,14 +41,14 @@ function mockAgentStream(content: string | string[]) { yield { type: 'text' as const, text: chunk } } return 'mock-message-id' - }) + } } describe('web_search tool with researcher agent', () => { beforeEach(() => { // Mock analytics and tracing spyOn(analytics, 'initAnalytics').mockImplementation(() => {}) - analytics.initAnalytics(testAgentRuntimeImpl) + analytics.initAnalytics(agentRuntimeImpl) spyOn(analytics, 'trackEvent').mockImplementation(() => {}) spyOn(bigquery, 'insertTrace').mockImplementation(() => Promise.resolve(true), @@ -84,6 +86,7 @@ describe('web_search tool with researcher agent', () => { afterEach(() => { mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) // MockWebSocket and mockFileContext imported from test-utils @@ -114,12 +117,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -164,12 +167,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -223,12 +226,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -267,12 +270,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -325,12 +328,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -382,12 +385,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -429,12 +432,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, @@ -488,12 +491,12 @@ describe('web_search tool with researcher agent', () => { agentType: 'researcher' as const, } const { agentTemplates } = assembleLocalAgentTemplates({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, fileContext: mockFileContextWithAgents, }) const { agentState: newAgentState } = await runAgentStep({ - ...testAgentRuntimeImpl, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, diff --git a/backend/src/impl/agent-runtime.ts b/backend/src/impl/agent-runtime.ts index a766de3ea7..42894d3b9a 100644 --- a/backend/src/impl/agent-runtime.ts +++ b/backend/src/impl/agent-runtime.ts @@ -1,13 +1,15 @@ import { addAgentStep, finishAgentRun, startAgentRun } from '../agent-run' +import { promptAiSdkStream } from '../llm-apis/vercel-ai-sdk/ai-sdk' import { logger } from '../util/logger' import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' -export const backendAgentRuntimeImpl: AgentRuntimeDeps = { +export const BACKEND_AGENT_RUNTIME_IMPL: AgentRuntimeDeps = Object.freeze({ logger, startAgentRun, finishAgentRun, addAgentStep, - // promptAiSdkStream, -} + + promptAiSdkStream, +}) diff --git a/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts b/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts index c2f57e10e4..1f0cbf101b 100644 --- a/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts +++ b/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts @@ -17,8 +17,12 @@ import { openRouterLanguageModel } from '../openrouter' import { vertexFinetuned } from './vertex-finetuned' import type { Model, OpenAIModel } from '@codebuff/common/old-constants' -import type { ParamsExcluding } from '@codebuff/common/types/function-params' +import type { PromptAiSdkStreamFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { + ParamsExcluding, + ParamsOf, +} from '@codebuff/common/types/function-params' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { OpenRouterProviderOptions, @@ -27,17 +31,6 @@ import type { import type { LanguageModel } from 'ai' import type { z } from 'zod/v4' -export type StreamChunk = - | { - type: 'text' - text: string - } - | { - type: 'reasoning' - text: string - } - | { type: 'error'; message: string } - // TODO: We'll want to add all our models here! const modelToAiSDKModel = (model: Model): LanguageModel => { if ( @@ -61,22 +54,8 @@ const modelToAiSDKModel = (model: Model): LanguageModel => { // also take an array of form [{model: Model, retries: number}, {model: Model, retries: number}...] // eg: [{model: "gemini-2.0-flash-001"}, {model: "vertex/gemini-2.0-flash-001"}, {model: "claude-3-5-haiku", retries: 3}] export const promptAiSdkStream = async function* ( - params: { - messages: Message[] - clientSessionId: string - fingerprintId: string - model: Model - userId: string | undefined - chargeUser?: boolean - thinkingBudget?: number - userInputId: string - agentId?: string - maxRetries?: number - onCostCalculated?: (credits: number) => Promise - includeCacheControl?: boolean - logger: Logger - } & ParamsExcluding, -): AsyncGenerator { + params: ParamsOf, +): ReturnType { const { logger } = params if ( !checkLiveUserInput({ ...params, clientSessionId: params.clientSessionId }) diff --git a/backend/src/prompt-agent-stream.ts b/backend/src/prompt-agent-stream.ts index bad7e3aca8..fa8239887d 100644 --- a/backend/src/prompt-agent-stream.ts +++ b/backend/src/prompt-agent-stream.ts @@ -1,12 +1,13 @@ import { providerModelNames } from '@codebuff/common/old-constants' -import { promptAiSdkStream } from './llm-apis/vercel-ai-sdk/ai-sdk' import { globalStopSequence } from './tools/constants' import type { AgentTemplate } from './templates/types' +import type { PromptAiSdkStreamFn } from '@codebuff/common/types/contracts/llm' +import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsOf } from '@codebuff/common/types/function-params' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { OpenRouterProviderOptions } from '@codebuff/internal/openrouter-ai-sdk' -import type { Logger } from '@codebuff/common/types/contracts/logger' export const getAgentStreamFromTemplate = (params: { clientSessionId: string @@ -19,6 +20,7 @@ export const getAgentStreamFromTemplate = (params: { template: AgentTemplate logger: Logger + promptAiSdkStream: PromptAiSdkStreamFn }) => { const { clientSessionId, @@ -30,6 +32,7 @@ export const getAgentStreamFromTemplate = (params: { includeCacheControl, template, logger, + promptAiSdkStream, } = params if (!template) { @@ -39,7 +42,7 @@ export const getAgentStreamFromTemplate = (params: { const { model } = template const getStream = (messages: Message[]) => { - const aiSdkStreamParams: Parameters[0] = { + const aiSdkStreamParams: ParamsOf = { messages, model, stopSequences: [globalStopSequence], diff --git a/backend/src/run-agent-step.ts b/backend/src/run-agent-step.ts index a2490a7c17..92746d5ff8 100644 --- a/backend/src/run-agent-step.ts +++ b/backend/src/run-agent-step.ts @@ -103,7 +103,11 @@ export const runAgentStep = async ( | 'agentTemplate' | 'agentContext' | 'fullResponse' - >, + > & + ParamsExcluding< + typeof getAgentStreamFromTemplate, + 'agentId' | 'template' | 'onCostCalculated' | 'includeCacheControl' + >, ): Promise<{ agentState: AgentState fullResponse: string @@ -246,10 +250,7 @@ export const runAgentStep = async ( const { model } = agentTemplate const { getStream } = getAgentStreamFromTemplate({ - clientSessionId, - fingerprintId, - userInputId, - userId, + ...params, agentId: agentState.agentId, template: agentTemplate, onCostCalculated: async (credits: number) => { @@ -270,7 +271,6 @@ export const runAgentStep = async ( } }, includeCacheControl: supportsCacheControl(agentTemplate.model), - logger, }) const iterationNum = agentState.messageHistory.length diff --git a/backend/src/tools/stream-parser.ts b/backend/src/tools/stream-parser.ts index 49fa0fcc2f..b406a84af8 100644 --- a/backend/src/tools/stream-parser.ts +++ b/backend/src/tools/stream-parser.ts @@ -16,10 +16,12 @@ import { executeCustomToolCall, executeToolCall } from './tool-executor' import type { BatchStrReplaceState } from './batch-str-replace' import type { CustomToolCall, ExecuteToolCallParams } from './tool-executor' -import type { StreamChunk } from '../llm-apis/vercel-ai-sdk/ai-sdk' import type { AgentTemplate } from '../templates/types' import type { ToolName } from '@codebuff/common/tools/constants' import type { CodebuffToolCall } from '@codebuff/common/tools/list' +import type { StreamChunk } from '@codebuff/common/types/contracts/llm' +import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { Message, ToolMessage, @@ -28,8 +30,6 @@ import type { ToolResultPart } from '@codebuff/common/types/messages/content-par import type { PrintModeEvent } from '@codebuff/common/types/print-mode' import type { AgentState, Subgoal } from '@codebuff/common/types/session-state' import type { ProjectFileContext } from '@codebuff/common/util/file' -import type { ParamsExcluding } from '@codebuff/common/types/function-params' -import type { Logger } from '@codebuff/common/types/contracts/logger' import type { ToolCallPart } from 'ai' import type { WebSocket } from 'ws' diff --git a/backend/src/websockets/middleware.ts b/backend/src/websockets/middleware.ts index 1f95ae62ce..ae0528c168 100644 --- a/backend/src/websockets/middleware.ts +++ b/backend/src/websockets/middleware.ts @@ -16,7 +16,7 @@ import { getUserInfoFromAuthToken } from './auth' import { updateRequestContext } from './request-context' import { sendAction } from './websocket-action' import { withAppContext } from '../context/app-context' -import { backendAgentRuntimeImpl } from '../impl/agent-runtime' +import { BACKEND_AGENT_RUNTIME_IMPL } from '../impl/agent-runtime' import { checkAuth } from '../util/check-auth' import type { UserInfo } from './auth' @@ -164,7 +164,7 @@ export class WebSocketMiddleware { } } -export const protec = new WebSocketMiddleware(backendAgentRuntimeImpl) +export const protec = new WebSocketMiddleware(BACKEND_AGENT_RUNTIME_IMPL) protec.use(async ({ action, clientSessionId, logger }) => checkAuth({ diff --git a/backend/src/xml-stream-parser.ts b/backend/src/xml-stream-parser.ts index 4956cb9e32..9fefc26ef7 100644 --- a/backend/src/xml-stream-parser.ts +++ b/backend/src/xml-stream-parser.ts @@ -7,14 +7,14 @@ import { toolNameParam, } from '@codebuff/common/tools/constants' -import type { StreamChunk } from './llm-apis/vercel-ai-sdk/ai-sdk' import type { Model } from '@codebuff/common/old-constants' +import type { StreamChunk } from '@codebuff/common/types/contracts/llm' +import type { Logger } from '@codebuff/common/types/contracts/logger' import type { PrintModeError, PrintModeText, PrintModeToolCall, } from '@codebuff/common/types/print-mode' -import type { Logger } from '@codebuff/common/types/contracts/logger' const toolExtractionPattern = new RegExp( `${startToolTag}(.*?)${endToolTag}`, diff --git a/common/src/testing/impl/agent-runtime.ts b/common/src/testing/impl/agent-runtime.ts index 78b57edbae..d4d87cbdd9 100644 --- a/common/src/testing/impl/agent-runtime.ts +++ b/common/src/testing/impl/agent-runtime.ts @@ -8,13 +8,17 @@ export const testLogger: Logger = { warn: () => {}, } -export const testAgentRuntimeImpl: AgentRuntimeDeps = { - logger: testLogger, - +export const TEST_AGENT_RUNTIME_IMPL: AgentRuntimeDeps = Object.freeze({ + // Database startAgentRun: async () => 'test-agent-run-id', finishAgentRun: async () => {}, addAgentStep: async () => 'test-agent-step-id', - // promptAiSdkStream: async function* () { - // throw new Error('promptAiSdkStream not implemented in test runtime') - // }, -} + + // LLM + promptAiSdkStream: async function* () { + throw new Error('promptAiSdkStream not implemented in test runtime') + }, + + // Other + logger: testLogger, +}) diff --git a/common/src/types/contracts/agent-runtime.ts b/common/src/types/contracts/agent-runtime.ts index b3431f3dfc..83eae0d0a6 100644 --- a/common/src/types/contracts/agent-runtime.ts +++ b/common/src/types/contracts/agent-runtime.ts @@ -3,14 +3,18 @@ import type { FinishAgentRunFn, StartAgentRunFn, } from './database' +import type { PromptAiSdkStreamFn } from './llm' import type { Logger } from './logger' export type AgentRuntimeDeps = { - logger: Logger - + // Database startAgentRun: StartAgentRunFn finishAgentRun: FinishAgentRunFn addAgentStep: AddAgentStepFn - // promptAiSdkStream: PromptAiSdkStreamFn + // LLM + promptAiSdkStream: PromptAiSdkStreamFn + + // Other + logger: Logger } diff --git a/evals/impl/agent-runtime.ts b/evals/impl/agent-runtime.ts index 25e9a6ac8a..40d6674159 100644 --- a/evals/impl/agent-runtime.ts +++ b/evals/impl/agent-runtime.ts @@ -1,12 +1,13 @@ import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' -export const evalAgentRuntimeImpl: AgentRuntimeDeps = { +export const EVALS_AGENT_RUNTIME_IMPL: AgentRuntimeDeps = Object.freeze({ logger: console, startAgentRun: async () => 'test-agent-run-id', finishAgentRun: async () => {}, addAgentStep: async () => 'test-agent-step-id', - // promptAiSdkStream: async function* () { - // throw new Error('promptAiSdkStream not implemented in eval runtime') - // }, -} + + promptAiSdkStream: async function* () { + throw new Error('promptAiSdkStream not implemented in eval runtime') + }, +}) diff --git a/evals/scaffolding.ts b/evals/scaffolding.ts index 01911eb377..7e67347525 100644 --- a/evals/scaffolding.ts +++ b/evals/scaffolding.ts @@ -14,7 +14,7 @@ import { getSystemInfo } from '@codebuff/npm-app/utils/system-info' import { mock } from 'bun:test' import { blue } from 'picocolors' -import { evalAgentRuntimeImpl } from './impl/agent-runtime' +import { EVALS_AGENT_RUNTIME_IMPL } from './impl/agent-runtime' import { getAllFilePaths, getProjectFileTree, @@ -182,7 +182,7 @@ export async function runAgentStepScaffolding( }) const result = await runAgentStep({ - ...evalAgentRuntimeImpl, + ...EVALS_AGENT_RUNTIME_IMPL, ws: mockWs, userId: TEST_USER_ID, userInputId: generateCompactId(), diff --git a/knowledge.md b/knowledge.md index f04efcb543..b0c86e4a6d 100644 --- a/knowledge.md +++ b/knowledge.md @@ -152,9 +152,6 @@ afterEach(() => { spyOn(aisdk, 'promptAiSdk').mockImplementation(() => Promise.resolve('Test response'), ) -spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* () { - yield 'Test response' -}) // From rage-detector.test.ts - Mocking Date spyOn(Date, 'now').mockImplementation(() => currentTime) From 74935ba914c17007ff6646acab5c1cd5cd8f083e Mon Sep 17 00:00:00 2001 From: Charles Lien Date: Thu, 9 Oct 2025 16:40:39 -0700 Subject: [PATCH 2/6] add promptAiSdk to AgentRuntimeDeps --- .../__tests__/main-prompt.integration.test.ts | 17 ++-- .../src/__tests__/malformed-tool-call.test.ts | 11 ++- .../src/__tests__/process-file-block.test.ts | 20 ++--- .../__tests__/run-agent-step-tools.test.ts | 7 +- backend/src/__tests__/web-search-tool.test.ts | 7 +- backend/src/impl/agent-runtime.ts | 6 +- backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts | 27 ++---- backend/src/process-file-block.ts | 46 +++++----- backend/src/tools/handlers/tool/write-file.ts | 85 +++++++++++-------- common/src/testing/impl/agent-runtime.ts | 3 + common/src/types/contracts/agent-runtime.ts | 3 +- common/src/types/contracts/llm.ts | 21 ++++- evals/impl/agent-runtime.ts | 3 + knowledge.md | 10 ++- 14 files changed, 154 insertions(+), 112 deletions(-) diff --git a/backend/src/__tests__/main-prompt.integration.test.ts b/backend/src/__tests__/main-prompt.integration.test.ts index 0f13ba63b4..75f20e6c03 100644 --- a/backend/src/__tests__/main-prompt.integration.test.ts +++ b/backend/src/__tests__/main-prompt.integration.test.ts @@ -15,16 +15,18 @@ import { } from 'bun:test' import * as checkTerminalCommandModule from '../check-terminal-command' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { mainPrompt } from '../main-prompt' import * as websocketAction from '../websockets/websocket-action' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { PrintModeEvent } from '@codebuff/common/types/print-mode' import type { ProjectFileContext } from '@codebuff/common/util/file' import type { WebSocket } from 'ws' // --- Shared Mocks & Helpers --- +let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } + class MockWebSocket { send(msg: string) {} close() {} @@ -103,6 +105,7 @@ describe.skip('mainPrompt (Integration)', () => { afterEach(() => { mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) it('should delete a specified function while preserving other code', async () => { @@ -337,7 +340,9 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) { ) // Mock LLM calls - spyOn(aisdk, 'promptAiSdk').mockResolvedValue('Mocked non-stream AiSdk') + agentRuntimeImpl.promptAiSdk = async function () { + return 'Mocked non-stream AiSdk' + } const sessionState = getInitialSessionState(mockFileContext) sessionState.mainAgentState.messageHistory.push( @@ -377,7 +382,7 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) { } const { output, sessionState: finalSessionState } = await mainPrompt({ - ...TEST_AGENT_RUNTIME_IMPL, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, @@ -419,7 +424,9 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) { ).mockResolvedValue(null) // Mock LLM calls - spyOn(aisdk, 'promptAiSdk').mockResolvedValue('Mocked non-stream AiSdk') + agentRuntimeImpl.promptAiSdk = async function () { + return 'Mocked non-stream AiSdk' + } const sessionState = getInitialSessionState(mockFileContext) sessionState.mainAgentState.messageHistory.push( @@ -459,7 +466,7 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) { } await mainPrompt({ - ...TEST_AGENT_RUNTIME_IMPL, + ...agentRuntimeImpl, ws: new MockWebSocket() as unknown as WebSocket, action, userId: TEST_USER_ID, diff --git a/backend/src/__tests__/malformed-tool-call.test.ts b/backend/src/__tests__/malformed-tool-call.test.ts index eec1caed18..62f19644fb 100644 --- a/backend/src/__tests__/malformed-tool-call.test.ts +++ b/backend/src/__tests__/malformed-tool-call.test.ts @@ -16,17 +16,19 @@ import { } from 'bun:test' import { MockWebSocket, mockFileContext } from './test-utils' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { processStreamWithTools } from '../tools/stream-parser' import * as websocketAction from '../websockets/websocket-action' import type { AgentTemplate } from '../templates/types' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { Message, ToolMessage, } from '@codebuff/common/types/messages/codebuff-message' import type { WebSocket } from 'ws' +let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } + describe('malformed tool call error handling', () => { let testAgent: AgentTemplate let mockWs: MockWebSocket @@ -72,9 +74,9 @@ describe('malformed tool call error handling', () => { })) // Mock LLM APIs - spyOn(aisdk, 'promptAiSdk').mockImplementation(() => - Promise.resolve('Test response'), - ) + agentRuntimeImpl.promptAiSdk = async function () { + return 'Test response' + } // Mock generateCompactId for consistent test results spyOn(stringUtils, 'generateCompactId').mockReturnValue('test-tool-call-id') @@ -82,6 +84,7 @@ describe('malformed tool call error handling', () => { afterEach(() => { mock.restore() + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) function createMockStream(chunks: string[]) { diff --git a/backend/src/__tests__/process-file-block.test.ts b/backend/src/__tests__/process-file-block.test.ts index 0fcd13bfba..05e266d945 100644 --- a/backend/src/__tests__/process-file-block.test.ts +++ b/backend/src/__tests__/process-file-block.test.ts @@ -1,4 +1,5 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { clearMockedModules, mockModule, @@ -9,14 +10,9 @@ import { applyPatch } from 'diff' import { processFileBlock } from '../process-file-block' -import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' -const logger: Logger = { - debug: () => {}, - info: () => {}, - warn: () => {}, - error: () => {}, -} +let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } describe('processFileBlockModule', () => { beforeAll(() => { @@ -74,6 +70,7 @@ describe('processFileBlockModule', () => { const expectedContent = 'function test() {\n return true;\n}' const result = await processFileBlock({ + ...agentRuntimeImpl, path: 'test.ts', instructions: undefined, initialContentPromise: Promise.resolve(null), @@ -85,7 +82,6 @@ describe('processFileBlockModule', () => { fingerprintId: 'fingerprintId', userInputId: 'userInputId', userId: TEST_USER_ID, - logger, }) expect(result).not.toBeNull() @@ -111,6 +107,7 @@ describe('processFileBlockModule', () => { '}\r\n' const result = await processFileBlock({ + ...agentRuntimeImpl, path: 'test.ts', instructions: undefined, initialContentPromise: Promise.resolve(oldContent), @@ -122,7 +119,6 @@ describe('processFileBlockModule', () => { fingerprintId: 'fingerprintId', userInputId: 'userInputId', userId: TEST_USER_ID, - logger, }) expect(result).not.toBeNull() @@ -144,6 +140,7 @@ describe('processFileBlockModule', () => { const newContent = 'function test() {\n return true;\n}\n' const result = await processFileBlock({ + ...agentRuntimeImpl, path: 'test.ts', instructions: undefined, initialContentPromise: Promise.resolve(oldContent), @@ -155,7 +152,6 @@ describe('processFileBlockModule', () => { fingerprintId: 'fingerprintId', userInputId: 'userInputId', userId: TEST_USER_ID, - logger, }) expect(result).not.toBeNull() @@ -170,6 +166,7 @@ describe('processFileBlockModule', () => { const newContent = 'const x = 1;\r\nconst z = 3;\r\n' const result = await processFileBlock({ + ...agentRuntimeImpl, path: 'test.ts', instructions: undefined, initialContentPromise: Promise.resolve(oldContent), @@ -181,7 +178,6 @@ describe('processFileBlockModule', () => { fingerprintId: 'fingerprintId', userInputId: 'userInputId', userId: TEST_USER_ID, - logger, }) expect(result).not.toBeNull() @@ -217,6 +213,7 @@ describe('processFileBlockModule', () => { '// ... existing code ...\nconst x = 1;\n// ... existing code ...' const result = await processFileBlock({ + ...agentRuntimeImpl, path: 'test.ts', instructions: undefined, initialContentPromise: Promise.resolve(null), @@ -228,7 +225,6 @@ describe('processFileBlockModule', () => { fingerprintId: 'fingerprintId', userInputId: 'userInputId', userId: TEST_USER_ID, - logger, }) expect(result).not.toBeNull() diff --git a/backend/src/__tests__/run-agent-step-tools.test.ts b/backend/src/__tests__/run-agent-step-tools.test.ts index 840ab53df8..06bf87bdda 100644 --- a/backend/src/__tests__/run-agent-step-tools.test.ts +++ b/backend/src/__tests__/run-agent-step-tools.test.ts @@ -18,7 +18,6 @@ import { // Mock imports import * as liveUserInputs from '../live-user-inputs' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { runAgentStep } from '../run-agent-step' import { clearAgentGeneratorCache } from '../run-programmatic-step' import { asUserMessage } from '../util/messages' @@ -106,9 +105,9 @@ describe('runAgentStep - set_output tool', () => { // Don't mock requestToolCall for integration test - let real tool execution happen // Mock LLM APIs - spyOn(aisdk, 'promptAiSdk').mockImplementation(() => - Promise.resolve('Test response'), - ) + agentRuntimeImpl.promptAiSdk = async function () { + return 'Test response' + } clearAgentGeneratorCache(agentRuntimeImpl) }) diff --git a/backend/src/__tests__/web-search-tool.test.ts b/backend/src/__tests__/web-search-tool.test.ts index 6190db3916..82e094b25f 100644 --- a/backend/src/__tests__/web-search-tool.test.ts +++ b/backend/src/__tests__/web-search-tool.test.ts @@ -23,7 +23,6 @@ import * as requestFilesPrompt from '../find-files/request-files-prompt' import * as liveUserInputs from '../live-user-inputs' import { MockWebSocket, mockFileContext } from './test-utils' import * as linkupApi from '../llm-apis/linkup-api' -import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { runAgentStep } from '../run-agent-step' import { assembleLocalAgentTemplates } from '../templates/agent-registry' import * as websocketAction from '../websockets/websocket-action' @@ -67,9 +66,9 @@ describe('web_search tool with researcher agent', () => { })) // Mock LLM APIs - spyOn(aisdk, 'promptAiSdk').mockImplementation(() => - Promise.resolve('Test response'), - ) + agentRuntimeImpl.promptAiSdk = async function () { + return 'Test response' + } // Mock other required modules spyOn(requestFilesPrompt, 'requestRelevantFiles').mockImplementation( diff --git a/backend/src/impl/agent-runtime.ts b/backend/src/impl/agent-runtime.ts index 42894d3b9a..351ad4d839 100644 --- a/backend/src/impl/agent-runtime.ts +++ b/backend/src/impl/agent-runtime.ts @@ -1,5 +1,8 @@ import { addAgentStep, finishAgentRun, startAgentRun } from '../agent-run' -import { promptAiSdkStream } from '../llm-apis/vercel-ai-sdk/ai-sdk' +import { + promptAiSdk, + promptAiSdkStream, +} from '../llm-apis/vercel-ai-sdk/ai-sdk' import { logger } from '../util/logger' import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' @@ -12,4 +15,5 @@ export const BACKEND_AGENT_RUNTIME_IMPL: AgentRuntimeDeps = Object.freeze({ addAgentStep, promptAiSdkStream, + promptAiSdk, }) diff --git a/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts b/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts index 1f0cbf101b..a5f71df11b 100644 --- a/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts +++ b/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts @@ -17,12 +17,12 @@ import { openRouterLanguageModel } from '../openrouter' import { vertexFinetuned } from './vertex-finetuned' import type { Model, OpenAIModel } from '@codebuff/common/old-constants' -import type { PromptAiSdkStreamFn } from '@codebuff/common/types/contracts/llm' -import type { Logger } from '@codebuff/common/types/contracts/logger' import type { - ParamsExcluding, - ParamsOf, -} from '@codebuff/common/types/function-params' + PromptAiSdkFn, + PromptAiSdkStreamFn, +} from '@codebuff/common/types/contracts/llm' +import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsOf } from '@codebuff/common/types/function-params' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { OpenRouterProviderOptions, @@ -231,21 +231,8 @@ export const promptAiSdkStream = async function* ( // TODO: figure out a nice way to unify stream & non-stream versions maybe? export const promptAiSdk = async function ( - params: { - messages: Message[] - clientSessionId: string - fingerprintId: string - userInputId: string - model: Model - userId: string | undefined - chargeUser?: boolean - agentId?: string - onCostCalculated?: (credits: number) => Promise - includeCacheControl?: boolean - maxRetries?: number - logger: Logger - } & ParamsExcluding, -): Promise { + params: ParamsOf, +): ReturnType { const { logger } = params if (!checkLiveUserInput(params)) { diff --git a/backend/src/process-file-block.ts b/backend/src/process-file-block.ts index 6d8101abb5..7ed48a3336 100644 --- a/backend/src/process-file-block.ts +++ b/backend/src/process-file-block.ts @@ -8,26 +8,32 @@ import { parseAndGetDiffBlocksSingleFile, retryDiffBlocksPrompt, } from './generate-diffs-prompt' -import { promptAiSdk } from './llm-apis/vercel-ai-sdk/ai-sdk' import { countTokens } from './util/token-counter' -import type { Message } from '@codebuff/common/types/messages/codebuff-message' +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' +import type { Message } from '@codebuff/common/types/messages/codebuff-message' -export async function processFileBlock(params: { - path: string - instructions: string | undefined - initialContentPromise: Promise - newContent: string - messages: Message[] - fullResponse: string - lastUserPrompt: string | undefined - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - logger: Logger -}): Promise< +export async function processFileBlock( + params: { + path: string + instructions: string | undefined + initialContentPromise: Promise + newContent: string + messages: Message[] + fullResponse: string + lastUserPrompt: string | undefined + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + logger: Logger + } & ParamsExcluding< + typeof handleLargeFile, + 'oldContent' | 'editSnippet' | 'filePath' + >, +): Promise< | { tool: 'write_file' path: string @@ -113,14 +119,10 @@ export async function processFileBlock(params: { ) if (tokenCount > LARGE_FILE_TOKEN_LIMIT) { const largeFileContent = await handleLargeFile({ + ...params, oldContent: normalizedInitialContent, editSnippet: normalizedEditSnippet, - clientSessionId, - fingerprintId, - userInputId, - userId, filePath: path, - logger, }) if (!largeFileContent) { @@ -239,6 +241,7 @@ export async function handleLargeFile(params: { userId: string | undefined filePath: string logger: Logger + promptAiSdk: PromptAiSdkFn }): Promise { const { oldContent, @@ -248,6 +251,7 @@ export async function handleLargeFile(params: { userInputId, userId, filePath, + promptAiSdk, logger, } = params const startTime = Date.now() diff --git a/backend/src/tools/handlers/tool/write-file.ts b/backend/src/tools/handlers/tool/write-file.ts index c4cfe5a533..0624b15d58 100644 --- a/backend/src/tools/handlers/tool/write-file.ts +++ b/backend/src/tools/handlers/tool/write-file.ts @@ -3,14 +3,14 @@ import { partition } from 'lodash' import { processFileBlock } from '../../../process-file-block' import { requestOptionalFile } from '../../../websockets/websocket-action' -import type { Logger } from '@codebuff/common/types/contracts/logger' - import type { CodebuffToolHandlerFunction } from '../handler-function-type' import type { ClientToolCall, CodebuffToolCall, CodebuffToolOutput, } from '@codebuff/common/tools/list' +import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { WebSocket } from 'ws' @@ -63,45 +63,59 @@ export function getFileProcessingValues( return fileProcessingValues } -export const handleWriteFile = (({ - previousToolCallFinished, - toolCall, - - clientSessionId, - userInputId, +export const handleWriteFile = (( + params: { + previousToolCallFinished: Promise + toolCall: CodebuffToolCall<'write_file'> - requestClientToolCall, - writeToClient, + clientSessionId: string + userInputId: string - getLatestState, - state, - logger, -}: { - previousToolCallFinished: Promise - toolCall: CodebuffToolCall<'write_file'> - - clientSessionId: string - userInputId: string - - requestClientToolCall: ( - toolCall: ClientToolCall<'write_file'>, - ) => Promise> - writeToClient: (chunk: string) => void + requestClientToolCall: ( + toolCall: ClientToolCall<'write_file'>, + ) => Promise> + writeToClient: (chunk: string) => void - getLatestState: () => FileProcessingState - state: { - ws?: WebSocket - fingerprintId?: string - userId?: string - fullResponse?: string - prompt?: string - messages?: Message[] - } & OptionalFileProcessingState - logger: Logger -}): { + getLatestState: () => FileProcessingState + state: { + ws?: WebSocket + fingerprintId?: string + userId?: string + fullResponse?: string + prompt?: string + messages?: Message[] + } & OptionalFileProcessingState + logger: Logger + } & ParamsExcluding< + typeof processFileBlock, + | 'path' + | 'instructions' + | 'fingerprintId' + | 'userId' + | 'initialContentPromise' + | 'newContent' + | 'messages' + | 'fullResponse' + | 'lastUserPrompt' + >, +): { result: Promise> state: FileProcessingState } => { + const { + previousToolCallFinished, + toolCall, + + clientSessionId, + userInputId, + + requestClientToolCall, + writeToClient, + + getLatestState, + state, + logger, + } = params const { path, instructions, content } = toolCall.input const { ws, fingerprintId, userId, fullResponse, prompt } = state if (!ws) { @@ -143,6 +157,7 @@ export const handleWriteFile = (({ logger.debug({ path, content }, `write_file ${path}`) const newPromise = processFileBlock({ + ...params, path, instructions, initialContentPromise: latestContentPromise, diff --git a/common/src/testing/impl/agent-runtime.ts b/common/src/testing/impl/agent-runtime.ts index d4d87cbdd9..e524020276 100644 --- a/common/src/testing/impl/agent-runtime.ts +++ b/common/src/testing/impl/agent-runtime.ts @@ -18,6 +18,9 @@ export const TEST_AGENT_RUNTIME_IMPL: AgentRuntimeDeps = Object.freeze({ promptAiSdkStream: async function* () { throw new Error('promptAiSdkStream not implemented in test runtime') }, + promptAiSdk: async function () { + throw new Error('promptAiSdk not implemented in test runtime') + }, // Other logger: testLogger, diff --git a/common/src/types/contracts/agent-runtime.ts b/common/src/types/contracts/agent-runtime.ts index 83eae0d0a6..7446cf7846 100644 --- a/common/src/types/contracts/agent-runtime.ts +++ b/common/src/types/contracts/agent-runtime.ts @@ -3,7 +3,7 @@ import type { FinishAgentRunFn, StartAgentRunFn, } from './database' -import type { PromptAiSdkStreamFn } from './llm' +import type { PromptAiSdkFn, PromptAiSdkStreamFn } from './llm' import type { Logger } from './logger' export type AgentRuntimeDeps = { @@ -14,6 +14,7 @@ export type AgentRuntimeDeps = { // LLM promptAiSdkStream: PromptAiSdkStreamFn + promptAiSdk: PromptAiSdkFn // Other logger: Logger diff --git a/common/src/types/contracts/llm.ts b/common/src/types/contracts/llm.ts index 1846d48dfd..e0fa903a4a 100644 --- a/common/src/types/contracts/llm.ts +++ b/common/src/types/contracts/llm.ts @@ -1,8 +1,8 @@ import type { ParamsExcluding } from '../function-params' import type { Logger } from './logger' -import type { Message } from '../messages/codebuff-message' -import type { streamText } from 'ai' import type { Model } from '../../old-constants' +import type { Message } from '../messages/codebuff-message' +import type { generateText, streamText } from 'ai' export type StreamChunk = | { @@ -32,3 +32,20 @@ export type PromptAiSdkStreamFn = ( logger: Logger } & ParamsExcluding, ) => AsyncGenerator + +export type PromptAiSdkFn = ( + params: { + messages: Message[] + clientSessionId: string + fingerprintId: string + userInputId: string + model: Model + userId: string | undefined + chargeUser?: boolean + agentId?: string + onCostCalculated?: (credits: number) => Promise + includeCacheControl?: boolean + maxRetries?: number + logger: Logger + } & ParamsExcluding, +) => Promise diff --git a/evals/impl/agent-runtime.ts b/evals/impl/agent-runtime.ts index 40d6674159..7a46149be4 100644 --- a/evals/impl/agent-runtime.ts +++ b/evals/impl/agent-runtime.ts @@ -10,4 +10,7 @@ export const EVALS_AGENT_RUNTIME_IMPL: AgentRuntimeDeps = Object.freeze({ promptAiSdkStream: async function* () { throw new Error('promptAiSdkStream not implemented in eval runtime') }, + promptAiSdk: async function () { + throw new Error('promptAiSdk not implemented in eval runtime') + }, }) diff --git a/knowledge.md b/knowledge.md index b0c86e4a6d..1e598319b3 100644 --- a/knowledge.md +++ b/knowledge.md @@ -149,9 +149,13 @@ afterEach(() => { ```typescript // From main-prompt.test.ts - Mocking LLM APIs -spyOn(aisdk, 'promptAiSdk').mockImplementation(() => - Promise.resolve('Test response'), -) +agentRuntimeImpl.promptAiSdk = async function () { + return 'Test response' +} +agentRuntimeImpl.promptAiSdkStream = async function* () { + yield { type: 'text' as const, text: 'Test response' } + return 'mock-message-id' +} // From rage-detector.test.ts - Mocking Date spyOn(Date, 'now').mockImplementation(() => currentTime) From e2abf2ad3bdb29691d18b9362e1fb3fbe6e2c2ac Mon Sep 17 00:00:00 2001 From: Charles Lien Date: Thu, 9 Oct 2025 16:54:45 -0700 Subject: [PATCH 3/6] add promptAiSdkStructured to AgentRuntimeDeps --- backend/src/impl/agent-runtime.ts | 9 +++++-- backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts | 25 ++++---------------- common/src/testing/impl/agent-runtime.ts | 3 +++ common/src/types/contracts/agent-runtime.ts | 7 +++++- common/src/types/contracts/llm.ts | 24 +++++++++++++++++++ evals/impl/agent-runtime.ts | 3 +++ 6 files changed, 48 insertions(+), 23 deletions(-) diff --git a/backend/src/impl/agent-runtime.ts b/backend/src/impl/agent-runtime.ts index 351ad4d839..53759bedea 100644 --- a/backend/src/impl/agent-runtime.ts +++ b/backend/src/impl/agent-runtime.ts @@ -2,18 +2,23 @@ import { addAgentStep, finishAgentRun, startAgentRun } from '../agent-run' import { promptAiSdk, promptAiSdkStream, + promptAiSdkStructured, } from '../llm-apis/vercel-ai-sdk/ai-sdk' import { logger } from '../util/logger' import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' export const BACKEND_AGENT_RUNTIME_IMPL: AgentRuntimeDeps = Object.freeze({ - logger, - + // Database startAgentRun, finishAgentRun, addAgentStep, + // LLM promptAiSdkStream, promptAiSdk, + promptAiSdkStructured, + + // Other + logger, }) diff --git a/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts b/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts index a5f71df11b..02fd22ca5e 100644 --- a/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts +++ b/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts @@ -20,10 +20,10 @@ import type { Model, OpenAIModel } from '@codebuff/common/old-constants' import type { PromptAiSdkFn, PromptAiSdkStreamFn, + PromptAiSdkStructuredInput, + PromptAiSdkStructuredOutput, } from '@codebuff/common/types/contracts/llm' -import type { Logger } from '@codebuff/common/types/contracts/logger' import type { ParamsOf } from '@codebuff/common/types/function-params' -import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { OpenRouterProviderOptions, OpenRouterUsageAccounting, @@ -287,24 +287,9 @@ export const promptAiSdk = async function ( } // Copied over exactly from promptAiSdk but with a schema -export const promptAiSdkStructured = async function (params: { - messages: Message[] - schema: z.ZodType - clientSessionId: string - fingerprintId: string - userInputId: string - model: Model - userId: string | undefined - maxTokens?: number - temperature?: number - timeout?: number - chargeUser?: boolean - agentId?: string - onCostCalculated?: (credits: number) => Promise - includeCacheControl?: boolean - maxRetries?: number - logger: Logger -}): Promise { +export const promptAiSdkStructured = async function ( + params: PromptAiSdkStructuredInput, +): PromptAiSdkStructuredOutput { const { logger } = params if (!checkLiveUserInput(params)) { diff --git a/common/src/testing/impl/agent-runtime.ts b/common/src/testing/impl/agent-runtime.ts index e524020276..90a4e0bce1 100644 --- a/common/src/testing/impl/agent-runtime.ts +++ b/common/src/testing/impl/agent-runtime.ts @@ -21,6 +21,9 @@ export const TEST_AGENT_RUNTIME_IMPL: AgentRuntimeDeps = Object.freeze({ promptAiSdk: async function () { throw new Error('promptAiSdk not implemented in test runtime') }, + promptAiSdkStructured: async function () { + throw new Error('promptAiSdkStructured not implemented in test runtime') + }, // Other logger: testLogger, diff --git a/common/src/types/contracts/agent-runtime.ts b/common/src/types/contracts/agent-runtime.ts index 7446cf7846..65d9358b2f 100644 --- a/common/src/types/contracts/agent-runtime.ts +++ b/common/src/types/contracts/agent-runtime.ts @@ -3,7 +3,11 @@ import type { FinishAgentRunFn, StartAgentRunFn, } from './database' -import type { PromptAiSdkFn, PromptAiSdkStreamFn } from './llm' +import type { + PromptAiSdkFn, + PromptAiSdkStreamFn, + PromptAiSdkStructuredFn, +} from './llm' import type { Logger } from './logger' export type AgentRuntimeDeps = { @@ -15,6 +19,7 @@ export type AgentRuntimeDeps = { // LLM promptAiSdkStream: PromptAiSdkStreamFn promptAiSdk: PromptAiSdkFn + promptAiSdkStructured: PromptAiSdkStructuredFn // Other logger: Logger diff --git a/common/src/types/contracts/llm.ts b/common/src/types/contracts/llm.ts index e0fa903a4a..0a2ecb9ccb 100644 --- a/common/src/types/contracts/llm.ts +++ b/common/src/types/contracts/llm.ts @@ -3,6 +3,7 @@ import type { Logger } from './logger' import type { Model } from '../../old-constants' import type { Message } from '../messages/codebuff-message' import type { generateText, streamText } from 'ai' +import type z from 'zod/v4' export type StreamChunk = | { @@ -49,3 +50,26 @@ export type PromptAiSdkFn = ( logger: Logger } & ParamsExcluding, ) => Promise + +export type PromptAiSdkStructuredInput = { + messages: Message[] + schema: z.ZodType + clientSessionId: string + fingerprintId: string + userInputId: string + model: Model + userId: string | undefined + maxTokens?: number + temperature?: number + timeout?: number + chargeUser?: boolean + agentId?: string + onCostCalculated?: (credits: number) => Promise + includeCacheControl?: boolean + maxRetries?: number + logger: Logger +} +export type PromptAiSdkStructuredOutput = Promise +export type PromptAiSdkStructuredFn = ( + params: PromptAiSdkStructuredInput, +) => PromptAiSdkStructuredOutput diff --git a/evals/impl/agent-runtime.ts b/evals/impl/agent-runtime.ts index 7a46149be4..373e8b637d 100644 --- a/evals/impl/agent-runtime.ts +++ b/evals/impl/agent-runtime.ts @@ -13,4 +13,7 @@ export const EVALS_AGENT_RUNTIME_IMPL: AgentRuntimeDeps = Object.freeze({ promptAiSdk: async function () { throw new Error('promptAiSdk not implemented in eval runtime') }, + promptAiSdkStructured: async function () { + throw new Error('promptAiSdkStructured not implemented in eval runtime') + }, }) From cf5aaa9ee4d6d56166e52806546dbeb7859c1b29 Mon Sep 17 00:00:00 2001 From: Charles Lien Date: Thu, 9 Oct 2025 17:09:09 -0700 Subject: [PATCH 4/6] change promptAiSdk functions to function instead of const --- .../cost-aggregation.integration.test.ts | 4 ++-- backend/src/__tests__/main-prompt.test.ts | 16 ++++------------ backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts | 6 +++--- 3 files changed, 9 insertions(+), 17 deletions(-) diff --git a/backend/src/__tests__/cost-aggregation.integration.test.ts b/backend/src/__tests__/cost-aggregation.integration.test.ts index 18216240c1..10f77415a4 100644 --- a/backend/src/__tests__/cost-aggregation.integration.test.ts +++ b/backend/src/__tests__/cost-aggregation.integration.test.ts @@ -99,9 +99,10 @@ class MockWebSocket { describe('Cost Aggregation Integration Tests', () => { let mockLocalAgentTemplates: Record let mockWebSocket: MockWebSocket - let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } + let agentRuntimeImpl: AgentRuntimeDeps beforeEach(async () => { + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } mockWebSocket = new MockWebSocket() // Setup mock agent templates @@ -230,7 +231,6 @@ describe('Cost Aggregation Integration Tests', () => { afterEach(() => { mock.restore() - agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) it('should correctly aggregate costs across the entire main prompt flow', async () => { diff --git a/backend/src/__tests__/main-prompt.test.ts b/backend/src/__tests__/main-prompt.test.ts index 25577a880f..296c47907a 100644 --- a/backend/src/__tests__/main-prompt.test.ts +++ b/backend/src/__tests__/main-prompt.test.ts @@ -28,11 +28,10 @@ import * as websocketAction from '../websockets/websocket-action' import type { AgentTemplate } from '@codebuff/common/types/agent-template' import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' -import type { Logger } from '@codebuff/common/types/contracts/logger' import type { ProjectFileContext } from '@codebuff/common/util/file' import type { WebSocket } from 'ws' -let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } +let agentRuntimeImpl: AgentRuntimeDeps const mockAgentStream = (streamOutput: string) => { agentRuntimeImpl.promptAiSdkStream = async function* ({}) { @@ -43,14 +42,10 @@ const mockAgentStream = (streamOutput: string) => { describe('mainPrompt', () => { let mockLocalAgentTemplates: Record - const logger: Logger = { - debug: () => {}, - error: () => {}, - info: () => {}, - warn: () => {}, - } beforeEach(() => { + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } + // Setup common mock agent templates mockLocalAgentTemplates = { [AgentTemplateTypes.base]: { @@ -86,12 +81,10 @@ describe('mainPrompt', () => { stepPrompt: '', } satisfies AgentTemplate, } - }) - beforeEach(() => { // Mock analytics and tracing spyOn(analytics, 'initAnalytics').mockImplementation(() => {}) - analytics.initAnalytics({ logger }) // Initialize the mock + analytics.initAnalytics(agentRuntimeImpl) // Initialize the mock spyOn(analytics, 'trackEvent').mockImplementation(() => {}) spyOn(bigquery, 'insertTrace').mockImplementation(() => Promise.resolve(true), @@ -176,7 +169,6 @@ describe('mainPrompt', () => { afterEach(() => { // Clear all mocks after each test mock.restore() - agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) class MockWebSocket { diff --git a/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts b/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts index 02fd22ca5e..995a218ced 100644 --- a/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts +++ b/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts @@ -53,7 +53,7 @@ const modelToAiSDKModel = (model: Model): LanguageModel => { // TODO: Add retries & fallbacks: likely by allowing this to instead of "model" // also take an array of form [{model: Model, retries: number}, {model: Model, retries: number}...] // eg: [{model: "gemini-2.0-flash-001"}, {model: "vertex/gemini-2.0-flash-001"}, {model: "claude-3-5-haiku", retries: 3}] -export const promptAiSdkStream = async function* ( +export async function* promptAiSdkStream( params: ParamsOf, ): ReturnType { const { logger } = params @@ -230,7 +230,7 @@ export const promptAiSdkStream = async function* ( } // TODO: figure out a nice way to unify stream & non-stream versions maybe? -export const promptAiSdk = async function ( +export async function promptAiSdk( params: ParamsOf, ): ReturnType { const { logger } = params @@ -287,7 +287,7 @@ export const promptAiSdk = async function ( } // Copied over exactly from promptAiSdk but with a schema -export const promptAiSdkStructured = async function ( +export async function promptAiSdkStructured( params: PromptAiSdkStructuredInput, ): PromptAiSdkStructuredOutput { const { logger } = params From 89ef32f554d011f8d21d1173ced3c54b716fcd4d Mon Sep 17 00:00:00 2001 From: Charles Lien Date: Thu, 9 Oct 2025 17:59:09 -0700 Subject: [PATCH 5/6] pass in promptAiSdk --- backend/src/__tests__/fast-rewrite.test.ts | 20 +-- .../__tests__/request-files-prompt.test.ts | 31 ++-- backend/src/admin/grade-runs.ts | 6 +- backend/src/admin/relabelRuns.ts | 25 ++- backend/src/check-terminal-command.ts | 8 +- backend/src/fast-rewrite.ts | 94 ++++------- .../src/find-files/request-files-prompt.ts | 156 +++++++----------- backend/src/generate-diffs-prompt.ts | 41 +++-- backend/src/llm-apis/gemini-with-fallbacks.ts | 43 ++--- backend/src/llm-apis/relace-api.ts | 4 +- backend/src/main-prompt.ts | 14 +- backend/src/process-file-block.ts | 67 ++++---- backend/src/tools/handlers/tool/find-files.ts | 121 +++++++------- scripts/ft-file-selection/grade-traces.ts | 7 +- .../relabel-traces-with-context.ts | 2 + scripts/ft-file-selection/relabel-traces.ts | 1 + 16 files changed, 281 insertions(+), 359 deletions(-) diff --git a/backend/src/__tests__/fast-rewrite.test.ts b/backend/src/__tests__/fast-rewrite.test.ts index 19e50abf90..8db17d40b9 100644 --- a/backend/src/__tests__/fast-rewrite.test.ts +++ b/backend/src/__tests__/fast-rewrite.test.ts @@ -1,23 +1,19 @@ import path from 'path' import { TEST_USER_ID } from '@codebuff/common/old-constants' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { clearMockedModules, mockModule, } from '@codebuff/common/testing/mock-modules' -import { afterAll, beforeAll, describe, expect, it } from 'bun:test' +import { afterAll, beforeAll, beforeEach, describe, expect, it } from 'bun:test' import { createPatch } from 'diff' import { rewriteWithOpenAI } from '../fast-rewrite' -import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' -const logger: Logger = { - debug: () => {}, - info: () => {}, - warn: () => {}, - error: () => {}, -} +let agentRuntimeImpl: AgentRuntimeDeps describe.skip('rewriteWithOpenAI', () => { beforeAll(() => { @@ -42,6 +38,10 @@ describe.skip('rewriteWithOpenAI', () => { })) }) + beforeEach(() => { + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } + }) + afterAll(() => { clearMockedModules() }) @@ -53,15 +53,13 @@ describe.skip('rewriteWithOpenAI', () => { const expectedResult = await Bun.file(`${testDataDir}/expected.go`).text() const result = await rewriteWithOpenAI({ + ...agentRuntimeImpl, oldContent: originalContent, editSnippet, - filePath: 'taskruntoolcall.go', clientSessionId: 'clientSessionId', fingerprintId: 'fingerprintId', userInputId: 'userInputId', userId: TEST_USER_ID, - userMessage: undefined, - logger, }) const patch = createPatch('test.ts', expectedResult, result) diff --git a/backend/src/__tests__/request-files-prompt.test.ts b/backend/src/__tests__/request-files-prompt.test.ts index d11ab037a2..56552574d5 100644 --- a/backend/src/__tests__/request-files-prompt.test.ts +++ b/backend/src/__tests__/request-files-prompt.test.ts @@ -1,4 +1,5 @@ import { finetunedVertexModels } from '@codebuff/common/old-constants' +import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { clearMockedModules, mockModule, @@ -18,10 +19,13 @@ import * as OriginalRequestFilesPromptModule from '../find-files/request-files-p import * as geminiWithFallbacksModule from '../llm-apis/gemini-with-fallbacks' import type { CostMode } from '@codebuff/common/old-constants' +import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { ProjectFileContext } from '@codebuff/common/util/file' import type { Mock } from 'bun:test' +let agentRuntimeImpl: AgentRuntimeDeps + describe('requestRelevantFiles', () => { const mockMessages: Message[] = [{ role: 'user', content: 'test prompt' }] const mockSystem = 'test system' @@ -58,12 +62,6 @@ describe('requestRelevantFiles', () => { const mockUserId = 'user1' const mockCostMode: CostMode = 'normal' const mockRepoId = 'owner/repo' - const logger = { - debug: () => {}, - info: () => {}, - warn: () => {}, - error: () => {}, - } let getCustomFilePickerConfigForOrgSpy: any // Explicitly typed as any @@ -81,15 +79,6 @@ describe('requestRelevantFiles', () => { })), })) - mockModule('@codebuff/backend/util/logger', () => ({ - logger: { - info: bunMockFn(() => {}), - error: bunMockFn(() => {}), - warn: bunMockFn(() => {}), - debug: bunMockFn(() => {}), - }, - })) - mockModule('@codebuff/common/db', () => ({ default: { insert: bunMockFn(() => ({ @@ -109,6 +98,8 @@ describe('requestRelevantFiles', () => { }) beforeEach(() => { + agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } + // If the spy was created in a previous test, restore it if ( getCustomFilePickerConfigForOrgSpy && @@ -134,6 +125,7 @@ describe('requestRelevantFiles', () => { it('should use default file counts and maxFiles when no custom config', async () => { await OriginalRequestFilesPromptModule.requestRelevantFiles({ + ...agentRuntimeImpl, messages: mockMessages, system: mockSystem, fileContext: mockFileContext, @@ -144,7 +136,6 @@ describe('requestRelevantFiles', () => { userInputId: mockUserInputId, userId: mockUserId, repoId: mockRepoId, - logger, }) expect( geminiWithFallbacksModule.promptFlashWithFallbacks, @@ -161,6 +152,7 @@ describe('requestRelevantFiles', () => { getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any) await OriginalRequestFilesPromptModule.requestRelevantFiles({ + ...agentRuntimeImpl, messages: mockMessages, system: mockSystem, fileContext: mockFileContext, @@ -171,7 +163,6 @@ describe('requestRelevantFiles', () => { userInputId: mockUserInputId, userId: mockUserId, repoId: mockRepoId, - logger, }) expect( geminiWithFallbacksModule.promptFlashWithFallbacks, @@ -187,6 +178,7 @@ describe('requestRelevantFiles', () => { getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any) const result = await OriginalRequestFilesPromptModule.requestRelevantFiles({ + ...agentRuntimeImpl, messages: mockMessages, system: mockSystem, fileContext: mockFileContext, @@ -197,7 +189,6 @@ describe('requestRelevantFiles', () => { userInputId: mockUserInputId, userId: mockUserId, repoId: mockRepoId, - logger, }) expect(result).toBeArray() if (result) { @@ -213,6 +204,7 @@ describe('requestRelevantFiles', () => { getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any) await OriginalRequestFilesPromptModule.requestRelevantFiles({ + ...agentRuntimeImpl, messages: mockMessages, system: mockSystem, fileContext: mockFileContext, @@ -223,7 +215,6 @@ describe('requestRelevantFiles', () => { userInputId: mockUserInputId, userId: mockUserId, repoId: mockRepoId, - logger, }) expect( geminiWithFallbacksModule.promptFlashWithFallbacks, @@ -242,6 +233,7 @@ describe('requestRelevantFiles', () => { getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any) await OriginalRequestFilesPromptModule.requestRelevantFiles({ + ...agentRuntimeImpl, messages: mockMessages, system: mockSystem, fileContext: mockFileContext, @@ -252,7 +244,6 @@ describe('requestRelevantFiles', () => { userInputId: mockUserInputId, userId: mockUserId, repoId: mockRepoId, - logger, }) const expectedModel = finetunedVertexModels.ft_filepicker_010 expect( diff --git a/backend/src/admin/grade-runs.ts b/backend/src/admin/grade-runs.ts index 967c684606..287d300154 100644 --- a/backend/src/admin/grade-runs.ts +++ b/backend/src/admin/grade-runs.ts @@ -1,9 +1,8 @@ import { models, TEST_USER_ID } from '@codebuff/common/old-constants' import { closeXml } from '@codebuff/common/util/xml' -import { promptAiSdk } from '../llm-apis/vercel-ai-sdk/ai-sdk' - import type { Relabel, GetRelevantFilesTrace } from '@codebuff/bigquery' +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' const PROMPT = ` @@ -100,9 +99,10 @@ function extractResponse(response: string): { export async function gradeRun(params: { trace: GetRelevantFilesTrace relabels: Relabel[] + promptAiSdk: PromptAiSdkFn logger: Logger }) { - const { trace, relabels, logger } = params + const { trace, relabels, promptAiSdk, logger } = params const messages = trace.payload.messages const originalOutput = trace.payload.output diff --git a/backend/src/admin/relabelRuns.ts b/backend/src/admin/relabelRuns.ts index e40fd56c4a..f52982b519 100644 --- a/backend/src/admin/relabelRuns.ts +++ b/backend/src/admin/relabelRuns.ts @@ -24,8 +24,10 @@ import type { GetRelevantFilesTrace, Relabel, } from '@codebuff/bigquery' -import type { Message } from '@codebuff/common/types/messages/codebuff-message' +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' +import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { Request, Response } from 'express' // --- GET Handler Logic --- @@ -153,6 +155,7 @@ export async function relabelForUserHandler(params: { const relaceResults = relabelUsingFullFilesForUser({ userId, limit, + promptAiSdk, logger, }) @@ -265,11 +268,16 @@ export async function relabelForUserHandler(params: { } } -async function relabelUsingFullFilesForUser(params: { - userId: string - limit: number - logger: Logger -}) { +async function relabelUsingFullFilesForUser( + params: { + userId: string + limit: number + logger: Logger + } & ParamsExcluding< + typeof relabelWithClaudeWithFullFileContext, + 'trace' | 'fileBlobs' | 'model' + >, +) { const { userId, limit, logger } = params // TODO: We need to figure out changing _everything_ to use `getTracesAndAllDataForUser` const tracesBundles = await getTracesAndAllDataForUser(userId) @@ -308,10 +316,10 @@ async function relabelUsingFullFilesForUser(params: { ) { relabelPromises.push( relabelWithClaudeWithFullFileContext({ + ...params, trace, fileBlobs, model, - logger, }), ) didRelabel = true @@ -392,9 +400,10 @@ export async function relabelWithClaudeWithFullFileContext(params: { fileBlobs: GetExpandedFileContextForTrainingBlobTrace model: string dataset?: string + promptAiSdk: PromptAiSdkFn logger: Logger }) { - const { trace, fileBlobs, model, dataset, logger } = params + const { trace, fileBlobs, model, dataset, promptAiSdk, logger } = params if (dataset) { await setupBigQuery({ dataset, logger }) } diff --git a/backend/src/check-terminal-command.ts b/backend/src/check-terminal-command.ts index 1c1a2931bc..b1973b44b0 100644 --- a/backend/src/check-terminal-command.ts +++ b/backend/src/check-terminal-command.ts @@ -1,8 +1,7 @@ import { models } from '@codebuff/common/old-constants' import { withTimeout } from '@codebuff/common/util/promise' -import { promptAiSdk } from './llm-apis/vercel-ai-sdk/ai-sdk' - +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' import type { ParamsExcluding } from '@codebuff/common/types/function-params' @@ -13,10 +12,11 @@ import type { ParamsExcluding } from '@codebuff/common/types/function-params' export async function checkTerminalCommand( params: { prompt: string + promptAiSdk: PromptAiSdkFn logger: Logger - } & ParamsExcluding, + } & ParamsExcluding, ): Promise { - const { prompt, logger } = params + const { prompt, promptAiSdk, logger } = params if (!prompt?.trim()) { return null } diff --git a/backend/src/fast-rewrite.ts b/backend/src/fast-rewrite.ts index 49212bd584..f6709909e9 100644 --- a/backend/src/fast-rewrite.ts +++ b/backend/src/fast-rewrite.ts @@ -5,52 +5,33 @@ import { generateCompactId, hasLazyEdit } from '@codebuff/common/util/string' import { promptFlashWithFallbacks } from './llm-apis/gemini-with-fallbacks' import { promptRelaceAI } from './llm-apis/relace-api' -import { promptAiSdk } from './llm-apis/vercel-ai-sdk/ai-sdk' import type { CodebuffToolMessage } from '@codebuff/common/tools/list' +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { Message, ToolMessage, } from '@codebuff/common/types/messages/codebuff-message' -export async function fastRewrite(params: { - initialContent: string - editSnippet: string - filePath: string - instructions: string | undefined - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - userMessage: string | undefined - logger: Logger -}) { - const { - initialContent, - editSnippet, - filePath, - instructions, - clientSessionId, - fingerprintId, - userInputId, - userId, - userMessage, - logger, - } = params +export async function fastRewrite( + params: { + initialContent: string + editSnippet: string + filePath: string + userMessage: string | undefined + logger: Logger + } & ParamsExcluding & + ParamsExcluding, +) { + const { initialContent, editSnippet, filePath, userMessage, logger } = params const relaceStartTime = Date.now() const messageId = generateCompactId('cb-') let response = await promptRelaceAI({ + ...params, initialCode: initialContent, - editSnippet, - instructions, - clientSessionId, - fingerprintId, - userInputId, - userId, - userMessage, messageId, - logger, }) const relaceDuration = Date.now() - relaceStartTime @@ -62,15 +43,8 @@ export async function fastRewrite(params: { ) { const relaceResponse = response response = await rewriteWithOpenAI({ + ...params, oldContent: initialContent, - editSnippet, - filePath, - clientSessionId, - fingerprintId, - userInputId, - userId, - userMessage, - logger, }) logger.debug( { filePath, relaceResponse, openaiResponse: response, messageId }, @@ -97,23 +71,21 @@ export async function fastRewrite(params: { export async function rewriteWithOpenAI(params: { oldContent: string editSnippet: string - filePath: string clientSessionId: string fingerprintId: string userInputId: string userId: string | undefined - userMessage: string | undefined + promptAiSdk: PromptAiSdkFn logger: Logger }): Promise { const { oldContent, editSnippet, - filePath, clientSessionId, fingerprintId, userInputId, userId, - userMessage, + promptAiSdk, logger, } = params const prompt = `You are an expert programmer tasked with implementing changes to a file. Please rewrite the file to implement the changes shown in the edit snippet, while preserving the original formatting and behavior of unchanged parts. @@ -158,28 +130,22 @@ Please output just the complete updated file content with the edit applied and n * sketches an update to a single function, but forgets to add ... existing code ... * above and below the function. */ -export const shouldAddFilePlaceholders = async (params: { - filePath: string - oldContent: string - rewrittenNewContent: string - messageHistory: Message[] - fullResponse: string - userId: string | undefined - clientSessionId: string - fingerprintId: string - userInputId: string - logger: Logger -}) => { +export const shouldAddFilePlaceholders = async ( + params: { + filePath: string + oldContent: string + rewrittenNewContent: string + messageHistory: Message[] + fullResponse: string + logger: Logger + } & ParamsExcluding, +) => { const { filePath, oldContent, rewrittenNewContent, messageHistory, fullResponse, - userId, - clientSessionId, - fingerprintId, - userInputId, logger, } = params const fileWasPreviouslyEdited = messageHistory @@ -244,13 +210,9 @@ Do not write anything else. }, ) const response = await promptFlashWithFallbacks({ + ...params, messages, - clientSessionId, - fingerprintId, - userInputId, model: models.openrouter_gemini2_5_flash, - userId, - logger, }) const shouldAddPlaceholderComments = response.includes('LOCAL_CHANGE_ONLY') logger.debug( diff --git a/backend/src/find-files/request-files-prompt.ts b/backend/src/find-files/request-files-prompt.ts index 4d29189d96..7f311d7836 100644 --- a/backend/src/find-files/request-files-prompt.ts +++ b/backend/src/find-files/request-files-prompt.ts @@ -14,7 +14,6 @@ import { range, shuffle, uniq } from 'lodash' import { CustomFilePickerConfigSchema } from './custom-file-picker-config' import { promptFlashWithFallbacks } from '../llm-apis/gemini-with-fallbacks' -import { promptAiSdk } from '../llm-apis/vercel-ai-sdk/ai-sdk' import { castAssistantMessage, messagesWithSystem, @@ -28,9 +27,11 @@ import type { GetExpandedFileContextForTrainingTrace, GetRelevantFilesTrace, } from '@codebuff/bigquery' +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' +import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { ProjectFileContext } from '@codebuff/common/util/file' -import type { Logger } from '@codebuff/common/types/contracts/logger' const NUMBER_OF_EXAMPLE_FILES = 100 const MAX_FILES_PER_REQUEST = 30 @@ -120,32 +121,25 @@ function isValidFilePickerModelName( return Object.keys(finetunedVertexModels).includes(modelName) } -export async function requestRelevantFiles(params: { - messages: Message[] - system: string | Array - fileContext: ProjectFileContext - assistantPrompt: string | null - agentStepId: string - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - repoId: string | undefined - logger: Logger -}) { - const { - messages, - system, - fileContext, - assistantPrompt, - agentStepId, - clientSessionId, - fingerprintId, - userInputId, - userId, - repoId, - logger, - } = params +export async function requestRelevantFiles( + params: { + messages: Message[] + system: string | Array + fileContext: ProjectFileContext + assistantPrompt: string | null + agentStepId: string + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + repoId: string | undefined + logger: Logger + } & ParamsExcluding< + typeof getRelevantFiles, + 'messages' | 'userPrompt' | 'requestType' | 'modelId' + >, +) { + const { messages, fileContext, assistantPrompt, logger } = params // Check for organization custom file picker feature const requestContext = getRequestContext() const orgId = requestContext?.approvedOrgIdForRepo @@ -193,18 +187,11 @@ export async function requestRelevantFiles(params: { } const keyPromise = getRelevantFiles({ + ...params, messages: messagesExcludingLastIfByUser, - system, userPrompt: keyPrompt, requestType: 'Key', - agentStepId, - clientSessionId, - fingerprintId, - userInputId, - userId, - repoId, modelId: modelIdForRequest, - logger, }).catch((error) => { logger.error({ error }, 'Error requesting key files') return { files: [] as string[], duration: 0 } @@ -227,32 +214,18 @@ export async function requestRelevantFiles(params: { return candidateFiles.slice(0, maxFilesPerRequest) } -export async function requestRelevantFilesForTraining(params: { - messages: Message[] - system: string | Array - fileContext: ProjectFileContext - assistantPrompt: string | null - agentStepId: string - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - repoId: string | undefined - logger: Logger -}) { - const { - messages, - system, - fileContext, - assistantPrompt, - agentStepId, - clientSessionId, - fingerprintId, - userInputId, - userId, - repoId, - logger, - } = params +export async function requestRelevantFilesForTraining( + params: { + messages: Message[] + fileContext: ProjectFileContext + assistantPrompt: string | null + logger: Logger + } & ParamsExcluding< + typeof getRelevantFilesForTraining, + 'messages' | 'userPrompt' | 'requestType' + >, +) { + const { messages, fileContext, assistantPrompt, logger } = params const COUNT = 50 const lastMessage = messages[messages.length - 1] @@ -279,31 +252,17 @@ export async function requestRelevantFilesForTraining(params: { ) const keyFiles = await getRelevantFilesForTraining({ + ...params, messages: messagesExcludingLastIfByUser, - system, userPrompt: keyFilesPrompt, requestType: 'Key', - agentStepId, - clientSessionId, - fingerprintId, - userInputId, - userId, - repoId, - logger, }) const nonObviousFiles = await getRelevantFilesForTraining({ + ...params, messages: messagesExcludingLastIfByUser, - system, userPrompt: nonObviousPrompt, requestType: 'Non-Obvious', - agentStepId, - clientSessionId, - fingerprintId, - userInputId, - userId, - repoId, - logger, }) const candidateFiles = [...keyFiles.files, ...nonObviousFiles.files] @@ -315,20 +274,25 @@ export async function requestRelevantFilesForTraining(params: { return validatedFiles.slice(0, MAX_FILES_PER_REQUEST) } -async function getRelevantFiles(params: { - messages: Message[] - system: string | Array - userPrompt: string - requestType: string - agentStepId: string - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - repoId: string | undefined - modelId?: FinetunedVertexModel - logger: Logger -}) { +async function getRelevantFiles( + params: { + messages: Message[] + system: string | Array + userPrompt: string + requestType: string + agentStepId: string + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + repoId: string | undefined + modelId?: FinetunedVertexModel + logger: Logger + } & ParamsExcluding< + typeof promptFlashWithFallbacks, + 'messages' | 'model' | 'useFinetunedModel' + >, +) { const { messages, system, @@ -374,14 +338,10 @@ async function getRelevantFiles(params: { const finetunedModel = modelId ?? finetunedVertexModels.ft_filepicker_010 let response = await promptFlashWithFallbacks({ + ...params, messages: codebuffMessages, - clientSessionId, - userInputId, model: models.openrouter_gemini2_5_flash, - userId, useFinetunedModel: finetunedModel, - fingerprintId, - logger, }) const end = performance.now() const duration = end - start @@ -425,6 +385,7 @@ async function getRelevantFilesForTraining(params: { userInputId: string userId: string | undefined repoId: string | undefined + promptAiSdk: PromptAiSdkFn logger: Logger }) { const { @@ -438,6 +399,7 @@ async function getRelevantFilesForTraining(params: { userInputId, userId, repoId, + promptAiSdk, logger, } = params const bufferTokens = 100_000 diff --git a/backend/src/generate-diffs-prompt.ts b/backend/src/generate-diffs-prompt.ts index 11f0ac39c9..4575ab3c54 100644 --- a/backend/src/generate-diffs-prompt.ts +++ b/backend/src/generate-diffs-prompt.ts @@ -4,9 +4,9 @@ import { createSearchReplaceBlock, } from '@codebuff/common/util/file' -import { promptAiSdk } from './llm-apis/vercel-ai-sdk/ai-sdk' - +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' export const parseAndGetDiffBlocksSingleFile = (params: { newContent: string @@ -133,24 +133,27 @@ export const tryToDoStringReplacementWithExtraIndentation = (params: { return null } -export async function retryDiffBlocksPrompt(params: { - filePath: string - oldContent: string - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - diffBlocksThatDidntMatch: { searchContent: string; replaceContent: string }[] - logger: Logger -}) { +export async function retryDiffBlocksPrompt( + params: { + filePath: string + oldContent: string + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + diffBlocksThatDidntMatch: { + searchContent: string + replaceContent: string + }[] + promptAiSdk: PromptAiSdkFn + logger: Logger + } & ParamsExcluding, +) { const { filePath, oldContent, - clientSessionId, - fingerprintId, - userInputId, - userId, diffBlocksThatDidntMatch, + promptAiSdk, logger, } = params const newPrompt = @@ -168,13 +171,9 @@ The search content needs to match an exact substring of the old file content, wh Provide a new set of SEARCH/REPLACE changes to make the intended edit from the old file.`.trim() const response = await promptAiSdk({ + ...params, messages: [{ role: 'user', content: newPrompt }], model: models.openrouter_claude_sonnet_4, - clientSessionId, - fingerprintId, - userInputId, - userId, - logger, }) const { diffBlocks: newDiffBlocks, diff --git a/backend/src/llm-apis/gemini-with-fallbacks.ts b/backend/src/llm-apis/gemini-with-fallbacks.ts index 15763c43f9..4a2cb71350 100644 --- a/backend/src/llm-apis/gemini-with-fallbacks.ts +++ b/backend/src/llm-apis/gemini-with-fallbacks.ts @@ -1,14 +1,13 @@ import { openaiModels, openrouterModels } from '@codebuff/common/old-constants' -import { promptAiSdk } from './vercel-ai-sdk/ai-sdk' - import type { CostMode, FinetunedVertexModel, - Model, } from '@codebuff/common/old-constants' -import type { Message } from '@codebuff/common/types/messages/codebuff-message' +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' +import type { Message } from '@codebuff/common/types/messages/codebuff-message' /** * Prompts a Gemini model with fallback logic. @@ -35,38 +34,33 @@ import type { Logger } from '@codebuff/common/types/contracts/logger' * @returns A promise that resolves to the complete response string from the successful API call. * @throws If all API calls (primary and fallbacks) fail. */ -export async function promptFlashWithFallbacks(params: { - messages: Message[] - clientSessionId: string - fingerprintId: string - userInputId: string - model: Model - userId: string | undefined - maxTokens?: number - temperature?: number - costMode?: CostMode - useGPT4oInsteadOfClaude?: boolean - thinkingBudget?: number - useFinetunedModel?: FinetunedVertexModel | undefined - logger: Logger -}): Promise { +export async function promptFlashWithFallbacks( + params: { + messages: Message[] + costMode?: CostMode + useGPT4oInsteadOfClaude?: boolean + thinkingBudget?: number + useFinetunedModel?: FinetunedVertexModel | undefined + promptAiSdk: PromptAiSdkFn + logger: Logger + } & ParamsExcluding, +): Promise { const { messages, costMode, useGPT4oInsteadOfClaude, useFinetunedModel, + promptAiSdk, logger, - ...geminiOptions } = params // Try finetuned model first if enabled if (useFinetunedModel) { try { return await promptAiSdk({ - ...geminiOptions, + ...params, messages, model: useFinetunedModel, - logger, }) } catch (error) { logger.warn( @@ -78,14 +72,14 @@ export async function promptFlashWithFallbacks(params: { try { // First try Gemini - return await promptAiSdk({ ...geminiOptions, messages, logger }) + return await promptAiSdk({ ...params, messages }) } catch (error) { logger.warn( { error }, `Error calling Gemini API, falling back to ${useGPT4oInsteadOfClaude ? 'gpt-4o' : 'Claude'}`, ) return await promptAiSdk({ - ...geminiOptions, + ...params, messages, model: useGPT4oInsteadOfClaude ? openaiModels.gpt4o @@ -96,7 +90,6 @@ export async function promptFlashWithFallbacks(params: { experimental: openrouterModels.openrouter_claude_3_5_haiku, ask: openrouterModels.openrouter_claude_3_5_haiku, }[costMode ?? 'normal'], - logger, }) } } diff --git a/backend/src/llm-apis/relace-api.ts b/backend/src/llm-apis/relace-api.ts index 54a2ba5075..1daf04aa49 100644 --- a/backend/src/llm-apis/relace-api.ts +++ b/backend/src/llm-apis/relace-api.ts @@ -7,8 +7,8 @@ import { env } from '@codebuff/internal' import { saveMessage } from '../llm-apis/message-cost-tracker' import { countTokens } from '../util/token-counter' -import { promptAiSdk } from './vercel-ai-sdk/ai-sdk' +import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' const timeoutPromise = (ms: number) => @@ -26,6 +26,7 @@ export async function promptRelaceAI(params: { userId: string | undefined messageId: string userMessage?: string + promptAiSdk: PromptAiSdkFn logger: Logger }) { const { @@ -38,6 +39,7 @@ export async function promptRelaceAI(params: { userId, userMessage, messageId, + promptAiSdk, logger, } = params const startTime = Date.now() diff --git a/backend/src/main-prompt.ts b/backend/src/main-prompt.ts index 4206cd69d8..a4faa0ccae 100644 --- a/backend/src/main-prompt.ts +++ b/backend/src/main-prompt.ts @@ -11,14 +11,14 @@ import { requestToolCall } from './websockets/websocket-action' import type { AgentTemplate } from './templates/types' import type { ClientAction } from '@codebuff/common/actions' import type { CostMode } from '@codebuff/common/old-constants' +import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { PrintModeEvent } from '@codebuff/common/types/print-mode' import type { SessionState, AgentTemplateType, AgentOutput, } from '@codebuff/common/types/session-state' -import type { ParamsExcluding } from '@codebuff/common/types/function-params' -import type { Logger } from '@codebuff/common/types/contracts/logger' import type { WebSocket } from 'ws' export const mainPrompt = async ( @@ -42,7 +42,11 @@ export const mainPrompt = async ( | 'agentType' | 'fingerprintId' | 'fileContext' - >, + > & + ParamsExcluding< + typeof checkTerminalCommand, + 'prompt' | 'fingerprintId' | 'userInputId' + >, ): Promise<{ sessionState: SessionState output: AgentOutput @@ -156,12 +160,10 @@ export const mainPrompt = async ( // Check if this is a direct terminal command const startTime = Date.now() const terminalCommand = await checkTerminalCommand({ + ...params, prompt, - clientSessionId, fingerprintId, userInputId: promptId, - userId, - logger, }) const duration = Date.now() - startTime diff --git a/backend/src/process-file-block.ts b/backend/src/process-file-block.ts index 7ed48a3336..b74bad728e 100644 --- a/backend/src/process-file-block.ts +++ b/backend/src/process-file-block.ts @@ -18,7 +18,6 @@ import type { Message } from '@codebuff/common/types/messages/codebuff-message' export async function processFileBlock( params: { path: string - instructions: string | undefined initialContentPromise: Promise newContent: string messages: Message[] @@ -32,7 +31,15 @@ export async function processFileBlock( } & ParamsExcluding< typeof handleLargeFile, 'oldContent' | 'editSnippet' | 'filePath' - >, + > & + ParamsExcluding< + typeof fastRewrite, + 'initialContent' | 'editSnippet' | 'filePath' | 'userMessage' + > & + ParamsExcluding< + typeof shouldAddFilePlaceholders, + 'filePath' | 'oldContent' | 'rewrittenNewContent' | 'messageHistory' + >, ): Promise< | { tool: 'write_file' @@ -49,7 +56,6 @@ export async function processFileBlock( > { const { path, - instructions, initialContentPromise, newContent, messages, @@ -137,44 +143,29 @@ export async function processFileBlock( updatedContent = largeFileContent } else { updatedContent = await fastRewrite({ + ...params, initialContent: normalizedInitialContent, editSnippet: normalizedEditSnippet, filePath: path, - instructions, - clientSessionId, - fingerprintId, - userInputId, - userId, userMessage: lastUserPrompt, - logger, }) const shouldAddPlaceholders = await shouldAddFilePlaceholders({ + ...params, filePath: path, oldContent: normalizedInitialContent, rewrittenNewContent: updatedContent, messageHistory: messages, - fullResponse, - userId, - clientSessionId, - fingerprintId, - userInputId, - logger, }) if (shouldAddPlaceholders) { const placeholderComment = `... existing code ...` const updatedEditSnippet = `${placeholderComment}\n${updatedContent}\n${placeholderComment}` updatedContent = await fastRewrite({ + ...params, initialContent: normalizedInitialContent, editSnippet: updatedEditSnippet, filePath: path, - instructions, - clientSessionId, - fingerprintId, - userInputId, - userId, userMessage: lastUserPrompt, - logger, }) } } @@ -232,17 +223,22 @@ export async function processFileBlock( const LARGE_FILE_TOKEN_LIMIT = 64_000 -export async function handleLargeFile(params: { - oldContent: string - editSnippet: string - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - filePath: string - logger: Logger - promptAiSdk: PromptAiSdkFn -}): Promise { +export async function handleLargeFile( + params: { + oldContent: string + editSnippet: string + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + filePath: string + logger: Logger + promptAiSdk: PromptAiSdkFn + } & ParamsExcluding< + typeof retryDiffBlocksPrompt, + 'oldContent' | 'diffBlocksThatDidntMatch' + >, +): Promise { const { oldContent, editSnippet, @@ -329,14 +325,9 @@ Please output just the SEARCH/REPLACE blocks like this: const { newDiffBlocks, newDiffBlocksThatDidntMatch } = await retryDiffBlocksPrompt({ - filePath, + ...params, oldContent: updatedContent, - clientSessionId, - fingerprintId, - userInputId, - userId, diffBlocksThatDidntMatch, - logger, }) if (newDiffBlocksThatDidntMatch.length > 0) { diff --git a/backend/src/tools/handlers/tool/find-files.ts b/backend/src/tools/handlers/tool/find-files.ts index 322fe12f84..60a0d450f7 100644 --- a/backend/src/tools/handlers/tool/find-files.ts +++ b/backend/src/tools/handlers/tool/find-files.ts @@ -10,14 +10,17 @@ import { renderReadFilesResult } from '../../../util/parse-tool-call-xml' import { countTokens, countTokensJson } from '../../../util/token-counter' import { requestFiles } from '../../../websockets/websocket-action' -import type { TextBlock } from '../../../llm-apis/claude' import type { CodebuffToolHandlerFunction } from '../handler-function-type' -import type { Logger } from '@codebuff/common/types/contracts/logger' import type { GetExpandedFileContextForTrainingBlobTrace } from '@codebuff/bigquery' import type { CodebuffToolCall, CodebuffToolOutput, } from '@codebuff/common/tools/list' +import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { + ParamsExcluding, + ParamsOf, +} from '@codebuff/common/types/function-params' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { ProjectFileContext } from '@codebuff/common/util/file' import type { WebSocket } from 'ws' @@ -26,24 +29,44 @@ import type { WebSocket } from 'ws' // TODO: We might want to be able to turn this on on a per-repo basis. const COLLECT_FULL_FILE_CONTEXT = false -export const handleFindFiles = ((params: { - previousToolCallFinished: Promise - toolCall: CodebuffToolCall<'find_files'> - logger: Logger - - fileContext: ProjectFileContext - agentStepId: string - clientSessionId: string - userInputId: string - - state: { - ws?: WebSocket - fingerprintId?: string - userId?: string - repoId?: string - messages?: Message[] - } -}): { result: Promise>; state: {} } => { +export const handleFindFiles = (( + params: { + previousToolCallFinished: Promise + toolCall: CodebuffToolCall<'find_files'> + logger: Logger + + fileContext: ProjectFileContext + agentStepId: string + clientSessionId: string + userInputId: string + + state: { + ws?: WebSocket + fingerprintId?: string + userId?: string + repoId?: string + messages?: Message[] + } + } & ParamsExcluding< + typeof requestRelevantFiles, + | 'messages' + | 'system' + | 'assistantPrompt' + | 'fingerprintId' + | 'userId' + | 'repoId' + > & + ParamsExcluding< + typeof uploadExpandedFileContextForTraining, + | 'ws' + | 'messages' + | 'system' + | 'assistantPrompt' + | 'fingerprintId' + | 'userId' + | 'repoId' + >, +): { result: Promise>; state: {} } => { const { previousToolCallFinished, toolCall, @@ -87,36 +110,29 @@ export const handleFindFiles = ((params: { CodebuffToolOutput<'find_files'> > = async () => { const requestedFiles = await requestRelevantFiles({ + ...params, messages, system, - fileContext, assistantPrompt: prompt, - agentStepId, - clientSessionId, fingerprintId, - userInputId, userId, repoId, - logger, }) if (requestedFiles && requestedFiles.length > 0) { const addedFiles = await getFileReadingUpdates(ws, requestedFiles) if (COLLECT_FULL_FILE_CONTEXT && addedFiles.length > 0) { - uploadExpandedFileContextForTraining( + uploadExpandedFileContextForTraining({ + ...params, ws, - { messages, system }, - fileContext, - prompt, - agentStepId, - clientSessionId, + messages, + system, + assistantPrompt: prompt, fingerprintId, - userInputId, userId, repoId, - logger, - ).catch((error) => { + }).catch((error) => { logger.error( { error }, 'Error uploading expanded file context for training', @@ -165,37 +181,26 @@ export const handleFindFiles = ((params: { }) satisfies CodebuffToolHandlerFunction<'find_files'> async function uploadExpandedFileContextForTraining( - ws: WebSocket, - { - messages, - system, - }: { - messages: Message[] - system: string | Array - }, - fileContext: ProjectFileContext, - assistantPrompt: string | null, - agentStepId: string, - clientSessionId: string, - fingerprintId: string, - userInputId: string, - userId: string | undefined, - repoId: string | undefined, - logger: Logger, + params: { + ws: WebSocket + agentStepId: string + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + logger: Logger + } & ParamsOf, ) { - const files = await requestRelevantFilesForTraining({ - messages, - system, - fileContext, - assistantPrompt, + const { + ws, agentStepId, clientSessionId, fingerprintId, userInputId, userId, - repoId, logger, - }) + } = params + const files = await requestRelevantFilesForTraining(params) const loadedFiles = await requestFiles({ ws, filePaths: files }) diff --git a/scripts/ft-file-selection/grade-traces.ts b/scripts/ft-file-selection/grade-traces.ts index 7e7fdcedd3..b54d477240 100644 --- a/scripts/ft-file-selection/grade-traces.ts +++ b/scripts/ft-file-selection/grade-traces.ts @@ -1,3 +1,4 @@ +import { promptAiSdk } from '@codebuff/backend/llm-apis/vercel-ai-sdk/ai-sdk' import { getTracesAndRelabelsForUser, setupBigQuery } from '@codebuff/bigquery' import { gradeRun } from '../../backend/src/admin/grade-runs' @@ -46,7 +47,11 @@ async function gradeTraces({ logger }: { logger: Logger }) { batch.map(async (traceAndRelabels) => { try { console.log(`Grading trace ${traceAndRelabels.trace.id}`) - const result = await gradeRun({ ...traceAndRelabels, logger }) + const result = await gradeRun({ + ...traceAndRelabels, + promptAiSdk, + logger, + }) return { traceId: traceAndRelabels.trace.id, status: 'success', diff --git a/scripts/ft-file-selection/relabel-traces-with-context.ts b/scripts/ft-file-selection/relabel-traces-with-context.ts index 144f929c0b..0e4c63c39c 100644 --- a/scripts/ft-file-selection/relabel-traces-with-context.ts +++ b/scripts/ft-file-selection/relabel-traces-with-context.ts @@ -1,3 +1,4 @@ +import { promptAiSdk } from '@codebuff/backend/llm-apis/vercel-ai-sdk/ai-sdk' import { getTracesAndAllDataForUser, setupBigQuery } from '@codebuff/bigquery' import { models } from '@codebuff/common/old-constants' @@ -95,6 +96,7 @@ async function runTraces() { fileBlobs, model: MODEL_TO_TEST, dataset: DATASET, + promptAiSdk, logger: console, }) console.log(`Successfully stored relabel for trace ${trace.id}`) diff --git a/scripts/ft-file-selection/relabel-traces.ts b/scripts/ft-file-selection/relabel-traces.ts index 5ce518477f..f9b296e834 100644 --- a/scripts/ft-file-selection/relabel-traces.ts +++ b/scripts/ft-file-selection/relabel-traces.ts @@ -82,6 +82,7 @@ async function runTraces() { userInputId: 'relabel-trace-run', userId: 'relabel-trace-run', logger: console, + promptAiSdk, }) } From 01a4aff777ecd08c6c9ce50da7bf2126a7b1e55d Mon Sep 17 00:00:00 2001 From: Charles Lien Date: Thu, 9 Oct 2025 18:02:30 -0700 Subject: [PATCH 6/6] pass in promptAiSdkStructured --- backend/src/get-documentation-for-query.ts | 51 +++++++++++++--------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/backend/src/get-documentation-for-query.ts b/backend/src/get-documentation-for-query.ts index 35dafa8cd5..4492216818 100644 --- a/backend/src/get-documentation-for-query.ts +++ b/backend/src/get-documentation-for-query.ts @@ -4,9 +4,13 @@ import { uniq } from 'lodash' import { z } from 'zod/v4' import { fetchContext7LibraryDocumentation } from './llm-apis/context7-api' -import { promptAiSdkStructured } from './llm-apis/vercel-ai-sdk/ai-sdk' +import type { PromptAiSdkStructuredFn } from '@codebuff/common/types/contracts/llm' import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { + ParamsExcluding, + ParamsOf, +} from '@codebuff/common/types/function-params' const DELIMITER = `\n\n----------------------------------------\n\n` @@ -20,15 +24,18 @@ const DELIMITER = `\n\n----------------------------------------\n\n` * @param options.userId The ID of the user making the request * @returns The documentation text chunks or null if no relevant docs found */ -export async function getDocumentationForQuery(params: { - query: string - tokens?: number - clientSessionId: string - userInputId: string - fingerprintId: string - userId?: string - logger: Logger -}): Promise { +export async function getDocumentationForQuery( + params: { + query: string + tokens?: number + clientSessionId: string + userInputId: string + fingerprintId: string + userId?: string + logger: Logger + } & ParamsOf & + ParamsExcluding, +): Promise { const { query, tokens, @@ -41,14 +48,7 @@ export async function getDocumentationForQuery(params: { const startTime = Date.now() // 1. Search for relevant libraries - const libraryResults = await suggestLibraries({ - query, - clientSessionId, - userInputId, - fingerprintId, - userId, - logger, - }) + const libraryResults = await suggestLibraries(params) if (!libraryResults || libraryResults.libraries.length === 0) { logger.info( @@ -104,6 +104,7 @@ export async function getDocumentationForQuery(params: { // 3. Filter relevant chunks using another LLM call const filterResults = await filterRelevantChunks({ + ...params, query, allChunks: allUniqueChunks, clientSessionId, @@ -163,10 +164,18 @@ const suggestLibraries = async (params: { userInputId: string fingerprintId: string userId?: string + promptAiSdkStructured: PromptAiSdkStructuredFn logger: Logger }) => { - const { query, clientSessionId, userInputId, fingerprintId, userId, logger } = - params + const { + query, + clientSessionId, + userInputId, + fingerprintId, + userId, + promptAiSdkStructured, + logger, + } = params const prompt = `You are an expert at documentation for libraries. Given a user's query return a list of (library name, topic) where each library name is the name of a library and topic is a keyword or phrase that specifies a topic within the library that is most relevant to the user's query. @@ -231,6 +240,7 @@ async function filterRelevantChunks(params: { userInputId: string fingerprintId: string userId?: string + promptAiSdkStructured: PromptAiSdkStructuredFn logger: Logger }): Promise<{ relevantChunks: string[]; geminiDuration: number } | null> { const { @@ -240,6 +250,7 @@ async function filterRelevantChunks(params: { userInputId, fingerprintId, userId, + promptAiSdkStructured, logger, } = params const prompt = `You are an expert at analyzing documentation queries. Given a user's query and a list of documentation chunks, determine which chunks are relevant to the query. Choose as few chunks as possible, likely none. Only include chunks if they are relevant to the user query.