diff --git a/Makefile b/Makefile index 8221f7c0..45c119f2 100644 --- a/Makefile +++ b/Makefile @@ -91,6 +91,15 @@ record-api: -output=tests/contract/testdata/openai/models.json @echo "Done! Golden files saved to tests/contract/testdata/" +record-api-kimi: + @echo "Recording Kimi chat completion..." + go run ./cmd/recordapi -provider=kimi -endpoint=chat \ + -output=tests/contract/testdata/kimi/chat_completion.json + @echo "Recording Kimi models..." + go run ./cmd/recordapi -provider=kimi -endpoint=models \ + -output=tests/contract/testdata/kimi/models.json + @echo "Done! Golden files saved to tests/contract/testdata/" + swagger: go run github.com/swaggo/swag/v2/cmd/swag init --generalInfo main.go \ --dir cmd/gomodel,internal \ diff --git a/README.md b/README.md index 4da6d3cc..dca55ad5 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@

- A fast and lightweight AI gateway written in Go, providing unified OpenAI-compatible and Anthropic-compatible APIs for OpenAI, Anthropic, Gemini, DeepSeek, xAI, Groq, OpenRouter, Z.ai, Azure OpenAI, Oracle, Ollama, and more. + A fast and lightweight AI gateway written in Go, providing unified OpenAI-compatible and Anthropic-compatible APIs for OpenAI, Anthropic, Gemini, DeepSeek, xAI, Groq, OpenRouter, Z.ai, Kimi, Azure OpenAI, Oracle, Ollama, and more.

@@ -64,9 +64,9 @@ curl http://localhost:8080/v1/chat/completions \ ### Supported LLM Providers GoModel supports OpenAI, Anthropic, Google Gemini, Vertex AI, DeepSeek, Groq, -OpenRouter, Z.ai, xAI (Grok), Alibaba Cloud Model Studio (Bailian), MiniMax, -Xiaomi MiMo, OpenCode Go, Azure OpenAI, Oracle, Ollama, vLLM, Amazon Bedrock, -and all OpenAI-compatible providers. +OpenRouter, Z.ai, xAI (Grok), Alibaba Cloud Model Studio (Bailian), Kimi, +MiniMax, Xiaomi MiMo, OpenCode Go, Azure OpenAI, Oracle, Ollama, vLLM, Amazon +Bedrock, and all OpenAI-compatible providers. See the [Providers Overview](./docs/providers/overview.mdx) for the full per-provider feature matrix (chat, `/responses`, embeddings, files, batches, diff --git a/cmd/gomodel/main.go b/cmd/gomodel/main.go index 216710e3..48e50a3e 100644 --- a/cmd/gomodel/main.go +++ b/cmd/gomodel/main.go @@ -22,6 +22,7 @@ import ( "gomodel/internal/providers/deepseek" "gomodel/internal/providers/gemini" "gomodel/internal/providers/groq" + "gomodel/internal/providers/kimi" "gomodel/internal/providers/minimax" "gomodel/internal/providers/ollama" "gomodel/internal/providers/openai" @@ -147,6 +148,7 @@ func main() { factory.Add(gemini.Registration) factory.Add(vertex.Registration) factory.Add(groq.Registration) + factory.Add(kimi.Registration) factory.Add(minimax.Registration) factory.Add(ollama.Registration) factory.Add(opencodego.Registration) diff --git a/cmd/recordapi/main.go b/cmd/recordapi/main.go index 44969438..4b0ce2d2 100644 --- a/cmd/recordapi/main.go +++ b/cmd/recordapi/main.go @@ -9,6 +9,7 @@ package main import ( "bytes" + "compress/gzip" "flag" "fmt" "io" @@ -21,7 +22,10 @@ import ( "github.com/goccy/go-json" ) -const oracleDefaultModel = "openai.gpt-oss-120b" +const ( + oracleDefaultModel = "openai.gpt-oss-120b" + kimiDefaultModel = "kimi-for-coding" +) // Provider configurations var providerConfigs = map[string]struct { @@ -30,6 +34,7 @@ var providerConfigs = map[string]struct { envKey string authHeader string 
contentType string + setHeaders func(*http.Request) }{ "openai": { baseURL: "https://api.openai.com", @@ -61,6 +66,13 @@ var providerConfigs = map[string]struct { authHeader: "Authorization", contentType: "application/json", }, + "kimi": { + baseURL: "https://api.kimi.com/coding", + envKey: "KIMI_API_KEY", + authHeader: "Authorization", + contentType: "application/json", + setHeaders: setKimiHeaders, + }, "oracle": { baseURLEnv: "ORACLE_BASE_URL", envKey: "ORACLE_API_KEY", @@ -155,7 +167,7 @@ func providerSupportsResponses(provider string) bool { } func main() { - provider := flag.String("provider", "openai", "Provider to test (openai, anthropic, gemini, groq, xai, oracle)") + provider := flag.String("provider", "openai", "Provider to test (openai, anthropic, gemini, groq, xai, kimi, oracle)") endpoint := flag.String("endpoint", "chat", "Endpoint to test (chat, chat_stream, models, responses, responses_stream)") output := flag.String("output", "", "Output file path (required)") model := flag.String("model", "", "Override model in request") @@ -203,12 +215,12 @@ func main() { if eConfig.requestBody != nil { reqBody := eConfig.requestBody - // Oracle's OpenAI-compatible endpoint expects OCI-hosted model IDs, - // so use a provider-specific default instead of the generic gpt-4o-mini fixture. 
if *model != "" { reqBody["model"] = *model } else if *provider == "oracle" { reqBody["model"] = oracleDefaultModel + } else if *provider == "kimi" { + reqBody["model"] = kimiDefaultModel } // Adjust request for different providers @@ -250,6 +262,11 @@ func main() { req.Header.Set("anthropic-version", "2023-06-01") } + // Apply provider-specific header overrides + if pConfig.setHeaders != nil { + pConfig.setHeaders(req) + } + // Send request client := &http.Client{Timeout: 60 * time.Second} fmt.Printf("Sending request to %s %s...\n", eConfig.method, url) @@ -263,8 +280,18 @@ func main() { fmt.Printf("Response status: %d %s\n", resp.StatusCode, resp.Status) - // Read response body - body, err := io.ReadAll(resp.Body) + // Decompress if the server honored an explicit Accept-Encoding header. + var respReader io.Reader = resp.Body + if resp.Header.Get("Content-Encoding") == "gzip" { + gr, err := gzip.NewReader(resp.Body) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating gzip reader: %v\n", err) + os.Exit(1) + } + defer gr.Close() + respReader = gr + } + body, err := io.ReadAll(respReader) if err != nil { fmt.Fprintf(os.Stderr, "Error reading response: %v\n", err) os.Exit(1) @@ -311,6 +338,26 @@ func main() { } } +// setKimiHeaders installs the client-identity header bundle Kimi's coding +// endpoint expects from an OpenAI-SDK-style client. 
+func setKimiHeaders(req *http.Request) { + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Encoding", "gzip") /* gzip only: setting this header disables net/http's transparent decompression, and main's response reader decodes gzip and nothing else */ + req.Header.Set("Accept-Language", "*") + req.Header.Set("Connection", "keep-alive") + req.Header.Set("Http-Referer", "https://github.com/Zoo-Code-Org/Zoo-Code") + req.Header.Set("Sec-Fetch-Mode", "cors") + req.Header.Set("User-Agent", "ZooCode/3.62.0") + req.Header.Set("X-Stainless-Arch", "x64") + req.Header.Set("X-Stainless-Lang", "js") + req.Header.Set("X-Stainless-Os", "Linux") + req.Header.Set("X-Stainless-Package-Version", "5.12.2") + req.Header.Set("X-Stainless-Retry-Count", "0") + req.Header.Set("X-Stainless-Runtime", "node") + req.Header.Set("X-Stainless-Runtime-Version", "v22.22.1") + req.Header.Set("X-Title", "Zoo Code") +} + // adjustForAnthropic converts OpenAI-style request to Anthropic format func adjustForAnthropic(req map[string]any) map[string]any { result := make(map[string]any) diff --git a/config/config.example.yaml b/config/config.example.yaml index bbe657c8..0b474670 100644 --- a/config/config.example.yaml +++ b/config/config.example.yaml @@ -230,6 +230,10 @@ providers: openai: type: openai api_key: "sk-..." + # Optional per-provider header fields (same for every provider): + # passthrough_user_headers: false # false default (kimi defaults true); true forwards inbound headers + # passthrough_user_headers_skip: [] # optional list of header names to exclude from passthrough + # custom_upstream_headers: {} # static header map, sent on every request; mutually exclusive with passthrough # Per-provider resilience overrides (optional). # Only specified fields override the global defaults above. # resilience: @@ -283,6 +287,29 @@ providers: # Optional: use GLM Coding Plan endpoint instead of the general endpoint.
# base_url: "https://api.z.ai/api/coding/paas/v4" + kimi: + type: kimi + # base_url: "https://api.kimi.com/coding/v1" # default + api_key: "${KIMI_API_KEY}" + # Kimi defaults passthrough_user_headers to true: inbound headers (minus a + # skip list) are forwarded automatically. custom_upstream_headers and + # passthrough_user_headers are mutually exclusive. See docs/features/provider-header-passthrough.mdx. + # + # The Kimi coding endpoint expects a recognizable client identity. With + # passthrough on, a caller sending those headers is forwarded as-is. + # Alternatively, disable passthrough and ship a static identity bundle. + # Use values truthful for your client and compliant with Kimi's terms. + # passthrough_user_headers: false + # passthrough_user_headers_skip: + # - X-MyOrg-Trace + # custom_upstream_headers: + # Http-Referer: "https://github.com/Zoo-Code-Org/Zoo-Code" + # User-Agent: "ZooCode/3.62.0" + # X-Stainless-Lang: "js" + # X-Stainless-Os: "Linux" + # X-Stainless-Runtime: "node" + # X-Title: "Zoo Code" + xiaomi: type: xiaomi api_key: "..." diff --git a/config/providers.go b/config/providers.go index 20876024..bbee9993 100644 --- a/config/providers.go +++ b/config/providers.go @@ -4,19 +4,22 @@ package config // overrides, credential filtering, or resilience merging. Exported so the // providers package can resolve it into a fully-configured ProviderConfig. 
type RawProviderConfig struct { - Type string `yaml:"type"` - APIKey string `yaml:"api_key"` - BaseURL string `yaml:"base_url"` - APIVersion string `yaml:"api_version"` - Backend string `yaml:"backend"` - AuthType string `yaml:"auth_type"` - APIMode string `yaml:"api_mode"` - VertexProject string `yaml:"vertex_project"` - VertexLocation string `yaml:"vertex_location"` - ServiceAccountFile string `yaml:"service_account_file"` - ServiceAccountJSON string `yaml:"service_account_json"` - ServiceAccountJSONBase64 string `yaml:"service_account_json_base64"` - GCPScope string `yaml:"gcp_scope"` - Models []RawProviderModel `yaml:"models"` - Resilience *RawResilienceConfig `yaml:"resilience"` + Type string `yaml:"type"` + APIKey string `yaml:"api_key"` + BaseURL string `yaml:"base_url"` + APIVersion string `yaml:"api_version"` + Backend string `yaml:"backend"` + AuthType string `yaml:"auth_type"` + APIMode string `yaml:"api_mode"` + VertexProject string `yaml:"vertex_project"` + VertexLocation string `yaml:"vertex_location"` + ServiceAccountFile string `yaml:"service_account_file"` + ServiceAccountJSON string `yaml:"service_account_json"` + ServiceAccountJSONBase64 string `yaml:"service_account_json_base64"` + GCPScope string `yaml:"gcp_scope"` + Models []RawProviderModel `yaml:"models"` + CustomUpstreamHeaders map[string]string `yaml:"custom_upstream_headers"` + PassthroughUserHeaders *bool `yaml:"passthrough_user_headers"` + PassthroughUserHeadersSkip []string `yaml:"passthrough_user_headers_skip"` + Resilience *RawResilienceConfig `yaml:"resilience"` } diff --git a/docs/advanced/configuration.mdx b/docs/advanced/configuration.mdx index a001c38e..c2648efe 100644 --- a/docs/advanced/configuration.mdx +++ b/docs/advanced/configuration.mdx @@ -218,6 +218,8 @@ Set these to automatically register providers. No YAML configuration required. Most providers can use a custom base URL via `_BASE_URL` (for example `OPENAI_BASE_URL`). 
DeepSeek defaults to `https://api.deepseek.com`; set `DEEPSEEK_BASE_URL` only for a compatible proxy or alternate DeepSeek endpoint. OpenRouter defaults to `https://openrouter.ai/api/v1` and can be overridden with `OPENROUTER_BASE_URL`. Z.ai defaults to `https://api.z.ai/api/paas/v4`; set `ZAI_BASE_URL=https://api.z.ai/api/coding/paas/v4` for the GLM Coding Plan endpoint. vLLM defaults to `http://localhost:8000/v1` when `VLLM_API_KEY` is set, but keyless deployments should set `VLLM_BASE_URL` explicitly to register the provider. Azure uses `AZURE_BASE_URL` for its deployment base URL and accepts an optional `AZURE_API_VERSION` override; otherwise it defaults to `2024-10-21`. Oracle requires `ORACLE_BASE_URL` because its OpenAI-compatible endpoint is region-specific. +YAML provider blocks also accept `custom_upstream_headers` and `passthrough_user_headers` for tuning outbound headers; see the [Provider Header Passthrough feature page](/features/provider-header-passthrough) for the schema, defaults, env-var pattern, and worked examples. + Every provider type also accepts a comma-separated configured model list via `_MODELS`, for example `OPENROUTER_MODELS`, `ORACLE_MODELS`, `AZURE_MODELS`, or `VLLM_MODELS`. By default, diff --git a/docs/docs.json b/docs/docs.json index b50b4048..20fc155d 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -54,6 +54,7 @@ "features/virtual-models", "features/user-path", "features/passthrough-api", + "features/provider-header-passthrough", "features/budgets", "features/cost-tracking", "features/cache", diff --git a/docs/features/provider-header-passthrough.mdx b/docs/features/provider-header-passthrough.mdx new file mode 100644 index 00000000..1716d7ff --- /dev/null +++ b/docs/features/provider-header-passthrough.mdx @@ -0,0 +1,118 @@ +--- +title: "Provider Header Passthrough" +description: "Tune outbound request headers per provider — pick one: static custom_upstream_headers OR inbound passthrough. Kimi defaults to passthrough on." 
+icon: "shuffle" +--- + +## Overview + +Each provider block accepts two optional, **mutually exclusive** header fields. +Enabling passthrough while also defining `custom_upstream_headers` fails config-load validation. + +| Field | Purpose | +| --- | --- | +| `custom_upstream_headers` | Static map of header name → value, written on every outbound request. YAML-only. | +| `passthrough_user_headers` | Boolean. When `true`, every non-skipped inbound header is forwarded onto the outbound request. | + +Pick the path that fits the upstream: + +- **Forward what the caller sent** → `passthrough_user_headers: true`. Defaults + to `true` for Kimi, `false` for every other provider. Drop specific keys the + upstream rejects with `passthrough_user_headers_skip`. +- **Write a static bundle** → `passthrough_user_headers: false` + + `custom_upstream_headers`. The outbound request gets only the static headers + + provider-set auth, never anything the caller sent. + +Provider-set authentication headers (e.g. `Authorization: Bearer …`) are +written by the provider factory before any of this runs; the feature never +touches them. + +Override precedence, highest wins: env var > YAML field > provider-type +default. + +```bash +export KIMI_PASSTHROUGH_USER_HEADERS=false +``` + +This is unrelated to the [HTTP passthrough API](/features/passthrough-api) at +`/p/{provider}/...`, which proxies provider-native bodies verbatim. + +## Static custom headers (passthrough off) + +`passthrough_user_headers: false` plus a `custom_upstream_headers` map writes +the same bundle on every request — useful for stable client identity, tracing, +or any upstream-specific static key the inbound call shouldn't influence. + +```yaml +providers: + <name>: + api_key: "<api-key>" + passthrough_user_headers: false + custom_upstream_headers: + User-Agent: "my-app/1.0" + X-Trace-Source: "gomodel" +``` + +Header names are canonicalized with `http.CanonicalHeaderKey`, so +`x-trace-source` and `X-Trace-Source` are the same key.
There is no env-var +equivalent — keep static bundles in `config.yaml`. + +## Inbound header passthrough (custom headers off) + +`passthrough_user_headers: true` forwards every non-skipped inbound header +onto the outbound request. Provider auth headers win over inbound ones; +cookies and `X-Forwarded-*` are never forwarded. + +```yaml +providers: + <name>: + api_key: "<api-key>" + passthrough_user_headers: true +``` + +### Per-provider skip additions + +When the upstream rejects a header that the always-on floor (see [Skip list](#skip-list)) doesn't already cover, +list it under `passthrough_user_headers_skip` (or the +`<PROVIDER>_PASSTHROUGH_USER_HEADERS_SKIP` env var, comma-separated; env wins +over YAML). + +```yaml +providers: + <name>: + api_key: "<api-key>" + passthrough_user_headers: true + passthrough_user_headers_skip: + - X-MyOrg-Trace + - X-Internal-Id +``` + +```bash +export KIMI_PASSTHROUGH_USER_HEADERS_SKIP="X-MyOrg-Trace, X-Internal-Id" +``` + +Keys are matched with `http.CanonicalHeaderKey`, so `x-myorg-trace` and +`X-MyOrg-Trace` are the same. + +## Skip list + +The always-on floor drops the following from inbound passthrough on every +provider, and is not per-provider configurable — only additions via +`passthrough_user_headers_skip` are: + +- **Credential headers**: `Authorization`, `X-Api-Key` +- **Transport-managed / hop-by-hop (RFC 7230)**: `Host`, `Content-Length`, + `Connection`, `Keep-Alive`, `Proxy-Authenticate`, `Proxy-Authorization`, + `Te`, `Trailer`, `Transfer-Encoding`, `Upgrade` +- **Cookies and forwarding context**: `Cookie`, `Set-Cookie`, `Forwarded`, + every header with the `X-Forwarded-` prefix + +The skip list has no effect when `passthrough_user_headers: false`. + +## See also + +- [Kimi provider guide](/providers/kimi) — Kimi's passthrough default and the + client-identity headers its coding endpoint expects. +- [Configuration reference](/advanced/configuration) — full `config.yaml` + schema. +- [Providers overview](/providers/overview) — comparison matrix.
diff --git a/docs/providers/kimi.mdx b/docs/providers/kimi.mdx new file mode 100644 index 00000000..cad290dd --- /dev/null +++ b/docs/providers/kimi.mdx @@ -0,0 +1,113 @@ +--- +title: "Kimi" +description: "Configure Kimi (Moonshot AI) in GoModel as a plain OpenAI-compatible provider — standard Bearer auth, no built-in header injection, inbound-header passthrough on by default." +icon: "sparkles" +--- + +[Kimi](https://www.kimi.com/) (Moonshot AI) exposes an OpenAI-compatible API. +GoModel ships a native `kimi` provider that is a thin wrapper over the shared +OpenAI-compatible transport — no provider-specific headers, no request +mutators, and no auth scheme overrides. Authentication uses the standard +`Authorization: Bearer <KIMI_API_KEY>` header. + +## Configure + +```bash +KIMI_API_KEY=... +``` + +Or in `config.yaml`: + +```yaml +providers: + kimi: + type: kimi + base_url: "https://api.kimi.com/coding/v1" + api_key: "${KIMI_API_KEY}" +``` + +The default base URL is `https://api.kimi.com/coding/v1`. Override it with +`KIMI_BASE_URL` or the YAML `base_url` field when you need to point at a +compatible proxy or alternate Kimi endpoint. + +## Authentication + +Kimi uses standard Bearer auth. GoModel sets one header on every outbound +request: + +| Header | Value | +| ------ | ----- | +| `Authorization` | `Bearer <KIMI_API_KEY>` | + +GoModel itself adds no other static headers — but note the [Header +customization](#header-customization) default below, which forwards inbound +headers onto the outbound request. + +## Model IDs + +Kimi's `/v1/models` endpoint advertises the model family exposed through the +configured base URL. Common identifiers include the `kimi-k2` series. Pin the +advertised set with a configured list when you only want a subset: + +```bash +KIMI_MODELS=kimi-for-coding +``` + +By default, `CONFIGURED_PROVIDER_MODELS_MODE=fallback` uses the configured +list only when the upstream `/models` response is unavailable or empty.
Set +`CONFIGURED_PROVIDER_MODELS_MODE=allowlist` to expose only configured models +and skip the upstream `/models` call. + +## Endpoints + +Kimi is served through the standard OpenAI-compatible surface. All of these +work out of the box: + +| Endpoint | Notes | +| -------- | ----- | +| `/v1/chat/completions` | Chat completions, including streaming. | +| `/v1/responses` | Translated to chat completions. | +| `/v1/embeddings` | Embeddings. | +| `/v1/models` | Model listing (see [Model IDs](#model-ids)). | +| `/v1/files` | File uploads and retrieval. | +| `/v1/batches` | Batch jobs. | +| `/v1/audio/speech` | Text-to-speech. | +| `/v1/audio/transcriptions` | Audio transcription. | +| `/p/kimi/...` | Provider-native passthrough (unchanged). | + +## Header customization + +Kimi is the only provider where GoModel **defaults `passthrough_user_headers` +to `true`**: inbound request headers (minus a small skip list of credential +and transport headers) are forwarded onto the outbound Kimi request +automatically. + +The Kimi coding endpoint also expects a recognizable client identity — a +`User-Agent` plus the `X-Stainless-*`/`X-Title` headers most OpenAI SDKs send. +With passthrough on, a caller that sends those headers (e.g. an OpenAI SDK or +a configured proxy) is enough. Alternatively, set a static bundle and turn +passthrough off: + +```yaml +providers: + kimi: + type: kimi + api_key: "${KIMI_API_KEY}" + passthrough_user_headers: false + custom_upstream_headers: + User-Agent: "ZooCode/3.62.0" + X-Title: "Zoo Code" + Http-Referer: "https://github.com/Zoo-Code-Org/Zoo-Code" + X-Stainless-Lang: "js" + X-Stainless-Os: "Linux" +``` + +```bash +KIMI_PASSTHROUGH_USER_HEADERS=false +``` + +Header identity is forwarded as supplied. Make sure the values you send are +truthful for your client, and that their use complies with Kimi's terms. 
+ +See the [Provider Header Passthrough feature page](/features/provider-header-passthrough) +for the full skip list, env-var overrides, and the per-provider schema. diff --git a/docs/providers/overview.mdx b/docs/providers/overview.mdx index 281e6e27..cc94ab1b 100644 --- a/docs/providers/overview.mdx +++ b/docs/providers/overview.mdx @@ -29,6 +29,7 @@ support, not every individual model capability exposed by an upstream provider. | Z.ai | `ZAI_API_KEY` (`ZAI_BASE_URL` optional) | `glm-5.1` | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | — | | xAI (Grok) | `XAI_API_KEY` | `grok-4` | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | — | | Alibaba Cloud Model Studio (Bailian) | `BAILIAN_API_KEY` (`BAILIAN_BASE_URL` optional) | `qwen3-max` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | [Alibaba Cloud Model Studio](/providers/bailian) | +| Kimi | `KIMI_API_KEY` (`KIMI_BASE_URL` optional) | `kimi-for-coding` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | [Kimi](/providers/kimi) | | MiniMax | `MINIMAX_API_KEY` (`MINIMAX_BASE_URL` optional) | `MiniMax-M3` | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | — | | Xiaomi MiMo | `XIAOMI_API_KEY` (`XIAOMI_BASE_URL` optional) | `mimo-v2.5-pro` | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | [Xiaomi MiMo](/providers/xiaomi) | | OpenCode Go | `OPENCODE_GO_API_KEY` (`OPENCODE_GO_BASE_URL` optional) | `glm-5.1` | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | [OpenCode Go](/providers/opencode-go) | @@ -66,6 +67,13 @@ support, not every individual model capability exposed by an upstream provider. for providers that define a list, skipping their upstream `/models` calls. - **vLLM** — set `VLLM_API_KEY` only if the upstream server was started with `--api-key`. +- **Per-provider header customization** — each provider block accepts + `custom_upstream_headers` (static bundle, YAML-only) and + `passthrough_user_headers` (forward inbound headers after a skip list; env: + `<PROVIDER>_PASSTHROUGH_USER_HEADERS`). Kimi defaults passthrough to `true`, + every other provider to `false`; the two fields are mutually exclusive.
See + the [Provider Header Passthrough feature page](/features/provider-header-passthrough) + for the full schema, skip list, and examples. - **Multiple instances of one provider type** — without `config.yaml`, use suffixed env vars such as `OPENAI_EAST_API_KEY` and `OPENAI_EAST_BASE_URL`; add `OPENAI_EAST_MODELS` to configure that instance's model list. This diff --git a/internal/providers/anthropic/anthropic.go b/internal/providers/anthropic/anthropic.go index 9a6a7ed4..b9c56b28 100644 --- a/internal/providers/anthropic/anthropic.go +++ b/internal/providers/anthropic/anthropic.go @@ -47,6 +47,10 @@ type Provider struct { client *llmclient.Client apiKey string + customHeaders map[string]string + passthrough bool + passthroughSkip []string + batchEndpointsMu sync.RWMutex // batchResultEndpoints keeps endpoint hints by provider batch id and custom_id. // Used only to shape native batch result items (e.g., /v1/responses vs /v1/chat/completions). @@ -57,6 +61,9 @@ type Provider struct { func New(providerCfg providers.ProviderConfig, opts providers.ProviderOptions) core.Provider { p := &Provider{ apiKey: providerCfg.APIKey, + customHeaders: providerCfg.CustomUpstreamHeaders, + passthrough: providerCfg.PassthroughUserHeaders, + passthroughSkip: providerCfg.PassthroughUserHeadersSkip, batchResultEndpoints: make(map[string]map[string]string), } clientCfg := llmclient.Config{ @@ -164,6 +171,8 @@ func (p *Provider) setHeaders(req *http.Request) { if requestID := core.GetRequestID(req.Context()); requestID != "" { req.Header.Set("X-Request-Id", requestID) } + + providers.ApplyRequestHeaderOverrides(req.Context(), req.Header, p.customHeaders, p.passthrough, p.passthroughSkip...) } // Passthrough forwards an opaque Anthropic-native request without typed translation. 
diff --git a/internal/providers/azure/azure.go b/internal/providers/azure/azure.go index e074c43e..2008da29 100644 --- a/internal/providers/azure/azure.go +++ b/internal/providers/azure/azure.go @@ -37,9 +37,12 @@ func New(providerCfg providers.ProviderConfig, opts providers.ProviderOptions) c apiVersion := providers.ResolveAPIVersion(providerCfg.APIVersion, defaultAPIVersion) p := &Provider{apiVersion: apiVersion, apiKey: providerCfg.APIKey} clientCfg := openai.CompatibleProviderConfig{ - ProviderName: "azure", - BaseURL: baseURL, - SetHeaders: setHeaders, + ProviderName: "azure", + BaseURL: baseURL, + SetHeaders: setHeaders, + CustomUpstreamHeaders: providerCfg.CustomUpstreamHeaders, + PassthroughUserHeaders: providerCfg.PassthroughUserHeaders, + PassthroughUserHeadersSkip: providerCfg.PassthroughUserHeadersSkip, } p.CompatibleProvider = openai.NewCompatibleProvider(providerCfg.APIKey, opts, clientCfg) p.resourceProvider = openai.NewCompatibleProvider(providerCfg.APIKey, opts, clientCfg) diff --git a/internal/providers/config.go b/internal/providers/config.go index dd2a5bfc..0f0072f1 100644 --- a/internal/providers/config.go +++ b/internal/providers/config.go @@ -1,9 +1,11 @@ package providers import ( + "fmt" "maps" "os" "sort" + "strconv" "strings" "unicode" @@ -32,8 +34,11 @@ type ProviderConfig struct { // ID (as it appears in the provider's /models response). The registry merges // these onto remote-registry metadata after enrichment; non-zero fields here // win. Empty/nil when no per-model metadata is declared in YAML. 
- ModelMetadataOverrides map[string]*core.ModelMetadata - Resilience config.ResilienceConfig + ModelMetadataOverrides map[string]*core.ModelMetadata + CustomUpstreamHeaders map[string]string + PassthroughUserHeaders bool + PassthroughUserHeadersSkip []string + Resilience config.ResilienceConfig } // resolveProviders applies env var overrides to the raw YAML provider map, filters @@ -47,6 +52,18 @@ func resolveProviders(raw map[string]config.RawProviderConfig, global config.Res return buildProviderConfigs(filtered, global), filtered } +func validateMutuallyExclusiveHeaders(providers map[string]ProviderConfig) error { /* validateMutuallyExclusiveHeaders returns an error naming a provider whose resolved config has PassthroughUserHeaders true and a non-empty CustomUpstreamHeaders map; it returns nil when no provider combines the two. */ + for name, cfg := range providers { /* NOTE(review): map iteration order is unspecified, so when several providers conflict the one named in the error can differ between runs. */ + if cfg.PassthroughUserHeaders && len(cfg.CustomUpstreamHeaders) > 0 { + return fmt.Errorf( + "provider %q sets both passthrough_user_headers and custom_upstream_headers; pick one (passthrough forwards inbound headers; custom writes a static bundle)", + name, + ) + } + } + return nil +} + // applyProviderEnvVars overlays well-known provider env vars onto the raw YAML map. // Env var values always win over YAML values for the same provider name. func applyProviderEnvVars(raw map[string]config.RawProviderConfig, discovery map[string]DiscoveryConfig) map[string]config.RawProviderConfig { @@ -91,6 +108,8 @@ const ( providerEnvFieldServiceAccountJSON providerEnvFieldServiceAccountJSONBase64 providerEnvFieldGCPScope + providerEnvFieldPassthroughUserHeaders + providerEnvFieldPassthroughUserHeadersSkip ) type providerEnvSource struct { @@ -114,6 +133,10 @@ type providerEnvValues struct { ServiceAccountJSONBase64 string GCPScope string Models []string + // *bool so "unset" is distinguishable from "explicitly false". + PassthroughUserHeaders *bool + // Comma-separated header names from <PREFIX>_PASSTHROUGH_USER_HEADERS_SKIP.
+ PassthroughUserHeadersSkip []string } func (v providerEnvValues) empty() bool { @@ -129,7 +152,9 @@ func (v providerEnvValues) empty() bool { strings.TrimSpace(v.ServiceAccountJSON) == "" && strings.TrimSpace(v.ServiceAccountJSONBase64) == "" && strings.TrimSpace(v.GCPScope) == "" && - len(v.Models) == 0 + len(v.Models) == 0 && + v.PassthroughUserHeaders == nil && + len(v.PassthroughUserHeadersSkip) == 0 } func providerEnvSources(providerType string, spec DiscoveryConfig) []providerEnvSource { @@ -188,6 +213,14 @@ func collectProviderEnvValues(prefix string, spec DiscoveryConfig, environ []str values.ServiceAccountJSONBase64 = value case providerEnvFieldGCPScope: values.GCPScope = value + case providerEnvFieldPassthroughUserHeaders: + parsed, err := strconv.ParseBool(value) + if err != nil { + continue + } + values.PassthroughUserHeaders = &parsed + case providerEnvFieldPassthroughUserHeadersSkip: + values.PassthroughUserHeadersSkip = parseCSVEnvList(value) } groups[suffix] = values } @@ -214,6 +247,8 @@ func parseProviderEnvKey(prefix, key string, spec DiscoveryConfig) (string, prov name string field providerEnvField }{ + {name: "PASSTHROUGH_USER_HEADERS_SKIP", field: providerEnvFieldPassthroughUserHeadersSkip}, + {name: "PASSTHROUGH_USER_HEADERS", field: providerEnvFieldPassthroughUserHeaders}, {name: "API_VERSION", field: providerEnvFieldAPIVersion}, {name: "BASE_URL", field: providerEnvFieldBaseURL}, {name: "AUTH_TYPE", field: providerEnvFieldAuthType}, @@ -333,20 +368,22 @@ func applySuffixedProviderEnvVars(result map[string]config.RawProviderConfig, pr func (v providerEnvValues) rawConfig(providerType string, spec DiscoveryConfig) config.RawProviderConfig { backend := v.Backend return config.RawProviderConfig{ - Type: providerType, - APIKey: v.APIKey, - BaseURL: v.resolvedBaseURL(spec), - APIVersion: v.APIVersion, - Backend: backend, - AuthType: v.AuthType, - APIMode: v.APIMode, - VertexProject: v.VertexProject, - VertexLocation: v.VertexLocation, - 
ServiceAccountFile: v.ServiceAccountFile, - ServiceAccountJSON: v.ServiceAccountJSON, - ServiceAccountJSONBase64: v.ServiceAccountJSONBase64, - GCPScope: v.GCPScope, - Models: rawProviderModelsFromIDs(v.Models), + Type: providerType, + APIKey: v.APIKey, + BaseURL: v.resolvedBaseURL(spec), + APIVersion: v.APIVersion, + Backend: backend, + AuthType: v.AuthType, + APIMode: v.APIMode, + VertexProject: v.VertexProject, + VertexLocation: v.VertexLocation, + ServiceAccountFile: v.ServiceAccountFile, + ServiceAccountJSON: v.ServiceAccountJSON, + ServiceAccountJSONBase64: v.ServiceAccountJSONBase64, + GCPScope: v.GCPScope, + Models: rawProviderModelsFromIDs(v.Models), + PassthroughUserHeaders: v.PassthroughUserHeaders, + PassthroughUserHeadersSkip: v.PassthroughUserHeadersSkip, } } @@ -400,6 +437,12 @@ func overlayProviderEnvValues(existing config.RawProviderConfig, values provider if len(values.Models) > 0 { existing.Models = rawProviderModelsFromIDs(values.Models) } + if values.PassthroughUserHeaders != nil { + existing.PassthroughUserHeaders = values.PassthroughUserHeaders + } + if len(values.PassthroughUserHeadersSkip) > 0 { + existing.PassthroughUserHeadersSkip = values.PassthroughUserHeadersSkip + } return existing } @@ -622,22 +665,25 @@ func buildProviderConfigs(raw map[string]config.RawProviderConfig, global config // Non-nil fields in the raw config override the global defaults. 
func buildProviderConfig(raw config.RawProviderConfig, global config.ResilienceConfig) ProviderConfig { resolved := ProviderConfig{ - Type: normalizeProviderType(raw), - APIKey: raw.APIKey, - BaseURL: raw.BaseURL, - APIVersion: raw.APIVersion, - Backend: raw.Backend, - AuthType: raw.AuthType, - APIMode: raw.APIMode, - VertexProject: raw.VertexProject, - VertexLocation: raw.VertexLocation, - ServiceAccountFile: raw.ServiceAccountFile, - ServiceAccountJSON: raw.ServiceAccountJSON, - ServiceAccountJSONBase64: raw.ServiceAccountJSONBase64, - GCPScope: raw.GCPScope, - Models: config.ProviderModelIDs(raw.Models), - ModelMetadataOverrides: config.ProviderModelMetadataOverrides(raw.Models), - Resilience: global, + Type: normalizeProviderType(raw), + APIKey: raw.APIKey, + BaseURL: raw.BaseURL, + APIVersion: raw.APIVersion, + Backend: raw.Backend, + AuthType: raw.AuthType, + APIMode: raw.APIMode, + VertexProject: raw.VertexProject, + VertexLocation: raw.VertexLocation, + ServiceAccountFile: raw.ServiceAccountFile, + ServiceAccountJSON: raw.ServiceAccountJSON, + ServiceAccountJSONBase64: raw.ServiceAccountJSONBase64, + GCPScope: raw.GCPScope, + Models: config.ProviderModelIDs(raw.Models), + ModelMetadataOverrides: config.ProviderModelMetadataOverrides(raw.Models), + CustomUpstreamHeaders: raw.CustomUpstreamHeaders, + PassthroughUserHeaders: resolvePassthroughUserHeaders(raw), + PassthroughUserHeadersSkip: raw.PassthroughUserHeadersSkip, + Resilience: global, } if raw.Resilience == nil { @@ -677,6 +723,13 @@ func buildProviderConfig(raw config.RawProviderConfig, global config.ResilienceC return resolved } +func resolvePassthroughUserHeaders(raw config.RawProviderConfig) bool { + if raw.PassthroughUserHeaders != nil { + return *raw.PassthroughUserHeaders + } + return strings.EqualFold(normalizeProviderType(raw), "kimi") +} + func normalizeProviderType(raw config.RawProviderConfig) string { providerType := strings.TrimSpace(raw.Type) if strings.EqualFold(providerType, "gemini") 
&& strings.EqualFold(strings.TrimSpace(raw.Backend), "vertex") { diff --git a/internal/providers/config_test.go b/internal/providers/config_test.go index 74ecc01d..cad668f2 100644 --- a/internal/providers/config_test.go +++ b/internal/providers/config_test.go @@ -1,6 +1,7 @@ package providers import ( + "strings" "testing" "time" @@ -43,6 +44,9 @@ var testDiscoveryConfigs = map[string]DiscoveryConfig{ "zai": { DefaultBaseURL: "https://api.z.ai/api/paas/v4", }, + "kimi": { + DefaultBaseURL: "https://api.kimi.com/coding/v1", + }, "vllm": { DefaultBaseURL: "http://localhost:8000/v1", AllowAPIKeyless: true, @@ -1516,6 +1520,160 @@ func TestResolveProviders_SingleCustomNamedProviderDoesNotDuplicateTypeKey(t *te } } +func boolPtr(b bool) *bool { return &b } + +func TestBuildProviderConfig_KimiPassthroughDefaultTrue(t *testing.T) { + raw := config.RawProviderConfig{Type: "kimi", APIKey: "sk-kimi"} + got := buildProviderConfig(raw, globalResilience) + + if !got.PassthroughUserHeaders { + t.Errorf("kimi default PassthroughUserHeaders = false, want true") + } +} + +func TestBuildProviderConfig_NonKimiPassthroughDefaultFalse(t *testing.T) { + cases := []struct { + name string + typ string + }{ + {"openai", "openai"}, + {"anthropic", "anthropic"}, + {"gemini", "gemini"}, + {"vertex", "vertex"}, + {"custom", "custom-provider"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + raw := config.RawProviderConfig{Type: tc.typ, APIKey: "sk"} + got := buildProviderConfig(raw, globalResilience) + + if got.PassthroughUserHeaders { + t.Errorf("%s default PassthroughUserHeaders = true, want false", tc.typ) + } + }) + } +} + +func TestBuildProviderConfig_PassthroughYAMLOverride(t *testing.T) { + t.Run("kimi_false", func(t *testing.T) { + raw := config.RawProviderConfig{ + Type: "kimi", + APIKey: "sk-kimi", + PassthroughUserHeaders: boolPtr(false), + } + got := buildProviderConfig(raw, globalResilience) + if got.PassthroughUserHeaders { + t.Errorf("kimi with explicit 
PassthroughUserHeaders=false got true, want false") + } + }) + t.Run("openai_true", func(t *testing.T) { + raw := config.RawProviderConfig{ + Type: "openai", + APIKey: "sk", + PassthroughUserHeaders: boolPtr(true), + } + got := buildProviderConfig(raw, globalResilience) + if !got.PassthroughUserHeaders { + t.Errorf("openai with explicit PassthroughUserHeaders=true got false, want true") + } + }) +} + +func TestBuildProviderConfig_PassthroughEnvOverride(t *testing.T) { + t.Setenv("KIMI_PASSTHROUGH_USER_HEADERS", "false") + + discovery := map[string]DiscoveryConfig{ + "kimi": {DefaultBaseURL: "https://api.kimi.com/v1"}, + } + raw := map[string]config.RawProviderConfig{ + "kimi": {Type: "kimi", APIKey: "sk-kimi"}, + } + + merged := applyProviderEnvVars(raw, discovery) + got := buildProviderConfig(merged["kimi"], globalResilience) + + if got.PassthroughUserHeaders { + t.Errorf("KIMI_PASSTHROUGH_USER_HEADERS=false should override kimi default true, got true") + } +} + +func TestBuildProviderConfig_PassthroughUserHeadersSkipRoundTrip(t *testing.T) { + skip := []string{"X-MyOrg-Trace", "X-Internal-Id"} + raw := config.RawProviderConfig{ + Type: "kimi", + APIKey: "sk-kimi", + PassthroughUserHeadersSkip: skip, + } + got := buildProviderConfig(raw, globalResilience) + + if len(got.PassthroughUserHeadersSkip) != len(skip) { + t.Fatalf("PassthroughUserHeadersSkip length = %d, want %d", len(got.PassthroughUserHeadersSkip), len(skip)) + } + for _, k := range skip { + found := false + for _, got := range got.PassthroughUserHeadersSkip { + if got == k { + found = true + break + } + } + if !found { + t.Errorf("PassthroughUserHeadersSkip missing %q", k) + } + } +} + +func TestBuildProviderConfig_PassthroughSkipEnvOverride(t *testing.T) { + t.Setenv("KIMI_PASSTHROUGH_USER_HEADERS_SKIP", "X-A, X-B , X-C") + + discovery := map[string]DiscoveryConfig{ + "kimi": {DefaultBaseURL: "https://api.kimi.com/v1"}, + } + raw := map[string]config.RawProviderConfig{ + "kimi": { + Type: "kimi", + 
APIKey: "sk-kimi", + PassthroughUserHeadersSkip: []string{"old"}, + }, + } + + merged := applyProviderEnvVars(raw, discovery) + got := buildProviderConfig(merged["kimi"], globalResilience) + + want := []string{"X-A", "X-B", "X-C"} + if len(got.PassthroughUserHeadersSkip) != len(want) { + t.Fatalf("PassthroughUserHeadersSkip = %v, want %v", got.PassthroughUserHeadersSkip, want) + } + for i, k := range want { + if got.PassthroughUserHeadersSkip[i] != k { + t.Errorf("PassthroughUserHeadersSkip[%d] = %q, want %q", i, got.PassthroughUserHeadersSkip[i], k) + } + } +} + +func TestBuildProviderConfig_CustomUpstreamHeadersRoundTrip(t *testing.T) { + headers := map[string]string{ + "X-Org-Id": "acme", + "X-Tenant": "primary", + "X-Trace": "yes", + } + raw := config.RawProviderConfig{ + Type: "openai", + APIKey: "sk", + CustomUpstreamHeaders: headers, + } + got := buildProviderConfig(raw, globalResilience) + + if len(got.CustomUpstreamHeaders) != len(headers) { + t.Fatalf("CustomUpstreamHeaders length = %d, want %d", len(got.CustomUpstreamHeaders), len(headers)) + } + for k, v := range headers { + if got.CustomUpstreamHeaders[k] != v { + t.Errorf("CustomUpstreamHeaders[%q] = %q, want %q", k, got.CustomUpstreamHeaders[k], v) + } + } +} + func TestResolveProviders_NoProvidersNoEnvVars(t *testing.T) { got, filteredRaw := resolveProviders(map[string]config.RawProviderConfig{}, globalResilience, testDiscoveryConfigs) if len(got) != 0 { @@ -1525,3 +1683,65 @@ func TestResolveProviders_NoProvidersNoEnvVars(t *testing.T) { t.Errorf("expected empty filtered raw, got %d entries", len(filteredRaw)) } } + +func TestValidateMutuallyExclusiveHeaders_RejectsBoth(t *testing.T) { + t.Run("passthrough_true_with_custom_headers", func(t *testing.T) { + providers := map[string]ProviderConfig{ + "kimi": { + Type: "kimi", + PassthroughUserHeaders: true, + CustomUpstreamHeaders: map[string]string{"X-Trace-Source": "gomodel"}, + }, + } + err := validateMutuallyExclusiveHeaders(providers) + if err == 
nil { + t.Fatalf("expected error, got nil") + } + if !strings.Contains(err.Error(), "kimi") { + t.Errorf("expected error to name the provider, got %v", err) + } + if !strings.Contains(err.Error(), "both") { + t.Errorf("expected error to explain the conflict, got %v", err) + } + }) + + t.Run("default_passthrough_with_custom_headers", func(t *testing.T) { + // Default for non-kimi is false; only true should trigger. + providers := map[string]ProviderConfig{ + "openai": { + Type: "openai", + PassthroughUserHeaders: false, + CustomUpstreamHeaders: map[string]string{"X-Trace-Source": "gomodel"}, + }, + } + if err := validateMutuallyExclusiveHeaders(providers); err != nil { + t.Errorf("passthrough=false with custom must be allowed, got %v", err) + } + }) +} + +func TestValidateMutuallyExclusiveHeaders_AllowsEachAlone(t *testing.T) { + t.Run("neither_set", func(t *testing.T) { + if err := validateMutuallyExclusiveHeaders(map[string]ProviderConfig{ + "openai": {Type: "openai"}, + }); err != nil { + t.Errorf("neither set must be allowed, got %v", err) + } + }) + + t.Run("only_passthrough", func(t *testing.T) { + if err := validateMutuallyExclusiveHeaders(map[string]ProviderConfig{ + "kimi": {Type: "kimi", PassthroughUserHeaders: true}, + }); err != nil { + t.Errorf("only passthrough must be allowed, got %v", err) + } + }) + + t.Run("only_custom", func(t *testing.T) { + if err := validateMutuallyExclusiveHeaders(map[string]ProviderConfig{ + "openai": {Type: "openai", CustomUpstreamHeaders: map[string]string{"X-Trace": "v"}}, + }); err != nil { + t.Errorf("only custom must be allowed, got %v", err) + } + }) +} diff --git a/internal/providers/gemini/gemini.go b/internal/providers/gemini/gemini.go index ac00cff8..c2b9345c 100644 --- a/internal/providers/gemini/gemini.go +++ b/internal/providers/gemini/gemini.go @@ -54,6 +54,10 @@ type Provider struct { useNativeAPI bool modelsURL string configErr error + + customHeaders map[string]string + passthrough bool + passthroughSkip 
[]string } // New creates a new Gemini provider. @@ -75,11 +79,14 @@ func newProvider(providerCfg providers.ProviderConfig, opts providers.ProviderOp baseURL, nativeBaseURL := geminiBaseURLs(providerCfg, backend) modelsURL := geminiModelsBaseURL(backend, nativeBaseURL) p := &Provider{ - apiKey: providerCfg.APIKey, - backend: backend, - authType: authType, - useNativeAPI: useNativeAPI(providerCfg.APIMode), - modelsURL: modelsURL, + apiKey: providerCfg.APIKey, + backend: backend, + authType: authType, + useNativeAPI: useNativeAPI(providerCfg.APIMode), + modelsURL: modelsURL, + customHeaders: providerCfg.CustomUpstreamHeaders, + passthrough: providerCfg.PassthroughUserHeaders, + passthroughSkip: providerCfg.PassthroughUserHeadersSkip, } p.validateConfig(providerCfg) if !preauthenticated { @@ -123,11 +130,14 @@ func NewWithHTTPClient(apiKey string, httpClient *http.Client, hooks llmclient.H baseURL, nativeBaseURL := geminiBaseURLs(providerCfg, geminiBackendAIStudio) modelsURL := geminiModelsBaseURL(geminiBackendAIStudio, nativeBaseURL) p := &Provider{ - apiKey: apiKey, - backend: geminiBackendAIStudio, - authType: geminiAuthTypeAPIKey, - useNativeAPI: useNativeAPIFromEnv(), - modelsURL: modelsURL, + apiKey: apiKey, + backend: geminiBackendAIStudio, + authType: geminiAuthTypeAPIKey, + useNativeAPI: useNativeAPIFromEnv(), + modelsURL: modelsURL, + customHeaders: providerCfg.CustomUpstreamHeaders, + passthrough: providerCfg.PassthroughUserHeaders, + passthroughSkip: providerCfg.PassthroughUserHeadersSkip, } modelsCfg := llmclient.DefaultConfig("gemini", modelsURL) modelsCfg.Hooks = hooks @@ -235,6 +245,8 @@ func (p *Provider) setHeaders(req *http.Request) { if requestID := core.GetRequestID(req.Context()); requestID != "" { req.Header.Set("X-Request-Id", requestID) } + + providers.ApplyRequestHeaderOverrides(req.Context(), req.Header, p.customHeaders, p.passthrough, p.passthroughSkip...) } // setNativeHeaders sets the required headers for Gemini native API requests. 
diff --git a/internal/providers/groq/groq.go b/internal/providers/groq/groq.go index ca89fac2..0f20b891 100644 --- a/internal/providers/groq/groq.go +++ b/internal/providers/groq/groq.go @@ -37,9 +37,12 @@ type Provider struct { func New(providerCfg providers.ProviderConfig, opts providers.ProviderOptions) core.Provider { return &Provider{ compat: openai.NewCompatibleProvider(providerCfg.APIKey, opts, openai.CompatibleProviderConfig{ - ProviderName: "groq", - BaseURL: providers.ResolveBaseURL(providerCfg.BaseURL, defaultBaseURL), - SetHeaders: setHeaders, + ProviderName: "groq", + BaseURL: providers.ResolveBaseURL(providerCfg.BaseURL, defaultBaseURL), + SetHeaders: setHeaders, + CustomUpstreamHeaders: providerCfg.CustomUpstreamHeaders, + PassthroughUserHeaders: providerCfg.PassthroughUserHeaders, + PassthroughUserHeadersSkip: providerCfg.PassthroughUserHeadersSkip, }), } } diff --git a/internal/providers/headers.go b/internal/providers/headers.go new file mode 100644 index 00000000..fe822002 --- /dev/null +++ b/internal/providers/headers.go @@ -0,0 +1,35 @@ +package providers + +import ( + "net/http" + "strings" +) + +// SkipPassthroughHeader reports whether a header key is excluded from provider +// passthrough. The always-on floor covers hop-by-hop, transport-managed, +// credential, cookie, and X-Forwarded-* headers; extras adds operator-defined +// keys on top. The key is trimmed and canonicalized before matching. 
+func SkipPassthroughHeader(key string, extras ...string) bool { + canonicalKey := http.CanonicalHeaderKey(strings.TrimSpace(key)) + if canonicalKey == "" { + return false + } + switch canonicalKey { + case "Authorization", "X-Api-Key", "Host", "Content-Length", "Connection", "Keep-Alive", + "Proxy-Authenticate", "Proxy-Authorization", "Te", "Trailer", "Transfer-Encoding", "Upgrade", + "Cookie", "Forwarded", "Set-Cookie": + return true + } + if strings.HasPrefix(canonicalKey, "X-Forwarded-") { + return true + } + for _, raw := range extras { + if raw == "" { + continue + } + if http.CanonicalHeaderKey(strings.TrimSpace(raw)) == canonicalKey { + return true + } + } + return false +} diff --git a/internal/providers/headers_test.go b/internal/providers/headers_test.go new file mode 100644 index 00000000..a5adac5d --- /dev/null +++ b/internal/providers/headers_test.go @@ -0,0 +1,162 @@ +package providers + +import "testing" + +func TestSkipPassthroughHeader_ExactSkippedKeys(t *testing.T) { + skipped := []string{ + "Authorization", + "X-Api-Key", + "Host", + "Content-Length", + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", + "Trailer", + "Transfer-Encoding", + "Upgrade", + "Cookie", + "Forwarded", + "Set-Cookie", + } + + for _, key := range skipped { + key := key + t.Run(key, func(t *testing.T) { + if !SkipPassthroughHeader(key) { + t.Fatalf("expected %q to be skipped", key) + } + }) + } +} + +func TestSkipPassthroughHeader_XForwardedPrefix(t *testing.T) { + cases := []string{ + "X-Forwarded-For", + "X-Forwarded-Host", + "X-Forwarded-Proto", + "X-Forwarded-Client-IP", + } + + for _, key := range cases { + key := key + t.Run(key, func(t *testing.T) { + if !SkipPassthroughHeader(key) { + t.Fatalf("expected %q (X-Forwarded-* prefix) to be skipped", key) + } + }) + } +} + +func TestSkipPassthroughHeader_NotSkipped(t *testing.T) { + notSkipped := []string{ + "X-Request-Id", + "Content-Type", + "Accept", + "User-Agent", + 
"X-Stainless-Arch", + "X-Custom", + } + + for _, key := range notSkipped { + key := key + t.Run(key, func(t *testing.T) { + if SkipPassthroughHeader(key) { + t.Fatalf("expected %q NOT to be skipped", key) + } + }) + } +} + +func TestSkipPassthroughHeader_CaseInsensitive(t *testing.T) { + cases := []struct { + input string + expected bool + }{ + {"authorization", true}, + {"AUTHORIZATION", true}, + {"x-api-key", true}, + {"x-api-KEY", true}, + {"x-forwarded-for", true}, + {"X-FORWARDED-FOR", true}, + {"x-forwarded-proto", true}, + {"x-request-id", false}, + {"CONTENT-TYPE", false}, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + got := SkipPassthroughHeader(tc.input) + if got != tc.expected { + t.Fatalf("SkipPassthroughHeader(%q) = %v, want %v", tc.input, got, tc.expected) + } + }) + } +} + +func TestSkipPassthroughHeader_TrimsWhitespace(t *testing.T) { + if !SkipPassthroughHeader(" Authorization ") { + t.Fatalf("expected whitespace-padded Authorization to be skipped") + } + + if !SkipPassthroughHeader("\tX-Forwarded-For\n") { + t.Fatalf("expected whitespace-padded X-Forwarded-For to be skipped") + } + + if SkipPassthroughHeader(" Content-Type ") { + t.Fatalf("expected whitespace-padded Content-Type NOT to be skipped") + } +} + +func TestSkipPassthroughHeader_UserExtras(t *testing.T) { + t.Run("empty_extras_does_nothing", func(t *testing.T) { + if SkipPassthroughHeader("X-Custom", "" /* no extras */) { + t.Fatalf("X-Custom must not be skipped with empty extras") + } + if !SkipPassthroughHeader("Authorization") { + t.Fatalf("Authorization must still be skipped regardless of extras") + } + }) + + t.Run("exact_match", func(t *testing.T) { + if !SkipPassthroughHeader("X-Custom", "X-Custom") { + t.Fatalf("X-Custom must be skipped when in extras") + } + }) + + t.Run("case_insensitive_match", func(t *testing.T) { + if !SkipPassthroughHeader("x-custom", "X-Custom") { + t.Fatalf("lowercase x-custom must match uppercase extra X-Custom") 
+ } + if !SkipPassthroughHeader("X-CUSTOM", "x-custom") { + t.Fatalf("uppercase X-CUSTOM must match lowercase extra x-custom") + } + }) + + t.Run("whitespace_trimmed_match", func(t *testing.T) { + if !SkipPassthroughHeader("X-Custom", " X-Custom ") { + t.Fatalf("extra with whitespace padding must still match") + } + }) + + t.Run("empty_extra_entry_ignored", func(t *testing.T) { + if SkipPassthroughHeader("X-Custom", "", " ", "Other") { + t.Fatalf("X-Custom must not be skipped by empty/whitespace extra entries") + } + if !SkipPassthroughHeader("Other", "", "Other") { + t.Fatalf("Other must be skipped when listed in extras") + } + }) + + t.Run("does_not_affect_always_on_floor", func(t *testing.T) { + // Always-on keys must still be skipped even if not in extras. + if !SkipPassthroughHeader("Authorization", "X-Custom") { + t.Fatalf("Authorization must remain skipped") + } + if !SkipPassthroughHeader("X-Forwarded-For", "X-Custom") { + t.Fatalf("X-Forwarded-* prefix must remain skipped") + } + }) +} diff --git a/internal/providers/init.go b/internal/providers/init.go index 2f9fe573..ed716def 100644 --- a/internal/providers/init.go +++ b/internal/providers/init.go @@ -82,6 +82,10 @@ func Init(ctx context.Context, result *config.LoadResult, factory *ProviderFacto providerMap, credentialResolved := resolveProviders(result.RawProviders, result.Config.Resilience, factory.discoveryConfigsSnapshot()) + if err := validateMutuallyExclusiveHeaders(providerMap); err != nil { + return nil, fmt.Errorf("invalid provider header configuration: %w", err) + } + modelCache, err := initCache(result.Config) if err != nil { return nil, fmt.Errorf("failed to initialize cache: %w", err) diff --git a/internal/providers/kimi/kimi.go b/internal/providers/kimi/kimi.go new file mode 100644 index 00000000..a404143d --- /dev/null +++ b/internal/providers/kimi/kimi.go @@ -0,0 +1,178 @@ +// Package kimi integrates Kimi's OpenAI-compatible API as a thin wrapper over +// the shared 
openai.CompatibleProvider with standard Bearer auth. +package kimi + +import ( + "context" + "io" + "net/http" + + "gomodel/internal/core" + "gomodel/internal/llmclient" + "gomodel/internal/providers" + "gomodel/internal/providers/openai" +) + +const defaultBaseURL = "https://api.kimi.com/coding/v1" + +func bearerSetHeaders(req *http.Request, apiKey string) { + providers.SetAuthHeaders(req, apiKey, providers.AuthHeaderConfig{AuthScheme: "Bearer "}) +} + +// Registration provides factory registration for the Kimi provider. +var Registration = providers.Registration{ + Type: "kimi", + New: New, + Discovery: providers.DiscoveryConfig{ + DefaultBaseURL: defaultBaseURL, + }, +} + +type Provider struct { + compat *openai.CompatibleProvider +} + +var _ core.Provider = (*Provider)(nil) + +// New creates a new Kimi provider. +func New(providerCfg providers.ProviderConfig, opts providers.ProviderOptions) core.Provider { + return &Provider{ + compat: openai.NewCompatibleProvider(providerCfg.APIKey, opts, openai.CompatibleProviderConfig{ + ProviderName: "kimi", + BaseURL: providers.ResolveBaseURL(providerCfg.BaseURL, defaultBaseURL), + SetHeaders: bearerSetHeaders, + CustomUpstreamHeaders: providerCfg.CustomUpstreamHeaders, + PassthroughUserHeaders: providerCfg.PassthroughUserHeaders, + PassthroughUserHeadersSkip: providerCfg.PassthroughUserHeadersSkip, + }), + } +} + +// NewWithHTTPClient creates a new Kimi provider with a custom HTTP client. +// If httpClient is nil, http.DefaultClient is used. +func NewWithHTTPClient(apiKey string, httpClient *http.Client, hooks llmclient.Hooks) *Provider { + return &Provider{ + compat: openai.NewCompatibleProviderWithHTTPClient(apiKey, httpClient, hooks, openai.CompatibleProviderConfig{ + ProviderName: "kimi", + BaseURL: defaultBaseURL, + SetHeaders: bearerSetHeaders, + }), + } +} + +// SetBaseURL allows configuring a custom base URL for the provider. 
+func (p *Provider) SetBaseURL(url string) { + p.compat.SetBaseURL(url) +} + +// GetBaseURL returns the provider's current base URL. It reads from the +// underlying client so it always reflects SetBaseURL overrides. +func (p *Provider) GetBaseURL() string { + return p.compat.GetBaseURL() +} + +// ChatCompletion sends a chat completion request to Kimi. +func (p *Provider) ChatCompletion(ctx context.Context, req *core.ChatRequest) (*core.ChatResponse, error) { + return p.compat.ChatCompletion(ctx, req) +} + +// StreamChatCompletion returns a raw response body for streaming (caller must close). +func (p *Provider) StreamChatCompletion(ctx context.Context, req *core.ChatRequest) (io.ReadCloser, error) { + return p.compat.StreamChatCompletion(ctx, req) +} + +// ListModels retrieves the list of available models from Kimi. +func (p *Provider) ListModels(ctx context.Context) (*core.ModelsResponse, error) { + return p.compat.ListModels(ctx) +} + +// Responses sends a Responses API request to Kimi. +func (p *Provider) Responses(ctx context.Context, req *core.ResponsesRequest) (*core.ResponsesResponse, error) { + return p.compat.Responses(ctx, req) +} + +// StreamResponses returns a raw response body for streaming Responses API (caller must close). +func (p *Provider) StreamResponses(ctx context.Context, req *core.ResponsesRequest) (io.ReadCloser, error) { + return p.compat.StreamResponses(ctx, req) +} + +// GetResponse retrieves a stored Response by id. +func (p *Provider) GetResponse(ctx context.Context, id string, params core.ResponseRetrieveParams) (*core.ResponsesResponse, error) { + return p.compat.GetResponse(ctx, id, params) +} + +// ListResponseInputItems lists the input items of a stored Response. 
+func (p *Provider) ListResponseInputItems(ctx context.Context, id string, params core.ResponseInputItemsParams) (*core.ResponseInputItemListResponse, error) { + return p.compat.ListResponseInputItems(ctx, id, params) +} + +// Embeddings sends an embeddings request to Kimi. +func (p *Provider) Embeddings(ctx context.Context, req *core.EmbeddingRequest) (*core.EmbeddingResponse, error) { + return p.compat.Embeddings(ctx, req) +} + +// Passthrough routes an opaque provider-native request to Kimi. +func (p *Provider) Passthrough(ctx context.Context, req *core.PassthroughRequest) (*core.PassthroughResponse, error) { + return p.compat.Passthrough(ctx, req) +} + +// CreateBatch creates a native Kimi batch job. +func (p *Provider) CreateBatch(ctx context.Context, req *core.BatchRequest) (*core.BatchResponse, error) { + return p.compat.CreateBatch(ctx, req) +} + +// GetBatch retrieves a native Kimi batch job. +func (p *Provider) GetBatch(ctx context.Context, id string) (*core.BatchResponse, error) { + return p.compat.GetBatch(ctx, id) +} + +// ListBatches lists native Kimi batch jobs. +func (p *Provider) ListBatches(ctx context.Context, limit int, after string) (*core.BatchListResponse, error) { + return p.compat.ListBatches(ctx, limit, after) +} + +// CancelBatch cancels a native Kimi batch job. +func (p *Provider) CancelBatch(ctx context.Context, id string) (*core.BatchResponse, error) { + return p.compat.CancelBatch(ctx, id) +} + +// GetBatchResults fetches Kimi batch results via the output file API. +func (p *Provider) GetBatchResults(ctx context.Context, id string) (*core.BatchResultsResponse, error) { + return p.compat.GetBatchResults(ctx, id) +} + +// CreateFile uploads a file through Kimi's OpenAI-compatible /files API. +func (p *Provider) CreateFile(ctx context.Context, req *core.FileCreateRequest) (*core.FileObject, error) { + return p.compat.CreateFile(ctx, req) +} + +// ListFiles lists files through Kimi's OpenAI-compatible /files API. 
+func (p *Provider) ListFiles(ctx context.Context, purpose string, limit int, after string) (*core.FileListResponse, error) { + return p.compat.ListFiles(ctx, purpose, limit, after) +} + +// GetFile retrieves one file object through Kimi's OpenAI-compatible /files API. +func (p *Provider) GetFile(ctx context.Context, id string) (*core.FileObject, error) { + return p.compat.GetFile(ctx, id) +} + +// DeleteFile deletes a file object through Kimi's OpenAI-compatible /files API. +func (p *Provider) DeleteFile(ctx context.Context, id string) (*core.FileDeleteResponse, error) { + return p.compat.DeleteFile(ctx, id) +} + +// GetFileContent fetches raw file bytes through Kimi's /files/{id}/content API. +func (p *Provider) GetFileContent(ctx context.Context, id string) (*core.FileContentResponse, error) { + return p.compat.GetFileContent(ctx, id) +} + +// CreateTranscription transcribes audio through Kimi's OpenAI-compatible +// /audio/transcriptions API. +func (p *Provider) CreateTranscription(ctx context.Context, req *core.AudioTranscriptionRequest) (*core.AudioResponse, error) { + return p.compat.CreateTranscription(ctx, req) +} + +// CreateSpeech synthesizes speech through Kimi's OpenAI-compatible /audio/speech API. +func (p *Provider) CreateSpeech(ctx context.Context, req *core.AudioSpeechRequest) (*core.AudioResponse, error) { + return p.compat.CreateSpeech(ctx, req) +} diff --git a/internal/providers/kimi/kimi_test.go b/internal/providers/kimi/kimi_test.go new file mode 100644 index 00000000..68d7b59b --- /dev/null +++ b/internal/providers/kimi/kimi_test.go @@ -0,0 +1,499 @@ +package kimi + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "gomodel/internal/core" + "gomodel/internal/llmclient" + "gomodel/internal/providers" + "gomodel/internal/providers/openai" +) + +// forbiddenHeaders are headers that must never be present on outbound Kimi +// requests. 
Kimi's API does not require them and the gateway must not invent +// them on its behalf. +var forbiddenHeaders = []string{ + "X-Stainless-Arch", + "X-Stainless-Os", + "X-Stainless-Package-Version", + "X-Stainless-Runtime", + "X-Stainless-Raw-Response", + "Http-Referer", + "X-Title", +} + +// TestRegistration_Type asserts the factory registration advertises the +// expected provider type string. +func TestRegistration_Type(t *testing.T) { + if Registration.Type != "kimi" { + t.Errorf("Registration.Type = %q, want %q", Registration.Type, "kimi") + } +} + +// TestRegistration_HasDefaultBaseURL asserts the registration carries the +// documented Kimi default base URL so the factory can use it when no override +// is supplied. +func TestRegistration_HasDefaultBaseURL(t *testing.T) { + if Registration.Discovery.DefaultBaseURL != defaultBaseURL { + t.Errorf("Registration.Discovery.DefaultBaseURL = %q, want %q", + Registration.Discovery.DefaultBaseURL, defaultBaseURL) + } +} + +// TestNew_ReturnsNonNilProvider asserts the public constructor returns a +// usable core.Provider implementation. +func TestNew_ReturnsNonNilProvider(t *testing.T) { + p := New(providers.ProviderConfig{APIKey: "test-api-key"}, providers.ProviderOptions{}) + if p == nil { + t.Fatal("New() returned nil") + } +} + +// TestNew_DelegatesToCompatibleProvider asserts the provider wraps a +// *openai.CompatibleProvider. This is the structural contract that lets Kimi +// reuse the shared OpenAI-compatible transport. +func TestNew_DelegatesToCompatibleProvider(t *testing.T) { + p := New(providers.ProviderConfig{APIKey: "test-api-key"}, providers.ProviderOptions{}) + + kp, ok := p.(*Provider) + if !ok { + t.Fatalf("New() returned %T, want *kimi.Provider", p) + } + if kp.compat == nil { + t.Fatal("Provider.compat should not be nil") + } +} + +// TestNewWithHTTPClient_DelegatesToCompatibleProvider asserts the +// HTTP-client-aware constructor also wires up the shared compatible provider. 
+func TestNewWithHTTPClient_DelegatesToCompatibleProvider(t *testing.T) { + p := NewWithHTTPClient("test-api-key", nil, llmclient.Hooks{}) + if p == nil { + t.Fatal("NewWithHTTPClient() returned nil") + } + if p.compat == nil { + t.Fatal("Provider.compat should not be nil") + } + if _, ok := any(p).(core.Provider); !ok { + t.Fatal("*Provider should satisfy core.Provider") + } +} + +// TestNew_DefaultBaseURL asserts that omitting BaseURL on the public +// constructor resolves to the package-level default base URL. +func TestNew_DefaultBaseURL(t *testing.T) { + p := New(providers.ProviderConfig{APIKey: "test-api-key"}, providers.ProviderOptions{}) + kp, ok := p.(*Provider) + if !ok { + t.Fatalf("New() returned %T, want *kimi.Provider", p) + } + if got := kp.GetBaseURL(); got != defaultBaseURL { + t.Errorf("GetBaseURL() = %q, want %q", got, defaultBaseURL) + } +} + +// TestNewWithHTTPClient_DefaultBaseURL asserts the same default for the +// HTTP-client-aware constructor. +func TestNewWithHTTPClient_DefaultBaseURL(t *testing.T) { + p := NewWithHTTPClient("test-api-key", nil, llmclient.Hooks{}) + if got := p.GetBaseURL(); got != defaultBaseURL { + t.Errorf("GetBaseURL() = %q, want %q", got, defaultBaseURL) + } +} + +// TestNew_CustomBaseURLOverridesDefault asserts a caller-supplied BaseURL +// takes precedence over the package default. +func TestNew_CustomBaseURLOverridesDefault(t *testing.T) { + const custom = "https://example.test/v1" + p := New(providers.ProviderConfig{ + APIKey: "test-api-key", + BaseURL: custom, + }, providers.ProviderOptions{}) + + kp, ok := p.(*Provider) + if !ok { + t.Fatalf("New() returned %T, want *kimi.Provider", p) + } + if got := kp.GetBaseURL(); got != custom { + t.Errorf("GetBaseURL() = %q, want %q", got, custom) + } +} + +// TestChatCompletion_PlainBearerAuthOnly asserts that an outbound Kimi chat +// completion request carries exactly one Authorization header with the API key +// and no other provider-specific headers. 
+func TestChatCompletion_PlainBearerAuthOnly(t *testing.T) {
+	var captured http.Header
+	var capturedPath string
+	var capturedMethod string
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		captured = r.Header.Clone() // clone: the request's header map is not ours to keep
+		capturedPath = r.URL.Path
+		capturedMethod = r.Method
+
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte(`{
+			"id": "chatcmpl-kimi-1",
+			"object": "chat.completion",
+			"created": 1700000000,
+			"model": "kimi-k2-0711-preview",
+			"choices": [{
+				"index": 0,
+				"message": {"role": "assistant", "content": "hi"},
+				"finish_reason": "stop"
+			}],
+			"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
+		}`))
+	}))
+	defer server.Close()
+
+	provider := NewWithHTTPClient("kimi-test-key", server.Client(), llmclient.Hooks{})
+	provider.SetBaseURL(server.URL)
+
+	resp, err := provider.ChatCompletion(context.Background(), &core.ChatRequest{
+		Model: "kimi-k2-0711-preview",
+		Messages: []core.Message{
+			{Role: "user", Content: "Hello"},
+		},
+	})
+	if err != nil {
+		t.Fatalf("ChatCompletion() error = %v", err)
+	}
+	if resp == nil || len(resp.Choices) != 1 {
+		t.Fatalf("ChatCompletion() returned unexpected response: %+v", resp)
+	}
+
+	if capturedMethod != http.MethodPost {
+		t.Errorf("method = %q, want %q", capturedMethod, http.MethodPost)
+	}
+	if capturedPath != "/chat/completions" {
+		t.Errorf("path = %q, want %q", capturedPath, "/chat/completions")
+	}
+
+	// Exactly one Authorization header, with the value "Bearer <api-key>".
+	auth := captured.Values("Authorization")
+	if len(auth) != 1 {
+		t.Fatalf("Authorization header count = %d, want 1", len(auth))
+	}
+	if auth[0] != "Bearer kimi-test-key" {
+		t.Errorf("Authorization = %q, want %q", auth[0], "Bearer kimi-test-key")
+	}
+
+	// No provider-specific headers must leak onto the outbound request.
+	for _, h := range forbiddenHeaders {
+		if v := captured.Get(h); v != "" {
+			t.Errorf("unexpected %s header present: %q", h, v)
+		}
+	}
+
+	// Content-Type must still be JSON (the only "non-auth" header we expect).
+	if ct := captured.Get("Content-Type"); !strings.HasPrefix(ct, "application/json") {
+		t.Errorf("Content-Type = %q, want application/json", ct)
+	}
+}
+
+// TestListModels_PlainBearerAuthOnly asserts the same plain-Bearer contract
+// for the ListModels endpoint, which is also routed through the compat layer.
+func TestListModels_PlainBearerAuthOnly(t *testing.T) {
+	var captured http.Header
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		captured = r.Header.Clone()
+
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte(`{"object":"list","data":[]}`))
+	}))
+	defer server.Close()
+
+	provider := NewWithHTTPClient("kimi-test-key", server.Client(), llmclient.Hooks{})
+	provider.SetBaseURL(server.URL)
+
+	if _, err := provider.ListModels(context.Background()); err != nil {
+		t.Fatalf("ListModels() error = %v", err)
+	}
+
+	auth := captured.Values("Authorization")
+	if len(auth) != 1 {
+		t.Fatalf("Authorization header count = %d, want 1", len(auth))
+	}
+	if auth[0] != "Bearer kimi-test-key" { // pin the exact key, as the chat and stream tests do
+		t.Errorf("Authorization = %q, want %q", auth[0], "Bearer kimi-test-key")
+	}
+
+	for _, h := range forbiddenHeaders {
+		if v := captured.Get(h); v != "" {
+			t.Errorf("unexpected %s header present: %q", h, v)
+		}
+	}
+}
+
+// TestProvider_SatisfiesCoreProviderInterface is a compile-time check that
+// *Provider continues to implement core.Provider. If a method is added to
+// the interface and not delegated here, this test fails to build.
+func TestProvider_SatisfiesCoreProviderInterface(t *testing.T) {
+	var _ core.Provider = (*Provider)(nil)
+	var _ core.Provider = New(providers.ProviderConfig{APIKey: "k"}, providers.ProviderOptions{})
+}
+
+// TestSetBaseURL_OverridesDefault asserts the SetBaseURL passthrough updates
+// the underlying compatible provider.
+func TestSetBaseURL_OverridesDefault(t *testing.T) {
+	provider := NewWithHTTPClient("kimi-test-key", nil, llmclient.Hooks{})
+	const custom = "https://kimi.example.test/v1"
+	provider.SetBaseURL(custom)
+
+	if got := provider.GetBaseURL(); got != custom {
+		t.Errorf("GetBaseURL() = %q, want %q", got, custom)
+	}
+}
+
+// TestStreamChatCompletion_PlainBearerAuthOnly asserts streaming requests
+// also use only Bearer auth.
+func TestStreamChatCompletion_PlainBearerAuthOnly(t *testing.T) {
+	var captured http.Header
+	var bodyBytes []byte
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		captured = r.Header.Clone()
+		bodyBytes, _ = io.ReadAll(r.Body) // a read error leaves bodyBytes short; the emptiness check below reports it
+
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("data: {\"id\":\"x\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"m\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"},\"finish_reason\":null}]}\n\ndata: [DONE]\n\n"))
+	}))
+	defer server.Close()
+
+	provider := NewWithHTTPClient("kimi-test-key", server.Client(), llmclient.Hooks{})
+	provider.SetBaseURL(server.URL)
+
+	body, err := provider.StreamChatCompletion(context.Background(), &core.ChatRequest{
+		Model:    "kimi-k2-0711-preview",
+		Messages: []core.Message{{Role: "user", Content: "Hello"}},
+	})
+	if err != nil {
+		t.Fatalf("StreamChatCompletion() error = %v", err)
+	}
+	defer func() { _ = body.Close() }()
+
+	auth := captured.Values("Authorization")
+	if len(auth) != 1 || auth[0] != "Bearer kimi-test-key" {
+		t.Errorf("Authorization headers = %v, want exactly [\"Bearer kimi-test-key\"]", auth)
+	}
+
+	for _, h := range forbiddenHeaders {
+		if v := captured.Get(h); v != "" {
+			t.Errorf("unexpected %s header present: %q", h, v)
+		}
+	}
+
+	if len(bodyBytes) == 0 {
+		t.Error("request body was empty")
+	}
+}
+
+// TestCompatProviderType ensures the delegation target is exactly the
+// shared OpenAI-compatible provider implementation. This guards against
+// accidental swaps to a different transport type.
+func TestCompatProviderType(t *testing.T) {
+	p := NewWithHTTPClient("kimi-test-key", nil, llmclient.Hooks{})
+	if p.compat == nil {
+		t.Fatal("Provider.compat should not be nil")
+	}
+	var _ *openai.CompatibleProvider = p.compat
+}
+
+// kimiHeaderServer starts an httptest.Server whose handler records the headers
+// of the most recent inbound request and replies with a canned chat completion.
+// The capture is unsynchronized: call the getter only after the request returns.
+func kimiHeaderServer(t *testing.T) (*httptest.Server, func() http.Header) {
+	t.Helper()
+	var last http.Header
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		last = r.Header.Clone()
+		w.Header().Set("Content-Type", "application/json")
+		_, _ = w.Write([]byte(`{
+			"id": "chatcmpl-kimi-1",
+			"object": "chat.completion",
+			"created": 1700000000,
+			"model": "kimi-k2-0711-preview",
+			"choices": [{
+				"index": 0,
+				"message": {"role": "assistant", "content": "hi"},
+				"finish_reason": "stop"
+			}],
+			"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
+		}`))
+	}))
+	t.Cleanup(server.Close)
+	return server, func() http.Header { return last }
+}
+
+// kimiSnapshotContext builds a context carrying a RequestSnapshot populated
+// with the supplied inbound headers, so the header applier observes them as
+// if they had arrived at ingress.
+func kimiSnapshotContext(t *testing.T, headers map[string][]string) context.Context { + t.Helper() + snapshot := core.NewRequestSnapshot( + http.MethodPost, + "/v1/chat/completions", + nil, + nil, + headers, + "", + nil, + false, + "", + nil, + ) + return core.WithRequestSnapshot(context.Background(), snapshot) +} + +// TestChatCompletion_PassthroughDefaultTrueForwardsInboundHeaders asserts that +// the kimi provider — when constructed with PassthroughUserHeaders: true, the +// resolved value of buildProviderConfig's default for "kimi" — forwards the +// non-skipped inbound headers onto the outbound chat completion request while +// still installing the bearer Authorization set by bearerSetHeaders. The +// default-on behavior itself is covered in providers/config_test.go; this +// test pins the wiring through kimi.New → CompatibleProviderConfig. +func TestChatCompletion_PassthroughDefaultTrueForwardsInboundHeaders(t *testing.T) { + server, captured := kimiHeaderServer(t) + + provider := New(providers.ProviderConfig{ + APIKey: "kimi-test-key", + BaseURL: server.URL, + PassthroughUserHeaders: true, // resolved value of kimi's default-true + }, providers.ProviderOptions{}) + + ctx := kimiSnapshotContext(t, map[string][]string{ + "X-Tenant-Id": {"acme"}, + "X-Trace-Id": {"trace-abc"}, + }) + + if _, err := provider.ChatCompletion(ctx, &core.ChatRequest{ + Model: "kimi-k2-0711-preview", + Messages: []core.Message{{Role: "user", Content: "Hello"}}, + }); err != nil { + t.Fatalf("ChatCompletion() error = %v", err) + } + + got := captured() + if v := got.Get("X-Tenant-Id"); v != "acme" { + t.Errorf("X-Tenant-Id = %q, want acme (default true must forward)", v) + } + if v := got.Get("X-Trace-Id"); v != "trace-abc" { + t.Errorf("X-Trace-Id = %q, want trace-abc (default true must forward)", v) + } + if v := got.Get("Authorization"); v != "Bearer kimi-test-key" { + t.Errorf("Authorization = %q, want Bearer kimi-test-key", v) + } +} + +// 
TestChatCompletion_PassthroughSkipsReservedInboundHeaders asserts that when +// passthrough is enabled, reserved hop-by-hop / credential / cookie headers +// from the inbound snapshot never reach the upstream Kimi call. In particular, +// Authorization must be preserved as the bearer credential installed by +// bearerSetHeaders, not overwritten by a hostile inbound value. +func TestChatCompletion_PassthroughSkipsReservedInboundHeaders(t *testing.T) { + server, captured := kimiHeaderServer(t) + + provider := New(providers.ProviderConfig{ + APIKey: "kimi-test-key", + BaseURL: server.URL, + PassthroughUserHeaders: true, + }, providers.ProviderOptions{}) + + ctx := kimiSnapshotContext(t, map[string][]string{ + "Authorization": {"Bearer inbound-attacker"}, + "X-Api-Key": {"inbound-key"}, + "Cookie": {"session=abc"}, + "Connection": {"close"}, + "Host": {"inbound.example"}, + "Content-Length": {"9999"}, + "X-Forwarded-For": {"1.2.3.4"}, + "X-Tenant-Id": {"acme"}, + }) + + if _, err := provider.ChatCompletion(ctx, &core.ChatRequest{ + Model: "kimi-k2-0711-preview", + Messages: []core.Message{{Role: "user", Content: "Hello"}}, + }); err != nil { + t.Fatalf("ChatCompletion() error = %v", err) + } + + got := captured() + if v := got.Get("Authorization"); v != "Bearer kimi-test-key" { + t.Errorf("Authorization = %q, want Bearer kimi-test-key (bearerSetHeaders must win)", v) + } + if v := got.Get("X-Api-Key"); v != "" { + t.Errorf("X-Api-Key = %q, want empty (skipped passthrough)", v) + } + if v := got.Get("Cookie"); v != "" { + t.Errorf("Cookie = %q, want empty (skipped passthrough)", v) + } + if v := got.Get("Connection"); v != "" { + t.Errorf("Connection = %q, want empty (skipped passthrough)", v) + } + if v := got.Get("X-Forwarded-For"); v != "" { + t.Errorf("X-Forwarded-For = %q, want empty (skipped X-Forwarded-* prefix)", v) + } + if v := got.Get("X-Tenant-Id"); v != "acme" { + t.Errorf("X-Tenant-Id = %q, want acme (non-skipped headers must still flow)", v) + } +} + +// 
TestChatCompletion_PassthroughDisabledIgnoresInboundHeaders asserts that +// setting PassthroughUserHeaders: false on the resolved ProviderConfig blocks +// inbound header forwarding entirely, while bearerSetHeaders still installs +// Authorization on the outbound request. CustomUpstreamHeaders remain applicable when +// supplied alongside the explicit disable. +func TestChatCompletion_PassthroughDisabledIgnoresInboundHeaders(t *testing.T) { + server, captured := kimiHeaderServer(t) + + provider := New(providers.ProviderConfig{ + APIKey: "kimi-test-key", + BaseURL: server.URL, + PassthroughUserHeaders: false, + CustomUpstreamHeaders: map[string]string{ + "X-Org": "custom-org", + }, + }, providers.ProviderOptions{}) + + ctx := kimiSnapshotContext(t, map[string][]string{ + "X-Tenant-Id": {"acme"}, + "X-Trace-Id": {"trace-abc"}, + }) + + if _, err := provider.ChatCompletion(ctx, &core.ChatRequest{ + Model: "kimi-k2-0711-preview", + Messages: []core.Message{{Role: "user", Content: "Hello"}}, + }); err != nil { + t.Fatalf("ChatCompletion() error = %v", err) + } + + got := captured() + if v := got.Get("X-Tenant-Id"); v != "" { + t.Errorf("X-Tenant-Id = %q, want empty (passthrough disabled)", v) + } + if v := got.Get("X-Trace-Id"); v != "" { + t.Errorf("X-Trace-Id = %q, want empty (passthrough disabled)", v) + } + if v := got.Get("X-Org"); v != "custom-org" { + t.Errorf("X-Org = %q, want custom-org (custom headers must still apply)", v) + } + if v := got.Get("Authorization"); v != "Bearer kimi-test-key" { + t.Errorf("Authorization = %q, want Bearer kimi-test-key (bearerSetHeaders still installs auth)", v) + } +} + +// (mutex behavior covered by TestApplyRequestHeaderOverrides_BothPassthroughAndCustomPanics +// in internal/providers/request_header_overrides_test.go — the runtime mutex is the +// second line of defense, and the config-load validator in resolveProviders rejects +// the combination at startup) diff --git a/internal/providers/minimax/minimax.go 
b/internal/providers/minimax/minimax.go index cdea795a..a2a9b5c6 100644 --- a/internal/providers/minimax/minimax.go +++ b/internal/providers/minimax/minimax.go @@ -37,8 +37,11 @@ var _ core.Provider = (*Provider)(nil) // New creates a new MiniMax provider. func New(cfg providers.ProviderConfig, opts providers.ProviderOptions) core.Provider { return &Provider{openai.NewChatCompatible(cfg.APIKey, opts, openai.CompatibleProviderConfig{ - ProviderName: "minimax", - BaseURL: providers.ResolveBaseURL(cfg.BaseURL, defaultBaseURL), + ProviderName: "minimax", + BaseURL: providers.ResolveBaseURL(cfg.BaseURL, defaultBaseURL), + CustomUpstreamHeaders: cfg.CustomUpstreamHeaders, + PassthroughUserHeaders: cfg.PassthroughUserHeaders, + PassthroughUserHeadersSkip: cfg.PassthroughUserHeadersSkip, })} } diff --git a/internal/providers/openai/compatible_provider.go b/internal/providers/openai/compatible_provider.go index 0aee5bfc..567346bb 100644 --- a/internal/providers/openai/compatible_provider.go +++ b/internal/providers/openai/compatible_provider.go @@ -16,10 +16,13 @@ import ( type RequestMutator func(*llmclient.Request) type CompatibleProviderConfig struct { - ProviderName string - BaseURL string - SetHeaders func(*http.Request, string) - RequestMutator RequestMutator + ProviderName string + BaseURL string + SetHeaders func(*http.Request, string) + RequestMutator RequestMutator + CustomUpstreamHeaders map[string]string + PassthroughUserHeaders bool + PassthroughUserHeadersSkip []string } type CompatibleProvider struct { @@ -42,11 +45,14 @@ func NewCompatibleProvider(apiKey string, opts providers.ProviderOptions, cfg Co Hooks: opts.Hooks, CircuitBreaker: opts.Resilience.CircuitBreaker, } - p.client = llmclient.New(clientCfg, func(req *http.Request) { + setHeaders := func(req *http.Request) { if cfg.SetHeaders != nil { cfg.SetHeaders(req, apiKey) } - }) + // Run last so custom/passthrough headers win over the auth header above. 
+ providers.ApplyRequestHeaderOverrides(req.Context(), req.Header, cfg.CustomUpstreamHeaders, cfg.PassthroughUserHeaders, cfg.PassthroughUserHeadersSkip...) + } + p.client = llmclient.New(clientCfg, setHeaders) return p } @@ -61,11 +67,13 @@ func NewCompatibleProviderWithHTTPClient(apiKey string, httpClient *http.Client, } clientCfg := llmclient.DefaultConfig(cfg.ProviderName, cfg.BaseURL) clientCfg.Hooks = hooks - p.client = llmclient.NewWithHTTPClient(httpClient, clientCfg, func(req *http.Request) { + setHeaders := func(req *http.Request) { if cfg.SetHeaders != nil { cfg.SetHeaders(req, apiKey) } - }) + providers.ApplyRequestHeaderOverrides(req.Context(), req.Header, cfg.CustomUpstreamHeaders, cfg.PassthroughUserHeaders, cfg.PassthroughUserHeadersSkip...) + } + p.client = llmclient.NewWithHTTPClient(httpClient, clientCfg, setHeaders) return p } diff --git a/internal/providers/openai/compatible_provider_test.go b/internal/providers/openai/compatible_provider_test.go index c4793416..cc615089 100644 --- a/internal/providers/openai/compatible_provider_test.go +++ b/internal/providers/openai/compatible_provider_test.go @@ -94,3 +94,272 @@ func TestCompatibleProvider_ListModels_ReturnsUpstreamError(t *testing.T) { t.Errorf("gatewayErr.Type = %q, want provider_error or not_found_error", gatewayErr.Type) } } + +// headerCaptureServer returns a test server that records the most recent +// request headers; the getter lets assertions run on the calling goroutine. 
+func headerCaptureServer(t *testing.T) (*httptest.Server, func() http.Header) { + t.Helper() + var last http.Header + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + last = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{}`)) + })) + t.Cleanup(server.Close) + return server, func() http.Header { return last } +} + +// snapshotContext builds a context whose RequestSnapshot carries the given +// inbound headers, so the header applier sees them as if set at ingress. +func snapshotContext(t *testing.T, headers map[string][]string) context.Context { + t.Helper() + snapshot := core.NewRequestSnapshot( + http.MethodPost, + "/v1/test", + nil, + nil, + headers, + "", + nil, + false, + "", + nil, + ) + return core.WithRequestSnapshot(context.Background(), snapshot) +} + +func TestCompatibleProvider_PassthroughForwardsInboundHeaders(t *testing.T) { + server, captured := headerCaptureServer(t) + + provider := NewCompatibleProviderWithHTTPClient( + "test-key", + server.Client(), + llmclient.Hooks{}, + CompatibleProviderConfig{ + ProviderName: "passthrough", + BaseURL: server.URL, + PassthroughUserHeaders: true, + }, + ) + + ctx := snapshotContext(t, map[string][]string{ + "X-Trace-Id": {"trace-abc"}, + "X-Tenant-Id": {"acme"}, + }) + + var out any + if err := provider.Do(ctx, llmclient.Request{ + Method: http.MethodPost, + Endpoint: "/test", + Body: map[string]any{"hello": "world"}, + }, &out); err != nil { + t.Fatalf("Do() error = %v", err) + } + + got := captured() + if v := got.Get("X-Trace-Id"); v != "trace-abc" { + t.Errorf("X-Trace-Id = %q, want trace-abc", v) + } + if v := got.Get("X-Tenant-Id"); v != "acme" { + t.Errorf("X-Tenant-Id = %q, want acme", v) + } +} + +func TestCompatibleProvider_PassthroughDisabledIgnoresInboundHeaders(t *testing.T) { + server, captured := headerCaptureServer(t) + + provider := NewCompatibleProviderWithHTTPClient( + "test-key", + server.Client(), + 
llmclient.Hooks{}, + CompatibleProviderConfig{ + ProviderName: "no-passthrough", + BaseURL: server.URL, + PassthroughUserHeaders: false, + }, + ) + + ctx := snapshotContext(t, map[string][]string{ + "X-Tenant-Id": {"acme"}, + }) + + var out any + if err := provider.Do(ctx, llmclient.Request{ + Method: http.MethodPost, + Endpoint: "/test", + Body: map[string]any{"hello": "world"}, + }, &out); err != nil { + t.Fatalf("Do() error = %v", err) + } + + if v := captured().Get("X-Tenant-Id"); v != "" { + t.Errorf("X-Tenant-Id = %q, want empty when passthrough disabled", v) + } +} + +func TestCompatibleProvider_CustomUpstreamHeadersApplied(t *testing.T) { + server, captured := headerCaptureServer(t) + + provider := NewCompatibleProviderWithHTTPClient( + "test-key", + server.Client(), + llmclient.Hooks{}, + CompatibleProviderConfig{ + ProviderName: "custom-headers", + BaseURL: server.URL, + CustomUpstreamHeaders: map[string]string{ + "X-Org": "acme", + "X-Request-Id": "req-1", + }, + }, + ) + + var out any + if err := provider.Do(context.Background(), llmclient.Request{ + Method: http.MethodPost, + Endpoint: "/test", + Body: map[string]any{"hello": "world"}, + }, &out); err != nil { + t.Fatalf("Do() error = %v", err) + } + + got := captured() + if v := got.Get("X-Org"); v != "acme" { + t.Errorf("X-Org = %q, want acme", v) + } + if v := got.Get("X-Request-Id"); v != "req-1" { + t.Errorf("X-Request-Id = %q, want req-1", v) + } +} + +func TestCompatibleProvider_SetHeadersAuthPreservedWhenNoConflict(t *testing.T) { + server, captured := headerCaptureServer(t) + + provider := NewCompatibleProviderWithHTTPClient( + "secret-key", + server.Client(), + llmclient.Hooks{}, + CompatibleProviderConfig{ + ProviderName: "auth-only", + BaseURL: server.URL, + SetHeaders: func(req *http.Request, apiKey string) { + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("X-Provider", "test") + }, + }, + ) + + var out any + if err := provider.Do(context.Background(), 
llmclient.Request{ + Method: http.MethodPost, + Endpoint: "/test", + Body: map[string]any{"hello": "world"}, + }, &out); err != nil { + t.Fatalf("Do() error = %v", err) + } + + got := captured() + if v := got.Get("Authorization"); v != "Bearer secret-key" { + t.Errorf("Authorization = %q, want Bearer secret-key", v) + } + if v := got.Get("X-Provider"); v != "test" { + t.Errorf("X-Provider = %q, want test", v) + } +} + +func TestCompatibleProvider_CustomHeaderOverridesSetHeaders(t *testing.T) { + server, captured := headerCaptureServer(t) + + provider := NewCompatibleProviderWithHTTPClient( + "secret-key", + server.Client(), + llmclient.Hooks{}, + CompatibleProviderConfig{ + ProviderName: "custom-overrides-auth", + BaseURL: server.URL, + SetHeaders: func(req *http.Request, apiKey string) { + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("X-Provider", "factory") + }, + CustomUpstreamHeaders: map[string]string{ + "X-Provider": "custom", + }, + }, + ) + + var out any + if err := provider.Do(context.Background(), llmclient.Request{ + Method: http.MethodPost, + Endpoint: "/test", + Body: map[string]any{"hello": "world"}, + }, &out); err != nil { + t.Fatalf("Do() error = %v", err) + } + + got := captured() + if v := got.Get("X-Provider"); v != "custom" { + t.Errorf("X-Provider = %q, want custom", v) + } + if v := got.Get("Authorization"); v != "Bearer secret-key" { + t.Errorf("Authorization = %q, want Bearer secret-key (custom header is non-auth)", v) + } +} + +func TestCompatibleProvider_PassthroughSkipsHopByHopAndAuthHeaders(t *testing.T) { + server, captured := headerCaptureServer(t) + + provider := NewCompatibleProviderWithHTTPClient( + "secret-key", + server.Client(), + llmclient.Hooks{}, + CompatibleProviderConfig{ + ProviderName: "skip", + BaseURL: server.URL, + SetHeaders: func(req *http.Request, apiKey string) { + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("X-Api-Key", apiKey) + }, + PassthroughUserHeaders: true, + }, 
+ ) + + ctx := snapshotContext(t, map[string][]string{ + "Authorization": {"Bearer inbound-attacker"}, + "X-Api-Key": {"inbound-key"}, + "Cookie": {"session=abc"}, + "Connection": {"close"}, + "Host": {"inbound.example"}, + "Content-Length": {"9999"}, + "X-Forwarded-For": {"1.2.3.4"}, + "X-Tenant-Id": {"acme"}, + }) + + var out any + if err := provider.Do(ctx, llmclient.Request{ + Method: http.MethodPost, + Endpoint: "/test", + Body: map[string]any{"hello": "world"}, + }, &out); err != nil { + t.Fatalf("Do() error = %v", err) + } + + got := captured() + if v := got.Get("Authorization"); v != "Bearer secret-key" { + t.Errorf("Authorization = %q, want Bearer secret-key (factory value, not passthrough)", v) + } + if v := got.Get("X-Api-Key"); v != "secret-key" { + t.Errorf("X-Api-Key = %q, want secret-key (factory value, not passthrough)", v) + } + if v := got.Get("Cookie"); v != "" { + t.Errorf("Cookie = %q, want empty (skipped passthrough)", v) + } + if v := got.Get("Connection"); v != "" { + t.Errorf("Connection = %q, want empty (skipped passthrough)", v) + } + if v := got.Get("X-Forwarded-For"); v != "" { + t.Errorf("X-Forwarded-For = %q, want empty (skipped passthrough)", v) + } + if v := got.Get("X-Tenant-Id"); v != "acme" { + t.Errorf("X-Tenant-Id = %q, want acme (allowed passthrough)", v) + } +} diff --git a/internal/providers/openai/openai.go b/internal/providers/openai/openai.go index 8a40816f..7fcf6fae 100644 --- a/internal/providers/openai/openai.go +++ b/internal/providers/openai/openai.go @@ -41,9 +41,12 @@ func New(cfg providers.ProviderConfig, opts providers.ProviderOptions) core.Prov baseURL := providers.ResolveBaseURL(cfg.BaseURL, defaultBaseURL) return &Provider{ CompatibleProvider: NewCompatibleProvider(cfg.APIKey, opts, CompatibleProviderConfig{ - ProviderName: "openai", - BaseURL: baseURL, - SetHeaders: setHeaders, + ProviderName: "openai", + BaseURL: baseURL, + SetHeaders: setHeaders, + CustomUpstreamHeaders: cfg.CustomUpstreamHeaders, + 
PassthroughUserHeaders: cfg.PassthroughUserHeaders, + PassthroughUserHeadersSkip: cfg.PassthroughUserHeadersSkip, }), apiKey: cfg.APIKey, } diff --git a/internal/providers/opencodego/opencodego.go b/internal/providers/opencodego/opencodego.go index d00cfac0..c220a285 100644 --- a/internal/providers/opencodego/opencodego.go +++ b/internal/providers/opencodego/opencodego.go @@ -73,8 +73,11 @@ var _ core.Provider = (*Provider)(nil) func New(cfg providers.ProviderConfig, opts providers.ProviderOptions) core.Provider { baseURL := providers.ResolveBaseURL(cfg.BaseURL, defaultBaseURL) chat := openai.NewChatCompatible(cfg.APIKey, opts, openai.CompatibleProviderConfig{ - ProviderName: "opencode_go", - BaseURL: baseURL, + ProviderName: "opencode_go", + BaseURL: baseURL, + CustomUpstreamHeaders: cfg.CustomUpstreamHeaders, + PassthroughUserHeaders: cfg.PassthroughUserHeaders, + PassthroughUserHeadersSkip: cfg.PassthroughUserHeadersSkip, }) messages := anthropic.New(providers.ProviderConfig{APIKey: cfg.APIKey, BaseURL: baseURL}, opts) return &Provider{ diff --git a/internal/providers/openrouter/openrouter.go b/internal/providers/openrouter/openrouter.go index 501a4ef9..dddba400 100644 --- a/internal/providers/openrouter/openrouter.go +++ b/internal/providers/openrouter/openrouter.go @@ -39,9 +39,12 @@ func New(cfg providers.ProviderConfig, opts providers.ProviderOptions) core.Prov appName: envOrDefault("OPENROUTER_APP_NAME", defaultAppName), } p.CompatibleProvider = openai.NewCompatibleProvider(cfg.APIKey, opts, openai.CompatibleProviderConfig{ - ProviderName: "openrouter", - BaseURL: baseURL, - SetHeaders: setHeaders, + ProviderName: "openrouter", + BaseURL: baseURL, + SetHeaders: setHeaders, + CustomUpstreamHeaders: cfg.CustomUpstreamHeaders, + PassthroughUserHeaders: cfg.PassthroughUserHeaders, + PassthroughUserHeadersSkip: cfg.PassthroughUserHeadersSkip, }) p.SetRequestMutator(p.mutateRequest) return p diff --git a/internal/providers/oracle/oracle.go 
b/internal/providers/oracle/oracle.go index b2cf8727..1b446cc3 100644 --- a/internal/providers/oracle/oracle.go +++ b/internal/providers/oracle/oracle.go @@ -29,9 +29,12 @@ func New(cfg providers.ProviderConfig, opts providers.ProviderOptions) core.Prov baseURL := providers.ResolveBaseURL(cfg.BaseURL, defaultBaseURL) return &Provider{ compat: openai.NewCompatibleProvider(cfg.APIKey, opts, openai.CompatibleProviderConfig{ - ProviderName: "oracle", - BaseURL: baseURL, - SetHeaders: setHeaders, + ProviderName: "oracle", + BaseURL: baseURL, + SetHeaders: setHeaders, + CustomUpstreamHeaders: cfg.CustomUpstreamHeaders, + PassthroughUserHeaders: cfg.PassthroughUserHeaders, + PassthroughUserHeadersSkip: cfg.PassthroughUserHeadersSkip, }), } } diff --git a/internal/providers/request_header_overrides.go b/internal/providers/request_header_overrides.go new file mode 100644 index 00000000..6de1899f --- /dev/null +++ b/internal/providers/request_header_overrides.go @@ -0,0 +1,63 @@ +package providers + +import ( + "context" + "net/http" + "strings" + + "gomodel/internal/core" +) + +// ApplyRequestHeaderOverrides applies per-provider header rules to h. The two +// modes are mutually exclusive: setting both panics. Config-load validation in +// resolveProviders should prevent reaching it; the panic is the backstop so +// internal callers never produce an ambiguous outbound header mix. +// +// With passthrough true, every non-skipped inbound header from the request +// snapshot overwrites its key on h (inbound wins over provider defaults; an +// empty inbound value removes the key). With passthrough false, customHeaders +// is applied verbatim and inbound headers are ignored. Provider-set auth +// headers are written by the factory before this call and are left alone. 
+func ApplyRequestHeaderOverrides(ctx context.Context, h http.Header, customHeaders map[string]string, passthrough bool, extras ...string) {
+	if h == nil {
+		return
+	}
+
+	hasCustom := len(customHeaders) > 0
+	if hasCustom && passthrough {
+		panic("ApplyRequestHeaderOverrides: passthrough_user_headers and custom_upstream_headers are mutually exclusive; the config-loader should have rejected this")
+	}
+
+	if passthrough {
+		snapshot := core.GetRequestSnapshot(ctx)
+		if snapshot == nil { // no inbound request on ctx: nothing to forward
+			return
+		}
+		headers := snapshot.HeadersView()
+		for key, values := range headers {
+			if SkipPassthroughHeader(key, extras...) {
+				continue
+			}
+			canonicalKey := http.CanonicalHeaderKey(key)
+			if canonicalKey == "" {
+				continue
+			}
+			h.Del(canonicalKey) // replace, rather than append to, any value already on h
+			for _, value := range values {
+				h.Add(canonicalKey, value)
+			}
+		}
+		return
+	}
+
+	if !hasCustom {
+		return
+	}
+	for key, value := range customHeaders {
+		canonicalKey := http.CanonicalHeaderKey(strings.TrimSpace(key))
+		if canonicalKey == "" { // blank or whitespace-only keys are dropped
+			continue
+		}
+		h.Set(canonicalKey, value)
+	}
+}
diff --git a/internal/providers/request_header_overrides_test.go b/internal/providers/request_header_overrides_test.go
new file mode 100644
index 00000000..117a6993
--- /dev/null
+++ b/internal/providers/request_header_overrides_test.go
@@ -0,0 +1,352 @@
+package providers
+
+import (
+	"context"
+	"net/http"
+	"reflect"
+	"sort"
+	"testing"
+
+	"gomodel/internal/core"
+)
+
+func TestApplyRequestHeaderOverrides_NoOverrides(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		setup   func() http.Header
+		headers map[string]string
+		pass    bool
+	}{
+		{
+			name: "nil custom + passthrough false",
+			setup: func() http.Header {
+				return http.Header{"Authorization": {"Bearer original"}}
+			},
+			headers: nil,
+			pass:    false,
+		},
+		{
+			name: "empty custom + passthrough false",
+			setup: func() http.Header {
+				return http.Header{"Authorization": {"Bearer original"}}
+			},
+			headers: map[string]string{},
+			pass:    false,
+		},
+		{
+			name: "passthrough true but no snapshot on context",
+			setup: func() http.Header {
+				return http.Header{"Authorization": {"Bearer original"}}
+			},
+			headers: nil,
+			pass:    true,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			h := tt.setup()
+			before := cloneHeader(h)
+
+			ApplyRequestHeaderOverrides(context.Background(), h, tt.headers, tt.pass)
+
+			if !reflect.DeepEqual(h, before) {
+				t.Fatalf("expected no changes, got %v want %v", h, before)
+			}
+		})
+	}
+}
+
+func TestApplyRequestHeaderOverrides_NilHeader(t *testing.T) {
+	t.Parallel()
+
+	// Must not panic with a nil header; nothing to apply onto.
+	ApplyRequestHeaderOverrides(context.Background(), nil, map[string]string{"X-Test": "v"}, false)
+	ApplyRequestHeaderOverrides(context.Background(), nil, nil, true)
+}
+
+func TestApplyRequestHeaderOverrides_CustomUpstreamHeadersCanonicalizeAndOverwrite(t *testing.T) {
+	t.Parallel()
+
+	h := http.Header{
+		"Authorization": {"Bearer original"},
+		"x-custom":      {"original-lower"},
+		"Existing":      {"original-existing"},
+	}
+
+	custom := map[string]string{
+		"authorization": "Bearer custom", // lower-case key, must canonicalize
+		"X-Custom":      "new-upper",     // mixed case
+		"New-Header":    "fresh",         // missing before
+		" Trim-Me ":     "trimmed",       // whitespace key
+	}
+
+	ApplyRequestHeaderOverrides(context.Background(), h, custom, false)
+
+	expectEqual(t, h.Get("Authorization"), "Bearer custom")
+	expectEqual(t, h.Get("X-Custom"), "new-upper")
+	expectEqual(t, h.Get("New-Header"), "fresh")
+	expectEqual(t, h.Get("Existing"), "original-existing") // untouched
+	expectEqual(t, h.Get("Trim-Me"), "trimmed")
+}
+
+func TestApplyRequestHeaderOverrides_CustomHeaderSkipsEmptyKey(t *testing.T) {
+	t.Parallel()
+
+	h := http.Header{"X-Existing": {"v"}}
+
+	custom := map[string]string{
+		"":       "ignored",
+		" ":      "ignored",
+		"X-Real": "value",
+	}
+
+	ApplyRequestHeaderOverrides(context.Background(), h, custom, false)
+
+	if _, ok := h[""]; ok || len(h) != 2 { // only X-Existing and X-Real may remain
+		t.Fatalf("unexpected empty key written: %v", h)
+	}
+	expectEqual(t, h.Get("X-Real"), "value")
+}
+
+func TestApplyRequestHeaderOverrides_PassthroughAppliesSnapshotHeaders(t *testing.T) {
+	t.Parallel()
+
+	h := http.Header{"X-Other": {"original"}}
+
+	snapshot := core.NewRequestSnapshot(
+		"POST",
+		"/v1/chat",
+		nil,
+		nil,
+		map[string][]string{
+			"X-Snapshot-Only": {"snap-1"},
+			"x-trace-id":      {"trace-abc"},
+			"X-Multi":         {"one", "two"},
+		},
+		"application/json",
+		nil,
+		false,
+		"req-1",
+		nil,
+	)
+	ctx := core.WithRequestSnapshot(context.Background(), snapshot)
+
+	ApplyRequestHeaderOverrides(ctx, h, nil, true)
+
+	expectEqual(t, h.Get("X-Snapshot-Only"), "snap-1")
+	expectEqual(t, h.Get("X-Trace-Id"), "trace-abc")
+	expectSlice(t, h.Values("X-Multi"), []string{"one", "two"})
+	expectEqual(t, h.Get("X-Other"), "original")
+}
+
+func TestApplyRequestHeaderOverrides_PassthroughSkipsReservedHeaders(t *testing.T) {
+	t.Parallel()
+
+	h := http.Header{}
+
+	snapshot := core.NewRequestSnapshot(
+		"POST", "/v1/chat", nil, nil,
+		map[string][]string{
+			"Authorization":    {"Bearer snap"},
+			"Cookie":           {"sid=abc"},
+			"X-Forwarded-For":  {"1.2.3.4"},
+			"X-Forwarded-Host": {"evil.example"},
+			"Host":             {"upstream.example"},
+			"X-Safe":           {"ok"},
+		},
+		"application/json", nil, false, "req-2", nil,
+	)
+	ctx := core.WithRequestSnapshot(context.Background(), snapshot)
+
+	ApplyRequestHeaderOverrides(ctx, h, nil, true)
+
+	if v := h.Get("Authorization"); v != "" {
+		t.Fatalf("Authorization must be skipped, got %q", v)
+	}
+	if v := h.Get("Cookie"); v != "" {
+		t.Fatalf("Cookie must be skipped, got %q", v)
+	}
+	if v := h.Get("X-Forwarded-For"); v != "" {
+		t.Fatalf("X-Forwarded-* must be skipped, got %q", v)
+	}
+	if v := h.Get("X-Forwarded-Host"); v != "" {
+		t.Fatalf("X-Forwarded-* must be skipped, got %q", v)
+	}
+	if v := h.Get("Host"); v != "" {
+		t.Fatalf("Host must be skipped, got %q", v)
+	}
+	expectEqual(t, h.Get("X-Safe"), "ok")
+}
+
+func TestApplyRequestHeaderOverrides_PassthroughOverwritesExisting(t *testing.T) {
+	t.Parallel()
+
+	h := http.Header{
+		"X-Trace-Id":   {"old"},
+		"X-Only-Local": {"untouched"},
+	}
+
+	snapshot := core.NewRequestSnapshot(
+		"POST", "/v1/chat", nil, nil,
+		map[string][]string{"X-Trace-Id": {"fresh"}},
+		"", nil, false, "req-3", nil,
+	)
+	ctx := core.WithRequestSnapshot(context.Background(), snapshot)
+
+	ApplyRequestHeaderOverrides(ctx, h, nil, true)
+
+	expectEqual(t, h.Get("X-Trace-Id"), "fresh")
+	expectEqual(t, h.Get("X-Only-Local"), "untouched")
+}
+
+func TestApplyRequestHeaderOverrides_CustomIgnoresSnapshot(t *testing.T) {
+	t.Parallel()
+
+	h := http.Header{"X-Existing": {"preserve"}}
+
+	snapshot := core.NewRequestSnapshot(
+		"POST", "/v1/chat", nil, nil,
+		map[string][]string{"X-Tenant-Id": {"acme"}, "X-Trace-Id": {"trace"}},
+		"", nil, false, "req-4", nil,
+	)
+	ctx := core.WithRequestSnapshot(context.Background(), snapshot)
+
+	custom := map[string]string{"X-My-Header": "value"}
+	ApplyRequestHeaderOverrides(ctx, h, custom, false)
+
+	expectEqual(t, h.Get("X-My-Header"), "value")
+	if v := h.Get("X-Tenant-Id"); v != "" {
+		t.Fatalf("custom path must not copy X-Tenant-Id from snapshot, got %q", v)
+	}
+	if v := h.Get("X-Trace-Id"); v != "" {
+		t.Fatalf("custom path must not copy X-Trace-Id from snapshot, got %q", v)
+	}
+	expectEqual(t, h.Get("X-Existing"), "preserve")
+}
+
+func TestApplyRequestHeaderOverrides_BothPassthroughAndCustomPanics(t *testing.T) {
+	t.Parallel()
+
+	defer func() {
+		if recover() == nil {
+			t.Fatalf("expected panic when both passthrough and custom headers are set")
+		}
+	}()
+
+	ApplyRequestHeaderOverrides(
+		context.Background(),
+		http.Header{},
+		map[string]string{"X-Custom": "value"},
+		true,
+	)
+}
+
+func
TestApplyRequestHeaderOverrides_PassthroughOverwritesExisting(t *testing.T) { + t.Parallel() + + h := http.Header{ + "X-Trace-Id": {"old"}, + "X-Only-Local": {"untouched"}, + } + + snapshot := core.NewRequestSnapshot( + "POST", "/v1/chat", nil, nil, + map[string][]string{"X-Trace-Id": {"fresh"}}, + "", nil, false, "req-3", nil, + ) + ctx := core.WithRequestSnapshot(context.Background(), snapshot) + + ApplyRequestHeaderOverrides(ctx, h, nil, true) + + expectEqual(t, h.Get("X-Trace-Id"), "fresh") + expectEqual(t, h.Get("X-Only-Local"), "untouched") +} + +func TestApplyRequestHeaderOverrides_CustomIgnoresSnapshot(t *testing.T) { + t.Parallel() + + h := http.Header{"X-Existing": {"preserve"}} + + snapshot := core.NewRequestSnapshot( + "POST", "/v1/chat", nil, nil, + map[string][]string{"X-Tenant-Id": {"acme"}, "X-Trace-Id": {"trace"}}, + "", nil, false, "req-4", nil, + ) + ctx := core.WithRequestSnapshot(context.Background(), snapshot) + + custom := map[string]string{"X-My-Header": "value"} + ApplyRequestHeaderOverrides(ctx, h, custom, false) + + expectEqual(t, h.Get("X-My-Header"), "value") + if v := h.Get("X-Tenant-Id"); v != "" { + t.Fatalf("custom path must not copy X-Tenant-Id from snapshot, got %q", v) + } + if v := h.Get("X-Trace-Id"); v != "" { + t.Fatalf("custom path must not copy X-Trace-Id from snapshot, got %q", v) + } + expectEqual(t, h.Get("X-Existing"), "preserve") +} + +func TestApplyRequestHeaderOverrides_BothPassthroughAndCustomPanics(t *testing.T) { + t.Parallel() + + defer func() { + if recover() == nil { + t.Fatalf("expected panic when both passthrough and custom headers are set") + } + }() + + ApplyRequestHeaderOverrides( + context.Background(), + http.Header{}, + map[string]string{"X-Custom": "value"}, + true, + ) +} + +func TestApplyRequestHeaderOverrides_PassthroughDeletesExistingValues(t *testing.T) { + t.Parallel() + + h := http.Header{"X-Trace-Id": {"old-1", "old-2"}} + + snapshot := core.NewRequestSnapshot( + "POST", "/v1/chat", nil, nil, + 
map[string][]string{"X-Trace-Id": {"fresh"}}, + "", nil, false, "req-4", nil, + ) + ctx := core.WithRequestSnapshot(context.Background(), snapshot) + + ApplyRequestHeaderOverrides(ctx, h, nil, true) + + expectSlice(t, h.Values("X-Trace-Id"), []string{"fresh"}) +} + +func TestApplyRequestHeaderOverrides_PassthroughEmptySnapshotMap(t *testing.T) { + t.Parallel() + + h := http.Header{"X-Custom": {"existing"}} + + snapshot := core.NewRequestSnapshot( + "POST", "/v1/chat", nil, nil, nil, "", nil, false, "req-6", nil, + ) + ctx := core.WithRequestSnapshot(context.Background(), snapshot) + + ApplyRequestHeaderOverrides(ctx, h, nil, true) + + expectEqual(t, h.Get("X-Custom"), "existing") +} + +func TestApplyRequestHeaderOverrides_PassthroughSnapshotWithEmptyValueSlice(t *testing.T) { + t.Parallel() + + h := http.Header{"X-Empty": {"keep-or-not"}} + + snapshot := core.NewRequestSnapshot( + "POST", "/v1/chat", nil, nil, + map[string][]string{"X-Empty": nil}, + "", nil, false, "req-7", nil, + ) + ctx := core.WithRequestSnapshot(context.Background(), snapshot) + + ApplyRequestHeaderOverrides(ctx, h, nil, true) + + if _, ok := h["X-Empty"]; ok { + t.Fatalf("X-Empty must be deleted when snapshot slice empty, got %v", h.Values("X-Empty")) + } +} + +func TestApplyRequestHeaderOverrides_PassthroughUserExtrasSkipped(t *testing.T) { + t.Parallel() + + h := http.Header{} + + snapshot := core.NewRequestSnapshot( + "POST", "/v1/chat", nil, nil, + map[string][]string{ + "X-Custom-Skip": {"upstream-rejects"}, + "X-Safe": {"ok"}, + }, + "", nil, false, "req-8", nil, + ) + ctx := core.WithRequestSnapshot(context.Background(), snapshot) + + ApplyRequestHeaderOverrides(ctx, h, nil, true, "X-Custom-Skip") + + if v := h.Get("X-Custom-Skip"); v != "" { + t.Fatalf("X-Custom-Skip must be skipped, got %q", v) + } + expectEqual(t, h.Get("X-Safe"), "ok") +} + +// --- helpers --------------------------------------------------------------- + +func cloneHeader(h http.Header) http.Header { + cloned := 
make(http.Header, len(h)) + for k, v := range h { + vs := make([]string, len(v)) + copy(vs, v) + cloned[k] = vs + } + return cloned +} + +func expectEqual(t *testing.T, got, want string) { + t.Helper() + if got != want { + t.Fatalf("got %q, want %q", got, want) + } +} + +func expectSlice(t *testing.T, got, want []string) { + t.Helper() + g := append([]string(nil), got...) + w := append([]string(nil), want...) + sort.Strings(g) + sort.Strings(w) + if !reflect.DeepEqual(g, w) { + t.Fatalf("got %v, want %v", got, want) + } +} diff --git a/internal/providers/vertex/vertex.go b/internal/providers/vertex/vertex.go index e1d2de35..cb0d9818 100644 --- a/internal/providers/vertex/vertex.go +++ b/internal/providers/vertex/vertex.go @@ -40,6 +40,10 @@ type Provider struct { nativeClient *llmclient.Client authType string configErr error + + customHeaders map[string]string + passthrough bool + passthroughSkip []string } // New creates a new Vertex AI provider. @@ -50,7 +54,10 @@ func New(providerCfg providers.ProviderConfig, opts providers.ProviderOptions) c func newProvider(providerCfg providers.ProviderConfig, opts providers.ProviderOptions, baseHTTPClient *http.Client) *Provider { providerCfg.Backend = "vertex" p := &Provider{ - authType: normalizeAuthType(providerCfg), + authType: normalizeAuthType(providerCfg), + customHeaders: providerCfg.CustomUpstreamHeaders, + passthrough: providerCfg.PassthroughUserHeaders, + passthroughSkip: providerCfg.PassthroughUserHeadersSkip, } p.validateConfig(providerCfg) @@ -153,6 +160,8 @@ func (p *Provider) setHeaders(req *http.Request) { if requestID := core.GetRequestID(req.Context()); requestID != "" { req.Header.Set("X-Request-Id", requestID) } + + providers.ApplyRequestHeaderOverrides(req.Context(), req.Header, p.customHeaders, p.passthrough, p.passthroughSkip...) } // ChatCompletion sends a chat completion request to Vertex AI Gemini. 
diff --git a/internal/providers/vllm/vllm.go b/internal/providers/vllm/vllm.go index 658f82f0..b5b2c5d7 100644 --- a/internal/providers/vllm/vllm.go +++ b/internal/providers/vllm/vllm.go @@ -38,9 +38,12 @@ func New(cfg providers.ProviderConfig, opts providers.ProviderOptions) core.Prov rootBaseURL := passthroughBaseURL(baseURL) return &Provider{ compatible: openai.NewCompatibleProvider(cfg.APIKey, opts, openai.CompatibleProviderConfig{ - ProviderName: "vllm", - BaseURL: baseURL, - SetHeaders: setHeaders, + ProviderName: "vllm", + BaseURL: baseURL, + SetHeaders: setHeaders, + CustomUpstreamHeaders: cfg.CustomUpstreamHeaders, + PassthroughUserHeaders: cfg.PassthroughUserHeaders, + PassthroughUserHeadersSkip: cfg.PassthroughUserHeadersSkip, }), rootClient: llmclient.New(llmclient.Config{ ProviderName: "vllm", diff --git a/internal/providers/xiaomi/xiaomi.go b/internal/providers/xiaomi/xiaomi.go index 9b84dc59..be8c3eed 100644 --- a/internal/providers/xiaomi/xiaomi.go +++ b/internal/providers/xiaomi/xiaomi.go @@ -34,8 +34,11 @@ var _ core.Provider = (*Provider)(nil) // New creates a new Xiaomi MiMo provider. 
func New(cfg providers.ProviderConfig, opts providers.ProviderOptions) core.Provider { return &Provider{openai.NewChatCompatible(cfg.APIKey, opts, openai.CompatibleProviderConfig{ - ProviderName: "xiaomi", - BaseURL: providers.ResolveBaseURL(cfg.BaseURL, defaultBaseURL), + ProviderName: "xiaomi", + BaseURL: providers.ResolveBaseURL(cfg.BaseURL, defaultBaseURL), + CustomUpstreamHeaders: cfg.CustomUpstreamHeaders, + PassthroughUserHeaders: cfg.PassthroughUserHeaders, + PassthroughUserHeadersSkip: cfg.PassthroughUserHeadersSkip, })} } diff --git a/internal/providers/zai/zai.go b/internal/providers/zai/zai.go index a64bde1d..25dda806 100644 --- a/internal/providers/zai/zai.go +++ b/internal/providers/zai/zai.go @@ -34,8 +34,11 @@ var _ core.Provider = (*Provider)(nil) func New(cfg providers.ProviderConfig, opts providers.ProviderOptions) core.Provider { return &Provider{ ChatCompatible: openai.NewChatCompatible(cfg.APIKey, opts, openai.CompatibleProviderConfig{ - ProviderName: "zai", - BaseURL: providers.ResolveBaseURL(cfg.BaseURL, defaultBaseURL), + ProviderName: "zai", + BaseURL: providers.ResolveBaseURL(cfg.BaseURL, defaultBaseURL), + CustomUpstreamHeaders: cfg.CustomUpstreamHeaders, + PassthroughUserHeaders: cfg.PassthroughUserHeaders, + PassthroughUserHeadersSkip: cfg.PassthroughUserHeadersSkip, }), apiKey: cfg.APIKey, } diff --git a/tests/contract/testdata/kimi/chat_completion.json b/tests/contract/testdata/kimi/chat_completion.json new file mode 100644 index 00000000..b39f3f3d --- /dev/null +++ b/tests/contract/testdata/kimi/chat_completion.json @@ -0,0 +1,26 @@ +{ + "id": "chatcmpl-dpG5AViIPO76geP1lTj1SuCl", + "object": "chat.completion", + "created": 1783000332, + "model": "kimi-for-coding", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello, World!", + "reasoning_content": "We need answer exactly user said: \"Say 'Hello, World!' and nothing else.\" So output Hello, World! and nothing else. Ensure no extra text." 
+ }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 18, + "completion_tokens": 39, + "total_tokens": 57, + "cached_tokens": 18, + "prompt_tokens_details": { + "cached_tokens": 18 + } + } +} \ No newline at end of file diff --git a/tests/contract/testdata/kimi/models.json b/tests/contract/testdata/kimi/models.json new file mode 100644 index 00000000..ad90c1b4 --- /dev/null +++ b/tests/contract/testdata/kimi/models.json @@ -0,0 +1,21 @@ +{ + "data": [ + { + "id": "kimi-for-coding", + "created": 1761264000, + "created_at": "2025-10-24T00:00:00Z", + "object": "model", + "display_name": "K2.7 Code", + "type": "model", + "context_length": 262144, + "supports_reasoning": true, + "supports_image_in": true, + "supports_video_in": true, + "supports_thinking_type": "only" + } + ], + "object": "list", + "first_id": "kimi-for-coding", + "last_id": "kimi-for-coding", + "has_more": false +} \ No newline at end of file