|
1 | 1 | import { createLogger, type Logger } from '@sim/logger' |
2 | 2 | import type { ChatCompletionChunk } from 'openai/resources/chat/completions' |
3 | 3 | import type { CompletionUsage } from 'openai/resources/completions' |
4 | | -import { getEnv, isTruthy } from '@/lib/core/config/env' |
| 4 | +import { env } from '@/lib/core/config/env' |
5 | 5 | import { isHosted } from '@/lib/core/config/feature-flags' |
6 | 6 | import { isCustomTool } from '@/executor/constants' |
7 | 7 | import { |
@@ -131,6 +131,9 @@ function filterBlacklistedModelsFromProviderMap( |
131 | 131 | ): Record<string, ProviderId> { |
132 | 132 | const filtered: Record<string, ProviderId> = {} |
133 | 133 | for (const [model, providerId] of Object.entries(providerMap)) { |
| 134 | + if (isProviderBlacklisted(providerId)) { |
| 135 | + continue |
| 136 | + } |
134 | 137 | if (!isModelBlacklisted(model)) { |
135 | 138 | filtered[model] = providerId |
136 | 139 | } |
@@ -192,35 +195,42 @@ export function getProviderModels(providerId: ProviderId): string[] { |
192 | 195 | return getProviderModelsFromDefinitions(providerId) |
193 | 196 | } |
194 | 197 |
|
195 | | -interface ModelBlacklist { |
196 | | - models: string[] |
197 | | - prefixes: string[] |
198 | | - envOverride?: string |
| 198 | +function getBlacklistedProviders(): string[] { |
| 199 | + if (!env.BLACKLISTED_PROVIDERS) return [] |
| 200 | + return env.BLACKLISTED_PROVIDERS.split(',').map((p) => p.trim().toLowerCase()) |
199 | 201 | } |
200 | 202 |
|
201 | | -const MODEL_BLACKLISTS: ModelBlacklist[] = [ |
202 | | - { |
203 | | - models: ['deepseek-chat', 'deepseek-v3', 'deepseek-r1'], |
204 | | - prefixes: ['openrouter/deepseek', 'openrouter/tngtech'], |
205 | | - envOverride: 'DEEPSEEK_MODELS_ENABLED', |
206 | | - }, |
207 | | -] |
| 203 | +export function isProviderBlacklisted(providerId: string): boolean { |
| 204 | + const blacklist = getBlacklistedProviders() |
| 205 | + return blacklist.includes(providerId.toLowerCase()) |
| 206 | +} |
| 207 | + |
| 208 | +/** |
| 209 | + * Get the list of blacklisted models from env var. |
| 210 | + * BLACKLISTED_MODELS supports: |
| 211 | + * - Exact model names: "gpt-4,claude-3-opus" |
| 212 | + * - Prefix patterns with *: "claude-*,gpt-4-*" (matches models starting with that prefix) |
| 213 | + */ |
| 214 | +function getBlacklistedModels(): { models: string[]; prefixes: string[] } { |
| 215 | + if (!env.BLACKLISTED_MODELS) return { models: [], prefixes: [] } |
| 216 | + |
| 217 | + const entries = env.BLACKLISTED_MODELS.split(',').map((m) => m.trim().toLowerCase()) |
| 218 | + const models = entries.filter((e) => !e.endsWith('*')) |
| 219 | + const prefixes = entries.filter((e) => e.endsWith('*')).map((e) => e.slice(0, -1)) |
| 220 | + |
| 221 | + return { models, prefixes } |
| 222 | +} |
208 | 223 |
|
209 | 224 | function isModelBlacklisted(model: string): boolean { |
210 | 225 | const lowerModel = model.toLowerCase() |
| 226 | + const blacklist = getBlacklistedModels() |
211 | 227 |
|
212 | | - for (const blacklist of MODEL_BLACKLISTS) { |
213 | | - if (blacklist.envOverride && isTruthy(getEnv(blacklist.envOverride))) { |
214 | | - continue |
215 | | - } |
216 | | - |
217 | | - if (blacklist.models.includes(lowerModel)) { |
218 | | - return true |
219 | | - } |
| 228 | + if (blacklist.models.includes(lowerModel)) { |
| 229 | + return true |
| 230 | + } |
220 | 231 |
|
221 | | - if (blacklist.prefixes.some((prefix) => lowerModel.startsWith(prefix))) { |
222 | | - return true |
223 | | - } |
| 232 | + if (blacklist.prefixes.some((prefix) => lowerModel.startsWith(prefix))) { |
| 233 | + return true |
224 | 234 | } |
225 | 235 |
|
226 | 236 | return false |
|
0 commit comments