Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/heavy-queens-watch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@browserbasehq/stagehand": patch
---

Add support for LLM provider middleware
30 changes: 24 additions & 6 deletions packages/core/lib/v3/llm/LLMProvider.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type { LanguageModelV2Middleware } from "@ai-sdk/provider";
import {
ExperimentalNotConfiguredError,
UnsupportedAISDKModelProviderError,
Expand Down Expand Up @@ -31,7 +32,7 @@ import { mistral, createMistral } from "@ai-sdk/mistral";
import { deepseek, createDeepSeek } from "@ai-sdk/deepseek";
import { perplexity, createPerplexity } from "@ai-sdk/perplexity";
import { ollama, createOllama } from "ollama-ai-provider-v2";
import { gateway, createGateway } from "ai";
import { gateway, createGateway, wrapLanguageModel } from "ai";
import { AISDKProvider, AISDKCustomProvider } from "../types/public/model.js";

const AISDKProviders: Record<string, AISDKProvider> = {
Expand Down Expand Up @@ -103,11 +104,13 @@ export function getAISDKLanguageModel(
subProvider: string,
subModelName: string,
clientOptions?: ClientOptions,
middleware?: LanguageModelV2Middleware,
) {
const hasValidOptions =
clientOptions &&
Object.values(clientOptions).some((v) => v !== undefined && v !== null);

let model;
if (hasValidOptions) {
const creator = AISDKProvidersWithAPIKey[subProvider];
if (!creator) {
Expand All @@ -117,8 +120,7 @@ export function getAISDKLanguageModel(
);
}
const provider = creator(clientOptions);
// Get the specific model from the provider
return provider(subModelName);
model = provider(subModelName);
} else {
const provider = AISDKProviders[subProvider];
if (!provider) {
Expand All @@ -127,21 +129,35 @@ export function getAISDKLanguageModel(
Object.keys(AISDKProviders),
);
}
return provider(subModelName);
model = provider(subModelName);
}

if (middleware) {
return wrapLanguageModel({ model, middleware });
}
return model;
}

export class LLMProvider {
private logger: (message: LogLine) => void;
private middleware?: LanguageModelV2Middleware;

constructor(logger: (message: LogLine) => void) {
constructor(
logger: (message: LogLine) => void,
middleware?: LanguageModelV2Middleware,
) {
this.logger = logger;
this.middleware = middleware;
}

getClient(
modelName: AvailableModel,
clientOptions?: ClientOptions,
options?: { experimental?: boolean; disableAPI?: boolean },
options?: {
experimental?: boolean;
disableAPI?: boolean;
middleware?: LanguageModelV2Middleware;
},
): LLMClient {
if (modelName.includes("/")) {
const firstSlashIndex = modelName.indexOf("/");
Expand All @@ -155,10 +171,12 @@ export class LLMProvider {
throw new ExperimentalNotConfiguredError("Vertex provider");
}

const effectiveMiddleware = options?.middleware ?? this.middleware;
const languageModel = getAISDKLanguageModel(
subProvider,
subModelName,
clientOptions,
effectiveMiddleware,
);

return new AISdkClient({
Expand Down
16 changes: 14 additions & 2 deletions packages/core/lib/v3/types/public/model.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import type { ClientOptions as AnthropicClientOptionsBase } from "@anthropic-ai/sdk";
import type { GoogleVertexProviderSettings as GoogleVertexProviderSettingsBase } from "@ai-sdk/google-vertex";
import type { LanguageModelV2 } from "@ai-sdk/provider";
import type {
LanguageModelV2,
LanguageModelV2Middleware,
} from "@ai-sdk/provider";
import type { ClientOptions as OpenAIClientOptionsBase } from "openai";
import type { AgentProviderType } from "./agent.js";

Expand Down Expand Up @@ -120,4 +123,13 @@ export type ClientOptions = (

export type ModelConfiguration =
| AvailableModel
| (ClientOptions & { modelName: AvailableModel });
| (ClientOptions & {
modelName: AvailableModel;
/**
* Optional AI SDK middleware applied to every LanguageModelV2 created for this model.
* Use this to intercept LLM calls for usage tracking, logging, request transforms, etc.
*
 * Only effective when running locally (direct mode). Cannot be serialized over HTTP.
*/
middleware?: LanguageModelV2Middleware;
});
23 changes: 18 additions & 5 deletions packages/core/lib/v3/v3.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type { LanguageModelV2Middleware } from "@ai-sdk/provider";
import fs from "fs";
import os from "os";
import path from "path";
Expand Down Expand Up @@ -98,6 +99,7 @@ const DEFAULT_AGENT_TOOL_TIMEOUT_MS = 45000;
type ResolvedModelConfiguration = {
modelName: AvailableModel;
clientOptions?: ClientOptions;
middleware?: LanguageModelV2Middleware;
};

function resolveModelConfiguration(
Expand All @@ -112,7 +114,7 @@ function resolveModelConfiguration(
}

if (model && typeof model === "object") {
const { modelName, ...clientOptions } = model;
const { modelName, middleware, ...clientOptions } = model;
if (!modelName) {
throw new StagehandInvalidArgumentError(
"model.modelName is required when providing client options.",
Expand All @@ -121,6 +123,7 @@ function resolveModelConfiguration(
return {
modelName,
clientOptions: clientOptions as ClientOptions,
middleware,
};
}

Expand Down Expand Up @@ -323,11 +326,13 @@ export class V3 {
} catch {
// ignore
}
const { modelName, clientOptions } = resolveModelConfiguration(opts.model);
const { modelName, clientOptions, middleware } = resolveModelConfiguration(
opts.model,
);
this.modelName = modelName;
this.experimental = opts.experimental ?? false;
this.logInferenceToFile = opts.logInferenceToFile ?? false;
this.llmProvider = new LLMProvider(this.logger);
this.llmProvider = new LLMProvider(this.logger, middleware);
this.domSettleTimeoutMs = opts.domSettleTimeout;
this.disableAPI = opts.disableAPI ?? false;

Expand Down Expand Up @@ -440,17 +445,20 @@ export class V3 {

let modelName: AvailableModel | string;
let clientOptions: ClientOptions | undefined;
let perCallMiddleware: LanguageModelV2Middleware | undefined;

if (typeof model === "string") {
modelName = model;
} else {
const { modelName: overrideModelName, ...rest } = model;
const { modelName: overrideModelName, middleware, ...rest } = model;
modelName = overrideModelName;
clientOptions = rest as ClientOptions;
perCallMiddleware = middleware;
}

if (
modelName === this.modelName &&
!perCallMiddleware &&
(!clientOptions || Object.keys(clientOptions).length === 0)
) {
return this.llmClient;
Expand All @@ -475,6 +483,7 @@ export class V3 {
const cacheKey = JSON.stringify({
modelName,
clientOptions: mergedOptions,
hasMiddleware: !!perCallMiddleware,
});

const cached = this.overrideLlmClients.get(cacheKey);
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
Outdated
Expand All @@ -485,7 +494,11 @@ export class V3 {
const client = this.llmProvider.getClient(
modelName as AvailableModel,
mergedOptions,
{ experimental: this.experimental, disableAPI: this.disableAPI },
{
experimental: this.experimental,
disableAPI: this.disableAPI,
middleware: perCallMiddleware,
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
Outdated
},
);

this.overrideLlmClients.set(cacheKey, client);
Expand Down
Loading
Loading