diff --git a/src/cloudflare/internal/ai-api.ts b/src/cloudflare/internal/ai-api.ts index 0176ec9e6bb..73f4a0e542a 100644 --- a/src/cloudflare/internal/ai-api.ts +++ b/src/cloudflare/internal/ai-api.ts @@ -133,6 +133,62 @@ function findReadableStreamKeys( return readableStreamKeys; } +/** + * Convert `gateway` options into the `cf-aig-*` request headers that AI Gateway + * enforces. This mirrors `AiGateway.#getHeadersFromOptions` in `aig-api.ts` so + * that `env.AI.run(model, inputs, { gateway })` honors the same options as the + * Universal endpoint (`env.AI.gateway(id).run(...)`). `gateway.id` is + * intentionally not emitted here — it is conveyed via the request body / the + * `/ai-gateway/run` endpoint. + */ +function gatewayOptionsToHeaders( + gateway: GatewayOptions +): Record { + const headers: Record = {}; + + if (gateway.skipCache !== undefined) { + headers['cf-aig-skip-cache'] = gateway.skipCache ? 'true' : 'false'; + } + + if (gateway.cacheTtl) { + headers['cf-aig-cache-ttl'] = gateway.cacheTtl.toString(); + } + + if (gateway.metadata) { + headers['cf-aig-metadata'] = JSON.stringify(gateway.metadata); + } + + if (gateway.cacheKey) { + headers['cf-aig-cache-key'] = gateway.cacheKey; + } + + if (gateway.collectLog !== undefined) { + headers['cf-aig-collect-log'] = gateway.collectLog ? 'true' : 'false'; + } + + if (gateway.eventId !== undefined) { + headers['cf-aig-event-id'] = gateway.eventId; + } + + if (gateway.requestTimeoutMs !== undefined) { + headers['cf-aig-request-timeout'] = gateway.requestTimeoutMs.toString(); + } + + if (gateway.retries !== undefined) { + if (gateway.retries.maxAttempts !== undefined) { + headers['cf-aig-max-attempts'] = gateway.retries.maxAttempts.toString(); + } + if (gateway.retries.retryDelayMs !== undefined) { + headers['cf-aig-retry-delay'] = gateway.retries.retryDelayMs.toString(); + } + if (gateway.retries.backoff !== undefined) { + headers['cf-aig-backoff'] = gateway.retries.backoff; + } + } + + return headers; +} + export class Ai { #fetcher: Fetcher; @@ -177,6 +233,12 @@ export class Ai { headers: { ...this.#options.sessionOptions?.extraHeaders, ...this.#options.extraHeaders, + // Translate gateway options into the cf-aig-* headers that AI Gateway + // enforces (e.g. requestTimeoutMs -> cf-aig-request-timeout). Placed + // after extraHeaders so the typed `gateway` option takes precedence. + ...(cleanedOptions.gateway + ? gatewayOptionsToHeaders(cleanedOptions.gateway) + : {}), 'content-type': 'application/json', 'cf-consn-sdk-version': '2.0.0', 'cf-consn-model-id': `${this.#options.prefix ? `${this.#options.prefix}:` : ''}${model}`, diff --git a/src/cloudflare/internal/test/ai/ai-api-test.js b/src/cloudflare/internal/test/ai/ai-api-test.js index e9b31898c43..7b3c3093770 100644 --- a/src/cloudflare/internal/test/ai/ai-api-test.js +++ b/src/cloudflare/internal/test/ai/ai-api-test.js @@ -271,6 +271,64 @@ export const tests = { }); } + { + // Gateway options must be forwarded as cf-aig-* request headers, not just + // in the body. Regression test for ESCALATION-3355: requestTimeoutMs was + // previously sent only in the body and therefore ignored by AI Gateway, + // which enforces the cf-aig-request-timeout header. + const resp = await env.ai.run( + 'echoGatewayHeaders', + { prompt: 'test' }, + { + gateway: { + id: 'my-gateway', + requestTimeoutMs: 1000, + cacheTtl: 3600, + skipCache: true, + cacheKey: 'abc', + metadata: { employee: 1233 }, + collectLog: false, + eventId: 'evt-1', + retries: { + maxAttempts: 3, + retryDelayMs: 250, + backoff: 'exponential', + }, + }, + } + ); + + assert.deepStrictEqual(resp, { + headers: { + 'cf-aig-request-timeout': '1000', + 'cf-aig-cache-ttl': '3600', + 'cf-aig-skip-cache': 'true', + 'cf-aig-cache-key': 'abc', + 'cf-aig-metadata': '{"employee":1233}', + 'cf-aig-collect-log': 'false', + 'cf-aig-event-id': 'evt-1', + 'cf-aig-max-attempts': '3', + 'cf-aig-retry-delay': '250', + 'cf-aig-backoff': 'exponential', + }, + requestUrl: 'https://workers-binding.ai/ai-gateway/run?version=3', + }); + } + + { + // Gateway requestTimeoutMs alone maps to cf-aig-request-timeout. + const resp = await env.ai.run( + 'echoGatewayHeaders', + { prompt: 'test' }, + { gateway: { id: 'my-gateway', requestTimeoutMs: 1000 } } + ); + + assert.deepStrictEqual(resp, { + headers: { 'cf-aig-request-timeout': '1000' }, + requestUrl: 'https://workers-binding.ai/ai-gateway/run?version=3', + }); + } + { // Test models const resp = await env.ai.models(); diff --git a/src/cloudflare/internal/test/ai/ai-mock.js b/src/cloudflare/internal/test/ai/ai-mock.js index 2b1746a985f..663d83e28e9 100644 --- a/src/cloudflare/internal/test/ai/ai-mock.js +++ b/src/cloudflare/internal/test/ai/ai-mock.js @@ -74,6 +74,26 @@ export default { ); } + if (modelName === 'echoGatewayHeaders') { + // Echo back the cf-aig-* headers the binding sent so the test can assert + // that gateway options are translated into headers (not just the body). + const aigHeaders = {}; + for (const [key, value] of request.headers.entries()) { + if (key.startsWith('cf-aig-')) { + aigHeaders[key] = value; + } + } + return Response.json( + { + headers: aigHeaders, + requestUrl: request.url, + }, + { + headers: respHeaders, + } + ); + } + if (modelName === 'readableStreamIputs') { return Response.json( {