-
Notifications
You must be signed in to change notification settings - Fork 62
Expand file tree
/
Copy pathdatabricks.ts
More file actions
79 lines (71 loc) · 2.1 KB
/
databricks.ts
File metadata and controls
79 lines (71 loc) · 2.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import { z } from "zod/v3";
import { DatabricksOAuthSecretSchema } from "@braintrust/proxy/schema";
/**
 * Expected body of the Databricks OIDC token endpoint response.
 *
 * It is either a successful grant (access token, token type and lifetime in
 * seconds) or an object carrying an `error` string.
 */
const databricksOAuthResponseSchema = z.union([
  z.object({
    access_token: z.string(),
    // RFC 6749 §5.1: the token_type value is case-insensitive, so accept
    // "bearer" as well as "Bearer" instead of requiring an exact literal.
    token_type: z
      .string()
      .refine((tokenType) => tokenType.toLowerCase() === "bearer", {
        message: 'Expected token_type "Bearer"',
      }),
    // Token lifetime in seconds; the caller uses it as a cache TTL.
    expires_in: z.number(),
  }),
  z.object({
    error: z.string(),
  }),
]);
/**
 * Obtain a Databricks OAuth access token via the client-credentials grant.
 *
 * A cached token is returned if present. Otherwise a new one is fetched and
 * cached for slightly less than its advertised lifetime.
 *
 * @param secret - OAuth credentials. `client_id` and `client_secret` are sent
 *   to the token endpoint using HTTP Basic authentication.
 * @param apiBase - Workspace base URL. Trailing slashes are stripped before
 *   building the token endpoint `<apiBase>/oidc/v1/token`.
 * @param digest - Hash/digest function used to derive the cache key and the
 *   cache encryption key from the credentials and workspace URL.
 * @param cacheGet - Cache lookup. A falsy result is treated as a miss.
 * @param cachePut - Cache write. Receives the token and a TTL in seconds.
 * @returns The access token string.
 * @throws Error when the token endpoint responds with a non-2xx status or
 *   with an `error` payload.
 * @throws A zod validation error when the response body has an unexpected
 *   shape.
 */
export async function getDatabricksOAuthAccessToken({
  secret,
  apiBase,
  digest,
  cacheGet,
  cachePut,
}: {
  secret: z.infer<typeof DatabricksOAuthSecretSchema>;
  apiBase: string;
  digest: (message: string) => Promise<string>;
  cacheGet: (encryptionKey: string, key: string) => Promise<string | null>;
  cachePut: (
    encryptionKey: string,
    key: string,
    value: string,
    ttl_seconds?: number,
  ) => Promise<void>;
}): Promise<string> {
  const { client_id, client_secret } = secret;

  // Strip trailing slashes so "https://host/" does not produce
  // "https://host//oidc/v1/token". It also lets "https://host/" and
  // "https://host" share one cache entry. Inputs without a trailing slash
  // are unchanged.
  const normalizedApiBase = apiBase.replace(/\/+$/, "");
  const tokenUrl = `${normalizedApiBase}/oidc/v1/token`;

  // Key the cache on a digest of credentials + workspace rather than on the
  // raw values. The encryption key is derived separately and additionally
  // mixes in the client secret.
  const cachePath = await digest(
    `${client_id}:${client_secret}:${normalizedApiBase}`,
  );
  const cacheKey = `aiproxy/proxy/databricks/${cachePath}`;
  const encryptionKey = await digest(`${cachePath}:${client_secret}`);

  const cached = await cacheGet(encryptionKey, cacheKey);
  if (cached) {
    return cached;
  }

  // Create credentials for basic auth.
  const credentials = Buffer.from(`${client_id}:${client_secret}`).toString(
    "base64",
  );
  const res = await fetch(tokenUrl, {
    method: "POST",
    headers: {
      "Content-Type": "application/x-www-form-urlencoded",
      Authorization: `Basic ${credentials}`,
    },
    body: new URLSearchParams({
      grant_type: "client_credentials",
      scope: "all-apis",
    }),
  });
  if (!res.ok) {
    throw new Error(
      `Databricks OAuth error (${res.status}): ${res.statusText} ${await res.text()}`,
    );
  }

  const data = await res.json();
  const parsed = databricksOAuthResponseSchema.parse(data);
  if ("error" in parsed) {
    throw new Error(`Databricks OAuth error: ${parsed.error}`);
  }

  // Expire the cache entry 60 seconds before the token itself so a cached
  // token is never handed out right at its expiry. Tokens with a lifetime
  // of 60 seconds or less are not cached at all.
  const cacheTtl = Math.max(parsed.expires_in - 60, 0);
  if (cacheTtl > 0) {
    await cachePut(encryptionKey, cacheKey, parsed.access_token, cacheTtl);
  }
  return parsed.access_token;
}