-
Notifications
You must be signed in to change notification settings - Fork 103
Expand file tree
/
Copy pathmodelEntries.ts
More file actions
195 lines (164 loc) · 5.16 KB
/
modelEntries.ts
File metadata and controls
195 lines (164 loc) · 5.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import type { ProviderModelEntry } from "@/common/orpc/types";
import { normalizeToCanonical } from "@/common/utils/ai/models";
/**
 * The slice of a providers configuration that model-entry lookups read:
 * provider name -> optional block carrying an optional `models` list.
 *
 * Both the raw on-disk config (`ProvidersConfig`) and the API-facing map
 * (`ProvidersConfigMap`) are meant to fit this shape, so callers can pass
 * either one without converting first.
 */
export type ProvidersConfigWithModels = {
  [provider: string]: { models?: ProviderModelEntry[] } | undefined;
};
/** Result of splitting a `provider:modelId` string at its first colon. */
interface ParsedProviderModelId {
  /** Everything after the first `:`; non-empty, and may itself contain `:`. */
  modelId: string;
  /** Everything before the first `:`; non-empty. */
  provider: string;
}
/**
 * Returns the model id of a provider model entry, whether the entry is a
 * bare id string or an object carrying an `id` field.
 */
export function getProviderModelEntryId(entry: ProviderModelEntry): string {
  if (typeof entry !== "string") {
    return entry.id;
  }
  return entry;
}
/**
 * Returns the entry's `contextWindowTokens` override, or `null` when the
 * entry is a bare id string or its object form leaves the field unset.
 */
export function getProviderModelEntryContextWindowTokens(entry: ProviderModelEntry): number | null {
  const tokens = typeof entry === "string" ? undefined : entry.contextWindowTokens;
  return tokens ?? null;
}
/**
 * Returns the model id this entry is mapped to (`mappedToModel`), or `null`
 * when the entry is a bare id string or its object form leaves it unset.
 */
export function getProviderModelEntryMappedTo(entry: ProviderModelEntry): string | null {
  const target = typeof entry === "string" ? undefined : entry.mappedToModel;
  return target ?? null;
}
/**
 * Splits `provider:modelId` at the FIRST colon, so a model id that itself
 * contains colons (e.g. `ollama:llama3:8b`) keeps them.
 *
 * Returns `null` when there is no colon, or when the text on either side of
 * the first colon would be empty.
 */
function parseProviderModelId(fullModelId: string): ParsedProviderModelId | null {
  const colonAt = fullModelId.indexOf(":");
  // colonAt === -1 (no colon) and colonAt === 0 (empty provider) both fail here.
  const hasProvider = colonAt > 0;
  const hasModelId = colonAt < fullModelId.length - 1;
  if (!hasProvider || !hasModelId) {
    return null;
  }
  return {
    provider: fullModelId.substring(0, colonAt),
    modelId: fullModelId.substring(colonAt + 1),
  };
}
/**
 * Looks up `modelId` in the `models` list of `provider`'s config block.
 *
 * @param providersConfig - Providers config map, or `null` when none is loaded.
 * @param provider - Provider key to read from `providersConfig`.
 * @param modelId - Exact model id to match against each entry's id.
 * @returns The first entry whose id equals `modelId`, or `null` when the
 *   config, the provider block, or its `models` list is missing/empty, or
 *   when no entry matches.
 */
function findProviderModelEntry(
  providersConfig: ProvidersConfigWithModels | null,
  provider: string,
  modelId: string
): ProviderModelEntry | null {
  const entries = providersConfig?.[provider]?.models;
  // `Array.isArray` rather than a truthiness check: callers may pass the raw
  // disk config, and a non-array `models` value (e.g. `{}`) would otherwise
  // slip past the guard and make the iteration below throw.
  if (!Array.isArray(entries)) {
    return null;
  }
  return entries.find((entry) => getProviderModelEntryId(entry) === modelId) ?? null;
}
/**
 * Scoped-first provider model entry lookup.
 *
 * 1. Split `fullModelId` exactly as given (which may be a gateway-scoped id)
 *    and look it up in that provider's block. This lets block-local overrides
 *    such as `contextWindowTokens` and `mappedToModel` take precedence.
 * 2. Only on a miss, run `normalizeToCanonical`. If that produces a different
 *    id, look up the canonical provider/model pair instead.
 *
 * Returns `null` when neither lookup finds an entry.
 */
function findProviderModelEntryScoped(
  fullModelId: string,
  providersConfig: ProvidersConfigWithModels | null
): ProviderModelEntry | null {
  const scoped = parseProviderModelId(fullModelId);
  const scopedHit = scoped
    ? findProviderModelEntry(providersConfig, scoped.provider, scoped.modelId)
    : null;
  if (scopedHit) {
    return scopedHit;
  }

  const canonical = normalizeToCanonical(fullModelId);
  // An unchanged id after normalization means there is nothing new to look up.
  const canonicalParts = canonical !== fullModelId ? parseProviderModelId(canonical) : null;
  if (!canonicalParts) {
    return null;
  }
  return findProviderModelEntry(providersConfig, canonicalParts.provider, canonicalParts.modelId);
}
/**
 * Returns the `contextWindowTokens` override configured for `fullModelId`.
 * The scoped provider block is consulted first, then the canonical id.
 * Returns `null` when no entry matches or the matching entry sets no override.
 */
export function getModelContextWindowOverride(
  fullModelId: string,
  providersConfig: ProvidersConfigWithModels | null
): number | null {
  const match = findProviderModelEntryScoped(fullModelId, providersConfig);
  if (!match) {
    return null;
  }
  return getProviderModelEntryContextWindowTokens(match);
}
/**
 * Returns the model id to use for metadata lookups. This is the matching
 * entry's `mappedToModel` when one is configured (scoped block first, then
 * canonical id); otherwise `fullModelId` is returned unchanged.
 */
export function resolveModelForMetadata(
  fullModelId: string,
  providersConfig: ProvidersConfigWithModels | null
): string {
  const match = findProviderModelEntryScoped(fullModelId, providersConfig);
  if (!match) {
    return fullModelId;
  }
  return getProviderModelEntryMappedTo(match) ?? fullModelId;
}
/**
 * Returns `rawValue` trimmed of surrounding whitespace. Returns `null` if it
 * is not a string or is empty after trimming.
 */
function parseModelId(rawValue: unknown): string | null {
  if (typeof rawValue === "string") {
    const trimmed = rawValue.trim();
    if (trimmed !== "") {
      return trimmed;
    }
  }
  return null;
}
/**
 * Returns `rawValue` unchanged if it is a positive integer `number`. Anything
 * else yields `null`, including non-numbers, zero, negatives, fractions, NaN
 * and Infinity.
 */
function parseContextWindowTokens(rawValue: unknown): number | null {
  if (typeof rawValue === "number" && Number.isInteger(rawValue) && rawValue > 0) {
    return rawValue;
  }
  return null;
}
/**
 * Validates a single raw `models` entry of unknown shape into a
 * `ProviderModelEntry`.
 *
 * - A string becomes its trimmed id, or `null` if it is blank.
 * - An object must have a non-blank string `id` (trimmed).
 *   - `contextWindowTokens` is kept only if it is a positive integer.
 *   - `mappedToModel` is kept only if it is a non-blank string (trimmed).
 *   - If neither optional field survives, the bare id string is returned, so
 *     the object form appears only when it carries extra data.
 * - Any other value (including `null`) yields `null`.
 */
export function normalizeProviderModelEntry(rawEntry: unknown): ProviderModelEntry | null {
  if (typeof rawEntry === "string") {
    return parseModelId(rawEntry);
  }
  if (rawEntry === null || typeof rawEntry !== "object") {
    return null;
  }

  const candidate = rawEntry as {
    id?: unknown;
    contextWindowTokens?: unknown;
    mappedToModel?: unknown;
  };

  const modelId = parseModelId(candidate.id);
  if (modelId === null) {
    return null;
  }

  const contextWindowTokens = parseContextWindowTokens(candidate.contextWindowTokens);
  const mappedToModel = parseModelId(candidate.mappedToModel);
  const hasExtras = contextWindowTokens !== null || mappedToModel !== null;
  if (!hasExtras) {
    return modelId;
  }

  return {
    id: modelId,
    ...(contextWindowTokens === null ? {} : { contextWindowTokens }),
    ...(mappedToModel === null ? {} : { mappedToModel }),
  };
}
/**
 * Normalizes a raw `models` array.
 *
 * - Entries that `normalizeProviderModelEntry` rejects are dropped.
 * - Only the first entry for each model id is kept, in input order.
 * - Non-array input yields an empty array.
 */
export function normalizeProviderModelEntries(rawEntries: unknown): ProviderModelEntry[] {
  if (!Array.isArray(rawEntries)) {
    return [];
  }

  // Map iteration follows insertion order, so first-seen ids keep their position.
  const firstById = new Map<string, ProviderModelEntry>();
  for (const rawEntry of rawEntries) {
    const entry = normalizeProviderModelEntry(rawEntry);
    if (!entry) {
      continue;
    }
    const id = getProviderModelEntryId(entry);
    if (!firstById.has(id)) {
      firstById.set(id, entry);
    }
  }
  return Array.from(firstById.values());
}