diff --git a/.env.example b/.env.example index 1c5b896fe..2f6fc45d6 100644 --- a/.env.example +++ b/.env.example @@ -333,6 +333,81 @@ # Default: false # STREAM_RETRY_ON_REASONING_ONLY=false +# --- Optional Structured Config File --- +# Optional JSON config for structured experimental settings such as fallback +# groups, model routes, advisory pricing, streaming observability, and field +# cache rules. Environment variables override JSON config. Do not put API keys, +# OAuth tokens, bearer tokens, or authorization headers in this file. +# Supported top-level sections include: routing, pricing, streaming, retry, +# responses, field_cache, and providers. The providers section is for non-secret +# provider tuning only; credentials must stay in env or provider credential files. +# LLM_PROXY_CONFIG_FILE=./config/llm-proxy.json + +# --- Fallback Groups and Model Routes --- +# Ordered fallback groups try targets in order when the previous target fails +# with a retryable provider/category error. Execution suffixes: +# @auto Let the provider choose custom/native/LiteLLM behavior +# @native Require native protocol execution +# @custom Require provider custom execution +# @litellm_fallback Explicitly use LiteLLM fallback +# FALLBACK_GROUPS=code_chain +# FALLBACK_GROUP_CODE_CHAIN=codex/gpt-5.1-codex@native,openai/gpt-5.1@litellm_fallback +# MODEL_ROUTE_CODE=group:code_chain + +# --- Provider Cooldown Activation --- +# Provider-level cooldown is conservative and only intended for large/global +# retry-after events, not every per-credential quota error. Model-capacity +# errors can start a model-scoped cooldown without blocking unrelated models. +# PROVIDER_COOLDOWN_MIN_SECONDS=10 +# PROVIDER_COOLDOWN_DEFAULT_SECONDS=30 +# PROVIDER_COOLDOWN_ON_QUOTA=false +# Repeated transient provider/model failures can increase bounded backoff. +# PROVIDER_BACKOFF_WINDOW_SECONDS=60 +# PROVIDER_BACKOFF_THRESHOLD=3 +# PROVIDER_BACKOFF_BASE_SECONDS=0 # 0/unset means use provider cooldown default +# PROVIDER_BACKOFF_MAX_SECONDS=300 +# FAILURE_HISTORY_MAX_ENTRIES=200 + +# --- Responses API Store Policy --- +# Responses are stored in-memory by default. Use provider_cache for durable JSON +# storage via the existing provider-cache layer. TTL/max limits are disabled +# when unset or <= 0. Failed stream responses are stored by default; in-progress +# streaming state is disabled unless explicitly enabled. +# RESPONSES_STORE_BACKEND=memory +# RESPONSES_STORE_CACHE_NAME=responses +# RESPONSES_STORE_CACHE_PREFIX=responses +# RESPONSES_STORE_CACHE_DIR= +# RESPONSES_STORE_CACHE_MEMORY_TTL_SECONDS=3600 +# RESPONSES_STORE_CACHE_DISK_TTL_SECONDS=172800 +# RESPONSES_STORE_TTL_SECONDS=0 +# RESPONSES_STORE_MAX_ITEMS=0 +# RESPONSES_STORE_FAILED=true +# RESPONSES_STORE_IN_PROGRESS=false + +# --- Streaming Observability --- +# Stream lifecycle metrics are traced by default. TTFB/stall timeout values are +# active when set to >0; keep at 0 to disable. Heartbeats are SSE comments and +# do not count as visible output. +# STREAM_TRACE_METRICS=true +# STREAM_TTFB_TIMEOUT_SECONDS=0 +# STREAM_STALL_TIMEOUT_SECONDS=0 +# STREAM_HEARTBEAT_INTERVAL_SECONDS=0 +# STREAM_HEARTBEAT_SECONDS=0 # Legacy alias +# STREAM_CANCEL_UPSTREAM_ON_DISCONNECT=true + +# --- Advisory Model Pricing --- +# Per-token advisory prices used only when providers do not report actual cost. +# Precedence: skip-cost provider setting > provider-reported cost/SSE cost event +# > provider explicit pricing > env pricing > JSON pricing > LiteLLM metadata. +# Streaming providers can report actual cost with `: cost {...}` comments or +# `event: cost` frames. +# Env names sanitize provider/model by replacing non-alphanumerics with `_`. +# MODEL_PRICE_OPENAI_GPT_5_1_INPUT=0.000001 +# MODEL_PRICE_OPENAI_GPT_5_1_OUTPUT=0.00001 +# MODEL_PRICE_OPENAI_GPT_5_1_CACHE_READ=0.0000001 +# MODEL_PRICE_OPENAI_GPT_5_1_CACHE_WRITE=0.000001 +# MODEL_PRICE_OPENAI_GPT_5_1_REASONING=0.00001 + # ------------------------------------------------------------------------------ # | [ADVANCED] HTTP Timeout Configuration | # ------------------------------------------------------------------------------ diff --git a/docs/experimental/00-master-plan.md b/docs/experimental/00-master-plan.md new file mode 100644 index 000000000..2c85482a2 --- /dev/null +++ b/docs/experimental/00-master-plan.md @@ -0,0 +1,132 @@ +# Experimental Native Protocol Roadmap + +This branch is for a long-running experimental rewrite that makes native protocol support the first-class extension point of `rotator_library`, while preserving the existing credential rotation, quota, fair-cycle, session tracking, and provider plugin strengths. + +## Operating Rules + +- Work only on the `experimental` branch. +- Keep all repository work inside `C:\Projects\test\LLM-API-Key-Proxy` and child paths. +- Treat commits as checkpoints. A phase may contain many commits. +- Commit messages must include a body describing what changed, why, tests run, and follow-up considerations. +- Do not commit phase reports written for the user unless explicitly requested. Planning docs under `docs/experimental/` are committed. +- Before each phase implementation, first produce a fresh exhaustive phase plan in conversation text, based on the current code state. Only after that plan is settled should it be written to `docs/experimental/phase-N-*.md`. +- After each phase implementation, call both `explore` and `explore-heavy` agents to review the work against the phase plan, external reference areas, and current proxy behavior. Fix findings and re-review as needed. +- Keep LiteLLM as a fallback path for protocols/providers that are not natively covered yet. Native protocol support should be preferred when available. + +## Strategic Goal + +The target architecture is: + +```text +client API request + -> protocol parse into unified representation + -> field-cache injection + -> adapter chain + -> provider override hooks + -> provider-native request build + -> provider execution and credential rotation + -> provider-native response/stream parse + -> field-cache extraction + -> adapter chain + -> protocol formatting for the client + -> transaction logging for every transform state +``` + +Providers should be able to declare an existing protocol and only override the parts that are genuinely provider-specific. A custom provider should usually be configurable through protocol choice, adapters, field-cache rules, auth strategy, and model options rather than requiring a large bespoke provider implementation. + +## Priority Order + +1. Native protocol foundations, unified types, transformers, adapters, and field-cache rules. +2. OpenAI Responses API support, including future WebSocket extension points. +3. Provider work following the protocol layer: Claude Code, Codex, Copilot, Antigravity, and Gemini CLI parity review. +4. Routing and fallback groups, with optional target-group selectors later. +5. Retry, provider/model cooldown, and failover cleanup. +6. Protocol-aware quota, usage, and cost normalization. +7. Streaming library hardening: SSE now, WebSocket-ready later. +8. Config polish using `.env` and optional JSON. No SQLite dependency for now. +9. Extensive staged tests and review-agent verification. + +## Non-Goals For This Branch + +- Do not make the proxy a full multi-user admin product yet. +- Do not require SQLite or Postgres for the main feature set. +- Do not remove LiteLLM before native coverage exists. +- Do not replace the existing `UsageManager`, fair-cycle, custom caps, or evidence-based `SessionTracker`. +- Do not port frontend/UI work from the external reference gateway. + +## Current Strengths To Preserve + +- Credential-level rotation and priority-aware selection. +- Fair cycle and custom caps. +- Windowed quota tracking and quota groups. +- Evidence-based session tracking with compaction handling. +- Provider plugin discovery. +- Gemini CLI provider behavior unless a reviewed change is clearly better. +- Resilient file/JSON state writing. +- Dynamic OpenAI-compatible provider discovery. + +## Reference Gateway Ideas To Import Carefully + +- Unified protocol/transformer style. +- Adapter registry and configurable provider/model adapters. +- Target groups and direct routing syntax, adapted into fallback-first routing. +- Responses API transformer and storage concepts. +- Stream TTFB/stall detection concepts, implemented with Python-native async primitives. +- Provider/model cooldown and retry-history concepts. +- Usage/cost normalization and provider-reported cost extraction. +- Broader provider support patterns for Claude Code, Codex, Copilot, and Antigravity. + +## Phase Index + +1. Protocol Core. +2. Transform Pass Logging. +3. Adapter and Field Cache System. +4. Responses API and WebSocket-Ready Transport Shape. +5. Provider Protocol Overhaul. +6. Routing and Fallback Groups. +7. Retry/Cooldown/Failover Cleanup. +8. Streaming Library Upgrade. +9. Usage, Quota, and Cost Accuracy. +10. Config Polish. + +Each phase may be subdivided if implementation scope becomes too large. + +## Completeness Matrix + +This matrix exists so the branch does not lose any requested scope while phases evolve. The phase plans are still refreshed before implementation, but every item below must remain accounted for. + +| Requested area | Planned coverage | +| --- | --- | +| Protocols are priority #1 | Phases 1 and 4 create native protocol foundations and Responses support before provider work. | +| Protocols are bases, not gospel | Phase 1 requires override-friendly protocol methods, subclassing, copy/mutate registration, and provider-specific overrides. | +| Move away from LiteLLM | Phase 1 adds a `litellm_fallback` protocol path; later providers should prefer native protocols and use LiteLLM only for unsupported coverage. | +| Add protocols automatically like providers | Phase 1 adds protocol auto-discovery and registry behavior modeled after provider discovery. | +| Cover current providers and reference providers | Phase 1 protocols must cover shapes used by current providers; Phase 5 covers Claude Code, Codex, Copilot, Antigravity, and Gemini CLI parity. | +| Responses API is very needed | Phase 4 is dedicated to Responses, `previous_response_id`, storage, SSE, and WebSocket-ready transport shape. | +| WebSocket support later | Phases 1, 4, and 8 require transport separation so WebSocket can be added without rewriting protocol logic. | +| Adapters/transformers tied to protocols | Phases 1, 2, and 3 define protocol parse/build plus transform tracing, adapter registry, and field-cache rules. | +| Cache and return provider fields | Phase 3 implements configurable extraction/injection rules for request, response, and stream fields with scope and mode controls. | +| Reasoning content and similar fields | Phase 3 explicitly covers reasoning content, thinking signatures, prompt cache keys, response IDs, and provider session IDs. | +| Return all possible or last user/assistant use | Phase 3 modes include `last`, `all`, `last_user_turn`, `last_assistant_turn`, and `per_tool_call`. | +| Per-model custom provider behavior | Phases 3, 5, and 10 cover provider/model field cache rules, adapters, model options, and optional JSON config. | +| Transaction logging after every transform | Phase 2 adds ordered request, response, and stream transform trace passes and integrates them with transaction logging. | +| Comments, docstrings, and key decisions | All implementation phases require docstrings for public abstractions and comments for non-obvious transform, protocol, and future-extension decisions. | +| Providers are priority #2 | Phase 5 follows protocol foundations with Claude Code, Codex, Copilot, Antigravity, and Gemini CLI parity review. | +| Antigravity comparison | Phase 5 explicitly compares the reference Antigravity behavior against `src/rotator_library/providers/_retired/`. | +| Routing is interesting | Phase 6 implements fallback chains first, with target-group selectors later if useful. | +| Fallback groups preferred over target groups | Phase 6 starts with ordered fallback groups and only adds target-group-style selectors after that base works. | +| Retry/cooldown/failover cleanup | Phase 7 makes provider/model cooldown real, adds retry history, backoff, retry-after precedence, and success reset. | +| Quota/usage/cost improvements | Phase 9 adds protocol-aware normalizers, provider-reported cost extraction, structured cost fields, and checker abstractions while keeping existing usage engines. | +| Streaming as library capability | Phase 8 hardens streaming below the proxy route layer with TTFB, TTFT, stall detection, cancellation, and transport-aware stream events. | +| Config via env/json, no SQLite | Phase 10 adds optional JSON config with env overrides and validation. SQLite remains out of scope. | +| Multi-user proxy later | The branch keeps multi-user/admin features as a future expansion and only preserves extension points where natural. | +| Exhaustive tests in stages | Every phase requires tests alongside implementation and phase-end review by both `explore` and `explore-heavy`. | +| Reports are for the user, not git | `06-phase-workflow.md` says planning docs are committed, but phase reports are not committed by default. | + +## Code Quality Expectations + +- Public protocol, adapter, transport, field-cache, and provider-extension classes must have docstrings that explain intent, override points, and future expansion hooks. +- Non-obvious transformations must have comments explaining why data is changed, preserved, reordered, or intentionally dropped. +- Lossy protocol conversions must be documented at the conversion site. +- Future WebSocket, target-group, and multi-user extension seams should be noted in comments where they affect today's design. +- Tests should prefer golden fixtures for protocol shapes and focused unit tests for transform edge cases. diff --git a/docs/experimental/01-protocol-architecture.md b/docs/experimental/01-protocol-architecture.md new file mode 100644 index 000000000..bf6349d83 --- /dev/null +++ b/docs/experimental/01-protocol-architecture.md @@ -0,0 +1,132 @@ +# Native Protocol Architecture + +Protocols are reusable bases, not rigid gospel. Providers can subclass, wrap, copy, or override protocol behavior when a provider deviates from an otherwise standard protocol. + +## Why Protocols First + +The current code relies heavily on LiteLLM and provider-specific transforms. That works, but it makes new protocols hard to reason about and makes debugging transformations difficult. The experimental goal is to make a provider mostly declarative: + +```text +provider = protocol + auth + adapters + field cache rules + model options + quota behavior +``` + +If a provider needs custom behavior, it should override a narrow protocol method instead of forcing an entirely bespoke request path. + +## Auto-Discovery + +Protocols should follow the provider plugin style: + +- protocol modules live under `src/rotator_library/protocols/`. +- modules register concrete protocol classes by name. +- a registry exposes names such as `openai_chat`, `anthropic_messages`, `gemini`, `responses`, and `litellm_fallback`. +- third-party or local protocol modules can be added later with minimal registry changes. + +## Core Types + +The unified representation should be explicit enough to cover all existing providers and the external reference protocols without losing important data. + +Suggested types: + +- `UnifiedRequest` +- `UnifiedResponse` +- `UnifiedStreamEvent` +- `UnifiedMessage` +- `ContentBlock` +- `ToolDefinition` +- `ToolCall` +- `ToolResult` +- `ReasoningBlock` +- `Usage` +- `CostDetails` +- `ProtocolMetadata` + +These types should retain unknown provider-specific metadata in explicit extension dictionaries instead of dropping it. Robustness matters more than a narrow perfect schema. + +## Protocol Interface + +The base protocol should provide default methods that can be overridden: + +- `parse_request(raw_request, context) -> UnifiedRequest` +- `build_request(unified_request, context) -> raw_provider_request` +- `parse_response(raw_response, context) -> UnifiedResponse` +- `format_response(unified_response, context) -> raw_client_response` +- `parse_stream_event(raw_event, context) -> UnifiedStreamEvent` +- `format_stream_event(unified_event, context) -> raw_stream_payload` +- `extract_usage(raw_or_unified, context) -> Usage | None` +- `supports_transport(transport_name) -> bool` + +Provider-specific overrides should receive context that includes provider name, model, credential identity, source protocol, target protocol, request ID, and session tracking information. + +## Initial Protocols + +### OpenAI Chat + +Must support: + +- chat completions request/response. +- stream chunks. +- tools and tool calls. +- function-call legacy shapes. +- reasoning fields from OpenAI-compatible providers. +- cached token and reasoning token usage details. + +### Anthropic Messages + +Must support: + +- messages request/response. +- system content extraction. +- text, image, thinking, redacted thinking, tool_use, tool_result blocks. +- stream lifecycle events. +- count_tokens path later if needed. + +### Gemini + +Must support: + +- generateContent and streamGenerateContent shapes. +- content parts. +- functionCall/functionResponse. +- thought signatures. +- safety settings passthrough without unsafe auto-injection. +- Google/Gemini usage metadata. + +### Responses + +Must support: + +- OpenAI Responses request/response. +- `previous_response_id`. +- output items. +- event streams. +- storage-friendly response objects. +- future WebSocket transport. + +### LiteLLM Fallback + +Must preserve existing behavior for providers/protocols not yet native. This path should be explicit and transaction-logged as a fallback, not hidden. + +## Transport Separation + +Protocol formatting must not be tied only to HTTP SSE. Define a transport boundary so the same unified stream events can be emitted through: + +- non-streaming HTTP JSON. +- HTTP SSE. +- future WebSocket. + +The Responses phase should leave clear extension points for WebSocket even if WebSocket is implemented later. + +## Error Handling + +Protocols should preserve provider error bodies where safe, but format client-facing errors consistently. Parsing errors should include transform-pass names and request IDs to make transaction logs useful. + +## Docstrings And Comments + +Protocol code should include docstrings explaining: + +- which external API shape it models. +- what fields are intentionally preserved in metadata. +- where provider overrides are expected. +- future expansion hooks. + +Comments should explain non-obvious transformations, especially lossy conversions between protocols. diff --git a/docs/experimental/02-transform-logging.md b/docs/experimental/02-transform-logging.md new file mode 100644 index 000000000..5cf74b1d1 --- /dev/null +++ b/docs/experimental/02-transform-logging.md @@ -0,0 +1,84 @@ +# Transform Pass Logging + +Debuggability is a core requirement. Every request, response, and stream payload must be inspectable after each transformation pass. + +## Existing Baseline + +The current project already has transaction and raw I/O logging, but it does not consistently show every intermediate transformed state. The experimental protocol layer must improve this without making normal operation too noisy. + +## Transform Trace Model + +Each request should have a trace containing ordered pass records. A pass record should include: + +- request ID. +- pass name. +- direction: request, response, stream_in, stream_out, error. +- protocol/provider/model context. +- timestamp. +- payload snapshot, redacted if needed. +- optional notes or warnings. +- exception information if the pass failed. + +## Required Request Passes + +Suggested pass names: + +- `raw_client_request` +- `parsed_unified_request` +- `after_session_inference` +- `after_field_cache_injection` +- `after_request_adapters` +- `after_provider_override` +- `provider_request` +- `litellm_fallback_request` when fallback is used + +## Required Response Passes + +- `raw_provider_response` +- `parsed_unified_response` +- `after_field_cache_extraction` +- `after_response_adapters` +- `after_client_protocol_format` +- `final_client_response` + +## Required Stream Passes + +For streams, every event can be large. Logging should support configurable sampling/full capture, but the architecture must be able to record: + +- `raw_provider_stream_event` +- `parsed_unified_stream_event` +- `after_stream_field_cache_extraction` +- `after_stream_adapters` +- `formatted_client_stream_event` + +The transaction logger should be able to record stream events as JSONL to avoid retaining large streams in memory. + +## Redaction + +Even though full multi-user security is out of scope, transform logging must avoid accidental credential leakage. + +Redaction hooks should cover: + +- `Authorization` headers. +- `x-api-key` and related API key headers. +- provider API keys. +- OAuth access and refresh tokens. +- cookies. +- obvious `api_key`, `access_token`, `refresh_token`, `client_secret`, and `Authorization` fields in JSON payloads. + +Redaction should happen at the logging boundary, not by mutating live request objects. + +## Failure Debugging + +When a transform fails, the trace should identify: + +- failed pass name. +- protocol class. +- provider/model. +- whether the failure occurred before or after provider execution. +- original error type. +- redacted payload snapshot when possible. + +## Future Expansion + +Later admin/debug endpoints can read these traces. For now, file-based transaction logging is sufficient. diff --git a/docs/experimental/03-field-cache-rules.md b/docs/experimental/03-field-cache-rules.md new file mode 100644 index 000000000..a93e13636 --- /dev/null +++ b/docs/experimental/03-field-cache-rules.md @@ -0,0 +1,113 @@ +# Field Cache Rules + +Field caching is required for providers that need values from previous responses or stream events to be returned on later requests. Examples include reasoning content, thought signatures, prompt cache keys, provider session IDs, and response IDs. + +## Goals + +- Let custom providers configure what fields to extract and where to inject them. +- Avoid hardcoding every provider-specific memory behavior. +- Preserve strict scoping so values never leak across provider, model, credential, classifier scope, or session. +- Support both non-streaming responses and streaming events. +- Support rules per provider and per model. + +## Rule Shape + +Illustrative JSON shape: + +```json +{ + "name": "reasoning_content", + "source": "response", + "path": "choices.*.message.reasoning_content", + "scope": "session", + "mode": "last", + "inject": { + "target": "request", + "path": "messages[-1].reasoning_content" + } +} +``` + +This is a design sketch, not the final schema. + +## Sources + +- `request` +- `response` +- `stream_event` +- `unified_request` +- `unified_response` +- `unified_stream_event` + +## Targets + +- raw provider request path. +- unified request field. +- protocol metadata. +- provider-specific extension field. + +## Scopes + +- `provider` +- `model` +- `credential` +- `session` +- `conversation` +- combinations of the above when needed. + +The default for conversation-affecting fields should be at least provider+model+session scoped. + +## Modes + +- `last`: only the latest matching value. +- `all`: all matching values within the scope. +- `last_user_turn`: latest value associated with the last user turn. +- `last_assistant_turn`: latest value associated with the last assistant turn. +- `per_tool_call`: keyed by tool call ID. + +## Backing Store + +Use the existing provider cache infrastructure first. Do not require SQLite. + +Potential cache keys should include: + +```text +provider / model / credential-or-scope / session-id / rule-name +``` + +Private/classifier scoped credentials must not share cached fields with global credentials. + +## Examples + +### DeepSeek Reasoning Content + +Extract reasoning content from assistant responses and inject it into the next provider request when the provider expects continuity. + +### Gemini Thought Signatures + +Extract thought signatures from Gemini response parts and return them with matching future content parts. + +### Responses Previous Response ID + +Store response IDs and output items so `previous_response_id` can load prior context. + +### Prompt Cache Keys + +Carry `prompt_cache_key` or equivalent provider cache routing values forward when a provider benefits from stable cache routing. + +## Tests + +Required test categories: + +- extraction from response. +- extraction from stream event. +- injection into next request. +- `last` versus `all` behavior. +- scope isolation. +- missing path is a no-op. +- malformed path produces useful validation error. +- redaction in transform logs. + +## Key Decision + +Field cache rules are a protocol/provider extension system, not a replacement for `SessionTracker`. Session tracking decides continuity and credential affinity; field cache rules preserve provider-specific protocol state. diff --git a/docs/experimental/04-provider-roadmap.md b/docs/experimental/04-provider-roadmap.md new file mode 100644 index 000000000..923f84d84 --- /dev/null +++ b/docs/experimental/04-provider-roadmap.md @@ -0,0 +1,115 @@ +# Provider Roadmap + +Providers follow protocols. The protocol layer must land first so provider work can be small, testable, and declarative where possible. + +## Provider Declaration Target + +A provider should eventually be expressible as: + +```text +provider name + + protocol(s) + + auth strategy + + model definitions/options + + adapter chain + + field cache rules + + quota checker/parser + + optional protocol/provider overrides +``` + +Providers can still have custom Python code. The point is to make custom code narrow. + +## Priority Providers + +1. Claude Code. +2. Codex. +3. Copilot. +4. Antigravity. +5. Gemini CLI review and parity improvements. + +## Claude Code + +Review the external reference gateway for: + +- OAuth/token handling. +- request/response protocol shape. +- tool filtering or tool proxy behavior. +- quota checks. +- stream behavior. +- Claude Code-specific headers and model naming. + +Expected implementation direction: + +- use Anthropic Messages or Responses where applicable. +- add provider-specific adapters for tool behavior. +- field cache rules for thinking/signatures if needed. + +## Codex + +Review the external reference gateway for: + +- Responses API route usage. +- Codex-specific user-agent/version behavior. +- OAuth/account handling. +- cooldown parsing for Codex usage limits. +- stream events. + +Expected implementation direction: + +- build on the Responses protocol. +- add Codex provider auth and headers. +- include version/user-agent support if needed. + +## Copilot + +Review the external reference gateway for: + +- GitHub Copilot OAuth flows. +- endpoint selection. +- model naming. +- quota checker behavior. +- provider-specific request filtering. + +Expected implementation direction: + +- use protocol adapters where possible. +- add provider-specific auth/token refresh only where necessary. + +## Antigravity + +Required comparison: + +- current retired implementation under `src/rotator_library/providers/_retired/`. +- external reference Antigravity provider/checker/OAuth behavior. + +Expected implementation direction: + +- restore only what is still valid. +- reuse protocol and field cache rules. +- avoid resurrecting obsolete device-profile behavior unless clearly required. + +## Gemini CLI + +The current Gemini CLI provider is already deep. Review the external reference gateway only for missed behavior: + +- quota checker details. +- thought signature handling. +- stream transform differences. +- OAuth edge cases. +- Gemini 3 tool behavior. +- request headers and endpoint details. + +Do not rewrite Gemini CLI just for architectural purity. + +## Provider Tests + +Each provider should have tests for: + +- config/load/registration. +- auth header or token acquisition. +- request translation. +- response translation. +- stream translation. +- quota parser/checker behavior. +- field cache extraction/injection. +- LiteLLM fallback not used when native path should apply. diff --git a/docs/experimental/05-routing-retry-usage-roadmap.md b/docs/experimental/05-routing-retry-usage-roadmap.md new file mode 100644 index 000000000..210ccf13a --- /dev/null +++ b/docs/experimental/05-routing-retry-usage-roadmap.md @@ -0,0 +1,114 @@ +# Routing, Retry, Cooldown, Usage, And Cost Roadmap + +These systems should be layered around the protocol/provider work without replacing the existing credential engine. + +## Routing Direction + +The preferred first routing model is fallback chains, not a full target-group router from the external reference gateway. + +Example: + +```json +{ + "fallback_groups": [ + ["gemini_cli/gemini-2.5-pro", "openrouter/google/gemini-2.5-pro", "openai/gpt-4.1"] + ] +} +``` + +Behavior: + +- If the requested model is in a group, start with the requested model and then continue through the rest of the group. +- Retryable provider/model failures can move to the next candidate. +- Non-retryable request errors stop unless explicitly configured otherwise. +- Each candidate delegates to the current credential rotation engine. +- Session tracking and classifier scopes must remain isolated. + +## Target Groups Later + +Reference target groups are still useful later for selectors: + +- `in_order` +- `random` +- `usage` +- `cost` +- `latency` +- `performance` + +But the first implementation should solve ordered fallback groups and the missing/stale fallback module before adding selector complexity. + +## Retry/Cooldown Direction + +Current provider cooldown is effectively inert because `CooldownManager.start_cooldown()` is not called. The real active cooldowns are per-credential in `UsageManager`. + +Upgrade direction: + +- keep per-credential cooldowns in `UsageManager`. +- add provider/model cooldown for provider-wide or model-wide failures. +- add consecutive failure tracking. +- add exponential backoff. +- `retry_after` should override computed backoff. +- success should clear provider/model failure count. +- credential quota exhaustion should not globally cool down a provider unless evidence says the provider/model is globally exhausted. + +## Retry History + +Add structured attempt records: + +- candidate provider/model. +- credential stable ID or masked identity. +- protocol path used. +- status: success, failed, skipped. +- error type. +- retryable decision. +- cooldown decision. +- timing. + +These records should be available to transaction logging and optionally client-facing debug output later. + +## Usage And Cost Direction + +Keep existing windowing, fair cycle, custom caps, quota groups, and usage persistence. Add protocol-aware normalization before usage is recorded. + +Needed normalizers: + +- OpenAI Chat. +- OpenAI Responses. +- Anthropic Messages. +- Gemini. +- OAuth/custom providers. + +Fields to preserve: + +- input tokens. +- output tokens. +- reasoning/thinking tokens. +- cache read tokens. +- cache write tokens. +- total tokens. +- provider-reported cost. +- estimated cost. +- cost source and metadata. + +## Cost Direction + +Provider-reported cost should win when available. Sources include: + +- `usage.cost_details`. +- provider-specific response fields. +- SSE comment lines such as `: cost { ... }`. + +Estimated cost should remain as fallback. + +## Tests + +Required tests: + +- fallback chain success after first failure. +- non-retryable error stops fallback. +- streaming fallback only before visible output. +- provider/model cooldown scope. +- backoff escalation. +- success reset. +- usage normalization for each protocol. +- provider-reported cost precedence. diff --git a/docs/experimental/06-phase-workflow.md b/docs/experimental/06-phase-workflow.md new file mode 100644 index 000000000..839fc88e8 --- /dev/null +++ b/docs/experimental/06-phase-workflow.md @@ -0,0 +1,89 @@ +# Phase Workflow + +This document describes how each phase should be executed. + +## 1. Refresh Understanding + +Before writing phase docs or code, inspect the current implementation relevant to the phase. The master plan is a guide, not a substitute for current-code analysis. + +## 2. Produce Phase Plan In Conversation + +First produce the phase plan in conversation text. This forces a fresh exhaustive design pass with the current implementation in mind. + +The plan should include: + +- goals. +- non-goals. +- files/modules to inspect or modify. +- data model changes. +- public API changes. +- transaction logging implications. +- docstrings, comments, and future-extension notes required for the phase. +- tests to add. +- commit checkpoints. +- risk/rollback notes. + +## 3. Write Phase Plan To Docs + +After the conversation plan is accepted or clearly settled, write it to: + +```text +docs/experimental/phase-N-*.md +``` + +Planning docs are committed. + +## 4. Implement In Checkpoint Commits + +Commits are not phase-only. Commit whenever a coherent slice is finished and tested. + +Each commit body should include: + +- what changed. +- why it changed. +- tests run. +- known limitations or follow-ups. + +## 5. Test Continuously + +Run the most relevant tests after each meaningful slice. Do not wait until the phase end. + +## 6. Review Agents + +At phase end, call exactly these two review perspectives: + +- `explore`: code/file-level verification. +- `explore-heavy`: deeper architecture/reference verification. + +Review prompts should compare the implementation against: + +- the phase plan. +- external reference areas where relevant. +- current proxy behavior that must be preserved. +- transaction logging expectations. +- tests. + +If either agent fails or runs out of context, restart it with a narrower prompt. + +## 7. Address Findings + +Fix real findings. If a finding is intentionally deferred, document it in the user-facing phase report, not necessarily in git. + +Repeat review if the changes are substantial. + +## 8. Report To User + +Write a phase report in the conversation. Reports are for the user and are not committed by default. + +The report should include: + +- completed work. +- commits made. +- tests run. +- review-agent findings and resolutions. +- known limitations. +- next phase recommendation. + +## 9. Move To Next Phase + +Do not rely blindly on previous plans. Start the next phase by refreshing context and producing the next phase plan in conversation text. diff --git a/docs/experimental/07-detailed-phase-roadmap.md b/docs/experimental/07-detailed-phase-roadmap.md new file mode 100644 index 000000000..45026167e --- /dev/null +++ b/docs/experimental/07-detailed-phase-roadmap.md @@ -0,0 +1,540 @@ +# Detailed Phase Roadmap + +This document expands the 10-phase roadmap before implementation begins. It exists to prevent later work from narrowing to only Phase 1 details and losing the full feature set. + +Each phase still requires a fresh conversation plan immediately before implementation. This document is the durable baseline; phase-specific plans can adapt after current-code inspection. + +## Phase 1: Protocol Core + +Purpose: create native protocol foundations without changing live execution. + +Primary deliverables: + +- `src/rotator_library/protocols/` package. +- Auto-discovered protocol registry modeled after provider discovery. +- Override-friendly `ProtocolAdapter` base class. +- Unified request, response, message, content, tool, reasoning, usage, cost, stream event, and context dataclasses. +- Base protocols for OpenAI Chat, Anthropic Messages, Gemini, OpenAI Responses, and LiteLLM fallback. +- JSON-safe serialization helpers for transaction tracing and fixtures. +- Protocol errors that identify protocol name, pass name, and payload preview. + +Key requirements: + +- Protocols are bases, not rigid implementations. +- Providers can subclass, wrap, copy, or override protocol methods. +- Unknown provider fields must be preserved in `extra`/metadata instead of dropped. +- Runtime behavior should not change yet. +- LiteLLM fallback remains explicit and named. + +Tests: + +- Registry discovery and alias resolution. +- Base preservation behavior. +- Round-trip and parse/format fixtures for OpenAI Chat, Anthropic Messages, Gemini, and Responses. +- Stream event smoke coverage. + +Review focus: + +- Ensure protocol abstractions are not too narrow for current providers or external reference protocols. +- Ensure types are trace-friendly for Phase 2. + +## Phase 2: Transform Pass Logging + +Purpose: make every transformation state inspectable for debugging. + +Primary deliverables: + +- Transform trace model with ordered pass records. +- Transaction logger integration for request, response, stream, and error transform states. +- Redaction-at-log-boundary helpers. +- JSONL stream transform logs. +- Pass names shared by protocols, adapters, provider overrides, field-cache rules, and fallback execution. + +Required pass coverage: + +- `raw_client_request` +- `parsed_unified_request` +- `after_session_inference` +- `after_field_cache_injection` +- `after_request_adapters` +- `after_provider_override` +- `provider_request` +- `litellm_fallback_request` +- `raw_provider_response` +- `parsed_unified_response` +- `after_field_cache_extraction` +- `after_response_adapters` +- `after_client_protocol_format` +- `final_client_response` +- `raw_provider_stream_event` +- `parsed_unified_stream_event` +- `after_stream_field_cache_extraction` +- `after_stream_adapters` +- `formatted_client_stream_event` + +Key requirements: + +- Log snapshots must not mutate live request/response objects. +- Logs must preserve enough context to debug provider-specific behavior. +- Secret redaction should cover API keys, OAuth tokens, auth headers, cookies, and common secret field names. +- Logging must be usable by future admin/debug endpoints but remain file-based for now. + +Tests: + +- Trace pass ordering. +- Redaction behavior. +- Request/response/stream JSON serialization. +- Transform failure logging with pass name and protocol/provider context. + +Review focus: + +- Verify every planned pass is reachable from the architecture. +- Verify log output is useful without leaking credentials. + +## Phase 3: Adapter And Field Cache System + +Purpose: let custom providers configure what to transform, cache, return, and reinject without hardcoding every provider. + +Primary deliverables: + +- Adapter registry. +- Adapter chain execution for request, response, and stream events. +- Field-cache rule schema and validation. +- JSON-path-like extraction/injection helpers. +- Cache scope builder using provider, model, credential or classifier scope, session, conversation, and rule name. +- Provider/model-level rule configuration hooks. +- Initial built-in adapter rules for reasoning/thinking-related fields. + +Field cache capabilities: + +- Sources: request, response, stream event, unified request, unified response, unified stream event. +- Targets: raw provider request path, unified request field, protocol metadata, provider extension field. +- Scopes: provider, model, credential, session, conversation, and combinations. +- Modes: `last`, `all`, `last_user_turn`, `last_assistant_turn`, `per_tool_call`. +- Missing paths are no-ops unless strict validation is enabled. + +Examples that must be supported by design: + +- DeepSeek-style reasoning content. +- Anthropic thinking and redacted thinking signatures. +- Gemini thought signatures. +- Prompt cache keys. +- Provider session IDs. +- Responses `previous_response_id` metadata. + +Key requirements: + +- Field cache rules complement `SessionTracker`; they do not replace it. +- Rules must never leak across provider/model/credential/session boundaries. +- Providers can define default rules, and model config can override them. +- Transform logging must capture before/after extraction and injection passes. + +Tests: + +- Extract from response and stream events. +- Inject into later requests. +- Scope isolation. +- Mode behavior. +- Malformed path validation. +- Redacted transform logs. + +Review focus: + +- Verify custom provider authoring becomes declarative for common stateful fields. +- Verify no cross-session or cross-credential leaks. + +## Phase 4: Responses API And WebSocket-Ready Transport Shape + +Purpose: add high-priority OpenAI Responses API support while designing stream transports for future WebSocket support. + +Primary deliverables: + +- `/v1/responses` route. +- `GET /v1/responses/{response_id}` route. +- `DELETE /v1/responses/{response_id}` route. +- Optional Codex alias route if supported by provider work. +- JSON/file response storage, not SQLite. +- TTL cleanup for stored responses. +- `previous_response_id` loading and session anchor integration. +- SSE Responses streaming. +- Transport interfaces that can later support WebSocket without rewriting protocol logic. + +Transport shape: + +- Non-streaming HTTP JSON transport. +- HTTP SSE transport. +- Future WebSocket transport extension point. +- Unified stream events from Phase 1 should flow through transports. + +Key requirements: + +- `previous_response_id` must become strong session evidence where safe. +- Stored response objects must be scoped and cleaned up. +- Responses usage and output items must be normalizable in Phase 9. +- Transform logging must show Responses parse/build/storage-relevant states. + +Tests: + +- Create response. +- Retrieve response. +- Delete response. +- Previous response continuation. +- SSE event formatting. +- TTL cleanup. +- Transport interface extension smoke tests. + +Review focus: + +- Verify compatibility with OpenAI Responses expectations and external reference Responses behavior. +- Verify WebSocket is not blocked by the design. + +## Phase 5: Provider Protocol Overhaul + +Purpose: make providers use native protocols where practical and add priority providers after protocol foundations exist. + +Primary deliverables: + +- Provider protocol declaration mechanism. +- Provider hooks for protocol selection, adapter rules, field cache rules, and model options. +- Claude Code provider implementation or integration path. +- Codex provider implementation or integration path. +- Copilot provider implementation or integration path. +- Restored Antigravity provider if current/reference behavior supports it. +- Gemini CLI parity review and targeted fixes only where the external reference gateway has real improvements. + +Provider priorities: + +- Claude Code. +- Codex. +- Copilot. +- Antigravity. +- Gemini CLI review. + +Antigravity comparison requirements: + +- Compare the external reference implementation to `src/rotator_library/providers/_retired/antigravity_provider.py`. +- Restore only valid behavior. +- Avoid obsolete device-profile or fragile logic unless required by the current service. + +Key requirements: + +- Providers can override protocol methods when the base is close but not exact. +- New providers should avoid monolithic transform logic where adapter rules suffice. +- Native path must be testable independently of live credentials. +- LiteLLM fallback should be explicit if used. + +Tests: + +- Provider registration. +- Protocol selection. +- Auth header/token behavior via mocks. +- Request/response/stream translation. +- Quota/checker parsing. +- Field cache rules. +- No accidental LiteLLM fallback when native path should apply. + +Review focus: + +- Verify provider logic follows protocol architecture instead of reintroducing bespoke protocol code everywhere. +- Verify Gemini CLI improvements do not regress existing behavior. + +## Phase 6: Routing And Fallback Groups + +Purpose: add ordered model/provider fallback while keeping credential rotation inside each candidate. + +Primary deliverables: + +- Fallback group config parser. +- Ordered candidate planner for provider/model fallback chains. +- Retryable/non-retryable fallback decisions. +- Streaming fallback rules that only fallback before visible output. +- Optional target-group structure after fallback groups work. +- Optional selectors after ordered fallback works. + +Fallback behavior: + +- If requested model is in a group, try it first. +- Continue through remaining candidates only for retryable failures or configured fallback conditions. +- Preserve classifier/private credential scopes. +- Preserve session namespace isolation. +- Each candidate delegates to current credential selection and usage tracking. + +Future target group selectors: + +- `in_order`. +- `random`. +- `usage`. +- `cost`. +- `latency`. +- `performance`. + +Key requirements: + +- Fix or replace stale `fallback_groups` expectations without breaking current model resolution. +- Do not replace `UsageManager` or `SelectionEngine`. +- Transform logging should show chosen candidate and fallback attempt history. + +Tests: + +- Ordered fallback after retryable failure. +- Non-retryable failure stops. +- Requested model promotion. +- Exhausted chain reports useful error. +- Streaming fallback before output only. +- Scope/session isolation. + +Review focus: + +- Verify fallback integrates above credential rotation, not inside it. +- Verify behavior matches user preference for fallback chains over target-group complexity. + +## Phase 7: Retry/Cooldown/Failover Cleanup + +Purpose: streamline retry behavior and make provider/model cooldown real. + +Primary deliverables: + +- Replace or activate the currently inert provider `CooldownManager` path. +- Provider/model cooldown keys. +- Consecutive failure tracking. +- Exponential backoff. +- Retry-after precedence over computed cooldown. +- Success reset behavior. +- Retry history records for logging and future debug surfaces. +- Integration with fallback groups from Phase 6. + +Cooldown layers: + +- Credential-level cooldown remains in `UsageManager`. +- Provider/model cooldown applies only to evidence of provider-wide or model-wide failure. +- Credential quota exhaustion must not automatically cool an entire provider. +- Model cooldown should not block healthy models on the same provider. +- Provider cooldown should be reserved for provider-wide failure evidence. + +Retry history fields: + +- candidate provider/model. +- credential stable ID or masked identity. +- protocol path used. +- attempt number. +- status: success, failed, skipped, cooled_down. +- error category and raw classifier result. +- retryable decision. +- cooldown decision and duration. +- timing and latency. + +Key requirements: + +- Preserve the current strong retry-after parser, especially Google/Gemini compound duration parsing. +- Preserve streaming safety: no retry after visible output unless explicitly safe. +- Make cooldown state observable by transaction logs and future endpoint surfaces. + +Tests: + +- `start_cooldown` or successor has production callers. +- Provider cooldown blocks provider-wide only. +- Model cooldown blocks one model only. +- Credential cooldown does not become provider cooldown. +- Retry-after overrides exponential backoff. +- Backoff escalates with repeated failures. +- Success clears counts. +- Retry history is recorded. +- Fallback respects cooldown skips. + +Review focus: + +- Verify no healthy credential/model is suppressed too broadly. +- Verify retry/fallback/streaming interactions are deterministic. + +## Phase 8: Streaming Library Upgrade + +Purpose: harden streaming as a reusable library capability, not just a proxy route behavior. + +Primary deliverables: + +- Transport-aware stream event pipeline using unified stream events. +- Explicit upstream iterator close/cancel on client disconnect. +- TTFB timeout before first emitted output. +- TTFT metrics. +- Throughput stall detector. +- Stream usage extraction through protocol normalizers. +- Stream transform logging for raw, unified, adapted, and formatted states. +- SSE transport improvements and WebSocket-ready transport boundaries. + +Streaming behavior: + +- Retry is allowed before client-visible output when the error is retryable. +- Retry is not allowed after visible output unless a future protocol explicitly proves safety. +- Partial streams should not record accepted session response anchors. +- Completed streams should record response anchors and normalized usage. +- Disconnects should close upstream resources promptly. + +Metrics: + +- time to first byte. +- time to first token/content. +- tokens per second where token counts are available. +- chunk count. +- stall duration. +- completion status. + +Key requirements: + +- Implement with Python-native async patterns, not runtime-specific socket internals from the external reference gateway. +- Keep current conservative retry safety. +- Preserve current usage recording behavior and then improve it via Phase 9 normalizers. + +Tests: + +- Disconnect closes upstream async iterator. +- TTFB timeout retries before output. +- Stall detection trips after grace period. +- No retry after visible content. +- Retry before first output. +- Stream transform logging writes expected pass records. +- SSE formatting remains compatible. + +Review focus: + +- Verify library-level design is not tied only to FastAPI route wrappers. +- Verify future WebSocket can reuse the same unified event stream. + +## Phase 9: Usage, Quota, And Cost Accuracy + +Purpose: make usage and cost accounting protocol-aware while preserving the current usage engine. + +Primary deliverables: + +- Protocol-aware usage normalizer interface. +- Usage normalizers for OpenAI Chat, OpenAI Responses, Anthropic Messages, Gemini, OAuth/custom providers, and LiteLLM fallback. +- Structured cost details model integrated with protocol usage. +- Provider-reported cost extraction. +- SSE cost comment parsing where providers emit cost metadata. +- Reasoning/thinking token normalization. +- Cache read/write token normalization. +- Meter/checker abstraction for proactive provider quota checks where useful. + +Usage fields: + +- input tokens. +- output tokens. +- total tokens. +- reasoning/thinking tokens. +- cache read tokens. +- cache write tokens. +- provider-reported cost. +- estimated cost. +- cost source. +- raw provider usage metadata. + +Cost precedence: + +- Provider-reported cost wins when present and trusted. +- Structured provider cost fields are preferred over text parsing. +- SSE `: cost` comments can supply provider-reported cost. +- Estimated cost remains fallback. + +Current systems to preserve: + +- `UsageManager` facade. +- windowed tracking. +- fair cycle. +- custom caps. +- quota groups. +- classifier/private scope separation. +- JSON usage persistence. + +Key requirements: + +- Do not replace the usage engine. +- Normalize before recording when native protocol path is used. +- Existing LiteLLM response usage should continue working. +- Cost details should be transaction-loggable and eventually available to APIs/TUI. + +Tests: + +- OpenAI Chat usage details. +- Responses usage details. +- Anthropic cache read/write tokens. +- Gemini usage metadata including thoughts/cache when available. +- Reasoning token extraction. +- Provider-reported cost precedence. +- SSE cost comment extraction. +- Existing usage aggregation tests still pass. + +Review focus: + +- Verify accounting is more accurate without disrupting selection and quota logic. +- Verify provider-specific raw usage remains available for debugging. + +## Phase 10: Config Polish + +Purpose: support powerful protocol/routing/provider configuration without SQLite. + +Primary deliverables: + +- Optional JSON config file support. +- Env var pointing to JSON config path. +- Env overrides JSON. +- Validation with actionable errors. +- Config sections for protocols, adapters, field-cache rules, fallback groups, providers, model overrides, quota checkers, and stream settings. +- Documentation/examples for custom provider setup using existing protocols. + +Config priorities: + +- `.env` remains enough for simple setups. +- JSON supports complex nested config that is painful in env vars. +- Env overrides allow quick local changes. +- No SQLite/Postgres requirement. + +Potential JSON sections: + +- `protocols`. +- `providers`. +- `models`. +- `adapters`. +- `field_cache`. +- `fallback_groups`. +- `quota_checkers`. +- `streaming`. +- `logging`. + +Key requirements: + +- Validation errors must name the config section and field. +- Bad config should fail early when possible. +- Config should support per-provider and per-model overrides. +- Config should not require rewriting existing `.env` usage immediately. +- Config docs should show custom provider examples using OpenAI Chat, Anthropic Messages, Gemini, Responses, and LiteLLM fallback. + +Tests: + +- Env-only config. +- JSON-only config. +- Env overrides JSON. +- Bad config failure messages. +- Provider/model adapter rule config. +- Fallback group config. +- Field-cache rule config. + +Review focus: + +- Verify custom provider setup becomes practical and documented. +- Verify no accidental database dependency is introduced. + +## Cross-Phase Review Contract + +Every phase ends with two review agents: + +- `explore` for code/file-level verification. +- `explore-heavy` for deeper architecture and reference comparison. + +Review prompts must include the relevant phase plan, external reference areas, current proxy behavior to preserve, tests run, and transaction logging expectations. If either agent fails or runs out of context, restart with narrower scope. + +## Cross-Phase Testing Contract + +Every implementation phase must add or update tests before being considered complete. New test files may need `git add -f` because the repository ignores most `tests/*` by default. + +## Cross-Phase Documentation Contract + +Implementation code must include docstrings and comments for public extension points, non-obvious transformations, lossy conversions, future WebSocket seams, provider override points, and config decisions. diff --git a/docs/experimental/phase-1-protocol-core.md b/docs/experimental/phase-1-protocol-core.md new file mode 100644 index 000000000..aacfea76d --- /dev/null +++ b/docs/experimental/phase-1-protocol-core.md @@ -0,0 +1,172 @@ +# Phase 1: Protocol Core + +## Goal + +Introduce the native protocol foundation without changing live request execution yet. This phase creates robust, documented primitives that later phases can wire into `RequestExecutor`, provider declarations, adapters, field-cache rules, Responses, and streaming transports. + +## Non-Goals + +- Do not replace LiteLLM execution yet. +- Do not add `/v1/responses` routes yet. +- Do not migrate providers to native protocols yet. +- Do not implement field-cache persistence yet. +- Do not rewrite current Anthropic compatibility routes yet. +- Do not change existing request behavior unless a test-only import path requires a harmless export. + +## Current Code Context + +- No `src/rotator_library/protocols/` package exists yet. +- Provider auto-discovery in `src/rotator_library/providers/__init__.py` is the model for protocol discovery. +- `RequestContext` currently holds execution/session/logging fields and can later receive protocol metadata, but Phase 1 should avoid mutating it unless necessary. +- `ProviderTransforms` is hardcoded and will remain active until later adapter migration. +- `TransactionLogger` currently logs initial request, transformed request, response, and stream chunks; Phase 2 will add transform-pass tracing, but Phase 1 types should be trace-friendly. +- `anthropic_compat` already has useful conversion knowledge, especially thinking/tool block handling, but Phase 1 should not delete or replace it. + +## Files To Add + +- `src/rotator_library/protocols/__init__.py` +- `src/rotator_library/protocols/types.py` +- `src/rotator_library/protocols/base.py` +- `src/rotator_library/protocols/registry.py` +- `src/rotator_library/protocols/openai_chat.py` +- `src/rotator_library/protocols/anthropic_messages.py` +- `src/rotator_library/protocols/gemini.py` +- `src/rotator_library/protocols/responses.py` +- `src/rotator_library/protocols/litellm_fallback.py` +- `tests/test_protocol_registry.py` +- `tests/test_protocol_openai_chat.py` +- `tests/test_protocol_anthropic_messages.py` +- `tests/test_protocol_gemini.py` +- `tests/test_protocol_responses.py` + +## Possible Files To Touch + +- `src/rotator_library/__init__.py` only if a public lazy export is needed. +- `src/rotator_library/core/__init__.py` only if shared exports are cleaner there. +- Avoid modifying `RequestExecutor` in Phase 1 unless tests reveal a strict import issue. + +## Data Model Plan + +- `ProtocolRole`: role names should remain strings in payloads, but internal dataclasses can use simple `str` fields to avoid over-constraining custom protocols. +- `ContentBlock`: typed block with `type`, optional text/image/source/tool fields, and `extra` dict for provider-specific data. +- `UnifiedMessage`: `role`, `content`, `name`, `tool_call_id`, `tool_calls`, `extra`. +- `ToolDefinition`: protocol-neutral tool schema with `name`, `description`, `input_schema`, `extra`. +- `ToolCall`: `id`, `name`, `arguments`, `type`, `index`, `extra`. +- `ToolResult`: `tool_call_id`, `content`, `is_error`, `extra`. +- `ReasoningBlock`: `type`, `text`, `signature`, `redacted`, `extra`. +- `Usage`: input/output/total tokens, cache read/write, reasoning tokens, raw usage, cost details. +- `CostDetails`: provider reported cost, estimated cost, currency, source, metadata. +- `UnifiedRequest`: model, messages, tools, system, stream flag, generation params, response format, previous response ID, metadata, raw payload, extra. +- `UnifiedResponse`: id, model, messages/output, stop reason, usage, metadata, raw payload, extra. +- `UnifiedStreamEvent`: event type, delta/message/tool/usage/error metadata, raw event, extra. +- `ProtocolContext`: provider, model, source protocol, target protocol, request ID, session ID, credential stable ID, transport, transaction metadata, provider options. +- `ProtocolResult` is probably unnecessary in Phase 1; keep methods direct and simple. + +## Protocol Interface Plan + +- `ProtocolAdapter` abstract/base class with override-friendly methods. +- Required class attributes: `name`, `aliases`, `supported_transports`. +- Methods: + - `parse_request(raw_request, context) -> UnifiedRequest` + - `build_request(unified_request, context) -> dict` + - `parse_response(raw_response, context) -> UnifiedResponse` + - `format_response(unified_response, context) -> dict` + - `parse_stream_event(raw_event, context) -> UnifiedStreamEvent` + - `format_stream_event(unified_event, context) -> Any` + - `extract_usage(raw_or_unified, context) -> Usage | None` + - `supports_transport(transport_name) -> bool` +- Defaults should preserve unknown data in `extra` rather than raising. +- Errors should be `ProtocolError` with protocol name, pass name, and optional payload preview. + +## Registry Plan + +- `PROTOCOL_PLUGINS: dict[str, type[ProtocolAdapter]]` +- `register_protocol(cls)` for explicit class registration. +- `get_protocol(name)` returns an instance or class consistently. +- `list_protocols()`. +- Auto-discover modules under `src/rotator_library/protocols/`, skipping private modules and infrastructure modules. +- Support aliases so `openai`, `chat_completions`, and `openai_chat` can resolve to the same adapter. +- Prevent duplicate names unless the class is identical or explicit replacement is requested later. +- Keep registry import safe and lightweight. + +## Initial Protocol Behavior + +### OpenAI Chat + +- Parse chat completions request messages, system/developer/user/assistant/tool roles, text content, multimodal content arrays, tools, tool calls, tool_choice, response_format, stream, temperature/top_p/max_tokens/stop. +- Parse responses with choices, assistant messages, tool_calls, reasoning/reasoning_content, and usage details. +- Parse stream chunks into `UnifiedStreamEvent` while preserving raw delta. +- Format back to OpenAI Chat without losing unknown fields in `extra`. + +### Anthropic Messages + +- Parse separate `system`, messages with content blocks, thinking/redacted_thinking, tool_use, tool_result, images, documents where currently supported. +- Map Anthropic usage to `Usage`. +- Preserve thinking signatures in `ReasoningBlock.extra`. +- Build/format enough for round-trip tests, not full replacement of `anthropic_compat` yet. + +### Gemini + +- Parse `contents`, roles, parts, text, inline data, file data, functionCall, functionResponse, thought/thoughtSignature where present. +- Parse `generationConfig`, `safetySettings`, `tools`, stream flag metadata. +- Map usage metadata prompt/candidates/thoughts/cache where possible. +- Preserve provider-specific safety and generation fields. + +### Responses + +- Parse `input`, `instructions`, `previous_response_id`, tools, metadata, stream. +- Parse `output` items, message content, reasoning items, function/tool calls. +- Parse common response stream events into unified stream events. +- Do not add storage or routes yet. + +### LiteLLM Fallback + +- Wrap existing OpenAI-compatible dicts into unified request/response with raw preservation. +- Exist mainly as a named explicit protocol path for later logging. + +## Transaction Logging Implications + +- Phase 1 does not integrate runtime transform logs yet. +- All dataclasses need `to_dict()`/`from_dict()` or safe serialization helpers so Phase 2 can log every pass cleanly. +- `ProtocolError` should include pass names that Phase 2 can reuse. +- Avoid mutating raw payloads in protocol parse methods unless explicitly building a new provider request. + +## Docstrings And Comments Required + +- Every public protocol class explains which external API shape it models and which parts are intentionally partial/base behavior. +- `ProtocolAdapter` docstring explains override contract and why protocols are bases. +- Registry comments explain auto-discovery and skip rules. +- Conversion helpers comment on lossy or approximate mappings. +- Future-extension comments note where WebSocket, field-cache rules, provider overrides, and target transports will attach. + +## Tests + +- Registry auto-discovers built-in protocols. +- Aliases resolve correctly. +- Duplicate registration behavior is deterministic. +- Base protocol default methods preserve raw payloads. +- OpenAI Chat request round-trip with system/developer/user/assistant/tool messages, tool definitions, tool calls, reasoning content, and usage with cache/reasoning token details. +- Anthropic Messages request round-trip with system field, text blocks, thinking/redacted_thinking blocks, tool_use/tool_result, and usage cache fields. +- Gemini request round-trip with contents and parts, functionCall/functionResponse, thoughtSignature, generationConfig/safetySettings, and usageMetadata. +- Responses request/response parse with instructions, input string and input message list, previous_response_id, output messages, reasoning items, and function/tool call items. +- Stream event parse smoke tests for each protocol where practical. +- Serialization tests ensure unified types are JSON-serializable. + +## Commit Checkpoints + +1. Add protocol dataclasses, errors, and serialization helpers. +2. Add base adapter and registry auto-discovery. +3. Add LiteLLM fallback and OpenAI Chat protocol with tests. +4. Add Anthropic Messages protocol with tests. +5. Add Gemini protocol with tests. +6. Add Responses protocol with tests. +7. Run focused protocol test set and fix issues. +8. Phase review with `explore` and `explore-heavy`, then fixes if needed. + +## Risk And Rollback + +- Keep Phase 1 isolated so rollback is just removing the new package/tests. +- Avoid touching executor behavior to prevent regressions. +- If auto-discovery causes import cycles, switch to explicit built-in imports inside registry while preserving the public registry API. +- If tests become too large for one file, split into `tests/protocols/` only if ignore rules are handled with force-add when committing. +- Since `.gitignore` ignores most `tests/*`, remember to force-add new test files when committing. diff --git a/docs/experimental/phase-10-config-polish.md b/docs/experimental/phase-10-config-polish.md new file mode 100644 index 000000000..fd0b884ab --- /dev/null +++ b/docs/experimental/phase-10-config-polish.md @@ -0,0 +1,286 @@ +# Phase 10 Plan: Config Polish + +## Goal + +Make the new protocol/routing/field-cache/streaming/pricing features configurable in a consistent, documented, testable way without replacing the current `.env` workflow. Phase 10 should add a small optional JSON configuration layer, keep environment variables as the final override, document all new knobs in `.env.example`, and expose validation helpers so invalid config fails clearly instead of silently changing routing or accounting behavior. + +## Non-Goals + +- Do not introduce SQLite or any database. +- Do not replace `.env` as the primary user workflow. +- Do not replace provider class declarations, `UsageManager`, `SessionTracker`, `SelectionEngine`, or provider quota trackers. +- Do not implement full multi-user/security config. +- Do not move secrets into JSON config. +- Do not require JSON config for existing deployments. +- Do not rewrite every direct `os.getenv()` call in the repo. +- Do not change default routing, pricing, usage, or streaming behavior when no new config is present. + +## Configuration Precedence + +- Built-in defaults and provider declarations are the base. +- Optional JSON config may provide structured routing/pricing/streaming/provider metadata. +- Environment variables override JSON config. +- Request-level explicit routing/provider fields still win for that request where already supported. +- Secrets remain environment/OAuth-file based; JSON config should not contain API keys or bearer tokens. + +## Current Code Context + +- `.env.example` documents many legacy env vars but not all Phase 6-9 additions. +- `routing/config.py` currently reads fallback groups and model routes from env only. +- Phase 9 added `ModelPricing`, `CostCalculator`, and `ProviderInterface.get_model_pricing()` but not env/JSON pricing. +- Phase 8 added stream metrics primitives but runtime TTFB/stall/heartbeat env knobs are not implemented. +- Phase 7 added provider cooldown env parsing in `retry_policy.provider_cooldown_env()`. +- Phase 3 added field-cache rule dataclasses, but no user-facing JSON config parser for cache rules. +- Provider base URLs and native provider declarations are mostly provider methods or direct env vars. +- Existing usage config env parsing is large and should not be rewritten in Phase 10. +- Reports are user-facing only and should remain uncommitted. + +## Files To Add + +- `src/rotator_library/config/experimental.py` +- `tests/test_experimental_config.py` +- `tests/test_config_pricing.py` +- `tests/test_config_routing_json.py` +- `tests/test_config_stream_settings.py` +- `tests/test_env_example_experimental_config.py` +- Maybe `docs/experimental/config-reference.md` if the implementation needs more detail than `.env.example`. + +## Files Likely To Touch + +- `src/rotator_library/routing/config.py` +- `src/rotator_library/usage/costs.py` +- `src/rotator_library/client/streaming.py` +- `src/rotator_library/client/executor.py` only if streaming env knobs need to be passed through. +- `src/rotator_library/field_cache/rules.py` or equivalent only if JSON parsing needs a small helper. +- `src/rotator_library/providers/provider_interface.py` only if pricing declarations need docstring/typing alignment. +- `.env.example` +- `docs/experimental/phase-10-config-polish.md` + +## JSON Config Model + +Add `ExperimentalConfig` dataclass with optional sections: + +- `routing` +- `pricing` +- `streaming` +- `field_cache` +- `providers` + +Add loader: + +- `load_experimental_config(path=None, env=None)` +- If `path` is `None`, read from `LLM_PROXY_CONFIG_FILE` or `PROXY_CONFIG_FILE`. +- Missing path returns an empty config. +- Invalid JSON raises `ExperimentalConfigError`. +- Unknown top-level sections should be preserved in metadata or warned about, not fatal. + +Add helpers: + +- `as_bool()` +- `as_int()` +- `as_float()` +- `env_key(provider, model, suffix)` sanitizing provider/model names consistently. + +JSON should not read or interpolate secrets. + +## Routing JSON + +Support JSON shape: + +- `routing.fallback_groups..targets = ["codex/gpt-5.1-codex@native", "openai/gpt-5.1@litellm_fallback"]` +- `routing.fallback_groups..failover_on = ["rate_limit", "server_error"]` +- `routing.fallback_groups..stop_on = ["authentication", "validation"]` +- `routing.model_routes. = "group:code_chain"` or target spec. + +`routing/config.py` should merge JSON first, then env: + +- env `FALLBACK_GROUPS`, `FALLBACK_GROUP_*`, `MODEL_ROUTE_*` override or add entries. +- env overrides should retain current behavior for existing deployments. + +Validation: + +- empty groups are invalid. +- `group:` model routes must reference known groups after merge. +- invalid target specs raise `RoutingConfigError`. + +Tests: + +- JSON-only routing works. +- env override replaces JSON group targets. +- env model route can reference JSON group. +- invalid JSON group route fails clearly. +- existing env-only routing tests still pass. + +## Pricing Config + +Add env pricing support in `CostCalculator` or a small helper: + +- `MODEL_PRICE_{PROVIDER}_{MODEL}_INPUT` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_OUTPUT` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_CACHE_READ` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_CACHE_WRITE` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_REASONING` + +Add JSON pricing support: + +- `pricing...input` +- `pricing...output` +- `pricing...cache_read` +- `pricing...cache_write` +- `pricing...reasoning` +- optional `currency` + +Precedence: + +- provider explicit `get_model_pricing()` first, because provider code may know exact native pricing. +- env pricing overrides JSON pricing. +- JSON pricing applies before LiteLLM fallback. +- LiteLLM remains last fallback. + +If env parsing is invalid, log warning and ignore that component. + +Tests: + +- JSON pricing calculates buckets. +- env pricing overrides JSON pricing. +- provider explicit pricing remains highest priority. +- missing pricing remains `unavailable`. +- `skip_cost_calculation` still wins over all pricing sources. + +## Streaming Settings + +Add `StreamRuntimeSettings` dataclass: + +- `ttfb_timeout_seconds` +- `stall_timeout_seconds` +- `heartbeat_seconds` +- `trace_metrics` + +Load from JSON `streaming` and env: + +- `STREAM_TTFB_TIMEOUT_SECONDS` +- `STREAM_STALL_TIMEOUT_SECONDS` +- `STREAM_HEARTBEAT_SECONDS` +- `STREAM_TRACE_METRICS` + +Phase 10 should not enforce timeouts by default. + +Minimal runtime integration: + +- `StreamingHandler.wrap_stream()` should read settings and conditionally emit lifecycle metrics only when trace metrics enabled, default true to preserve Phase 8 behavior. +- Do not implement heartbeat injection unless small and obviously safe. If implemented, default disabled and only emit SSE comments during controlled tests. +- Do not abort streams on TTFB/stall by default. If values are configured, trace `stream_stall_detected` when observable, but do not sever client streams unless explicitly `STREAM_STALL_ABORT=true` is introduced and tested. Prefer no abort in Phase 10. + +Tests: + +- JSON/env settings parse. +- env overrides JSON. +- `STREAM_TRACE_METRICS=false` suppresses lifecycle trace passes while SSE output remains unchanged. +- default trace behavior remains current. + +## Field-Cache Config + +If feasible, add parser for JSON field-cache rule declarations without wiring every provider: + +- `field_cache..[]` +- fields: `name`, `source`, `path`, `scope`, `mode`, `target_path`, `max_entries`. + +Keep it as helper-only if integration is too risky: + +- parse into existing `FieldCacheRule` dataclasses. +- providers/Phase 10+ can call it from `get_field_cache_rules()`. + +Tests: + +- valid JSON rule parses. +- invalid paths/rule modes fail clearly. +- provider/model wildcard merge order documented and tested if implemented. + +If the existing field-cache dataclasses do not support all desired fields cleanly, document and defer live provider integration. + +## Provider Metadata Config + +- Support safe non-secret provider metadata only. +- API base URLs can remain existing `*_API_BASE` env vars for now. +- JSON may define provider protocol/adapter names for future use, but Phase 10 should not make untrusted JSON instantiate arbitrary classes. +- Allow only names that already exist in protocol/adapter registries if validation is implemented. +- Do not place API keys, OAuth tokens, or authorization headers in JSON. +- Tests can validate config rejects/ignores secret-looking keys like `api_key`, `authorization`, `access_token`. + +## Config Validation And Diagnostics + +Add a validation result: + +- warnings for unknown sections/keys. +- errors for invalid routing target specs and invalid numeric pricing. +- no credential values in error messages. + +Add a CLI-free helper test; do not add a new CLI unless tiny. + +Add transform trace or startup log only if a config is loaded: + +- log path, loaded sections, warning count. +- no config contents with secrets. + +## `.env.example` Updates + +Add section for Phase 6 routing: + +- `FALLBACK_GROUPS` +- `FALLBACK_GROUP_` +- `MODEL_ROUTE_` +- execution suffixes `@auto`, `@native`, `@custom`, `@litellm_fallback`. + +Add section for Phase 7 provider cooldown: + +- `PROVIDER_COOLDOWN_MIN_SECONDS` +- `PROVIDER_COOLDOWN_DEFAULT_SECONDS` +- `PROVIDER_COOLDOWN_ON_QUOTA` +- transient retry delay/jitter already in defaults but should be documented. + +Add section for Phase 8 streaming: + +- `STREAM_RETRY_ON_REASONING_ONLY` +- `STREAM_TRACE_METRICS` +- `STREAM_TTFB_TIMEOUT_SECONDS` +- `STREAM_STALL_TIMEOUT_SECONDS` +- `STREAM_HEARTBEAT_SECONDS` + +Add section for Phase 9 pricing: + +- `MODEL_PRICE___INPUT` +- `MODEL_PRICE___OUTPUT` +- cache/reasoning variants. + +Add optional JSON config: + +- `LLM_PROXY_CONFIG_FILE=./config/llm-proxy.json` +- note that env overrides JSON and secrets should stay in env/OAuth files. + +Keep docs concise; do not document every old legacy env in Phase 10 unless already present. + +## Test Modernization + +- Do not take on the entire stale broad test suite unless small. +- Phase 10 can add focused tests for config parsing and maintained regression set. +- Existing stale ignored tests should not block completion unless they are tracked and part of maintained runs. + +## Implementation Checkpoints + +1. Add `ExperimentalConfig` loader and validation helpers with tests. +2. Integrate JSON+env routing config merge with existing routing tests. +3. Add JSON/env model pricing support to `CostCalculator` with tests. +4. Add stream runtime settings parsing and `STREAM_TRACE_METRICS` runtime integration with tests. +5. Add field-cache JSON rule parser helper if feasible with tests. +6. Update `.env.example` and optional config reference doc. +7. Run Phase 10 focused tests and Phase 1-9 maintained regressions. +8. Review with `explore` and `explore-heavy`; fix findings; write uncommitted Phase 10 report. + +## Risks And Mitigations + +- Config precedence confusion. Mitigation: tests for defaults, JSON, env override, provider explicit override, and request-level route behavior. +- Accidentally allowing secrets in JSON. Mitigation: validation rejects secret-looking keys in safe config sections. +- Routing merge could alter existing env behavior. Mitigation: env-only tests remain unchanged and env overrides JSON. +- Pricing could become authoritative incorrectly. Mitigation: keep costs advisory and return zero/unavailable when ambiguous. +- Streaming trace toggles could hide useful debug data unexpectedly. Mitigation: default trace metrics remains enabled. +- Over-scoping into a full config framework. Mitigation: only structured config for new experimental features; leave legacy env parsing intact. diff --git a/docs/experimental/phase-10b-config-surface-wiring.md b/docs/experimental/phase-10b-config-surface-wiring.md new file mode 100644 index 000000000..01aef40b4 --- /dev/null +++ b/docs/experimental/phase-10b-config-surface-wiring.md @@ -0,0 +1,55 @@ +# Phase 10b: Config Surface Wiring And Runtime Settings Completion + +## Goal + +Correct the Phase 10 validation finding that the config layer exists but does not wire enough runtime surfaces introduced in corrective phases. The proxy remains `.env`-first, optional JSON config is structured and secret-free, and environment variables override JSON values. + +## Non-Goals + +- Do not create a full application settings framework. +- Do not move secrets into JSON config. +- Do not replace provider credential discovery. +- Do not add security or multi-user config. +- Do not introduce SQLite or durable DB config. +- Do not commit user-facing reports. + +## Implementation Plan + +1. Add retry/cooldown runtime settings. + - Parse provider cooldown, provider backoff, and failure-history settings from JSON and env. + - Preserve existing env var names and defaults. + - Env overrides JSON. + +2. Wire retry/cooldown settings into retry policy. + - Keep imports lazy to avoid startup cycles. + - Preserve existing monkeypatch/env test behavior. + +3. Add Responses store settings config. + - Parse `RESPONSES_STORE_TTL_SECONDS`, `RESPONSES_STORE_MAX_ITEMS`, `RESPONSES_STORE_FAILED`, and `RESPONSES_STORE_IN_PROGRESS`. + - Support JSON under `responses.store`. + - Preserve default behavior. + +4. Wire Responses store settings into proxy startup. + - Construct `ResponsesService(store_settings=get_responses_store_settings())` in the FastAPI app path. + - Direct tests can still inject explicit settings. + +5. Harden field-cache/provider config parser coverage. + - Cover TTL, metadata mode hints, inject insert behavior, invalid rule errors, and secret rejection in new sections. + +6. Update `.env.example`. + - Document active streaming timeout/heartbeat settings. + - Document provider/model cooldown and backoff knobs. + - Document Responses store policy knobs. + - Document provider-reported cost precedence and SSE cost comments. + - Show safe structured-config section names. + +## Acceptance Criteria + +- Optional JSON config can configure retry/cooldown/backoff and Responses storage policies. +- Env vars override JSON for all new settings. +- Existing env-only behavior remains compatible. +- ResponsesService at proxy startup uses configured store settings. +- Streaming/pricing/routing/field-cache config regressions remain passing. +- Secret-like JSON keys are still rejected in all new sections. +- `.env.example` documents Phase 7b-10b runtime knobs accurately. +- Focused tests and dual-agent review pass. diff --git a/docs/experimental/phase-1b-protocol-breadth-operation-model.md b/docs/experimental/phase-1b-protocol-breadth-operation-model.md new file mode 100644 index 000000000..19ea02c03 --- /dev/null +++ b/docs/experimental/phase-1b-protocol-breadth-operation-model.md @@ -0,0 +1,157 @@ +# Phase 1b: Protocol Breadth And Operation Model + +## Goal + +Correct the main Phase 1 audit failure. Phase 1 created a good chat/message protocol foundation, but the starting requirements asked for protocols as the highest priority and for broad protocol coverage. The current protocol layer is still too chat-oriented: it has OpenAI Chat, Anthropic Messages, Gemini, Responses, and LiteLLM fallback, but no explicit operation model for embeddings, image generation, audio transcription, speech/TTS, Ollama-native chat/generate, MCP-style tool gateway calls, or count-token operations. + +## Non-Goals + +- Do not wire every new operation into live FastAPI routes in this corrective slice. This is protocol foundation work. +- Do not replace current embeddings or Anthropic count_tokens routes yet. +- Do not replace existing provider execution, `UsageManager`, `SessionTracker`, routing, or LiteLLM fallback. +- Do not add SQLite or persistent admin/config databases. +- Do not implement full MCP transport/proxy behavior yet; define the protocol shape and operation seam so it is not blocked later. +- Do not fake WebSocket runtime support; keep explicit future transport/capability metadata. +- Do not touch uncommitted phase reports or `docs/issues/`. + +## Current Code Context + +- `src/rotator_library/protocols/types.py` has unified chat/message/request/response/stream dataclasses, but `UnifiedRequest` has no `operation` field and no dedicated multimodal operation carrier fields. +- `ProtocolAdapter` has `supported_transports` but no `supported_operations` or `supports_operation()`. +- Existing concrete protocols are chat/message/generate-content focused. +- `ContentBlock` already has generic `type`, `source`, `raw`, and `extra`, which can support multimodal payloads with careful extension. +- Registry auto-discovery exists and should be reused for new protocol modules. +- Tests exist only for the initial protocol modules. + +## Files To Add + +- `src/rotator_library/protocols/operation.py` +- `src/rotator_library/protocols/openai_embeddings.py` +- `src/rotator_library/protocols/openai_images.py` +- `src/rotator_library/protocols/openai_audio.py` +- `src/rotator_library/protocols/ollama.py` +- `src/rotator_library/protocols/mcp.py` +- `tests/test_protocol_operation_model.py` +- `tests/test_protocol_openai_embeddings.py` +- `tests/test_protocol_openai_images_audio.py` +- `tests/test_protocol_ollama_mcp.py` + +## Files To Touch + +- `src/rotator_library/protocols/types.py` +- `src/rotator_library/protocols/base.py` +- `src/rotator_library/protocols/__init__.py` +- Existing protocol modules where they should declare `supported_operations`. +- Existing protocol tests if serialization expectations need operation fields added. + +## Data Model Additions + +- Add operation constants in a small `operation.py` module, not a rigid enum, so custom/local protocols can introduce operation strings without fighting the core. +- Initial standard operation names: `chat`, `messages`, `responses`, `count_tokens`, `embeddings`, `image_generation`, `image_edit`, `image_variation`, `audio_transcription`, `audio_translation`, `speech`, `ollama_chat`, `ollama_generate`, `mcp`, and `unknown`. +- Add `operation` to `UnifiedRequest`. +- Add flexible operation carrier fields to `UnifiedRequest`: `input`, `modalities`, and `files`. +- Add `operation` to `UnifiedResponse`. +- Add flexible response fields to `UnifiedResponse`: `data` and `content_type`. +- Keep existing fields and defaults chat-compatible so existing adapters do not break. +- Update explicit serialization fields so transform logging can see the new values. + +## Protocol Adapter Additions + +- Add class attribute `supported_operations`. +- Add `supports_operation(operation_name)`. +- Existing protocols should declare their primary operations. +- The base adapter defaults to `operation="unknown"` when no operation is identified. +- LiteLLM fallback stays broad and explicit, not a native implementation claim. + +## New Protocol Modules + +### OpenAI Embeddings + +- Protocol name: `openai_embeddings`. +- Aliases: `embeddings`. +- Operation: `embeddings`. +- Parse/build `/v1/embeddings` style fields: `model`, `input`, `encoding_format`, `dimensions`, and `user`. +- Parse `data[].embedding` and `usage` from responses. +- Preserve unknown fields. + +### OpenAI Images + +- Protocol name: `openai_images`. +- Aliases: `images`, `image_generation`. +- Operations: `image_generation`, `image_edit`, and `image_variation`. +- Parse/build standard OpenAI image fields: `prompt`, `model`, `n`, `size`, `quality`, `style`, `response_format`, `image`, `mask`, and `user`. +- Parse response `data` entries with `url`, `b64_json`, and `revised_prompt`. +- Preserve file/source metadata without reading file contents. + +### OpenAI Audio + +- Protocol name: `openai_audio`. +- Aliases: `audio`, `audio_transcription`, and `speech`. +- Operations: `audio_transcription`, `audio_translation`, and `speech`. +- Parse/build transcription/translation fields: `file`, `model`, `language`, `prompt`, `response_format`, `temperature`, and `timestamp_granularities`. +- Parse/build speech fields: `input`, `model`, `voice`, `response_format`, and `speed`. +- Preserve binary or text provider responses in `raw`; map JSON responses into `data` or `output` when possible. + +### Ollama + +- Protocol name: `ollama`. +- Operations: `ollama_chat`, `ollama_generate`, and compatible `embeddings`. +- Parse/build `/api/chat`, `/api/generate`, and `/api/embeddings` shapes. +- Map `messages`, `prompt`, `options`, `system`, `template`, `context`, `keep_alive`, `format`, and `stream`. +- Parse final and stream chunks while preserving raw context/performance fields. + +### MCP + +- Protocol name: `mcp`. +- Operation: `mcp`. +- Define a JSON-RPC carrier for `initialize`, `tools/list`, `tools/call`, `resources/list`, `resources/read`, `prompts/list`, `prompts/get`, and generic method calls. +- Do not claim a full MCP proxy implementation in this slice. +- Preserve `jsonrpc`, `method`, `params`, `id`, `result`, and `error` fields exactly. + +## Count Tokens + +- Add count-token operation support where it naturally fits, especially Anthropic Messages. +- Avoid duplicating current count-token routes in this slice. +- Ensure `ProtocolAdapter.supports_operation("count_tokens")` can represent a protocol's support. + +## Tests + +- Operation model serialization includes `operation`, `input`, `modalities`, `files`, response `data`, and `content_type`. +- `ProtocolAdapter.supports_operation()` works and remains override-friendly. +- Registry discovers all new protocols. +- Embeddings request/response parse/build round-trip and usage extraction. +- Images request/response parse/build, including edit file/mask references. +- Audio transcription and speech request/response parse/build. +- Ollama chat/generate parse/build and stream parsing. +- MCP request/response/error round-trip. +- Existing protocol regression tests remain clean. + +## Transaction Logging Implications + +- This slice only adds trace-friendly fields and protocol adapters. +- Do not wire live transform logs here. +- Preserve `raw` and `extra` so Phase 2b can log every state. + +## Docstrings And Comments Required + +- Public classes must explain what API shape they model and what remains intentionally future-facing. +- Operation constants must document that strings are intentionally extensible. +- Lossy conversions, especially binary/audio/image payload preservation, need comments explaining that file contents are not inspected and provider-specific metadata stays in `extra`. + +## Commit Checkpoints + +1. Plan doc commit: `docs(experimental): plan protocol breadth correction`. +2. Operation model/types/base adapter commit with tests. +3. Embeddings/images/audio protocol commit with tests. +4. Ollama/MCP/count-token capability commit with tests. +5. Focused protocol regression run. +6. `explore` and `explore-heavy` review against this plan and the source requirements. +7. Fix findings and re-review if substantial. +8. Write uncommitted Phase 1b report. + +## Risks + +- Adding fields to dataclasses can break exact dict assertions. Mitigation: update tests only where serialization now intentionally includes new fields. +- A rigid operation enum could block custom providers. Mitigation: use string constants/helpers, not a closed enum. +- New non-chat adapters could overclaim live support. Mitigation: protocol package supports parse/build only; route/provider wiring comes later. +- MCP can sprawl. Mitigation: implement JSON-RPC carrier shape only, not a full proxy. diff --git a/docs/experimental/phase-1c-protocol-output-operation-guardrails.md b/docs/experimental/phase-1c-protocol-output-operation-guardrails.md new file mode 100644 index 000000000..a9327d3fa --- /dev/null +++ b/docs/experimental/phase-1c-protocol-output-operation-guardrails.md @@ -0,0 +1,59 @@ +# Phase 1c: Protocol Output Correctness And Native Operation Guardrails + +## Goal + +Close the Phase 1/1b third-pass findings that still affect protocol correctness and later native provider behavior. + +## Scope + +- Fix OpenAI Chat formatted usage so it emits public OpenAI-compatible usage fields rather than unified internal usage fields. +- Fix Responses formatted usage so it emits Responses-compatible usage fields rather than unified internal usage fields. +- Promote OpenAI legacy `function_call` into unified `ToolCall` while preserving legacy round-trip shape. +- Add Ollama response formatting that respects mutated unified responses instead of stale raw payloads. +- Enforce protocol operation compatibility in native execution before transport calls. + +## Non-Goals + +- Do not solve every Phase 5 provider issue here beyond the protocol guardrail needed by native execution. +- Do not enable native streaming for priority providers. +- Do not rewrite protocol registry discovery unless a focused test exposes a direct failure. +- Do not touch unrelated dirty files or user-facing reports. + +## Implementation Plan + +1. OpenAI Chat usage formatting. + - Add a helper that converts `Usage` to `prompt_tokens`, `completion_tokens`, `total_tokens`, `prompt_tokens_details`, `completion_tokens_details`, and `cost_details`. + - Ensure `input_tokens`, `output_tokens`, `raw`, and `extra` do not leak into formatted OpenAI usage. + +2. Responses usage formatting. + - Add a helper that converts `Usage` to `input_tokens`, `output_tokens`, `total_tokens`, `input_tokens_details`, `output_tokens_details`, and `cost_details`. + - Preserve cache-write information only as a safe extension inside `input_tokens_details` when present. + +3. Legacy OpenAI `function_call`. + - Parse legacy `function_call` into a unified `ToolCall` when modern `tool_calls` are not present. + - Preserve the raw legacy field in `extra` so formatting emits `function_call` instead of incorrectly upgrading to `tool_calls`. + +4. Ollama response formatting. + - Implement `OllamaProtocol.format_response()` for chat, generate, and embeddings shapes. + - Merge non-core `extra` and usage/timing fields while honoring mutated unified messages/output/data. + +5. Native operation enforcement. + - In `NativeProviderExecutor.execute()` and `stream()`, reject unsupported protocol operations before parse/build/transport. + - Keep errors sanitized and explicit so bad provider/protocol pairings fail early. + +## Tests + +- `tests/test_protocol_openai_chat.py` +- `tests/test_protocol_responses.py` +- `tests/test_protocol_ollama_mcp.py` +- `tests/test_native_provider_executor.py` or `tests/test_request_executor_native_routing.py` +- Full protocol suite plus native routing smoke tests before review. + +## Acceptance Criteria + +- OpenAI Chat formatted responses expose OpenAI-compatible usage fields. +- Responses formatted responses expose Responses-compatible usage fields. +- Legacy OpenAI `function_call` is represented in unified tool calls and round-trips as legacy `function_call`. +- Ollama formatted responses reflect mutated unified state. +- Native execution rejects unsupported protocol operations before network transport. +- Focused tests pass and both `explore` and `explore-heavy` reviewers report no blockers/highs/mediums. diff --git a/docs/experimental/phase-2-transform-logging.md b/docs/experimental/phase-2-transform-logging.md new file mode 100644 index 000000000..77cfbaeec --- /dev/null +++ b/docs/experimental/phase-2-transform-logging.md @@ -0,0 +1,203 @@ +# Phase 2 Plan: Transform-Pass Logging + +## Goal + +Make transaction logging capable of recording every meaningful request, response, and stream state after each transformation pass, without yet replacing runtime execution with the native protocol adapters. This phase creates the durable trace format and wires it into existing request/executor/provider logging points so later protocol, adapter, and field-cache phases can add more trace entries without redesigning observability. + +## Non-Goals + +- Do not route execution through native protocols yet. +- Do not implement field-cache extraction or injection yet. +- Do not add Responses routes yet. +- Do not add full multi-user/security work. +- Do not rewrite provider-specific logging systems. +- Do not change request/response behavior. + +## Current Code Context + +- `TransactionLogger` creates one directory per transaction and writes legacy files such as `request.json`, `request_transformed.json`, `response.json`, `streaming_chunks.jsonl`, and `metadata.json`. +- `RequestContextBuilder` creates `TransactionLogger` and calls `log_request(kwargs)`. +- `RequestExecutor` calls `log_transformed_request(kwargs, context.kwargs)` after request preparation. +- `RequestExecutor` logs non-streaming responses through `log_response(response_data)`. +- `RequestExecutor._transaction_logging_stream_wrapper()` logs parsed chunks and assembles/logs the final streamed response. +- `ProviderLogger` writes provider-level request payloads, raw response chunks, final responses, and provider errors. +- Phase 1 unified types already provide JSON-safe serialization through `serialize_value()`. + +## Files To Add + +- `src/rotator_library/transform_trace.py` +- `tests/test_transform_trace.py` +- `tests/test_transaction_logger_transform_trace.py` + +## Files Likely To Touch + +- `src/rotator_library/transaction_logger.py` +- `src/rotator_library/client/request_builder.py` +- `src/rotator_library/client/executor.py` + +## Trace Model + +`TransformTraceEntry` fields: + +- `sequence`: monotonically increasing integer per writer instance. +- `timestamp_utc`. +- `component`: `client` or `provider`. +- `pass_name`: stable machine-readable name. +- `direction`: `request`, `response`, `stream`, `error`, or `metadata`. +- `stage`: `client`, `protocol`, `adapter`, `provider`, or `final`. +- `protocol`: optional protocol name. +- `provider`: optional provider name. +- `model`: optional model name. +- `credential_id`: optional stable or masked credential identifier only. +- `transport`: `http`, `sse`, `websocket`, or a future string. +- `changed_from_previous`: optional bool. +- `metadata`: JSON-safe dict. +- `data`: JSON-safe sanitized payload. + +`TransformTraceWriter` responsibilities: + +- Maintain a local sequence counter. +- Write compact ordered entries to `transform_trace.jsonl`. +- Optionally write individual request/response snapshots under `transforms/`. +- Avoid per-chunk snapshot files for stream chunks. +- Never raise logging failures into request execution. + +## Log Files + +Existing files stay for compatibility: + +- `request.json` +- `request_transformed.json` +- `response.json` +- `streaming_chunks.jsonl` +- `metadata.json` + +New files: + +- `transform_trace.jsonl` +- `transforms/0001_raw_client_request.json` +- `transforms/0002_prepared_provider_request.json` +- `transforms/NNNN_.json` for non-stream snapshots + +Stream chunks use JSONL entries rather than one file per chunk. + +## Required Pass Names + +Request passes: + +- `raw_client_request` +- `prepared_provider_request` +- `provider_request_payload` + +Response passes: + +- `raw_client_response` +- `final_client_response` +- `provider_final_response` + +Stream passes: + +- `raw_stream_chunk` +- `parsed_stream_chunk` +- `assembled_stream_response` +- `provider_raw_stream_chunk` + +Error passes: + +- `provider_error` +- `transform_log_error` + +## Sanitization + +This is pragmatic trace hygiene, not full security work. + +Redact recursively by key name: + +- `api_key` +- `credential_identifier` +- `authorization` +- `x-api-key` +- `x-goog-api-key` +- `access_token` +- `refresh_token` +- `client_secret` +- `password` +- `secret` +- `token` + +Rules: + +- Redacted value is `"[REDACTED]"`. +- Redact by key only, not by text value. +- Do not hide ordinary model text containing words such as token. +- Always serialize through Phase 1 JSON-safe helpers. + +## Integration Behavior + +- Trace logging is additive and backward compatible. +- Logging failures never fail requests. +- `log_request()` records `raw_client_request`. +- `log_transformed_request()` records `prepared_provider_request` even when legacy transformed file is skipped because payloads compare equal. +- `log_response()` records `final_client_response`. +- A dedicated method can record `raw_client_response` before final normalization when available. +- `log_stream_chunk()` records `parsed_stream_chunk`. +- `_transaction_logging_stream_wrapper()` records `raw_stream_chunk` before JSON parsing and `assembled_stream_response` before final `log_response()`. +- `ProviderLogger.log_request()` records `provider_request_payload`. +- `ProviderLogger.log_response_chunk()` records `provider_raw_stream_chunk`. +- `ProviderLogger.log_final_response()` records `provider_final_response`. +- `ProviderLogger.log_error()` records `provider_error`. + +## Context Design + +- `TransactionLogger` owns a client-side `TransformTraceWriter`. +- `TransactionContext` carries enough trace metadata for providers to create provider-side trace entries without sharing mutable logger objects. +- `ProviderLogger` creates its own writer with `component="provider"`. +- Client and provider sequence numbers are local to their writer instances. Entries include component and timestamps, so Phase 2 does not promise one global order across independent writers. + +## Comments And Docstrings + +- Public trace classes must explain they are observability-only and must not affect request behavior. +- Redaction helpers must explain why redaction is key-based rather than value-based. +- TransactionLogger comments should distinguish legacy files from the new trace ledger. +- Stream comments should explain why stream chunks use JSONL rather than per-chunk snapshots. + +## Tests + +`tests/test_transform_trace.py`: + +- JSON-safe serialization handles dataclasses and non-primitives. +- Redaction recurses through nested payloads. +- Redaction does not redact normal text values. +- Sequence increments per writer. +- Snapshot filenames are stable and sanitized. + +`tests/test_transaction_logger_transform_trace.py`: + +- `log_request()` writes legacy `request.json` and trace entry `raw_client_request`. +- `log_transformed_request()` writes trace entry even when legacy transformed file is skipped because payloads are equal. +- `log_response()` writes `final_client_response`. +- `log_stream_chunk()` writes `parsed_stream_chunk`. +- ProviderLogger writes provider request, chunk, final, and error trace entries. +- Logging disabled writes nothing. + +Regression tests: + +- Phase 1 protocol tests. +- `tests/test_session_tracking.py`. +- `tests/test_selection_engine.py`. + +## Commit Checkpoints + +1. Add transform trace dataclasses, writer, redaction, and tests. +2. Integrate client-side `TransactionLogger` trace entries and tests. +3. Integrate provider-side `ProviderLogger` trace entries and tests. +4. Run focused and regression tests. +5. Review with `explore` and `explore-heavy`, fix findings, and write the uncommitted Phase 2 report. + +## Risks And Mitigations + +- Stream logs can grow quickly. Use JSONL only for stream chunks. +- Redaction can hide useful fields if too broad. Redact by sensitive key names only. +- Provider/client entry ordering is not globally sequenced in Phase 2. Include component and timestamps. +- Provider SDK objects may not be JSON-native. Always use `serialize_value()`. +- Tests may import an installed package if local `src` is not first on `sys.path`. Keep `tests/conftest.py`. diff --git a/docs/experimental/phase-2b-transform-trace-coverage.md b/docs/experimental/phase-2b-transform-trace-coverage.md new file mode 100644 index 000000000..034874939 --- /dev/null +++ b/docs/experimental/phase-2b-transform-trace-coverage.md @@ -0,0 +1,184 @@ +# Phase 2b: Complete Live Transform-Pass Trace Coverage + +## Goal + +Correct the Phase 2 audit gap. Phase 2 created the trace writer and added important request, response, stream, provider, and error trace entries, but the validation pass found that it still does not capture every meaningful real transformation boundary. Phase 2b makes transform tracing systematic across the live LiteLLM path, provider transforms, native protocol execution, Responses service/streaming, adapter chains, field-cache passes, and stream wrappers without changing request behavior. + +## Non-Goals + +- Do not replace LiteLLM execution in this phase. +- Do not make non-chat protocols live routes here. +- Do not implement field-cache persistence or new field-cache semantics here; only improve logging of existing cache operations. +- Do not implement WebSocket runtime support here. +- Do not add security/multi-user features beyond existing redaction behavior. +- Do not commit user-facing reports. +- Do not touch unrelated dirty `ARCHITECTURE.md`, `STRUCTURE.md`, `.opencode/`, `docs/issues/`, or old phase reports unless explicitly needed. + +## Current Code State + +- `TransformTraceWriter` exists with JSONL and snapshot files. +- `TransactionLogger.log_transform_pass()` and `log_transform_error()` exist and are additive. +- Existing legacy trace entries include `raw_client_request`, `prepared_provider_request`, `final_client_response`, stream chunk entries, provider request/chunk/final/error entries, routing metadata, retry/cooldown metadata, and usage summaries. +- `ProviderTransforms.apply()` mutates request kwargs through built-in transforms, provider hook transforms, model options, and LiteLLM conversion, but it does not emit per-step trace entries. +- `RequestExecutor._prepare_request_kwargs()` calls `ProviderTransforms.apply()`, then `log_transformed_request()` logs only the final prepared provider kwargs. +- Non-streaming LiteLLM/provider responses are logged only after success as `final_client_response`; there is no explicit `raw_provider_response` or `post_usage_normalization_response` distinction. +- Streaming logs `raw_stream_chunk`, `parsed_stream_chunk`, and `assembled_stream_response`, but the live stream handler has additional normalization/usage/error decision work that is not fully trace-covered. +- `NativeProviderExecutor` traces some native passes, but it skips key states: raw native input, parsed unified request, built provider request before adapters, after response adapters, field-cache extraction details, and final formatting boundaries in a complete request/response sequence. +- `run_adapter_chain()` traces `before_adapter_chain` and `after_adapter`, but not a final `after_adapter_chain` summary and not enough metadata to distinguish built-in vs provider-declared adapter chains in live reviews. +- `FieldCacheEngine` logs per-rule operations, but injection/extraction needs start/end summary trace entries so debugging can see when a cache pass happened but no rule matched. +- Responses service currently logs a service create pass, but stream output and bridge conversions are not fully represented as protocol/bridge/final boundaries. + +## Trace Taxonomy + +### Client Request + +- `raw_client_request` +- `pre_provider_transform_request` +- `after_builtin_provider_transform` +- `after_provider_hook_transform` +- `after_provider_model_options` +- `before_litellm_conversion` +- `after_litellm_conversion` +- `prepared_provider_request` + +### LiteLLM / Provider Execution + +- `provider_execution_request` +- `raw_provider_response` +- `post_usage_normalization_response` +- `final_client_response` + +### Native Protocol Request + +- `raw_native_client_request` +- `native_protocol_selected` +- `parsed_native_unified_request` +- `built_native_provider_request` +- `before_adapter_chain` +- `after_adapter` +- `after_adapter_chain` +- `field_cache_injection_start` +- `field_cache_rule_*` +- `field_cache_injection_complete` +- `native_provider_request` + +### Native Protocol Response + +- `raw_native_provider_response` +- `parsed_native_unified_response` +- `formatted_native_response` +- `after_response_adapter_chain` +- `field_cache_extraction_start` +- `field_cache_rule_*` +- `field_cache_extraction_complete` +- `usage_accounting_summary` +- `final_client_response` + +### Native Stream + +- `native_provider_stream_request` +- `raw_native_provider_stream_chunk` +- `parsed_native_stream_event` +- `after_stream_event_adapter_chain` +- `after_field_cache_stream_extraction` +- `formatted_client_stream_event` + +### Responses API + +- `responses_raw_request` +- `responses_parsed_request` +- `responses_bridge_chat_request` +- `responses_bridge_chat_response` +- `responses_stored_response` +- `responses_final_response` +- stream equivalents for created, delta, completed, bridge chunk, and final states where available. + +### Errors + +- Continue using `transform_log_error`. +- Include failed pass name, component, provider, protocol, transport, and sanitized payload. + +## Implementation Plan + +1. Add small local trace helpers where they avoid repeated fragile `if transaction_logger` blocks. + - Avoid global state. + - Avoid dependency cycles. + - Prefer local helpers when a shared helper would create import loops. + +2. Expand `ProviderTransforms.apply()` tracing. + - Add optional keyword-only `transaction_logger`, `credential_id`, `transport`, and `trace_metadata` parameters. + - Trace `pre_provider_transform_request` before any mutation. + - Trace `after_builtin_provider_transform` after each built-in transform that returns a modification string. + - Trace `after_provider_hook_transform` when provider hooks modify or report modifications. + - Trace provider hook errors as `transform_log_error` without changing failure behavior. + - Trace `after_provider_model_options` when options are applied. + - Trace `before_litellm_conversion` and `after_litellm_conversion` around `convert_for_litellm()`. + +3. Pass trace context from `RequestExecutor._prepare_request_kwargs()`. + - Provide transaction logger, provider/model/session/scope/classifier metadata, transport, and stable credential ID where available. + - Add an optional `credential_id` argument rather than changing existing call semantics. + +4. Add raw provider and normalization response trace in non-streaming executor. + - Trace `provider_execution_request` immediately before the provider call after pre-request callback. + - Trace `after_pre_request_callback` if callback changed kwargs. + - Trace `raw_provider_response` immediately after provider call returns. + - Trace `post_usage_normalization_response` after `_normalize_response_usage()` and before returning. + +5. Expand streaming trace coverage. + - Ensure streaming request path gets the same provider-transform traces as non-streaming. + - Keep `raw_stream_chunk`, `parsed_stream_chunk`, and `assembled_stream_response`. + - Add `stream_error_event` for error SSE payloads produced by executor fallback/error handling. + - Add `stream_done_event` for `[DONE]` boundaries with snapshots disabled. + - Do not change emitted SSE bytes. + +6. Expand native provider trace coverage. + - Add `raw_native_client_request` before protocol parse. + - Trace `parsed_native_unified_request` after `protocol.parse_request()`. + - Trace `built_native_provider_request` after `protocol.build_request()` and before adapters. + - Trace `after_request_adapter_chain` after request adapters. + - Trace `parsed_native_unified_response`, `formatted_native_response`, and `after_response_adapter_chain` on response path. + - Add equivalent stream-event boundaries where applicable. + +7. Expand adapter chain summary traces. + - Keep `before_adapter_chain` and per-adapter `after_adapter`. + - Add `after_adapter_chain` with adapter list, stage, and count even when there are no adapters. + - Include `changed_from_previous` where feasible. + - Disable snapshots for stream-event summaries. + +8. Expand field-cache summary traces. + - Log `field_cache_injection_start` / `field_cache_extraction_start` before the pass. + - Log `field_cache_injection_complete` / `field_cache_extraction_complete` after completion. + - Include rule count, matched count, changed count, source, and operation type. + - Keep per-rule trace entries intact. + +9. Expand Responses API and bridge tracing. + - Trace raw request, parsed unified request, bridge chat request, bridge response, stored response, and final response where those boundaries exist. + - Trace emitted streaming response events with stable pass names and snapshots disabled for per-chunk events. + - Preserve route bodies and SSE event ordering. + +10. Tests. + - Extend `tests/test_transaction_logger_transform_trace.py` for new pass names and redaction stability. + - Add or extend tests for `ProviderTransforms.apply()` tracing each stage without changing payload behavior. + - Add native provider executor trace order tests for non-streaming and streaming. + - Add adapter chain summary trace test. + - Add field-cache start/complete trace tests, including no-rules/no-match cases. + - Add Responses service/stream trace tests if existing fixtures can cover it without large integration cost. + - Run protocol regressions because trace serialization touches protocol objects. + +## Commit Checkpoints + +1. `docs(experimental): plan transform trace coverage correction`. +2. Provider transform trace coverage and tests. +3. Executor request/response/stream trace coverage and tests. +4. Native provider/adapter/field-cache trace summaries and tests. +5. Responses trace coverage and tests. +6. Review fixes after `explore` and `explore-heavy`. +7. User-facing Phase 2b report, uncommitted. + +## Risks And Mitigations + +- Trace logging accidentally mutates payloads. Mitigation: pass deep-copied or read-only snapshots and rely on `sanitize_for_trace()`. +- New trace arguments break callers. Mitigation: make all new args keyword-only optional and keep old call sites valid. +- Stream trace output grows too much. Mitigation: keep per-stream-event snapshots disabled and log compact metadata. +- Import cycles. Mitigation: use private local helpers when needed. +- Brittle tests. Mitigation: assert pass presence and causal relative order only where required. diff --git a/docs/experimental/phase-2c-transform-trace-completion.md b/docs/experimental/phase-2c-transform-trace-completion.md new file mode 100644 index 000000000..0de390162 --- /dev/null +++ b/docs/experimental/phase-2c-transform-trace-completion.md @@ -0,0 +1,73 @@ +# Phase 2c: Transform Trace Completion And Redaction Hardening + +## Goal + +Close the Phase 2/2b third-pass tracing findings: Anthropic compatibility coverage, native stream-event adapter tracing, Responses SSE formatting traces, transform-failure traces, final-response ordering, and camelCase secret redaction. + +## Scope + +- Add transform-pass traces around live Anthropic Messages compatibility request/response conversion. +- Add Anthropic streaming conversion traces and explicit upstream close on disconnect or abnormal exit. +- Ensure final client response traces represent the final post-normalization returned payload. +- Run and trace native stream-event adapter chains. +- Trace Responses final SSE formatting boundaries. +- Harden redaction for camelCase secret-bearing keys. +- Emit transform-failure traces for built-in provider transforms and Responses conversion errors. + +## Non-Goals + +- Do not redesign transaction log storage. +- Do not remove legacy request/response JSON logs. +- Do not enable native streaming for priority providers. +- Do not change public API response shapes. + +## Implementation Plan + +1. Anthropic non-stream traces. + - Trace `anthropic_raw_request`, `anthropic_to_openai_request`, `anthropic_openai_response`, `openai_to_anthropic_response`, and `anthropic_final_response`. + +2. Anthropic streaming traces and close safety. + - Trace source OpenAI chunks and emitted Anthropic SSE frames. + - Close upstream stream via `aclose()` / `close()` on disconnect or abnormal exit. + - Trace stream transform errors before emitting the Anthropic error frames. + +3. Provider transform failure traces. + - Wrap built-in provider transforms and emit `transform_log_error` before re-raising. + - Preserve provider-hook behavior but keep sanitized metadata. + +4. Native stream-event adapter chain. + - Run adapter chain for parsed native stream events. + - Trace `after_stream_event_adapter_chain` before formatting. + +5. Responses SSE formatting trace. + - Trace each formatted `ResponsesStreamEvent` frame and terminal frame emitted by `stream_response()`. + +6. Final response trace ordering. + - Verify or adjust `final_client_response` ordering so it occurs after usage normalization/cost accounting. + +7. Redaction hardening. + - Normalize camelCase keys such as `apiKey`, `accessToken`, `refreshToken`, `clientSecret`, and `idToken`. + +8. Responses transform-failure traces. + - Emit standardized transform error traces for Responses parse/bridge/storage conversion errors without changing raised errors. + +## Tests + +- Anthropic non-stream and stream trace tests. +- Anthropic disconnect upstream close test. +- Built-in provider transform failure trace test. +- Native stream-event adapter mutation/trace test. +- Responses SSE formatting trace test. +- CamelCase redaction tests. +- Final response trace-ordering tests. + +## Acceptance Criteria + +- Anthropic compatibility has transform-pass coverage for request conversion, response conversion, and stream conversion. +- Anthropic stream disconnect closes upstream when possible. +- Built-in provider transform exceptions emit transform-failure traces. +- Native stream-event adapter chains are executed and traced. +- Responses stream SSE formatting boundaries are traced. +- Final client response traces reflect post-normalization final payloads. +- CamelCase secret-bearing keys are redacted. +- Focused tests pass and both `explore` and `explore-heavy` reviewers report no blockers/highs/mediums. diff --git a/docs/experimental/phase-3-adapters-field-cache.md b/docs/experimental/phase-3-adapters-field-cache.md new file mode 100644 index 000000000..d7fbf53cf --- /dev/null +++ b/docs/experimental/phase-3-adapters-field-cache.md @@ -0,0 +1,284 @@ +# Phase 3 Plan: Adapter And Field-Cache System + +## Goal + +Add the configurable adapter and field-cache foundation that lets protocol/provider implementations preserve and re-inject provider-specific state without hardcoding every provider. This phase should make it possible to define rules for values like reasoning content, thought signatures, prompt cache keys, provider session IDs, and response IDs, with strict provider/model/credential/session/classifier scoping and transform trace visibility. + +## Non-Goals + +- Do not migrate all existing providers to native protocols yet. +- Do not add `/v1/responses` routes yet. +- Do not replace `ProviderTransforms` yet. +- Do not implement JSON config loading for these rules yet; Phase 10 owns config polish. +- Do not instantiate SQLite or any DB. +- Do not implement provider-specific Claude/Codex/Copilot/Antigravity behavior yet. +- Do not make field cache a replacement for `SessionTracker`. + +## Current Code Context + +- `ProviderTransforms` is still the active runtime transform path and should remain untouched unless a small hook is needed. +- Phase 1 provides unified request/response/stream dataclasses and protocol adapters. +- Phase 2 provides `log_transform_pass()` and `log_transform_error()` for request/response/stream trace states. +- `ProviderCache` already exists and supports async `store_async()` / `retrieve_async()` with memory+disk TTL behavior, but it stores strings and starts async background tasks on construction. +- `RequestContext` has session, scope, classifier, provider, model, and credential metadata needed for cache isolation. +- Existing runtime behavior should stay stable while the new adapter/cache foundation is built and tested in isolation. + +## Files To Add + +- `src/rotator_library/adapters/__init__.py` +- `src/rotator_library/adapters/base.py` +- `src/rotator_library/adapters/registry.py` +- `src/rotator_library/adapters/builtin.py` +- `src/rotator_library/field_cache/__init__.py` +- `src/rotator_library/field_cache/types.py` +- `src/rotator_library/field_cache/paths.py` +- `src/rotator_library/field_cache/store.py` +- `src/rotator_library/field_cache/engine.py` +- `tests/test_adapter_registry.py` +- `tests/test_field_cache_paths.py` +- `tests/test_field_cache_engine.py` +- `tests/test_field_cache_trace.py` + +## Files Possibly Touched + +- `src/rotator_library/providers/provider_interface.py` to add optional provider declarations for adapter names and field-cache rules. +- `src/rotator_library/client/transforms.py` only if adding a no-op future hook is cleaner. +- `src/rotator_library/__init__.py` only if public lazy exports are useful. +- Avoid `RequestExecutor` runtime wiring in the first adapter/cache commits unless tests prove the hook is harmless and necessary. + +## Adapter System + +- Create a protocol-neutral `Adapter` base class. +- Adapters operate on raw dicts, unified dataclasses, or stream events depending on `supported_stages`. +- Adapter methods are override-friendly: + - `transform_request(payload, context) -> payload` + - `transform_response(payload, context) -> payload` + - `transform_stream_event(payload, context) -> payload` +- Adapter context should include provider, model, protocol, credential_id, session_id, scope_key, classifier, transport, metadata, and optional `transaction_logger`. +- Registry should support: + - explicit registration + - alias resolution + - auto-discovery from `src/rotator_library/adapters/` + - duplicate/collision checks similar to protocol registry + - order-preserving adapter chain resolution + +Built-in adapters: + +- `reasoning_rewrite`: copy/rename reasoning fields between configured paths. +- `reasoning_content`: normalize common reasoning content field names. +- `suppress_developer_role`: convert or remove developer-role messages for providers that reject it. +- `model_override`: replace model in outbound payload. +- `field_rename`: generic configured source-path to target-path copy/move. + +These adapters are bases, not gospel. Providers can subclass/copy/mutate adapters or override provider methods. + +## Field-Cache Rules + +`FieldCacheRule` fields: + +- `name` +- `source`: `request`, `response`, `stream_event`, `unified_request`, `unified_response`, `unified_stream_event` +- `path`: JSON-path-like selector +- `mode`: `last`, `all`, `last_user_turn`, `last_assistant_turn`, `per_tool_call` +- `scope`: list/tuple of scope dimensions +- `inject`: optional `FieldCacheInjection` +- `enabled` +- `ttl_seconds` optional, for future store implementations +- `metadata` for provider/model notes + +`FieldCacheInjection` fields: + +- `target`: `request`, `unified_request`, `metadata`, or provider-specific future target +- `path` +- `when_missing_only` +- `insert` +- `as_list` + +Scope dimensions: + +- `provider` +- `model` +- `credential` +- `session` +- `conversation` +- `classifier` + +Default scope for conversation-affecting rules: + +- provider + model + classifier + session. + +Cache keys must include rule name and selected scope values. Missing optional values should use stable placeholders only where safe; for session-scoped rules with no session, default to no-op unless a rule explicitly allows fallback. + +## Path Engine + +Support dotted paths: + +- `choices.0.message.reasoning_content` +- `choices.*.message.reasoning_content` +- `messages[-1].reasoning_content` +- `candidates.*.content.parts.*.thoughtSignature` + +Behavior: + +- Supports dict keys, list indices, `*`, and `[-1]`. +- Extraction returns all matches in stable traversal order. +- Injection creates missing dict containers when unambiguous. +- Injection does not create containers through wildcard paths. +- Malformed paths raise `FieldCachePathError` with useful messages. +- Missing paths are no-op, not errors. + +## Store Plan + +- `FieldCacheStore` protocol/interface with async `get`, `set`, `append`, and `clear`. +- `InMemoryFieldCacheStore` for tests and simple runtime. +- `ProviderCacheFieldStore` wrapper around an injected `ProviderCache`, JSON serializing values. +- Do not instantiate `ProviderCache` in module import paths because it starts async tasks. +- No SQLite. + +## Engine Plan + +`FieldCacheEngine` responsibilities: + +- hold rules and store +- validate rule names and paths +- `extract(source, payload, context, transaction_logger=None)` +- `inject(target, payload, context, transaction_logger=None)` + +Extraction: + +- select rules matching source +- extract values by path +- apply mode +- store under scoped key +- trace `after_field_cache_extraction` with rule name, match count, mode, scope key, and sanitized values summary + +Injection: + +- select rules with matching injection target +- retrieve scoped values +- apply mode result +- inject into a payload copy by path +- trace `after_field_cache_injection` with rule name, hit/miss, target path, and changed flag + +Engine rules: + +- Do not mutate caller payload unless explicitly requested; default returns a deep copy. +- Failures should call `log_transform_error()` with failed rule/pass and then raise validation errors in tests. Runtime integration later can choose fail-open or fail-closed per provider. + +## Modes + +- `last`: store latest extracted value. +- `all`: append extracted values preserving order. +- `last_user_turn`: use metadata/turn marker when available; for raw messages fallback to latest user-associated match if path includes messages. +- `last_assistant_turn`: same for assistant. +- `per_tool_call`: requires a tool-call ID path or extracted object with obvious `id` / `tool_call_id`; otherwise validation error. + +Phase 3 should implement `last` and `all` fully, and provide validated structure for turn/tool modes with tests for validation and simple cases. If a mode needs runtime conversation indexing not yet available, document it as present but limited. + +## Transform Trace Requirements + +Every adapter invocation should be traceable: + +- `before_adapter_chain` +- `after_adapter` +- `after_adapter_chain` + +Every field-cache injection/extraction should be traceable: + +- `before_field_cache_injection` +- `after_field_cache_injection` +- `before_field_cache_extraction` +- `after_field_cache_extraction` + +Errors use Phase 2 `transform_log_error`. + +Trace metadata should include: + +- adapter name +- rule name +- source/target +- path +- mode +- scope dimensions +- cache key hash or readable scoped key with no secrets +- match count +- changed flag + +Do not log huge cached payloads by default; log summaries and sanitized sample values. + +## Provider Declaration Plan + +Add optional provider methods/class attributes only if needed: + +- `protocol_name` +- `adapter_names` +- `field_cache_rules` +- `get_adapter_config(model)` +- `get_field_cache_rules(model)` + +Keep defaults empty/no-op to avoid provider changes. These declarations are for later provider migration. + +## Tests + +Adapter registry: + +- auto-discovers built-ins +- aliases resolve +- duplicate names/collisions are deterministic +- chain order is preserved +- no-op adapter preserves payload + +Built-in adapters: + +- `model_override` +- `suppress_developer_role` +- `field_rename` +- reasoning field copy/rename + +Path engine: + +- dict path extraction +- list index extraction +- wildcard extraction +- `[-1]` extraction +- injection into simple dict path +- missing path no-op +- wildcard injection rejected +- malformed path error + +Field-cache engine: + +- extract response value and inject into next request +- `last` overwrites +- `all` appends +- scope isolation by provider/model/session/classifier/credential +- missing path no-op +- malformed rule validation +- stream_event extraction +- trace entries emitted for injection/extraction +- transform errors emitted for rule failures + +Regression: + +- Phase 1 protocol tests. +- Phase 2 logging tests. +- core session/selection tests. + +## Commit Checkpoints + +1. Add adapter base/registry and built-in no-op/simple adapters with tests. +2. Add field-cache rule dataclasses and path engine with tests. +3. Add field-cache stores and scoped key builder with tests. +4. Add field-cache engine extraction/injection with trace tests. +5. Add optional provider declaration methods if needed. +6. Run focused and regression tests. +7. Review with `explore` and `explore-heavy`, fix findings, and write the uncommitted Phase 3 report. + +## Risks And Mitigations + +- JSON path scope could grow too large. Keep Phase 3 path syntax intentionally small and explicit. +- Cache leaks across providers/models/sessions would be severe. Scope-key tests must be extensive. +- Storing huge model outputs can bloat caches. Engine should store only matched values and trace summaries. +- ProviderCache lifecycle can be awkward because it starts async tasks. Wrap injected instances; do not instantiate globally. +- Turn/tool modes can be under-specified. Validate and implement safe subsets rather than guessing. +- Runtime integration can destabilize requests. Keep Phase 3 mostly isolated until review confirms the foundation. diff --git a/docs/experimental/phase-3b-field-cache-runtime-semantics.md b/docs/experimental/phase-3b-field-cache-runtime-semantics.md new file mode 100644 index 000000000..d65494933 --- /dev/null +++ b/docs/experimental/phase-3b-field-cache-runtime-semantics.md @@ -0,0 +1,147 @@ +# Phase 3b: Field-Cache Runtime Semantics And Persistence + +## Goal + +Correct the Phase 3 audit gap. Phase 3 built the adapter and field-cache foundation, but the audit found that the system is still partly declarative: `last_user_turn`, `last_assistant_turn`, and `per_tool_call` are validated/declared but not semantically implemented enough; the native provider executor creates a fresh `InMemoryFieldCacheStore` by default so cached values do not persist across native requests; JSON-configured field-cache rules are parsed but not merged into runtime provider declarations. + +## Non-Goals + +- Do not introduce SQLite or any database. +- Do not replace `SessionTracker`, `UsageManager`, or provider retry logic. +- Do not wire non-native LiteLLM paths through field-cache rules in this phase. +- Do not make field cache a substitute for session affinity. It preserves provider protocol state only. +- Do not implement a full external admin UI/config editor. +- Do not commit user-facing reports. +- Do not touch unrelated dirty files (`ARCHITECTURE.md`, `STRUCTURE.md`, `.opencode/`, `docs/issues/`, old phase reports). + +## Current Code State + +- `FieldCacheRule.mode` allows `last`, `all`, `last_user_turn`, `last_assistant_turn`, and `per_tool_call`. +- `FieldCacheEngine._store_values()` currently treats `last_user_turn` and `last_assistant_turn` exactly like `last`. +- `per_tool_call` currently extracts tool IDs from each extracted value, which only works when the selected value itself contains the tool ID. It cannot correlate values with a sibling or parent tool-call ID from the original payload. +- `FieldCacheInjection` has `as_list`, `when_missing_only`, and `insert`, but engine injection does not use `insert` and does not support selecting a per-tool-call value at injection time. +- `NativeProviderExecutor.__init__(field_cache_store=None)` passes `None` into each `FieldCacheEngine`, causing each request to get a new `InMemoryFieldCacheStore`. That means default native field-cache persistence is per-engine/per-request, not process-level. +- `ProviderCacheFieldStore` exists and no SQLite is present. +- `ExperimentalConfig.parse_field_cache_rules()` can parse JSON rules, but `RequestExecutor._native_context()` currently only uses `plugin.get_field_cache_rules(model)`. +- `NativeProviderContext.field_cache_context()` includes provider/model/credential/session/conversation/classifier plus metadata; metadata is the right place to carry current request payload/turn/tool-call hints without expanding core scope dimensions. +- Phase 2b hardened traces so field-cache logs expose shapes and counts, not raw cached values. + +## Implementation Plan + +### 1. Process-Local Default Native Store + +- Change `NativeProviderExecutor` so a missing `field_cache_store` creates one shared `InMemoryFieldCacheStore` in the executor instance, not a new store inside every `FieldCacheEngine`. +- Keep injected stores honored exactly as before. +- Add tests proving two calls through the same executor can extract on response and inject on a later request using the default store. +- Do not persist to disk by default; process-local is enough for runtime continuity and respects the no-SQLite rule. + +### 2. Value Envelopes For Contextual Modes + +- Add a small internal stored representation for contextual modes, likely dicts with fields such as `value`, `role`, `turn_index`, `tool_call_id`, and `source_path`. +- Keep stored values JSON-serializable for `ProviderCacheFieldStore`. +- Do not expose envelopes to injected payloads except where needed; injection should unwrap `value`. +- Preserve current external behavior for `last` and `all` as much as possible. + +### 3. Turn-Aware Extraction + +- Implement real `last_user_turn` and `last_assistant_turn` semantics. +- Minimum supported shapes: + - OpenAI/Responses-like `messages[*]` objects with `role`. + - Anthropic-like content lists nested under message objects with `role`. + - Generic payloads where a rule provides metadata hints. +- Add rule metadata hints: + - `turn_container_path`: path to turn/message list, default inferred from common roots such as `messages`. + - `turn_role_path`: role field inside each turn, default `role`. + - `turn_value_path`: value path relative to each turn when inference is better than global path. +- If no turn context can be inferred, skip with `reason="turn_context_not_found"` rather than silently behaving like `last`. +- Store only values from the latest matching user/assistant turn. +- Add tests for user turn, assistant turn, skip/no-op, and metadata-configured relative extraction. + +### 4. Per-Tool-Call Correlation And Injection + +- Keep requiring `metadata.tool_call_id_path` for `per_tool_call`. +- Support existing value-relative extraction when each selected value contains the tool ID. +- Add payload-relative correlation with metadata: + - `tool_container_path` + - `tool_call_id_path` relative to each tool container + - `tool_value_path` relative to each tool container +- Store a mapping of `tool_call_id -> value`. +- Add injection selection: + - If `metadata.inject_tool_call_id_path` exists, extract current tool IDs from target payload and inject matching values. + - If `FieldCacheContext.metadata["tool_call_id"]` exists, inject that value. + - If `inject.as_list=True`, inject all matching values as a list. + - Otherwise skip ambiguous maps with `reason="tool_call_id_not_found"` or `reason="ambiguous_tool_call_values"`. +- Add tests for sibling tool ID/value extraction and target-specific injection. + +### 5. Honor `FieldCacheInjection.insert` + +- Wire `rule.inject.insert` into `inject_path()`. +- Add tests for list insertion where path targets a list index or append-like position if supported by the path engine. +- If the path engine does not support safe insertion yet, add minimal support or explicitly validate/reject unsafe insert paths with clear errors. + +### 6. TTL Handling Without New Persistence + +- `FieldCacheRule.ttl_seconds` exists but stores ignore it. +- Add TTL support to `InMemoryFieldCacheStore`. +- Keep `ProviderCacheFieldStore` compatible: + - If the injected provider cache supports TTL arguments, pass them. + - If not, wrap values with expiry metadata and enforce expiry on `get()`. +- Update `FieldCacheStore` protocol with optional TTL on `set`/`append` only if safe; otherwise add internal engine helper methods to avoid breaking third-party stores. +- Add tests for memory-store expiry using a fake clock if possible. +- No SQLite. + +### 7. Merge JSON-Configured Rules Into Native Runtime Context + +- In `RequestExecutor._native_context()`, load optional `ExperimentalConfig` and merge `parse_field_cache_rules(config, provider, model)` with provider-declared rules. +- Provider-declared rules should come first; JSON rules append and can add/override by rule name. +- Override policy: + - If a JSON rule name matches a provider rule name, JSON replaces that rule. + - Otherwise append. +- Environment remains primary for config path and JSON secrets remain rejected by the existing loader. +- Add tests for native context rule merge without real network calls or credentials. +- Avoid top-level imports that create cycles; use local imports where needed. + +### 8. Tests + +- Engine semantics tests for contextual modes. +- Native executor test proving default process-local persistence across separate `execute()` calls. +- Config merge test for JSON field-cache rules. +- TTL store tests. +- Existing trace and protocol regression tests where touched. + +Focused suites: + +- `tests/test_field_cache_paths.py` +- `tests/test_field_cache_engine.py` +- `tests/test_field_cache_trace.py` +- `tests/test_native_provider_executor.py` +- `tests/test_experimental_config.py` +- relevant provider declaration tests +- Phase 1b/2b regression subsets if field-cache/protocol trace serialization is touched + +### 9. Documentation And Comments + +- Update docstrings in `FieldCacheEngine`, `FieldCacheRule`, and store classes to explain mode semantics. +- Document when turn/tool modes skip rather than fallback. +- Document process-local default store vs injected production store. +- Keep comments focused on algorithmic rationale: turn context inference, tool-call correlation, TTL handling, and JSON rule merge precedence. + +## Commit Checkpoints + +1. `docs(experimental): plan field cache runtime correction`. +2. Default persistent native store. +3. Turn/tool-call mode semantics. +4. TTL/store behavior. +5. JSON rule merge/runtime wiring. +6. Review fixes after `explore` and `explore-heavy`. +7. User-facing Phase 3b report, uncommitted. + +## Acceptance Criteria + +- Native field cache persists across requests by default within one executor instance. +- `last_user_turn` and `last_assistant_turn` have real turn-aware semantics and skip safely when context is unavailable. +- `per_tool_call` can correlate tool IDs with sibling values and inject the right cached value by current tool ID. +- `insert`, `ttl_seconds`, and JSON-configured rules are implemented or explicitly validated with clear errors if a requested shape is unsafe. +- No raw cached values leak in traces after Phase 2b hardening. +- No SQLite or new database is introduced. +- Focused tests and dual-agent review pass. diff --git a/docs/experimental/phase-3c-field-cache-runtime-completion.md b/docs/experimental/phase-3c-field-cache-runtime-completion.md new file mode 100644 index 000000000..4c6bb01db --- /dev/null +++ b/docs/experimental/phase-3c-field-cache-runtime-completion.md @@ -0,0 +1,66 @@ +# Phase 3c: Field-Cache Runtime Completion And Trace Safety + +## Goal + +Close the Phase 3/3b third-pass findings around field-cache runtime coverage, adapter-chain trace safety, and credential-scoped isolation. + +## Scope + +- Execute all declared field-cache sources used by native runtime: `request`, `response`, `stream_event`, `unified_request`, `unified_response`, and `unified_stream_event`. +- Execute declared injection targets used by native runtime: `request`, `unified_request`, and `metadata`. +- Suppress generic adapter-chain payload traces in native execution where field-cache-aware redaction has not yet run, and rely on native executor traces that apply rule-aware redaction. +- Fail closed when a rule includes `credential` scope but the runtime lacks a credential identifier. +- Add focused tests for each corrected behavior. + +## Non-Goals + +- Do not replace the field-cache store or add a database. +- Do not make native streaming inject into stream events; streaming extraction remains expected behavior. +- Do not make Copilot declare field-cache rules when none are required. +- Do not implement custom provider-specific turn inference beyond existing metadata hints. + +## Implementation Plan + +1. Credential scope isolation. + - Update cache-key construction so missing `credential` scope returns `None`. + - Keep existing `allow_missing_session` behavior only for `session` scope. + - Add tests proving credential-scoped extraction/injection skips instead of using a shared missing bucket. + +2. Unified request injection and extraction. + - In native non-streaming and streaming paths, create the field-cache engine before protocol build. + - Extract `unified_request` after parsing client payload. + - Inject `metadata` before building protocol context when metadata rules exist. + - Inject `unified_request` before `protocol.build_request()`. + - Hydrate injected serialized unified-request dictionaries back into the current `UnifiedRequest` for supported common fields. + - Trace `after_unified_request_field_cache_injection` and `after_metadata_field_cache_injection`. + +3. Unified response extraction. + - Extract `unified_response` after protocol response parsing and before formatting. + - Keep existing provider-response extraction after response adapter chain. + - Trace `after_unified_response_field_cache_extraction`. + +4. Unified stream-event extraction. + - Extract `unified_stream_event` after native stream-event adapter chain and before formatting. + - Keep existing provider stream-event extraction from formatted event payload. + - Trace `after_unified_stream_event_field_cache_extraction`. + +5. Native adapter trace safety. + - Disable generic adapter-chain traces inside native request/response paths, matching the stream-event path. + - Keep native executor traces after adapter chains, where configured field-cache paths are redacted. + - Add tests with arbitrary configured provider-state paths proving generic `before_adapter_chain`/`after_adapter` traces are not emitted in native execution and native traces redact configured paths. + +6. Tests. + - Field-cache engine credential-scope fail-closed tests. + - Native executor unified request source/target tests. + - Native executor metadata injection tests. + - Native executor unified response extraction tests. + - Native stream unified event extraction tests. + - Native adapter trace redaction/suppression tests. + +## Acceptance Criteria + +- Native runtime executes all declared field-cache sources/targets that it accepts, except stream-event injection which remains intentionally unsupported. +- Missing credential scope never shares a `_none` cache key. +- Native adapter chain traces cannot leak arbitrary configured provider-state fields before rule-aware redaction. +- Existing field-cache and native provider tests continue to pass. +- Both `explore` and `explore-heavy` reviewers report no blockers/highs/mediums. diff --git a/docs/experimental/phase-4-responses-api.md b/docs/experimental/phase-4-responses-api.md new file mode 100644 index 000000000..901510066 --- /dev/null +++ b/docs/experimental/phase-4-responses-api.md @@ -0,0 +1,285 @@ +# Phase 4 Plan: Responses API + +## Goal + +Add a first-class OpenAI-compatible Responses API surface with durable response storage, `previous_response_id` continuation support, HTTP SSE streaming, and an explicit WebSocket transport seam for future implementation. Phase 4 should use the Phase 1 Responses protocol, Phase 2 transform trace logging, and Phase 3 field-cache/adapter foundations without forcing all providers onto native protocols yet. + +## Non-Goals + +- Do not migrate every provider to native protocol execution in this phase. +- Do not remove or replace `/v1/chat/completions`. +- Do not replace `UsageManager`, `SessionTracker`, retry-after parsing, or current credential rotation. +- Do not implement full multi-user/security controls. +- Do not add SQLite or any database. +- Do not implement full provider-specific Responses APIs for Claude Code/Codex/Copilot yet. +- Do not implement WebSocket runtime handling yet; provide transport interfaces/seams and tests for HTTP SSE behavior. + +## Current Code Context + +- FastAPI routes live in `src/proxy_app/main.py`. +- `RotatingClient.acompletion()` is the stable execution entry point today. +- `ResponsesProtocol` already parses/formats Responses request/response/stream shapes and declares future `websocket` support. +- `TransactionLogger` can record new transform passes and errors. +- `FieldCacheEngine` and `ProviderCacheFieldStore` can preserve response IDs/output items but need targeted key deletion for Responses delete endpoints. +- Existing stream wrappers aggregate chat-completions SSE chunks, but Responses SSE needs its own event formatting instead of chat completion chunk output. +- Runtime execution should stay conservative: Phase 4 can bridge Responses requests to the existing completion path until native provider protocols are wired in later phases. + +## Files To Add + +- `src/rotator_library/responses/__init__.py` +- `src/rotator_library/responses/types.py` +- `src/rotator_library/responses/store.py` +- `src/rotator_library/responses/bridge.py` +- `src/rotator_library/responses/service.py` +- `src/rotator_library/responses/streaming.py` +- `tests/test_responses_store.py` +- `tests/test_responses_bridge.py` +- `tests/test_responses_service.py` +- `tests/test_responses_routes.py` +- `tests/test_responses_streaming.py` + +## Files Likely To Touch + +- `src/proxy_app/main.py` +- `src/rotator_library/client/rotating_client.py` +- `src/rotator_library/field_cache/store.py` if targeted delete is added to the store protocol. +- `src/rotator_library/__init__.py` only if lazy public exports are useful. +- Existing protocol tests only if the Responses adapter needs an additive event/output fix. + +## Route/API Scope + +`POST /v1/responses`: + +- Accept raw Responses JSON. +- Support non-streaming response. +- Support `stream: true` with HTTP SSE. +- Validate required `model`. +- Accept `input`, `instructions`, `tools`, `tool_choice`, `reasoning`, `metadata`, `include`, `store`, `previous_response_id`, and common generation params. + +`GET /v1/responses/{response_id}`: + +- Return stored response object when available. +- Return 404-compatible error if missing or deleted. + +`DELETE /v1/responses/{response_id}`: + +- Delete stored response metadata/output. +- Return `{ "id": ..., "object": "response.deleted", "deleted": true }` or equivalent compatible shape. + +`GET /v1/responses/{response_id}/input_items`: + +- Return stored input items for clients that need continuation inspection. + +Optional if simple: + +- `POST /v1/responses/{response_id}/cancel` for active stream cancellation can be planned but not fully implemented unless active stream tracking lands cleanly. + +## Response Storage + +`StoredResponse` fields: + +- `id` +- `created_at` +- `model` +- `status` +- `request` +- `response` +- `input_items` +- `output_items` +- `usage` +- `metadata` +- `session_id` +- `scope_key` +- `classifier` +- `expires_at` optional + +`ResponsesStore` protocol: + +- `save(stored_response)` +- `get(response_id)` +- `delete(response_id)` +- `list_input_items(response_id)` + +Store implementations: + +- `InMemoryResponsesStore` for tests and runtime default if no persistent cache is configured. +- `ProviderCacheResponsesStore` as an optional wrapper around injected `ProviderCache`. +- JSON serialization through `serialize_value()`. +- No global `ProviderCache` construction at import time. +- No SQLite. + +Targeted deletion: + +- Add `delete(key)` to `FieldCacheStore` only if reused directly. +- Prefer a dedicated Responses store so field cache remains small and focused. + +## Response IDs + +- Generate IDs like `resp_` when upstream response lacks an ID. +- Preserve upstream IDs when present. +- Store every response when request `store` is true or omitted if compatibility expects default storage. +- If `store: false`, return the response but do not persist it; `previous_response_id` cannot refer to it later. +- Include `previous_response_id` in stored metadata for lineage/debugging. + +## `previous_response_id` + +- On request with `previous_response_id`, load the stored parent response. +- If not found, return a 404/400-compatible invalid request error rather than silently ignoring. +- Build continuation context conservatively: + - Keep original current request input. + - Add parent response output items into protocol/service metadata for traceability. + - For bridge execution through chat completions, convert parent response message items into prior assistant messages where safe. + - Do not inject unknown output item types into chat messages unless `ResponsesProtocol` can represent them safely. +- Record transform trace pass: + - `responses_previous_response_loaded` + - Include parent ID, output count, input item count, and whether bridge context was expanded. +- This should line up with Phase 3 field-cache storage but does not need to fully rely on field-cache rules yet. + +## Bridge Execution + +`ResponsesBridge` responsibilities: + +- Convert `UnifiedRequest` from `ResponsesProtocol.parse_request()` into chat-completions kwargs for current `RotatingClient.acompletion()`. +- Map `model` to `model`. +- Map `instructions`/system blocks to a system message. +- Map input messages to chat messages. +- Map tool definitions to OpenAI tool shape when possible. +- Map generation params to existing compatible kwargs. +- Map `stream` to `stream`. +- Preserve unsupported Responses fields in trace metadata and request metadata, not silently discard them. +- Rebuild Responses output from chat completion responses using `ResponsesProtocol.format_response()`. + +This bridge is temporary compatibility, not the final native provider path. Docstrings/comments should explain that native provider execution will replace the bridge for covered providers in later phases. + +## Service + +`ResponsesService` responsibilities: + +- Own `ResponsesProtocol`, `ResponsesBridge`, `ResponsesStore`, and optional transform logging. +- `create_response(raw_request, client, request=None)` +- `stream_response(raw_request, client, request=None)` +- `get_response(response_id)` +- `delete_response(response_id)` +- `list_input_items(response_id)` + +Service transform passes: + +- `raw_responses_request` +- `parsed_unified_request` +- `responses_previous_response_loaded` +- `responses_bridge_chat_request` +- `raw_responses_provider_response` or `raw_chat_bridge_response` +- `parsed_unified_response` +- `stored_responses_response` +- `final_responses_response` + +Errors use `log_transform_error()`. + +## Streaming + +- HTTP SSE first. +- Convert chat-completions stream chunks into Responses SSE events. +- Required event names: + - `response.created` + - `response.output_item.added` + - `response.output_text.delta` + - `response.output_item.done` + - `response.completed` + - `response.failed` on errors + - final `data: [DONE]` only if compatibility requires it; prefer Responses-style events and document the chosen behavior. +- Preserve raw stream chunk trace: + - `raw_chat_bridge_stream_chunk` + - `parsed_unified_stream_event` + - `formatted_responses_stream_event` + - `stored_responses_stream_response` +- Accumulate final output items so streamed responses are retrievable by `GET /v1/responses/{id}` after completion. +- Handle client disconnect without crashing the server. +- Do not add the broader stream stall/backpressure overhaul; that belongs to the later streaming phase. + +## WebSocket Seam + +- Add transport-neutral interfaces: + - `ResponsesTransport` or small formatter abstraction. + - `ResponsesSSEFormatter`. + - `ResponsesWebSocketFormatter` placeholder or protocol with `NotImplementedError`. +- Keep protocol/service logic independent from `StreamingResponse`. +- Add tests that formatter/service API can select `sse` now and recognizes `websocket` as a future transport without implementing a WebSocket route. +- Do not expose a WebSocket endpoint yet unless it is a no-op documented placeholder; prefer no route over a misleading route. + +## FastAPI Integration + +- Add routes near existing OpenAI-compatible routes in `main.py`. +- Use existing `verify_api_key`. +- Use existing raw request logging where applicable. +- Use `JSONResponse` for normal responses. +- Use `StreamingResponse(..., media_type="text/event-stream")` for streams. +- Use OpenAI-compatible error response shape for validation/missing response IDs. +- Store service on app state during lifespan if it needs shared storage, or construct module-level in-memory store carefully if no async lifecycle is needed. Prefer `app.state.responses_service`. + +## Tests + +Store tests: + +- save/get/delete response +- input items list +- missing response returns None +- JSON serialization with provider-cache wrapper if implemented +- no SQLite/import-time cache construction + +Bridge tests: + +- Responses input string/list converts to chat messages +- instructions become system message +- tool definitions preserve function call shape +- unsupported fields are preserved in metadata/trace +- chat completion response formats back to Responses output + +Service tests: + +- non-stream create stores response +- `store: false` does not persist +- `previous_response_id` loads parent +- missing previous response raises useful error +- transform trace passes emitted + +Route tests: + +- `POST /v1/responses` non-stream success with fake client +- `POST /v1/responses` missing model returns 400 +- `GET /v1/responses/{id}` success and 404 +- `DELETE /v1/responses/{id}` success +- `GET /v1/responses/{id}/input_items` + +Streaming tests: + +- stream emits Responses SSE event names in order +- deltas aggregate into stored final response +- stream errors emit `response.failed` +- client disconnect path is safe if testable + +Regression tests: + +- Phase 1 protocol tests +- Phase 2 transform logging tests +- Phase 3 adapter/cache tests +- `test_session_tracking.py` +- `test_selection_engine.py` + +## Commit Checkpoints + +1. Add Responses store and tests. +2. Add Responses bridge and tests. +3. Add Responses service with non-stream create/get/delete/input-items and tests. +4. Add HTTP routes and route tests. +5. Add HTTP SSE formatter/streaming conversion and tests. +6. Add WebSocket transport seam tests/docstrings without runtime route. +7. Run focused and regression tests. +8. Review with `explore` and `explore-heavy`, fix findings, and write the uncommitted Phase 4 report. + +## Risks And Mitigations + +- Responses API is broader than chat completions. Keep the bridge explicit about unsupported fields and preserve them in metadata/trace rather than pretending full native support. +- `previous_response_id` can leak context across users/scopes if storage keys are global. Include request/session/scope/classifier metadata and leave full multi-user isolation for later, but avoid credential secrets in storage keys. +- Streaming conversion can lose item fidelity. Accumulate output text and tool/reasoning items explicitly; preserve raw chunks in trace. +- In-memory store will not survive restarts. This is acceptable for Phase 4 default; provider-cache-backed persistence can be enabled where lifecycle is available. +- WebSocket support can be over-promised. Expose the abstraction seam only, not a functioning route, until the later transport phase. diff --git a/docs/experimental/phase-4b-responses-session-storage-transport.md b/docs/experimental/phase-4b-responses-session-storage-transport.md new file mode 100644 index 000000000..81e352540 --- /dev/null +++ b/docs/experimental/phase-4b-responses-session-storage-transport.md @@ -0,0 +1,94 @@ +# Phase 4b Plan: Responses API Session, Storage, And Transport Corrections + +## Goal + +Correct the Phase 4 audit findings without over-claiming full native Responses provider execution. Phase 4 made `/v1/responses` usable through the chat bridge, but the validation pass found that important Responses semantics are still incomplete: `previous_response_id` is only used to replay parent output, not as session-routing evidence; default storage has only incidental expiry support with no runtime TTL policy or max-size pruning; stream generation is SSE-first inside the service instead of cleanly transport-neutral; and route/service tracing/storage should more clearly represent stored/skipped/current-state behavior. + +## Non-Goals + +- Do not implement a WebSocket FastAPI route in this phase unless the transport abstraction makes it trivial and tests stay small. +- Do not replace the bridge with fully native provider Responses execution here. +- Do not introduce SQLite or a database. +- Do not build multi-user security/admin features. +- Do not change current public SSE event order except where tests show it is wrong. +- Do not commit user-facing reports. +- Do not touch unrelated dirty `ARCHITECTURE.md`, `STRUCTURE.md`, `.opencode/`, `docs/issues/`, or old phase reports. + +## Current Code State + +- `ResponsesService.create_response()` parses `previous_response_id`, loads the parent from the store, and passes parent output to `ResponsesBridge.to_chat_kwargs()`. +- `ResponsesBridge.to_chat_kwargs()` stores `previous_response_id` only in private `_responses_bridge` metadata used for trace context, then `ResponsesService` pops that metadata before calling `client.acompletion()`. +- The actual `RotatingClient` request context cannot see `previous_response_id`, so `SessionTracker` cannot use it as a strong anchor for credential affinity. +- `StoredResponse` has `expires_at`, and stores check expiry on `get()`, but `ResponsesService` never sets a TTL policy and `InMemoryResponsesStore` has no max-size pruning. +- `ProviderCacheResponsesStore` exists but app startup always creates `ResponsesService()` with default in-memory store. +- `ResponsesService.stream_response()` directly instantiates `ResponsesSSEFormatter()` and yields formatted strings. +- `ResponsesWebSocketFormatter` exists as a placeholder with `NotImplementedError`, which is honest but does not prove that service logic is formatter-neutral. +- Responses stream storage stores only the completed payload after stream success and can skip storage when `store=false`. It does not store in-progress current state. +- `response.failed` events are yielded on exceptions, but failed responses are not stored as failed state even when `store=true`. + +## Implementation Plan + +1. Add an internal Responses session-hints carrier. + - Introduce private `_session_tracking_hints` request kwargs consumed by `RequestContextBuilder` and removed before provider execution. + - Merge service-level hints with provider hints before `SessionTracker.infer_session()`. + - Preserve provider hints and explain that these hints are proxy-internal evidence, never provider payload. + +2. Make `previous_response_id` a strong Responses session anchor. + - Attach `_session_tracking_hints` for continuation requests with a strong anchor like `responses_previous_response_id:{id}` and deterministic affinity key. + - Do not fake a strong first-turn anchor from a generated response ID before the request executes. + - Add tests proving continuation requests expose hidden hints to context construction and the hidden field is not sent to providers. + +3. Record generated Responses IDs as response-derived metadata. + - Store response IDs and parent IDs clearly in `StoredResponse.metadata`. + - Add a helper for Responses session hints so future native Responses execution can share it. + +4. Add runtime storage policy. + - Add `ResponsesStoreSettings` with `ttl_seconds`, `max_items`, `store_failed`, and `store_in_progress`. + - Preserve current defaults: no expiry, no max pruning, completed responses stored, in-progress updates disabled. + +5. Implement TTL assignment and max-size pruning. + - Set `StoredResponse.expires_at` in `_stored_response()` from settings. + - Add max-item pruning to `InMemoryResponsesStore.save()`. + - Keep provider-cache expiry via `StoredResponse.expires_at`; document max-size limitations if listing is unavailable. + +6. Store failed stream responses when `store=true`. + - Persist `response.failed` payloads when storage policy says failed responses are stored. + - Keep `store=false` skip behavior. + +7. Add a current-state storage seam. + - Add `store_in_progress` default false. + - When enabled, save created/intermediate/completed stream state using the same store API. + - Keep default behavior unchanged. + +8. Refactor stream generation to transport-neutral events. + - Add `ResponsesStreamEvent` and a formatter interface in `responses/streaming.py`. + - Add `ResponsesService.stream_events()` yielding event objects. + - Make `stream_response()` a thin SSE wrapper over `stream_events()`. + - Keep `ResponsesWebSocketFormatter` honest: either pure JSON event formatting or explicit route-level future support, but service logic must not be SSE-specific. + +9. Tighten trace boundaries around the event pipeline. + - Preserve existing `responses_stream_event_*` pass names. + - Avoid trace-only conversions when no transaction logger is present. + - Add trace pass for current-state storage updates if enabled. + +10. Update FastAPI route wiring minimally. + - Keep current `/v1/responses` route behavior. + - Continue using `service.stream_response()` for SSE. + - Defer provider-cache/durable config wiring to config work unless a safe existing setting exists. + +## Tests + +- Responses service tests for hidden session hints, TTL, max pruning, and metadata. +- Responses streaming tests for `stream_events()` order, SSE wrapper compatibility, failed storage, `store=false`, and WebSocket formatter seam. +- Responses store tests for memory max pruning and provider-cache expiry behavior. +- Request-builder/session tracking tests for `_session_tracking_hints` consumption and provider payload cleanup. +- Route regression tests for current HTTP behavior. + +## Acceptance Criteria + +- `previous_response_id` becomes strong session-routing evidence for continuation requests without leaking internal hints to provider payloads. +- Responses storage has explicit TTL/max-size policy and failed-stream storage behavior. +- Streaming service logic can emit transport-neutral event objects, with SSE as a formatter wrapper and WebSocket support no longer requiring service/protocol rewrite. +- Existing `/v1/responses`, retrieve, delete, input_items, and SSE behavior remain compatible. +- Trace-disabled paths avoid unnecessary trace-only conversions. +- Focused tests and dual-agent review pass. diff --git a/docs/experimental/phase-4c-responses-storage-lineage-errors.md b/docs/experimental/phase-4c-responses-storage-lineage-errors.md new file mode 100644 index 000000000..b35b1a329 --- /dev/null +++ b/docs/experimental/phase-4c-responses-storage-lineage-errors.md @@ -0,0 +1,60 @@ +# Phase 4c: Responses Storage, Continuation Lineage, And Error Shape + +## Goal + +Close the Phase 4/4b third-pass Responses findings while keeping the existing bridge architecture and no-SQLite constraint. + +## Scope + +- Wire configurable Responses storage at app startup. +- Keep process-local memory storage as the default. +- Add provider-cache-backed JSON storage as an opt-in durable backend. +- Preserve existing `ResponsesStoreSettings` policy for TTL, maximum items, failed storage, and in-progress storage. +- Extend `previous_response_id` continuation context to replay parent input and output lineage, oldest-to-newest. +- Return top-level OpenAI-compatible `error` bodies from Responses routes. + +## Non-Goals + +- Do not introduce SQLite or a new database. +- Do not replace the chat-completions bridge with native Responses execution. +- Do not add a WebSocket route. +- Do not implement a cancel endpoint. +- Do not perfect every rich Responses item type in the bridge; cover known text/message/tool-call cases and preserve unsupported fields. + +## Implementation Plan + +1. Configurable Responses store backend. + - Add runtime config helpers for `memory` and `provider_cache` backends. + - Use existing provider-cache JSON storage for durable mode. + - Keep memory backend as the default. + +2. Startup wiring. + - Construct `ResponsesService(store=..., store_settings=...)` at proxy startup and fallback service creation. + - Keep direct `ResponsesService()` defaults unchanged for tests and embedding users. + +3. Continuation lineage. + - Load parent chains through stored `request.previous_response_id` with depth and cycle guards. + - Pass oldest-to-newest lineage into the bridge. + +4. Bridge replay. + - Replay parent input items as user messages when convertible. + - Replay parent output message items as assistant messages. + - Keep existing parent-output behavior for external callers using the old argument. + +5. Top-level route errors. + - Return `JSONResponse(status_code=..., content={"error": ...})` for Responses service errors. + - Apply to create validation, retrieve, delete, and input-items routes. + +6. Tests. + - Config/env tests for storage backend selection. + - Store tests for provider-cache-backed persistence. + - Bridge/service tests for parent input + output lineage replay. + - Route tests for top-level error bodies. + +## Acceptance Criteria + +- App startup can use durable provider-cache-backed Responses storage via config/env. +- Existing memory storage remains default. +- Continuations replay parent input and output lineage, not just parent output. +- Responses route errors use top-level OpenAI-compatible `error` bodies. +- Focused tests pass and both reviewers report no blockers/highs/mediums. diff --git a/docs/experimental/phase-5-provider-protocol-overhaul.md b/docs/experimental/phase-5-provider-protocol-overhaul.md new file mode 100644 index 000000000..c05d94ad7 --- /dev/null +++ b/docs/experimental/phase-5-provider-protocol-overhaul.md @@ -0,0 +1,263 @@ +# Phase 5 Plan: Native Provider Protocol Overhaul + +## Goal + +Make provider implementations opt into the native protocol, adapter, field-cache, and Responses foundations added in Phases 1-4, then add or restore the priority providers in the requested order: Claude Code, Codex, Copilot, Antigravity, and Gemini CLI parity review. This phase should create a reusable provider-native execution seam first, then implement providers incrementally behind explicit declarations so LiteLLM remains fallback-only for uncovered cases. + +## Non-Goals + +- Do not replace `UsageManager`, `SessionTracker`, `SelectionEngine`, or retry-after parsing. +- Do not implement routing/fallback groups yet; Phase 6 owns ordered fallback chains. +- Do not implement full multi-user/security isolation. +- Do not add SQLite or any database. +- Do not rewrite the entire executor in one pass. +- Do not migrate every existing provider at once. +- Do not copy obsolete retired-provider behavior blindly. +- Do not remove LiteLLM fallback; make fallback explicit and observable. + +## Current Code Context + +- Phase 1 provides reusable protocol adapters: OpenAI Chat, Anthropic Messages, Gemini, Responses, and LiteLLM fallback. +- Phase 2 provides transaction transform tracing and provider trace entries. +- Phase 3 provides provider declarations, adapter chains, field-cache rules, and scoped cache engine. +- Phase 4 provides `/v1/responses`, response storage, HTTP SSE, and WebSocket seam. +- `ProviderInterface` already has optional `get_protocol_name()`, `get_adapter_names()`, `get_adapter_config()`, and `get_field_cache_rules()`. +- Existing custom providers with `has_custom_logic()` include Gemini CLI, Deepseek, and retired Antigravity/IFlow/Qwen-style providers. +- Gemini CLI is large and custom; Phase 5 should review and target improvements rather than destabilize it. +- Antigravity exists only under `src/rotator_library/providers/_retired/` and must be compared against current behavior before restoration. +- No Claude Code, Codex, or Copilot provider files exist in active providers today. + +## Files To Add + +- `src/rotator_library/native_provider/__init__.py` +- `src/rotator_library/native_provider/context.py` +- `src/rotator_library/native_provider/executor.py` +- `src/rotator_library/native_provider/http.py` +- `src/rotator_library/native_provider/streaming.py` +- `src/rotator_library/providers/claude_code_provider.py` +- `src/rotator_library/providers/codex_provider.py` +- `src/rotator_library/providers/copilot_provider.py` +- Possibly `src/rotator_library/providers/antigravity_provider.py` +- `tests/test_native_provider_executor.py` +- `tests/test_native_provider_streaming.py` +- `tests/test_claude_code_provider.py` +- `tests/test_codex_provider.py` +- `tests/test_copilot_provider.py` +- `tests/test_antigravity_provider_restore.py` if restored in this phase +- `tests/test_gemini_cli_protocol_declarations.py` + +## Files Likely To Touch + +- `src/rotator_library/providers/provider_interface.py` only for small optional native execution methods if needed. +- `src/rotator_library/client/executor.py` only if a minimal seam is needed to route declared native providers without breaking current providers. +- `src/rotator_library/client/models.py` only if model/provider resolution needs explicit provider declarations. +- `src/rotator_library/client/request_builder.py` only if provider declarations need context fields before execution. +- `src/rotator_library/responses/service.py` only if native Responses-capable providers can bypass the chat bridge cleanly. +- Existing provider files only for declaration additions or targeted Gemini CLI fixes. + +## Native Provider Execution Seam + +- Add `NativeProviderContext` with provider, model, credential identifier, headers, request context, protocol name, adapter names, field-cache rules, transport, transaction logger, session fields, scope key, classifier, and provider metadata. +- Add `NativeProviderExecutor` that can: + - Resolve protocol via `provider.get_protocol_name(model)`. + - Build `AdapterContext` and run `before_adapter_chain` / `after_adapter` / `after_adapter_chain`. + - Run `FieldCacheEngine.inject()` before provider request. + - Build native HTTP request payload from the selected protocol. + - Send request through a small HTTP transport wrapper. + - Parse provider response through selected protocol. + - Run `FieldCacheEngine.extract()` after response or stream events. + - Format back to the requested client protocol. + - Emit transform trace passes for every step. +- The executor must be provider-opt-in. If `get_protocol_name()` returns `None` or `"litellm_fallback"`, current behavior stays unchanged. +- It must be independently testable with mocked HTTP clients and fake providers. + +## Provider Interface Additions + +Prefer using existing declarations first. Add optional methods only if needed: + +- `get_native_endpoint(model, operation)` +- `get_native_headers(credential_identifier, model, operation)` +- `build_native_request_options(model, operation)` +- `supports_native_operation(operation, model)` +- `should_use_native_protocol(operation, model)` + +Defaults must preserve current behavior. Docstrings must explain provider overrides are encouraged when a base protocol is close but not exact. + +## HTTP Transport + +- Use injected `httpx.AsyncClient`. +- Support JSON POST first. +- Support SSE response iteration for streaming. +- Preserve raw request/response bodies for transform trace. +- Do not add WebSocket runtime transport yet; keep the seam compatible with Phase 4 WebSocket formatter. +- Convert provider HTTP errors into existing error classification paths where possible. + +## Streaming + +- Native stream parser should call protocol stream parsing where available. +- Emit transform passes: + - `native_provider_request` + - `raw_native_provider_stream_chunk` + - `parsed_native_stream_event` + - `after_field_cache_stream_extraction` + - `formatted_client_stream_event` +- Do not fallback after visible output. Phase 6 owns fallback policy; Phase 5 only makes stream behavior observable and safe. + +## Responses Integration + +- For providers with native Responses support, allow `ResponsesService` to bypass `ResponsesBridge`. +- Add a service-level native executor hook only if it can be done without coupling service to provider internals. +- If not clean in Phase 5, leave `/v1/responses` bridge as default and expose provider-native Responses in provider tests/foundation for Phase 6/8 wiring. +- Preserve `/v1/responses` route behavior from Phase 4. + +## Priority Provider Plan + +### Claude Code + +- Add provider first. +- Determine whether it is Anthropic Messages-compatible, OpenAI Chat-compatible, or a dedicated endpoint. +- Start with declarations and mocked native execution tests. +- Preserve Claude reasoning/thinking fields via field-cache rules if present. +- Suppress/transform unsupported roles through adapters instead of bespoke monolithic code. +- Add tests for auth headers, model list behavior, request transform, response transform, streaming text, and field-cache rule declarations. +- If live API details are uncertain, implement an integration path with explicit env names and mocked endpoint behavior rather than guessing secrets or undocumented flows. + +### Codex + +- Add provider second. +- Treat as likely OpenAI/Responses-compatible until provider-specific evidence says otherwise. +- Prefer native Responses protocol if supported; otherwise OpenAI Chat protocol. +- Add tests for protocol selection, Responses bypass or bridge compatibility, auth headers, model naming, and explicit no-LiteLLM path when native mode is declared. + +### Copilot + +- Add provider third. +- Use native protocol declarations and OAuth/header helpers. +- Keep credential acquisition/refresh minimal and mocked unless existing credential machinery already supports it. +- Add field-cache rules for conversation/session IDs if required by provider behavior. +- Add tests for model list, auth header, protocol selection, request conversion, and streaming. + +### Antigravity + +- Compare `src/rotator_library/providers/_retired/antigravity_provider.py`, auth base, quota tracker, and device profile utilities before restoring. +- Restore only valid/current behavior. +- Avoid fragile or obsolete device-profile behavior unless current service behavior requires it. +- Extract reusable pieces: + - model mapping + - auth headers + - schema cleanup + - thinking/reasoning preservation + - quota parsing + - SSE handling +- Do not resurrect the whole monolith unchanged. +- Add a restored provider only after tests describe the stable subset. +- If the current service cannot be validated safely, write an explicit integration-path provider skeleton and defer live-specific behavior. + +### Gemini CLI Parity Review + +- Review active `gemini_cli_provider.py` against the new protocol/adapter/cache foundation. +- Add provider declarations where safe: + - likely protocol `gemini` + - adapters for model override / field rename / role suppression only if needed + - field-cache rules for thought signatures and provider session fields if they match current behavior. +- Do not rewrite Gemini CLI in this phase. +- Add targeted tests proving declarations do not change current behavior. +- Fix only clear parity gaps found during review. + +## LiteLLM Fallback Policy + +- Providers with native declarations should not silently fall back to LiteLLM on the primary path. +- If fallback is allowed, it must be explicit: + - protocol name `"litellm_fallback"` + - trace pass `native_provider_litellm_fallback` + - metadata reason. +- Tests must assert no accidental LiteLLM fallback for providers declared native. + +## Transform Trace Requirements + +Provider-native calls should log: + +- `native_protocol_selected` +- `before_adapter_chain` +- `after_adapter` +- `after_field_cache_injection` +- `native_provider_request` +- `raw_native_provider_response` +- `parsed_native_provider_response` +- `after_field_cache_extraction` +- `final_client_response` + +Provider-native stream calls should log: + +- `native_protocol_selected` +- `native_provider_stream_request` +- `raw_native_provider_stream_chunk` +- `parsed_native_stream_event` +- `after_field_cache_stream_extraction` +- `formatted_client_stream_event` + +Errors should use `log_transform_error()` with provider/protocol/pass context. + +## Field-Cache Requirements + +- Use Phase 3 rules for: + - reasoning content + - thought signatures + - provider session IDs + - prompt cache keys + - previous/response IDs where provider-specific +- Scope must include at least provider + model + classifier + session for conversation-affecting values. +- Credential scope must be added for values tied to an account/token. +- Tests must prove no cross-provider/model/session/credential leakage for provider rules. + +## Testing Plan + +Native executor tests: + +- protocol selection +- adapter chain order +- field-cache injection/extraction +- non-stream request/response trace passes +- stream trace passes +- explicit fallback behavior +- provider opt-in leaves current fallback path untouched for undeclared providers + +Provider tests: + +- registration/discovery for Claude Code, Codex, Copilot, and restored Antigravity if added. +- env var naming and auth header construction via mocks. +- model list parsing via mocked HTTP. +- request payload build and response parse via mocked HTTP. +- SSE stream conversion via mocked chunks. +- field-cache declarations. +- no live credentials required. + +Regression tests: + +- Phase 1 protocol tests. +- Phase 2 logging tests. +- Phase 3 adapter/field-cache tests. +- Phase 4 Responses tests. +- `test_session_tracking.py`. +- `test_selection_engine.py`. + +## Commit Checkpoints + +1. Add native provider context/executor/HTTP transport with tests. +2. Add native streaming support with tests. +3. Add Claude Code provider or integration skeleton with tests. +4. Add Codex provider or integration skeleton with tests. +5. Add Copilot provider or integration skeleton with tests. +6. Compare and restore the safe Antigravity subset, or write a documented deferral with tests. +7. Add Gemini CLI declaration/parity fixes with tests. +8. Run focused and regression tests. +9. Review with `explore` and `explore-heavy`, fix findings, and write the uncommitted Phase 5 report. + +## Risks And Mitigations + +- Provider APIs may be undocumented or volatile. Use mocked behavior and explicit integration seams rather than guessing hidden flows. +- A native executor could destabilize current traffic. Keep native execution opt-in and leave undeclared providers unchanged. +- Restoring Antigravity could reintroduce brittle device/profile logic. Restore only tested stable subsets. +- Streaming fidelity for tool/reasoning deltas may vary by provider. Preserve raw chunks in trace and keep provider-specific parsers override-friendly. +- Field-cache leakage would be serious. Add scope tests per provider rule. +- LiteLLM fallback could hide native failures. Make fallback explicit, traced, and tested. diff --git a/docs/experimental/phase-5b-provider-native-integrations.md b/docs/experimental/phase-5b-provider-native-integrations.md new file mode 100644 index 000000000..e861a9df8 --- /dev/null +++ b/docs/experimental/phase-5b-provider-native-integrations.md @@ -0,0 +1,101 @@ +# Phase 5b Plan: Priority Providers From Skeletons To Mock-Live Native Integrations + +## Goal + +Correct the Phase 5 audit gap. Phase 5 added native-provider seams and priority provider declarations, but the validation pass found the priority providers are still mostly skeletons. Phase 5b will make Claude Code, Codex, Copilot, and Antigravity usable through the native execution path under mock HTTP tests, while preserving Gemini CLI's existing custom execution path and adding parity declarations where they are safe. + +## Non-Goals + +- Do not remove LiteLLM fallback. +- Do not replace Gemini CLI's existing custom provider implementation with the generic native executor in this phase. +- Do not invent device fingerprinting or brittle environment/device-profile behavior for Antigravity. +- Do not add real credential acquisition flows for Copilot/Codex/Claude Code; this phase consumes credentials supplied by the existing credential system. +- Do not use external files outside the project root. +- Do not introduce SQLite or new persistence. +- Do not touch unrelated dirty `ARCHITECTURE.md`, `STRUCTURE.md`, `.opencode/`, `docs/issues/`, or old phase reports. +- Do not commit user-facing phase reports. + +## Current Code State + +- `NativeProviderExecutor` can parse/build/adapter/cache/HTTP/format native requests and streams. +- `RequestExecutor` can select native execution when routing target execution is `native` or auto-detected by provider protocol declaration. +- `_build_native_provider_context()` currently asks every provider for endpoint and headers with operation `"chat"`, which is too generic. +- Priority provider files exist for Claude Code, Codex, Copilot, and Antigravity, but tests mostly cover declarations and helper methods. +- Provider-prefixed model names can leak into upstream native payloads unless each provider normalizes them. +- Native streaming exists, but provider support flags are conservative and not all priority providers have stream coverage. +- Gemini CLI has substantial existing custom logic and must not be silently bypassed by auto-native routing. + +## Implementation Plan + +1. Add provider-native operation resolution. + - Add default methods to `ProviderInterface`: `get_native_operation()`, `normalize_native_model()`, and optional `prepare_native_request()`. + - Preserve current behavior by default. + - Update `RequestExecutor._build_native_provider_context()` to ask providers for operation, endpoint, headers, normalized model, and prepared request metadata. + +2. Make model normalization explicit and tested. + - Claude Code strips `claude_code/`. + - Codex strips `codex/`. + - Copilot strips `copilot/`. + - Antigravity strips `antigravity/` and maps public aliases to internal upstream names. + - Gemini CLI remains custom-path first. + +3. Make provider endpoints operation-aware. + - Claude Code uses `messages` and `/v1/messages`. + - Codex uses `responses` and `/v1/responses`. + - Copilot uses `chat` and `/chat/completions`. + - Antigravity uses Gemini generate/stream-generate endpoints. + +4. Add minimal provider request preparation. + - Normalize model before protocol parsing. + - Allow provider `prepare_native_request()` to deep-copy and adjust request payloads before protocol parsing. + - Trace this pass without adding credentials. + +5. Strengthen provider auth/header behavior. + - Keep current supplied-credential model. + - Add tested header sets for each provider. + - Do not add secrets to JSON config. + +6. Native streaming support declarations. + - Enable only where tested safe. + - Prove native streaming selection through `RequestExecutor` tests. + - Keep unsupported providers on existing fallback/custom paths. + +7. Mock-live RequestExecutor integration tests. + - Prove Claude Code, Codex, Copilot, and Antigravity use the native executor with the correct protocol, operation, endpoint, headers, model normalization, and response formatting under fake HTTP. + - Cover streaming for providers that opt in. + +8. Provider model discovery hardening. + - Test fallback and successful discovery for each priority provider. + - Avoid duplicate prefixes and invalid aliases. + +9. Gemini CLI parity review. + - Keep custom execution path. + - Verify declarations align with Gemini protocol and field-cache paths. + - Explicitly prevent accidental auto-native routing if required. + +10. Documentation and comments. + - Update provider docstrings away from “skeleton” once behavior is mock-live. + - Explain model normalization, Antigravity safety boundaries, and Gemini CLI custom-path deferral. + +## Tests + +- `tests/test_claude_code_provider.py` +- `tests/test_codex_provider.py` +- `tests/test_copilot_provider.py` +- `tests/test_antigravity_provider_restore.py` +- `tests/test_gemini_cli_protocol_declarations.py` +- `tests/test_provider_protocol_declarations.py` +- `tests/test_native_provider_executor.py` +- `tests/test_native_provider_streaming.py` +- `tests/test_request_executor_native_routing.py` +- Relevant protocol, field-cache, and routing regressions. + +## Acceptance Criteria + +- Priority providers are no longer declaration-only skeletons; each has mock-live native `RequestExecutor` coverage or an explicit tested reason for custom-path deferral. +- Native operation and endpoint selection are provider-aware, not hardcoded to `"chat"`. +- Provider-prefixed model names are normalized before native upstream calls. +- Native streaming is enabled only where tested safe. +- Gemini CLI remains on its existing custom path unless explicitly routed otherwise. +- LiteLLM fallback remains available for uncovered providers/protocols. +- Focused tests and dual-agent review pass. diff --git a/docs/experimental/phase-5c-provider-native-correction.md b/docs/experimental/phase-5c-provider-native-correction.md new file mode 100644 index 000000000..f07939758 --- /dev/null +++ b/docs/experimental/phase-5c-provider-native-correction.md @@ -0,0 +1,72 @@ +# Phase 5c: Provider-Native Output Protocol And Contract Correction + +## Goal + +Close the Phase 5/5b third-pass provider-native findings while preserving the current safety stance: native streaming remains disabled for priority providers until each live stream path is proven, and LiteLLM remains fallback only for unsupported cases. + +## Scope + +- Return native provider responses in the originating client protocol instead of the provider protocol. +- Add explicit native endpoint/header/operation-selection hooks to `ProviderInterface`. +- Ensure Claude Code native requests always include required Anthropic `max_tokens`. +- Harden Claude Code API-key header behavior without breaking bearer/OAuth-style credentials. +- Fix Antigravity alias normalization so duplicated aliases do not collapse incorrectly and low/high thinking aliases can still affect request metadata. +- Centralize explicit-native streaming fail-closed behavior. +- Keep priority-provider native streaming disabled by default. +- Add focused tests for all blocker/high/medium findings. + +## Non-Goals + +- Do not enable live native streaming for Claude Code, Codex, Copilot, or Antigravity. +- Do not restore retired Antigravity device-profile/fingerprint behavior. +- Do not replace credential rotation, usage tracking, or routing. +- Do not implement all rich Codex Responses item conversions in this phase. +- Do not commit user-facing reports. + +## Implementation Plan + +1. Client target protocol in native context. + - Add `client_protocol_name` to `NativeProviderContext`. + - Default direct/library native use to provider protocol for backwards compatibility. + - Set `client_protocol_name="openai_chat"` from the chat-completions executor path. + - Format non-streaming native responses using the client protocol after provider protocol parsing. + - Format native stream events using the client protocol when a client protocol is set. + +2. Provider interface native contract. + - Add default `get_native_endpoint()` and `get_native_headers()` methods that raise clear `NotImplementedError`. + - Add `supports_native_operation()` defaulting to operation support through provider declarations. + - Add `should_use_native_protocol()` defaulting to true only when a provider declares a native protocol and operation is supported. + - Update executor checks to call hooks rather than relying on `hasattr()`. + +3. Claude Code hardening. + - `prepare_native_request()` ensures Anthropic `max_tokens` is present, using existing request value or `CLAUDE_CODE_MAX_TOKENS` default. + - Header selection supports bearer credentials and API-key credentials: + - `CLAUDE_CODE_AUTH_HEADER=bearer|x-api-key|auto` + - auto uses `x-api-key` for `sk-ant-*` style keys and bearer otherwise. + - Preserve `anthropic-version` and content type. + +4. Antigravity alias/thinking metadata. + - Fix alias-to-upstream mapping to use public alias map directly, not a lossy reverse map. + - Preserve thinking-level hints for `gemini-3-pro-low` / `gemini-3-pro-high` in request metadata before the upstream model is normalized. + - Keep model discovery output stable and prefixed. + +5. Native streaming fail-closed helper. + - Centralize explicit-native streaming unsupported handling so auto and explicit modes use the same fail-closed decision. + - Keep priority providers non-streaming native only. + +6. Tests. + - Native executor: provider protocol response -> OpenAI chat client response. + - Request executor: routed native Claude/Codex/Antigravity chat completions return OpenAI Chat shape. + - Claude Code: missing `max_tokens` gets defaulted and auth header modes behave correctly. + - Antigravity: low/high aliases normalize safely and preserve request metadata. + - ProviderInterface: native hooks exist and unsupported operation checks are explicit. + - Streaming: explicit native streaming fail-closed path uses centralized helper. + +## Acceptance Criteria + +- `/v1/chat/completions` native routes return OpenAI Chat response shape regardless of provider-native protocol. +- Claude Code native Messages requests include `max_tokens`. +- Antigravity model normalization does not lose low/high alias intent. +- Provider native endpoint/header/operation hooks are explicit on `ProviderInterface`. +- Explicit native streaming fails closed through one helper path. +- Focused tests pass and both reviewers report no blockers/highs/mediums. diff --git a/docs/experimental/phase-6-routing-fallback-groups.md b/docs/experimental/phase-6-routing-fallback-groups.md new file mode 100644 index 000000000..d546e8efc --- /dev/null +++ b/docs/experimental/phase-6-routing-fallback-groups.md @@ -0,0 +1,294 @@ +# Phase 6 Plan: Routing And Fallback Groups + +## Goal + +Add explicit ordered routing/fallback groups and a provider-native routing seam that can choose between the current LiteLLM execution path, provider custom logic, and the Phase 5 `NativeProviderExecutor`. Fallback groups are the priority deliverable: a model can resolve to an ordered chain of concrete targets, and the executor can try the next target on eligible failures without replacing the existing credential rotation, usage tracking, session tracking, or retry-after parsing. + +## Non-Goals + +- Do not replace `UsageManager`, `SelectionEngine`, `SessionTracker`, or retry-after parsing. +- Do not add full target-group selector language before fallback groups work. +- Do not implement multi-user/security isolation. +- Do not implement a complete config-file system; Phase 10 owns config polish. +- Do not route every provider natively by default. +- Do not fallback after visible streamed output has been emitted. +- Do not silently fallback to LiteLLM for native providers unless the fallback chain explicitly includes a LiteLLM target. +- Do not make Claude Code/Codex/Copilot/Antigravity live-production claims beyond mocked/tested routing behavior. + +## Current Code Context + +- `RequestExecutor` is the live retry/rotation path. +- `RequestExecutor` already gets provider plugin instances, usage managers, credentials, request context, transaction logger, and LiteLLM provider params. +- Phase 5 `NativeProviderExecutor` is isolated and opt-in. +- Providers can declare `protocol_name`, adapter names, field-cache rules, native endpoints, and native headers. +- Model resolution currently maps a single requested model to one provider/model pair. +- Existing retry loops rotate credentials inside one provider; Phase 6 must add model/provider target fallback around that without breaking per-provider credential rotation. +- Existing transaction trace can log routing decisions. +- Existing streaming handler can rotate credentials before output, but stream fallback after output must remain disallowed. + +## Files To Add + +- `src/rotator_library/routing/__init__.py` +- `src/rotator_library/routing/types.py` +- `src/rotator_library/routing/config.py` +- `src/rotator_library/routing/resolver.py` +- `src/rotator_library/routing/policy.py` +- `src/rotator_library/routing/executor.py` or `attempts.py` +- `tests/test_routing_config.py` +- `tests/test_fallback_resolver.py` +- `tests/test_fallback_policy.py` +- `tests/test_request_executor_fallback_groups.py` +- `tests/test_request_executor_native_routing.py` +- `tests/test_streaming_fallback_policy.py` + +## Files Likely To Touch + +- `src/rotator_library/core/types.py` to add optional routing/fallback fields to `RequestContext`. +- `src/rotator_library/client/request_builder.py` to resolve fallback groups before execution. +- `src/rotator_library/client/models.py` if model resolution needs a group-aware helper. +- `src/rotator_library/client/executor.py` to wrap existing per-provider retry loops in target attempts and call native executor for eligible targets. +- `src/rotator_library/client/rotating_client.py` only if passing route configuration into components requires it. +- `src/rotator_library/providers/provider_interface.py` only if a small native-support method is needed. +- `src/rotator_library/transaction_logger.py` only if adding route correlation helpers is cleaner. + +## Routing Types + +`RouteTarget`: + +- `name`: stable target identifier. +- `provider`: provider key. +- `model`: concrete model name, with provider prefix optional but normalized. +- `protocol`: optional override; defaults to provider declaration. +- `execution`: `auto`, `native`, `custom`, or `litellm_fallback`. +- `priority`: optional numeric order. +- `weight`: future target-group selector support, ignored for ordered fallback. +- `conditions`: optional metadata for later selectors. +- `metadata`. + +`FallbackGroup`: + +- `name`. +- `targets`: ordered `RouteTarget` list. +- `failover_on`: error categories that permit next-target fallback. +- `stop_on`: error categories that must stop. +- `streaming_policy`: e.g. `pre_output_only`. +- `max_targets`: optional guard. +- `metadata`. + +`RoutingDecision`: + +- `requested_model`. +- `group_name` optional. +- `targets`. +- `selected_target_index`. +- `reason`. + +`RouteAttemptResult`: + +- target, success/failure, error classification, emitted_output flag, usage summary. + +## Configuration Plan + +- Phase 6 supports env-first fallback definitions with a small parser. +- Optional JSON config can be parsed if simple, but deeper config merging belongs to Phase 10. +- Env examples: + - `FALLBACK_GROUPS=sonnet_chain,codex_chain` + - `FALLBACK_GROUP_SONNET_CHAIN=claude_code/claude-sonnet-4-5,copilot/claude-sonnet-4-5,anthropic/claude-3-5-sonnet-latest` + - `FALLBACK_GROUP_CODEX_CHAIN=codex/gpt-5.1-codex,openai/gpt-5.1` + - `MODEL_ROUTE_CLAUDE_SONNET=group:sonnet_chain` + - `MODEL_ROUTE_CODEX=group:codex_chain` +- Also allow programmatic construction in tests. +- Config parser must validate: + - group names are unique. + - targets have provider/model. + - empty chains are invalid. + - cycles through group aliases are rejected. + - provider names are not silently guessed when ambiguous. +- No secrets in route config. + +## Target Resolution + +- If requested model maps to a fallback group, return all group targets in order. +- If requested model is already `provider/model`, return a single target. +- If requested model maps through existing model definitions, preserve existing behavior. +- Target model should be normalized with provider prefix before entering `RequestExecutor`. +- Record transform trace pass `routing_decision` with requested model, group, target count, selected target names, execution modes, and no credentials. + +## Fallback Policy + +Try next target only on eligible failures: + +- rate limit/quota/capacity. +- provider unavailable/server error. +- transient connection errors. +- native provider unsupported operation when explicit fallback target exists. + +Do not fallback on: + +- auth errors for all credentials unless the group target explicitly marks auth fallback safe. +- validation/permanent request errors. +- pre-request callback failures. +- client cancellation. +- streaming errors after visible output. + +Respect existing per-provider credential rotation before moving to the next target: + +- one target attempt should let current `RequestExecutor` exhaust eligible credentials according to existing retry logic. +- after a target is exhausted, the fallback layer can advance to the next target. + +Keep error accumulator context across targets so final failure includes all target errors. + +## Native, Custom, And LiteLLM Execution Choice + +`execution=auto`: + +- if plugin has `has_custom_logic()`, use current plugin `acompletion()`. +- else if provider has native protocol declaration and native endpoint/header methods, use `NativeProviderExecutor`. +- else use current LiteLLM call path. + +`execution=native`: require native support or fail this target with a classified unsupported-operation error. + +`execution=custom`: require `has_custom_logic()`. + +`execution=litellm_fallback`: force current LiteLLM path and emit `native_provider_litellm_fallback` / `routing_litellm_fallback` trace metadata. + +Tests must prove native-declared providers do not silently use LiteLLM unless the target says `litellm_fallback`. + +## RequestExecutor Integration + +- Keep the existing `_execute_non_streaming()` and `_execute_streaming()` credential retry loops as the per-target implementation. +- Add a wrapper method that accepts a `RequestContext` with `routing_targets`. +- For each target: + - clone/update context provider/model/usage manager key/credentials for that target. + - record `routing_target_attempt_started`. + - run the existing target execution path. + - on success record `routing_target_attempt_succeeded` and return. + - on failure classify and ask fallback policy if next target is allowed. + - record `routing_target_attempt_failed`. +- Avoid mutating the original request context in-place. +- If no group exists, execute exactly as today. + +## Native Provider Execution Integration + +- Add a small execution branch inside the target attempt where the selected target resolves to native execution. +- Construct `NativeProviderContext` from: + - provider plugin declarations. + - selected model. + - credential identifier/token. + - request context session/scope/classifier. + - transaction logger. + - endpoint/headers from provider methods. +- Use Phase 5 `NativeProviderExecutor` for mocked/tested providers. +- Preserve usage recording through the existing credential context. If native responses do not provide full usage yet, record normalized available usage only and leave cost normalization for Phase 9. +- For custom providers with `has_custom_logic()`, keep existing plugin `acompletion()` branch. +- For undeclared providers, keep LiteLLM path. + +## Streaming Integration + +- Fallback to next target is allowed only before any visible chunk is yielded. +- Track `emitted_output` in the stream wrapper. +- If a stream errors before output, fallback policy may advance to next target. +- If a stream errors after output, propagate the error; do not switch models mid-stream. +- Record: + - `routing_stream_target_attempt_started` + - `routing_stream_target_attempt_failed` + - `routing_stream_target_attempt_succeeded` + - `routing_stream_fallback_blocked_after_output`. +- Keep current stream retry policy for same-target/same-provider errors before group fallback. + +## Transaction Trace Requirements + +- `routing_decision` +- `routing_target_attempt_started` +- `routing_target_attempt_failed` +- `routing_target_attempt_succeeded` +- `routing_fallback_selected` +- `routing_fallback_exhausted` +- `routing_litellm_fallback` +- `routing_native_execution_selected` +- stream equivalents where applicable. +- Trace metadata includes group, target index, provider, model, execution mode, classification, and reason. It must not include raw credentials. + +## Target Groups As Optional Richer Layer + +- Phase 6 focuses on ordered fallback groups. +- Add type seams for future target groups: + - `TargetSelector` + - `TargetGroup` + - `TargetSelectionPolicy` +- Do not implement weighted/latency/cost selector behavior unless it is trivial and isolated. +- Document that Phase 6 fallback groups are deterministic ordered chains. + +## Tests + +Config tests: + +- parse env fallback group. +- invalid empty group. +- duplicate group names. +- provider/model parsing. +- explicit `litellm_fallback` execution target. + +Resolver tests: + +- requested alias maps to group. +- provider/model remains single target. +- target order is preserved. +- existing model prefix behavior stays intact. + +Policy tests: + +- fallback on rate limit/server/transient. +- stop on auth/permanent/pre-request/cancel. +- streaming fallback allowed before output. +- streaming fallback blocked after output. + +Request executor tests: + +- first target success uses no fallback. +- first target rate limits all credentials then second target succeeds. +- final error includes both target failures. +- no fallback group preserves old behavior. +- transaction trace includes routing passes. + +Native routing tests: + +- native-declared provider target uses `NativeProviderExecutor`. +- native-declared provider does not use LiteLLM unless target execution is `litellm_fallback`. +- custom provider still uses plugin `acompletion()`. +- undeclared provider still uses LiteLLM. + +Streaming tests: + +- pre-output target failure falls through to next target. +- post-output failure does not fall through. +- trace records blocked fallback after output. + +Regression: + +- Phase 1 protocol tests. +- Phase 2 logging tests. +- Phase 3 adapter/cache tests. +- Phase 4 Responses tests. +- Phase 5 provider tests. +- `test_session_tracking.py`. +- `test_selection_engine.py`. + +## Commit Checkpoints + +1. Add routing types, config parser, resolver, and policy with tests. +2. Add request context routing fields and target cloning helpers with tests. +3. Integrate non-streaming fallback group wrapper around existing executor path with tests. +4. Integrate native/custom/LiteLLM execution mode selection with tests. +5. Integrate streaming pre-output fallback policy with tests. +6. Add trace pass coverage and run regressions. +7. Review with `explore` and `explore-heavy`, fix findings, and write the uncommitted Phase 6 report. + +## Risks And Mitigations + +- Fallback could double-spend or corrupt usage. Mitigation: each target uses existing credential context and usage manager; group fallback starts only after target failure. +- Fallback could mask permanent request bugs. Mitigation: policy stops on validation/permanent errors. +- Streaming fallback could corrupt client output. Mitigation: disallow fallback after visible output. +- Native routing could accidentally bypass LiteLLM behavior for existing providers. Mitigation: default no group/no native route is unchanged; native path is explicit/declared. +- Config could become too broad. Mitigation: small env parser now; richer JSON config later in Phase 10. +- Session affinity could leak across provider pools. Mitigation: keep `SessionTracker` namespace behavior and clone target context with provider/model-specific namespace where needed. diff --git a/docs/experimental/phase-6b-routing-fallback-correctness.md b/docs/experimental/phase-6b-routing-fallback-correctness.md new file mode 100644 index 000000000..3d53895bd --- /dev/null +++ b/docs/experimental/phase-6b-routing-fallback-correctness.md @@ -0,0 +1,91 @@ +# Phase 6b: Routing/Fallback Correctness And Structured Error Safety + +## Goal + +Correct the Phase 6 audit findings. Phase 6 built ordered fallback groups and live executor wrappers, but fallback still needs stronger safety around hard-stop errors, structured error responses, streaming policy enforcement, and sanitized target summaries. + +## Non-Goals + +- Do not replace `UsageManager`, `SessionTracker`, retry-after parsing, or credential rotation behavior. +- Do not implement rich target-group selector syntax beyond ordered fallback chains. +- Do not add security or multi-user features. +- Do not make native streaming work; unsupported native streaming remains fail-closed. +- Do not include raw provider messages or credentials in cross-target summaries. +- Do not commit user-facing reports. + +## Current State + +- `FallbackPolicy.should_fallback()` uses group `failover_on` and `stop_on` directly, so a group can currently override auth/permanent errors into fallback eligibility. +- `_route_error_type_from_response()` only recognizes retry-like summaries and a few proxy error types; it does not robustly inspect structured status/code/details fields. +- `FallbackGroup.streaming_policy` exists but live streaming fallback does not use it. +- Streaming fallback correctly tracks visible output, but needs stronger tests for error/control frames and `never` streaming policy. +- Non-streaming fallback summaries are structural, but call sites still pass raw exception strings that should be avoided entirely. + +## Implementation Plan + +1. Add hard-stop route error categories. + - Add a non-overridable hard-stop set for auth, forbidden, invalid request, context-window, credential reauth, pre-request callback, cancellation, and configuration errors. + - Document that these are safety boundaries and group policies cannot opt into cross-target fallback for them. + +2. Normalize routing policy vocabulary. + - Add `normalize_route_error_type()` for aliases such as `auth`, `permission_denied`, `bad_request`, `validation`, `transient`, `network`, and `configuration`. + - Use it in policy, routing runner, retry helper, and executor route classification. + +3. Make hard stops win in `FallbackPolicy`. + - Streaming visible output still blocks fallback first. + - Hard-stop categories return false before group policy evaluation. + - Group stop/failover sets are normalized before matching. + +4. Validate routing group policy. + - Reject configured `failover_on` entries that normalize to hard-stop categories. + - Parse and validate `streaming_policy` from JSON routing config. + - Preserve environment override precedence. + +5. Enforce `FallbackGroup.streaming_policy` in live executor paths. + - `never` prevents streaming fallback even before output. + - `pre_output_only` remains the default. + - Trace metadata should include the policy used. + +6. Harden structured response classification. + - Inspect `error.type`, `error.code`, `error.status`, `error.details.status_code`, and detail classification fields before summaries. + - Hard-stop signals win over retryable text. + - Keep retryable classification for 429/quota/rate-limit, 5xx/server, timeout, and connection signals. + +7. Sanitize target-failure summaries. + - Keep raw messages out of fallback details and traces. + - Remove unnecessary raw string arguments from summary call sites. + - Add tests proving secrets/provider text do not appear. + +8. Improve stream fallback frame handling. + - Verify error/control frames are non-visible output. + - Test `event: error`, `type: error`, `response.failed`, `[DONE]`, comments, and heartbeat-like frames. + - Ensure visible text/tool deltas still block fallback. + +9. Add deterministic exhaustion metadata. + - Include sanitized target summaries in routing exhaustion traces for stream and non-stream paths. + - Avoid raw exception text in trace metadata. + +10. Add execution-mode safety tests. + - Explicit native configuration errors must be hard stops. + - Unsupported operation behavior must be clear and tested. + - LiteLLM fallback remains available for retryable errors. + +## Tests + +- `tests/test_fallback_policy.py` +- `tests/test_retry_policy.py` +- `tests/test_routing_config.py` +- `tests/test_fallback_resolver.py` +- `tests/test_routing_executor.py` +- `tests/test_streaming_fallback_policy.py` +- `tests/test_request_executor_native_routing.py` +- Phase 5b provider/native regression subset if executor routing helpers change. + +## Acceptance Criteria + +- Auth, forbidden, invalid request, context-window, credential reauth, pre-request callback, cancellation, and configuration errors never fallback across targets. +- Structured error responses classify deterministically with hard-stop signals taking precedence. +- Streaming fallback respects `FallbackGroup.streaming_policy`. +- Error/control stream frames before visible output do not accidentally lock routing. +- Cross-target summaries and traces remain sanitized. +- Focused tests and dual-agent review pass. diff --git a/docs/experimental/phase-6c-routing-fallback-correction.md b/docs/experimental/phase-6c-routing-fallback-correction.md new file mode 100644 index 000000000..4df2eb01f --- /dev/null +++ b/docs/experimental/phase-6c-routing-fallback-correction.md @@ -0,0 +1,62 @@ +# Phase 6c: Routing/Fallback Correctness And Stale Test Repair + +## Goal + +Close all Phase 6/6b third-pass routing findings while preserving the fallback group architecture and the hard-stop safety from Phase 6b. + +## Scope + +- Replace stale `tests/test_fallback_groups.py` with tests for the current routing package. +- Implement requested-model-in-group promotion. +- Align streaming execution-mode behavior with non-streaming behavior for `@custom`, `@native`, and `@litellm_fallback`. +- Ensure explicit native configuration errors are hard-stop configuration errors, not retryable unsupported-operation errors. +- Expand structured route-error alias normalization. +- Adjust fallback target session namespace per target. +- Add focused tests for every blocker/high/medium. + +## Non-Goals + +- Do not resurrect the old `FallbackGroupManager`. +- Do not reintroduce `RotatingClient(model_fallback_groups=...)`. +- Do not implement routing for embeddings unless it is trivial and needed by tests. +- Do not weaken hard-stop policy. +- Do not commit user-facing reports. + +## Implementation Plan + +1. Replace stale fallback tests. + - Rewrite `tests/test_fallback_groups.py` around `routing.config`, `FallbackResolver`, `FallbackAttemptRunner`, and executor helpers. + - Cover env/JSON group parsing, target promotion, execution suffix parsing, and fallback exhaustion behavior. + +2. Requested-model promotion. + - Promote a requested provider/model target to the first attempt when it appears in a fallback group. + - Preserve the relative order of all other targets. + - Apply this both for explicit `MODEL_ROUTE_* = group:name` routes and provider-prefixed requests that appear in any group. + +3. Streaming execution-mode parity. + - Share execution-mode selection logic between streaming and non-streaming paths. + - `@litellm_fallback` must force LiteLLM. + - `@custom` must require custom provider logic. + - `@native` must fail closed if native streaming is unsupported. + - `auto` remains custom first, then native streaming if explicitly supported, then LiteLLM. + +4. Explicit native configuration errors. + - Missing native declaration, unsupported operation, and missing endpoint/header helpers should raise `RoutingExecutionError(error_type="configuration_error")` for explicit native execution failures. + - Hard-stop policy then blocks fallback. + +5. Structured route-error aliases. + - Add common status-less aliases including `invalid_api_key`, `unauthorized`, `invalid_argument`, `rate_limited`, `too_many_requests`, `resource_exhausted`, `unavailable`, `deadline_exceeded`, and context-window variants. + +6. Session namespace adjustment. + - `clone_context_for_target()` should rewrite standard session-tracking namespaces to the target provider/model instead of preserving the first target namespace. + - Unknown/custom namespace shapes remain unchanged. + +## Acceptance Criteria + +- Broad test collection no longer fails on stale fallback imports. +- Requested provider-prefixed models inside a group are tried first. +- Streaming and non-streaming execution-mode selection are consistent. +- Explicit native config errors are hard-stop configuration errors. +- Common structured aliases normalize correctly. +- Fallback target contexts do not reuse the first target's provider/model session namespace. +- Focused tests pass and both reviewers report no blockers/highs/mediums. diff --git a/docs/experimental/phase-7-retry-cooldown-failover.md b/docs/experimental/phase-7-retry-cooldown-failover.md new file mode 100644 index 000000000..ffcfba144 --- /dev/null +++ b/docs/experimental/phase-7-retry-cooldown-failover.md @@ -0,0 +1,253 @@ +# Phase 7 Plan: Retry/Cooldown/Failover Cleanup + +## Goal + +Harden the retry and failover layer now that ordered routing exists. Phase 7 makes provider-level cooldown functional, reduces fallback policy duplication, preserves the current retry-after parser, and improves final error visibility across target chains without replacing `UsageManager`, `SessionTracker`, credential rotation, or the existing retry classifier. + +## Non-Goals + +- Do not replace `UsageManager`, `SelectionEngine`, `SessionTracker`, or the retry-after parser. +- Do not rewrite the entire executor. +- Do not implement the Phase 8 streaming transport/library overhaul. +- Do not implement native streaming dispatch unless a small safe seam is necessary. +- Do not add SQLite or any database. +- Do not implement full JSON config merging; Phase 10 owns config polish. +- Do not make fallback happen after visible streamed output. +- Do not silently fallback to LiteLLM unless the route target explicitly selects `litellm_fallback`. + +## Current Code Context + +- `CooldownManager` exists and `_wait_for_cooldown()` is called before target attempts, but `start_cooldown()` is not called by the executor, so provider cooldown is effectively dormant. +- Existing retry-after parsing lives in `error_handler.py` and is strong; keep it. +- Existing `RequestExecutor._handle_error_with_context()` handles non-streaming per-credential retries and rotations. +- Streaming has duplicated retry/rotation handling in several exception branches. +- Phase 6 added `FallbackPolicy`, `FallbackAttemptRunner`, `clone_context_for_target()`, and executor inline fallback loops. +- Phase 6 final review found no blockers but noted two useful cleanup targets: executor inline loops do not pass `FallbackGroup` overrides to `FallbackPolicy`, and `FallbackAttemptRunner` is tested but not used by live executor. +- Phase 6 response-based fallback maps exhausted structured errors to retryable categories, but final target-chain failures do not yet summarize all target failures. +- Provider-level cooldown should apply to provider-wide or IP/global throttling, not every small per-key retry-after. + +## Files To Add + +- `src/rotator_library/retry_policy.py` +- `tests/test_retry_policy.py` +- `tests/test_cooldown_activation.py` +- `tests/test_request_executor_fallback_error_summary.py` +- Possibly `tests/test_request_executor_group_policy.py` + +## Files Likely To Touch + +- `src/rotator_library/client/executor.py` +- `src/rotator_library/cooldown_manager.py` +- `src/rotator_library/routing/types.py` +- `src/rotator_library/routing/executor.py` +- `src/rotator_library/core/types.py` +- `src/rotator_library/error_handler.py` only if docstring/type comments need alignment; do not weaken parser behavior. +- Existing Phase 6 tests if group policy wiring changes expected trace order. + +## Retry Policy Foundation + +Add a small `retry_policy.py` module to centralize decisions that are currently duplicated: + +- `classify_route_error(error, provider)` +- `should_provider_cooldown(classified_error, *, small_cooldown_threshold, provider_cooldown_threshold)` +- `provider_cooldown_duration(classified_error, default_duration)` +- `should_retry_same_credential(classified_error, small_cooldown_threshold)` +- `should_rotate_credential(classified_error)` +- `is_target_failover_eligible(error_type, group=None, stream=False, emitted_output=False)` + +This module should call existing `classify_error()`, `should_retry_same_key()`, and `should_rotate_on_error()` rather than reimplementing them. + +It should document why small retry-after values stay on the same credential while large/provider-level values can activate provider cooldown. + +It should not own sleeping or mutation; it only returns decisions. + +## Cooldown Activation + +Add provider cooldown activation in non-streaming error handling: + +- If `classified.retry_after` exists and is above `SMALL_COOLDOWN_RETRY_THRESHOLD`, start provider cooldown for that duration when the error is provider-wide enough. +- Candidate categories: `rate_limit`, `server_error`, and possibly provider-specific `quota_exceeded` only when retry-after suggests global reset rather than per-key quota. +- Be conservative: avoid cooling down an entire provider for every credential quota if the error looks per-credential. + +Add env knobs: + +- `PROVIDER_COOLDOWN_MIN_SECONDS` default perhaps same as small-cooldown threshold or a modest value. +- `PROVIDER_COOLDOWN_DEFAULT_SECONDS` for retryable provider-wide errors without retry-after. +- `PROVIDER_COOLDOWN_ON_QUOTA` default false unless evidence indicates provider-wide quota. + +Existing `_wait_for_cooldown()` stays the wait path. + +`CooldownManager.start_cooldown()` should extend only when the new expiry is later than the current expiry, not shorten an active cooldown. + +Add trace pass: + +- `provider_cooldown_started` +- metadata: provider, duration, error_type, retry_after_present, reason. + +Logging failure to start cooldown should never fail the request. + +## Fallback Runner And Live Loop Cleanup + +Either wire `FallbackAttemptRunner` into live non-streaming and streaming wrappers or keep inline wrappers but pass a resolved group object. + +Minimal preferred approach for Phase 7: + +- Add `routing_group` optional field to `RequestContext` or attach group policy data to context. +- `RequestContextBuilder` should carry the `FallbackGroup` resolved from config when routing is active. +- Executor fallback wrappers pass `group=context.routing_group` to `FallbackPolicy.should_fallback()`. +- Keep inline wrappers if that avoids contorting trace behavior. + +If refactoring to `FallbackAttemptRunner` stays small, do it; otherwise leave runner as tested support and make inline loops group-aware. + +Add tests proving group overrides are honored in live executor wrappers for non-streaming and streaming. + +## Cross-Target Error Summary + +Add a target failure accumulator for fallback groups: + +- target name +- provider +- model +- execution mode +- error type +- message summary +- emitted_output for streams + +If all fallback targets fail, final client error should include `fallback_targets` details under `error.details`, with no credentials. + +Preserve existing per-target credential error summaries from `RequestErrorAccumulator`. + +Do not expose credential secrets or raw request data. + +Add trace pass: + +- `routing_fallback_exhausted` +- include accumulated target failures. + +Tests should cover: + +- all targets fail and final response contains both target failures. +- exception failure on first target and structured proxy error on second target are both summarized. +- no credentials in summary. + +## Provider Cooldown And Fallback Interaction + +- If a target enters provider cooldown and still fails/exhausts, fallback to next target may proceed if policy allows. +- If a provider is already cooling down, `_wait_for_cooldown()` should wait only if there is request budget; otherwise target should fail fast in a way fallback can interpret. +- Consider returning/raising a classified `provider_cooldown_budget_exceeded` or mapping to `rate_limit` for target fallback. +- Keep this minimal: do not replace capacity waiting or usage limits. + +## Streaming Retry/Fallback Cleanup + +- Keep Phase 6 invariant: fallback only before visible output. +- Deduplicate obvious streaming error branches only if the change is small and tests cover it. +- Add cooldown activation in streaming error handling where retry-after is available and no visible output has been emitted. +- Do not implement native streaming execution-mode dispatch in Phase 7 unless reviewers insist; Phase 8 owns streaming transport design. +- Add provider cooldown trace for streaming provider-wide throttles. + +## Error-Classification Alignment + +Add tests tying actual `classify_error()` output to fallback policy decisions for: + +- `rate_limit` +- `quota_exceeded` +- `server_error` +- `api_connection` +- `authentication` +- `forbidden` +- `invalid_request` +- `context_window_exceeded` +- `credential_reauth_needed` +- `pre_request_callback_error` +- `cancelled` +- `unsupported_operation` + +Preserve conservative `unknown` behavior unless a test shows a strong reason to fallback. + +## Stale Fallback Test Cleanup + +There is a stale ignored test file `tests/test_fallback_groups.py` that imports a non-existent older module. + +- Do not delete it unless it is tracked and relevant to current runs. +- If it is tracked or causing collection failures in broader test runs, replace/archive it with current routing tests in a focused commit. +- If ignored/untracked, leave it alone unless it blocks CI. + +## Transform Trace Requirements + +Keep existing Phase 6 traces: + +- `routing_decision` +- `routing_target_attempt_started` +- `routing_target_attempt_failed` +- `routing_target_attempt_succeeded` +- `routing_fallback_selected` +- `routing_fallback_exhausted` +- `routing_litellm_fallback` +- `routing_native_execution_selected` +- stream equivalents. + +Add: + +- `provider_cooldown_started` +- `provider_cooldown_skipped` +- `routing_group_policy_applied` if group overrides influence a decision. + +Trace metadata must not include credentials or secrets. + +## Testing Plan + +Retry policy tests: + +- existing classifier output maps to expected retry/rotate/fallback/cooldown decisions. +- small retry-after chooses same-credential retry, not provider cooldown. +- large retry-after can start provider cooldown. +- `unknown` remains conservative for target fallback. + +Cooldown tests: + +- `CooldownManager.start_cooldown()` extends but does not shorten cooldown. +- non-streaming large retry-after starts provider cooldown. +- small retry-after does not start provider cooldown. +- cooldown trace emitted. +- `_wait_for_cooldown()` respects deadline budget. + +Fallback group policy tests: + +- live non-streaming fallback honors group-specific `failover_on`. +- live non-streaming fallback honors group-specific `stop_on`. +- live streaming fallback honors group policy before output. + +Error summary tests: + +- all target failures are summarized in final error details. +- target summary excludes credentials. + +Regression tests: + +- Phase 1 protocol tests. +- Phase 2 transform logging tests. +- Phase 3 adapter/field-cache tests. +- Phase 4 Responses tests. +- Phase 5 provider/native tests. +- Phase 6 routing tests. +- `test_session_tracking.py`. +- `test_selection_engine.py`. + +## Commit Checkpoints + +1. Add retry policy helper module and classifier-alignment tests. +2. Harden `CooldownManager` extension semantics and tests. +3. Wire provider cooldown activation into non-streaming executor with trace tests. +4. Add group-policy awareness to live fallback wrappers with tests. +5. Add cross-target error summaries with tests. +6. Add streaming cooldown/group-policy cleanup tests if not covered earlier. +7. Run focused and regression tests. +8. Review with `explore` and `explore-heavy`; fix findings; write uncommitted Phase 7 report. + +## Risks And Mitigations + +- Provider cooldown could over-throttle healthy credentials. Mitigation: only activate on large retry-after/provider-wide categories, keep quota cooldown disabled by default unless configured. +- Refactoring fallback loops could break trace order. Mitigation: prefer small group-aware changes unless runner integration stays simple. +- Error summaries could leak credentials. Mitigation: summarize target/provider/model/error only; never include credentials. +- Streaming fallback could corrupt output if changed carelessly. Mitigation: preserve Phase 6 visible-output gate and add tests. +- Retry policy could drift from existing parser. Mitigation: call existing parser/classifier helpers instead of rewriting them. diff --git a/docs/experimental/phase-7b-retry-cooldown-backoff.md b/docs/experimental/phase-7b-retry-cooldown-backoff.md new file mode 100644 index 000000000..5ceb14ff7 --- /dev/null +++ b/docs/experimental/phase-7b-retry-cooldown-backoff.md @@ -0,0 +1,81 @@ +# Phase 7b: Retry/Cooldown Backoff And Failure History + +## Goal + +Correct the Phase 7 audit findings. Phase 7 made provider cooldown activation real and centralized retry/failover decisions, but it still lacks provider/model cooldown scopes, scoped backoff, and structured in-memory failure history. + +## Non-Goals + +- Do not replace `UsageManager`, credential cooldowns, usage windows, `classify_error()`, or retry-after parsing. +- Do not introduce SQLite or a new persistence database. +- Do not make native streaming safe; unsupported native streaming remains fail-closed. +- Do not change Phase 6b fallback hard-stop behavior. +- Do not commit user-facing reports. + +## Current State + +- `CooldownManager` tracks provider cooldowns only. +- `RequestExecutor._wait_for_cooldown()` waits on provider cooldown only. +- `retry_policy.decide_provider_cooldown()` returns provider-level decisions without model scope or failure-history context. +- `MODEL_CAPACITY_EXHAUSTED` is noticed in logs but still behaves like a generic provider server error. +- There is no structured provider/model failure-history ring for future observability and bounded backoff decisions. + +## Implementation Plan + +1. Extend `CooldownManager` with scoped cooldown methods. + - Preserve `start_cooldown()`, `is_cooling_down()`, and `get_remaining_cooldown()`. + - Add provider/model scoped methods and extend-only semantics per scope key. + - Provider cooldown blocks all models; model cooldown blocks only that provider/model. + +2. Update executor cooldown waiting. + - Pass model into cooldown wait paths. + - Wait for the max of provider and model cooldown remaining. + - Trace waits without credentials or raw provider text. + +3. Add model-capacity detection. + - Detect `MODEL_CAPACITY_EXHAUSTED`, model capacity, and capacity-exhausted signals from exceptions and dict payloads. + - Keep `error_type="server_error"` for compatibility, but choose model cooldown scope. + +4. Extend cooldown decisions. + - Add `scope`, `model`, and optional backoff metadata to `ProviderCooldownDecision`. + - Large retry-after rate limits stay provider-scoped. + - Model-capacity failures become model-scoped. + - Quota cooldown remains disabled by default. + +5. Add in-memory failure history. + - Bounded ring with timestamp, provider, model, error type, scope, duration, and reason. + - No disk persistence. + - Executor records successful cooldown starts for future observability and backoff. + +6. Add bounded repeated-transient backoff. + - Track repeated transient `server_error`/`api_connection` failures within a configurable window. + - Escalate cooldown duration conservatively up to a max. + - Do not backoff hard-stop categories. + +7. Wire scoped cooldown start in `RequestExecutor`. + - Use scoped cooldown methods when available and fall back to provider-only fakes in tests. + - Keep existing trace pass names with added scope/model metadata. + +8. Update streaming error decisions. + - Keep `decide_streaming_error_action()` side-effect-free. + - Include cooldown scope/model in the decision. + - Visible output still suppresses provider/model cooldown. + +## Tests + +- `tests/test_cooldown_activation.py` +- `tests/test_retry_policy.py` +- `tests/test_streaming_error_handler.py` +- Executor-focused cooldown/trace tests. +- Phase 6b routing regression subset. + +## Acceptance Criteria + +- Provider cooldown behavior remains backwards-compatible and extend-only. +- Model cooldowns do not block unrelated models on the same provider. +- Provider cooldowns still block all models for that provider. +- Model-capacity errors produce model-scoped cooldown/backoff. +- Repeated transient failures can produce bounded scoped backoff from in-memory history. +- Cooldown starts/waits are traceable and sanitized. +- Streaming cooldown decisions expose scope and respect visible-output blocking. +- Focused tests and dual-agent review pass. diff --git a/docs/experimental/phase-7c-retry-cooldown-correction.md b/docs/experimental/phase-7c-retry-cooldown-correction.md new file mode 100644 index 000000000..8fba3b192 --- /dev/null +++ b/docs/experimental/phase-7c-retry-cooldown-correction.md @@ -0,0 +1,57 @@ +# Phase 7c: Cooldown Budget, Transient Backoff, And Attempt History + +## Goal + +Close all Phase 7/7b third-pass retry/cooldown findings while preserving existing retry-after parsing and credential-rotation semantics. + +## Scope + +- Fail fast when an active provider/model cooldown exceeds the remaining request deadline budget. +- Prevent a single generic `server_error` / `api_connection` without `retry_after` from starting provider-wide cooldown. +- Record repeated transient failures into `FailureHistory` even when the cooldown threshold has not yet been reached. +- Clear or reduce failure history on successful provider/model calls. +- Populate structured `routing_attempt_history` for live non-streaming and streaming fallback attempts. +- Add tests for cooldown-over-budget, single transient no-cooldown, repeated transient cooldown, success reset, model-scoped cooldown isolation, and stream parity. + +## Non-Goals + +- Do not replace the retry-after parser. +- Do not replace `UsageManager` or `SessionTracker`. +- Do not introduce durable failure-history storage. +- Do not change small retry-after same-credential behavior. +- Do not weaken fallback hard-stop policy from Phase 6c. + +## Implementation Plan + +1. Cooldown-over-budget fail-fast. + - Raise a retryable routing error instead of returning when cooldown remaining exceeds the request deadline budget. + - Trace `cooldown_wait_exceeds_budget` without exposing credentials or provider payloads. + +2. Generic transient cooldown threshold. + - No-`retry_after` `server_error` and `api_connection` should start cooldown only after `FailureHistory.backoff_for()` crosses the configured threshold. + - Large explicit `retry_after` behavior remains unchanged. + +3. Record skipped transient failures. + - Record sanitized transient entries when cooldown is skipped because the backoff threshold is not yet met. + - Keep history bounded and in-memory only. + +4. Success reset. + - Add a clear/reset helper on `FailureHistory`. + - Clear matching provider/model transient entries after successful non-streaming and completed streaming calls. + +5. Structured attempt history. + - Append sanitized attempt entries for fallback failures and successes into `RequestContext.routing_attempt_history`. + - Include error type, target identity, execution mode, output visibility, status code when available, fallback decision, and timing where cheap. + +6. Model cooldown isolation. + - Add tests proving model-scoped cooldowns block only the matching model while provider-scoped cooldowns block all provider models. + +## Acceptance Criteria + +- Cooldown that exceeds request budget does not silently allow execution. +- A single generic transient without retry-after does not start provider-wide cooldown. +- Repeated transient failures can still start bounded cooldown/backoff. +- Successful calls clear matching failure-history entries. +- Live routing attempt history is populated and sanitized. +- Model-scoped cooldown blocks only its model plus provider-wide cooldown still blocks all provider models. +- Focused tests pass and both reviewers report no blockers/highs/mediums. diff --git a/docs/experimental/phase-8-streaming-library-upgrade.md b/docs/experimental/phase-8-streaming-library-upgrade.md new file mode 100644 index 000000000..86c1c7fbc --- /dev/null +++ b/docs/experimental/phase-8-streaming-library-upgrade.md @@ -0,0 +1,329 @@ +# Phase 8 Plan: Streaming Library Upgrade + +## Goal + +Turn streaming into a reusable, transport-aware library layer instead of scattered executor/provider-specific logic. Phase 8 should preserve the current HTTP SSE behavior, add measured stream lifecycle events, consolidate retry/error handling where safe, expose a WebSocket-ready transport seam, and prepare native provider streaming to share the same policies. It must keep the Phase 6/7 invariants: no fallback after visible output, provider cooldown only before visible output or error-only chunks, and no replacement of `UsageManager`, `SessionTracker`, or retry-after parsing. + +## Non-Goals + +- Do not replace `UsageManager`, `SessionTracker`, `SelectionEngine`, or retry-after parsing. +- Do not implement a live WebSocket route yet unless it is strictly a disabled/placeholder seam. +- Do not migrate every provider to native streaming in one pass. +- Do not alter non-streaming request behavior. +- Do not add SQLite or any database. +- Do not implement full cost/quota overhaul; Phase 9 owns usage/quota/cost. +- Do not change client-visible SSE formats except to fix bugs or add compatible error/end events. +- Do not fallback to another target after visible model output. + +## Current Code Context + +- `RequestExecutor._execute_streaming()` owns the active streaming credential retry loop and still has four duplicated exception branches. +- `RequestExecutor._execute_streaming_with_fallback()` wraps target fallback and already tracks visible output. +- Phase 7 added `_stream_chunk_is_visible_output()` and `_can_start_stream_provider_cooldown()` safety helpers. +- `StreamingHandler.wrap_stream()` converts LiteLLM chunks to chat-completions SSE strings, records usage after completion, records response anchors, and handles client disconnect. +- `stream_retry_policy.py` owns reasoning-only retry safety for post-output stream failures. +- `ResponsesService.stream_response()` converts chat stream chunks into Responses SSE events, stores final streamed responses, and has a `ResponsesWebSocketFormatter` placeholder. +- `NativeProviderExecutor.stream()` can stream mocked native provider JSON-line chunks and trace raw/parsed/formatted events, but live executor streaming does not route through native execution modes yet. +- Transform trace already has stream pass names from Phase 2 and later phases. +- There is no central stream metrics object for TTFB/TTFT/stalls/cancellation. + +## Files To Add + +- `src/rotator_library/streaming/__init__.py` +- `src/rotator_library/streaming/events.py` +- `src/rotator_library/streaming/metrics.py` +- `src/rotator_library/streaming/transport.py` +- `src/rotator_library/streaming/policy.py` +- `src/rotator_library/streaming/errors.py` +- `tests/test_stream_events.py` +- `tests/test_stream_metrics.py` +- `tests/test_stream_transport.py` +- `tests/test_stream_policy.py` +- `tests/test_streaming_error_handler.py` +- `tests/test_request_executor_stream_metrics.py` +- `tests/test_native_streaming_transport_seam.py` + +## Files Likely To Touch + +- `src/rotator_library/client/executor.py` +- `src/rotator_library/client/streaming.py` +- `src/rotator_library/client/stream_retry_policy.py` +- `src/rotator_library/native_provider/executor.py` +- `src/rotator_library/native_provider/streaming.py` +- `src/rotator_library/responses/streaming.py` +- `src/rotator_library/responses/service.py` +- `src/proxy_app/main.py` only if route headers/cancellation handling need a small SSE-compatible fix. + +## Streaming Event Model + +Add `StreamEvent` dataclass: + +- `event_type`: `started`, `raw_chunk`, `parsed_chunk`, `delta`, `reasoning_delta`, `tool_delta`, `usage`, `error`, `completed`, `cancelled`, `heartbeat`, `metadata`. +- `protocol`: `openai_chat`, `responses`, `anthropic_messages`, `gemini`, `native`, `litellm_fallback`, or provider-specific. +- `transport`: `sse`, `websocket`, `jsonl`, or future string. +- `data`: JSON-safe event payload. +- `raw`: optional raw chunk for trace only, sanitized/serialized. +- `metadata`: JSON-safe metadata. +- `visible_output`: bool. +- `timestamp_utc`. + +Add helpers: + +- `stream_event_from_sse_chunk()` +- `stream_event_to_sse()` +- `stream_event_to_websocket_message()` placeholder/seam. + +Keep formatters small and override-friendly. + +## Transport Abstraction + +Add `StreamTransportFormatter` base/protocol: + +- `format_event(event)` +- `format_error(error_event)` +- `format_done()` +- `is_terminal_event(event)` + +Add `SSEStreamFormatter` for existing HTTP SSE output. + +Add `WebSocketStreamFormatter` placeholder that formats JSON messages but is not wired to a route yet. + +Add `JSONLineStreamFormatter` for native provider/internal stream tests if useful. + +Responses API can keep `ResponsesSSEFormatter` but should share the transport interface or adapt to it. + +## Metrics And Lifecycle + +Add `StreamMetrics`: + +- `started_at` +- `first_byte_at` +- `first_visible_output_at` +- `last_chunk_at` +- `completed_at` +- `chunk_count` +- `visible_chunk_count` +- `error_count` +- `cancelled` +- properties: `ttfb_seconds`, `ttft_seconds`, `duration_seconds`, `idle_seconds`. + +Add `StreamMonitor`: + +- records raw chunk, formatted chunk, visible output, errors, completion, cancellation. +- can detect stall if `time_since_last_chunk > stall_timeout`. + +Env knobs: + +- `STREAM_TTFB_TIMEOUT_SECONDS` optional/disabled by default. +- `STREAM_STALL_TIMEOUT_SECONDS` optional/disabled by default. +- `STREAM_HEARTBEAT_SECONDS` optional/disabled by default. + +Phase 8 should add metrics and trace events first. Enforced timeouts can be opt-in to avoid behavior surprises. + +## Stream Policy + +- Move/re-export `can_retry_stream_after_error()` into the new streaming policy package, preserving current behavior. +- Keep `stream_retry_policy.py` as a compatibility wrapper if imports exist. +- Add visible-output detection policy: + - Chat-completions content/tool/function deltas are visible. + - Reasoning-only deltas are not visible for fallback only if the existing env allows reasoning-only retry. + - Error chunks and `[DONE]` are not visible. + - Responses `response.output_text.delta` is visible. + - Responses `response.failed` is not visible by itself. +- Add tests for malformed chunks failing closed. + +## Streaming Error Handling + +Add `StreamingErrorDecision` dataclass: + +- `classified` +- `action`: `retry_same`, `rotate`, `fail`, `fallback_allowed`, `fallback_blocked_after_output` +- `start_provider_cooldown`: bool +- `provider_cooldown_duration` +- `reason` + +Add helper that consumes: + +- exception +- provider +- last_streamed_chunk +- attempt +- max_retries +- deadline +- allow_reasoning_only_retry +- retry/cooldown env settings + +It should use existing `classify_error()`, `should_retry_same_key()`, `should_rotate_on_error()`, and Phase 7 retry policy. + +It should not sleep or mutate credential state; executor still owns those side effects. + +Refactor `_execute_streaming()` gradually: + +- First introduce helper and tests. +- Then replace duplicated branch decision logic where safe. +- Preserve exact output semantics and trace ordering. + +If full deduplication is too risky, use helper for cooldown/visibility/fallback decisions but keep branch structure. + +## Fallback And Cooldown Invariants + +- Fallback to next target is allowed only before visible output. +- Provider cooldown can start only before visible output or after error-only chunks. +- Same-credential retry after reasoning-only chunks remains controlled by `STREAM_RETRY_ON_REASONING_ONLY` behavior from current code. +- If a stream emits visible output then errors: + - no cross-target fallback. + - return/raise the current upstream stream failure behavior. + - emit `routing_stream_fallback_blocked_after_output` when in a fallback group. +- Preserve Phase 7 tests. + +## Native Provider Streaming + +- Add a streaming mode seam to `NativeProviderExecutor` that can return `StreamEvent`s or formatted SSE through the common formatter. +- Live `RequestExecutor` does not need to route native streaming fully unless safe, but the mode should be testable: + - `execution=native` streaming target can call native executor stream when provider declares native streaming support. + - if not supported, fail with `unsupported_operation` before output so fallback can choose next target. +- Add provider interface optional method only if needed: + - `supports_native_streaming(model, operation)` + - default false/no-op. +- This should prepare Claude Code/Codex/Copilot/Antigravity skeletons without claiming live support. + +## Responses Streaming Integration + +- Keep current `/v1/responses` SSE behavior. +- Share metrics/trace helpers where possible. +- Preserve stored final streamed response behavior. +- Add tests: + - Responses stream emits visible-output metrics. + - failed stream records error metrics. + - WebSocket formatter seam can format equivalent event payloads but no route is exposed. +- Do not change stored response shape unless necessary. + +## Transaction Trace Requirements + +Add or standardize stream trace passes: + +- `stream_started` +- `stream_first_byte` +- `stream_first_visible_output` +- `stream_stall_detected` +- `stream_heartbeat_sent` +- `stream_completed` +- `stream_cancelled` +- `stream_error_decision` +- `stream_metrics_final` + +Existing pass names stay: + +- `raw_stream_chunk` +- `parsed_stream_chunk` +- `assembled_stream_response` +- provider/native/responses stream pass names. + +Trace metrics must not include raw credentials or auth headers. + +Raw chunks should continue through existing redaction/serialization. + +## Client Disconnect And Cancellation + +- `StreamingHandler.wrap_stream()` already checks `request.is_disconnected()`. +- Add metrics event for cancellation/disconnect. +- Add trace `stream_cancelled`. +- Do not mark success on cancellation. +- Preserve current behavior for partial streams. + +## Heartbeats And Stall Detection + +Implement monitor primitives and tests first. + +Optional runtime heartbeat support: + +- if `STREAM_HEARTBEAT_SECONDS > 0`, emit SSE comment heartbeat `: keep-alive\n\n` or compatible event only when no provider chunk has arrived. +- default disabled to avoid client behavior changes. + +Optional stall detection: + +- if `STREAM_STALL_TIMEOUT_SECONDS > 0`, classify as transient stream failure before visible output or fail after visible output. +- default disabled. + +Tests can use fake clocks. + +## Testing Plan + +Stream event tests: + +- SSE chunk to event parsing. +- Responses event visibility. +- malformed chunk fails closed. +- event-to-SSE formatting. +- WebSocket formatter seam output shape. + +Metrics tests: + +- TTFB and TTFT calculations. +- chunk counts. +- cancellation. +- stall detection with fake clock. + +Policy tests: + +- current reasoning-only retry behavior preserved. +- visible output detection for text/tool deltas. +- error and done chunks not visible. + +Error handler tests: + +- large retry-after before output starts cooldown decision. +- visible output blocks fallback/cooldown. +- transient errors choose same-key retry before max retry. +- permanent errors fail. + +Executor tests: + +- streaming metrics trace passes emitted. +- cancellation trace emitted. +- pre-output fallback still works. +- post-output fallback remains blocked. +- provider cooldown still starts before output. + +Native stream tests: + +- native streaming emits common stream events/trace. +- unsupported native streaming fails before output and can fallback. + +Responses streaming tests: + +- existing route behavior unchanged. +- metrics final trace emitted. + +Regression tests: + +- Phase 1 protocol tests. +- Phase 2 transform logging tests. +- Phase 3 adapter/cache tests. +- Phase 4 Responses tests. +- Phase 5 provider/native tests. +- Phase 6 routing tests. +- Phase 7 retry/cooldown tests. +- `test_session_tracking.py`. +- `test_selection_engine.py`. + +## Commit Checkpoints + +1. Add streaming event/transport/metrics primitives with tests. +2. Move/re-export stream retry/visibility policy with tests. +3. Add streaming error decision helper with tests. +4. Integrate stream metrics and trace passes into `StreamingHandler` and/or executor with tests. +5. Refactor streaming error branches only as far as tests make safe. +6. Add native streaming seam/support detection tests. +7. Add Responses streaming metrics/transport seam tests. +8. Run focused and regression tests. +9. Review with `explore` and `explore-heavy`; fix findings; write uncommitted Phase 8 report. + +## Risks And Mitigations + +- Stream behavior is client-visible. Mitigation: keep SSE output format unchanged by default and add primitives before enforcement. +- Timeout/stall handling can break long reasoning streams. Mitigation: default stall/TTFB enforcement disabled; only metrics on by default. +- Refactoring executor streaming can regress retry semantics. Mitigation: keep branch structure unless shared helper is proven by tests. +- WebSocket support can be over-promised. Mitigation: formatter seam only, no route unless explicitly implemented later. +- Visible-output detection can be too permissive. Mitigation: fail closed on malformed/ambiguous chunks and preserve current tests. +- Metrics can leak data if raw chunks are logged. Mitigation: trace summaries and use existing redaction/serialization paths. diff --git a/docs/experimental/phase-8b-streaming-hardening.md b/docs/experimental/phase-8b-streaming-hardening.md new file mode 100644 index 000000000..932bec89a --- /dev/null +++ b/docs/experimental/phase-8b-streaming-hardening.md @@ -0,0 +1,84 @@ +# Phase 8b: Streaming Hardening, Cancellation, Heartbeats, And Stall Policy + +## Goal + +Correct the Phase 8 audit findings. Phase 8 added stream events, metrics, formatters, and observability, but stream hardening still needs upstream cancellation, active TTFB/stall policy, heartbeat support, and generic native HTTP streaming support. + +## Non-Goals + +- Do not enable native streaming for priority providers from Phase 5b. +- Do not rewrite the entire streaming executor or change existing chat-completions SSE chunk format by default. +- Do not implement a WebSocket FastAPI route. +- Do not weaken Phase 6b fallback visible-output safety or Phase 7b cooldown/retry latch behavior. +- Do not introduce persistence. +- Do not commit user-facing reports. + +## Current State + +- `StreamMonitor` records TTFB, TTFT, chunk counts, errors, cancellation, and stall status, but no active timeout policy uses it. +- `StreamingHandler.wrap_stream()` detects client disconnect but does not guarantee upstream iterator closure. +- No heartbeat frames are emitted during long waits between chunks. +- `NativeHTTPTransport.stream_json_lines()` requires custom injected clients to expose `stream_json_lines()` and does not support generic `httpx.AsyncClient.stream()`. +- Existing tests cover metrics and fallback, but not upstream cancellation, heartbeats, TTFB timeout, stall timeout, or generic native stream transport. + +## Implementation Plan + +1. Extend stream runtime settings. + - Add `ttfb_timeout_seconds`, `stall_timeout_seconds`, `heartbeat_interval_seconds`, and `cancel_upstream_on_disconnect`. + - Add env overrides: `STREAM_TTFB_TIMEOUT_SECONDS`, `STREAM_STALL_TIMEOUT_SECONDS`, `STREAM_HEARTBEAT_INTERVAL_SECONDS`, `STREAM_CANCEL_UPSTREAM_ON_DISCONNECT`. + - Keep defaults behavior-compatible: no timeout/heartbeat unless configured; upstream cancellation enabled on disconnect. + +2. Add upstream stream close helper. + - Close via `aclose()` when available, otherwise `close()`. + - Use it on client disconnect, cancellation, or abnormal stream exit. + - Log close failures only at debug/trace level. + +3. Add heartbeat formatting. + - Add `format_heartbeat()` to SSE/WebSocket/JSONL formatters. + - SSE heartbeat is a comment frame such as `: heartbeat\n\n`. + - Heartbeats must not count as visible output, session evidence, or usage. + +4. Add heartbeat emission in `StreamingHandler.wrap_stream()`. + - When configured, wait for upstream chunks with heartbeat interval timeout and yield heartbeat comments while waiting. + - Default remains no heartbeat. + +5. Add TTFB timeout policy. + - If configured and no first byte arrives in time, raise `StreamedAPIError` with structured `api_connection` timeout payload. + - This occurs before visible output so existing retry/fallback can apply. + +6. Add stall timeout policy. + - If configured and no chunk arrives for the configured interval after first byte, raise `StreamedAPIError` with structured `api_connection` timeout payload. + - Phase 7b visible-output latch must still suppress retry/fallback/cooldown if output was already visible. + +7. Add native `httpx` stream support. + - Keep custom `stream_json_lines()` support first. + - Otherwise use `client.stream("POST", ...)` with `aiter_lines()` or `aiter_bytes()` fallback. + - Preserve `[DONE]`, ignore empty lines, and parse `data:` JSON when possible. + +8. Add lifecycle trace events. + - `stream_heartbeat` + - `stream_ttfb_timeout` + - `stream_stall_timeout` + - `stream_upstream_cancelled` + - `stream_upstream_close_failed` + - Keep snapshots disabled and metadata sanitized. + +## Tests + +- Formatter heartbeat tests. +- Streaming handler disconnect/upstream close tests. +- Heartbeat interval tests. +- TTFB timeout tests. +- Stall timeout tests before and after visible output. +- Native HTTP transport tests for `httpx`-style streaming. +- Phase 7b retry/cooldown/routing regression subset. + +## Acceptance Criteria + +- Client disconnect closes upstream async streams when possible. +- Heartbeats are supported, disabled by default, and non-visible. +- Configured TTFB/stall timeouts produce structured stream errors. +- Visible-output latch still prevents retry/fallback/cooldown after output. +- Native HTTP transport supports generic `httpx` streaming plus custom test clients. +- Stream traces include heartbeat/timeout/cancel metadata without secrets. +- Focused tests and dual-agent review pass. diff --git a/docs/experimental/phase-8c-streaming-runtime-hardening.md b/docs/experimental/phase-8c-streaming-runtime-hardening.md new file mode 100644 index 000000000..fa731ef9a --- /dev/null +++ b/docs/experimental/phase-8c-streaming-runtime-hardening.md @@ -0,0 +1,60 @@ +# Phase 8c: Responses Stream Runtime Hardening And Anthropic Close Safety + +## Goal + +Close all Phase 8/8b third-pass streaming findings while preserving the transport-neutral Responses stream model and the existing Anthropic compatibility wrapper. + +## Scope + +- Apply Phase 8b stream runtime settings directly in `ResponsesService.stream_events()`. +- Add Responses heartbeat formatting support so SSE wrappers can emit non-visible heartbeat frames. +- Enforce Responses TTFB timeout before first upstream chunk. +- Enforce Responses stall timeout after a prior upstream chunk/visible output. +- Close upstream Responses chat streams on client disconnect, timeout, or abnormal exit. +- Keep Responses cost-comment handling for Phase 9c, but avoid making heartbeat/cost metadata visible output. +- Ensure Anthropic compatibility streaming always attempts to close upstream on disconnect and wrapper exit, including generator-style streams that only expose close on the iterator. + +## Non-Goals + +- Do not add WebSocket routes. +- Do not replace the Responses chat bridge with native Responses execution. +- Do not change public Responses event ordering except inserting SSE comment heartbeats when configured. +- Do not make timeout defaults active; all timeout/heartbeat knobs remain opt-in through existing config. +- Do not solve Phase 9 cost-comment propagation here except preserving metadata as non-visible. + +## Implementation Plan + +1. Responses heartbeat formatter. + - Add heartbeat handling to `ResponsesSSEFormatter` and `ResponsesWebSocketFormatter`. + - Represent heartbeat as a transport-neutral non-terminal `ResponsesStreamEvent` with `event_name="heartbeat"`. + - SSE heartbeat output is `": heartbeat\n\n"`. + +2. Responses runtime settings. + - Load `get_stream_runtime_settings()` inside `ResponsesService.stream_events()`. + - Use polling around upstream `__anext__()` to enforce optional TTFB and stall timeouts and emit optional heartbeats. + - Keep all settings disabled by default. + +3. Upstream close helper. + - Close upstream iterator/stream with `aclose()` or `close()` on disconnect, timeout, or abnormal exit. + - Trace close/close-failure events. + +4. Disconnect detection. + - Poll `request.is_disconnected()` while waiting for upstream chunks. + - If disconnected, close upstream when configured and stop without emitting model output. + +5. Anthropic close safety. + - Track both the async iterator and original OpenAI stream as close candidates. + - Ensure disconnect and wrapper exit close whichever object exposes `aclose()` / `close()`. + +6. Tests. + - Responses heartbeat, TTFB timeout, stall timeout after output, disconnect close, and heartbeat SSE comment tests. + - Anthropic iterator-only upstream close test. + - Broader streaming/Responses/Anthropic regression slice. + +## Acceptance Criteria + +- Responses streaming honors `STREAM_TTFB_TIMEOUT_SECONDS`, `STREAM_STALL_TIMEOUT_SECONDS`, `STREAM_HEARTBEAT_INTERVAL_SECONDS`, and `STREAM_CANCEL_UPSTREAM_ON_DISCONNECT`. +- Responses heartbeat frames are non-visible SSE comments. +- Responses upstream chat streams are closed on disconnect/timeout/abnormal exit. +- Anthropic compatibility streaming closes upstream even when close is exposed only on the async iterator. +- Focused tests pass and both reviewers report no blockers/highs/mediums. diff --git a/docs/experimental/phase-9-usage-quota-cost.md b/docs/experimental/phase-9-usage-quota-cost.md new file mode 100644 index 000000000..66cb18273 --- /dev/null +++ b/docs/experimental/phase-9-usage-quota-cost.md @@ -0,0 +1,331 @@ +# Phase 9 Plan: Usage, Quota, Cost + +## Goal + +Make usage, quota, and cost accounting consistent across LiteLLM fallback, native protocol adapters, Responses API, streaming, routing fallback chains, and provider-specific quota metadata while preserving the existing `UsageManager`, `SelectionEngine`, fair-cycle behavior, storage format, and provider quota reset logic. Phase 9 should formalize a normalized usage/cost layer that feeds the current usage engine rather than replacing it. + +## Non-Goals + +- Do not replace `UsageManager`, `TrackingEngine`, `LimitEngine`, `SelectionEngine`, `SessionTracker`, or the retry-after parser. +- Do not introduce SQLite or any database. +- Do not implement the full Phase 10 JSON/env config system. +- Do not change credential selection strategy semantics. +- Do not rewrite provider quota trackers. +- Do not require live provider credentials or live pricing API calls. +- Do not make cost enforcement mandatory; cost tracking is additive unless a provider already has limits. +- Do not change client-visible response usage fields except where normalization already happens today. + +## Current Code Context + +- `CredentialContext.mark_success()` already accepts prompt, completion, thinking, cache read/write tokens, approximate cost, and response headers. +- `TrackingEngine.record_usage()` stores request count, token buckets, output tokens, total tokens, and approximate cost into windows/totals. +- `RequestExecutor._extract_usage_tokens()` extracts usage from LiteLLM-like response objects, subtracting reasoning tokens from completion tokens when the provider includes reasoning inside completion. +- `RequestExecutor._calculate_cost()` uses LiteLLM cost helpers and returns `0.0` for providers that set `skip_cost_calculation`. +- `StreamingHandler.wrap_stream()` records stream usage after final chunk and currently calculates cost with `litellm.get_model_info()`. +- Phase 4 Responses streaming maps chat usage into Responses usage for storage. +- Phase 5 native protocols can produce provider-native usage shapes through protocol parse/format paths. +- Phase 6 routing can try multiple targets, but only the successful target should record usage. +- Phase 7 fallback summaries are client-safe and should not be polluted by raw usage/cost internals. +- Phase 8 stream metrics are timing/count-only and do not alter usage recording. + +## Files To Add + +- `src/rotator_library/usage/accounting.py` +- `src/rotator_library/usage/costs.py` +- `src/rotator_library/usage/quota.py` +- `tests/test_usage_accounting.py` +- `tests/test_usage_costs.py` +- `tests/test_usage_quota_snapshots.py` +- `tests/test_executor_usage_accounting.py` +- `tests/test_streaming_usage_accounting.py` +- `tests/test_responses_usage_accounting.py` +- `tests/test_native_usage_accounting.py` + +## Files Likely To Touch + +- `src/rotator_library/usage/types.py` +- `src/rotator_library/usage/manager.py` +- `src/rotator_library/usage/tracking/engine.py` +- `src/rotator_library/client/executor.py` +- `src/rotator_library/client/streaming.py` +- `src/rotator_library/responses/bridge.py` +- `src/rotator_library/responses/service.py` +- `src/rotator_library/native_provider/executor.py` +- `src/rotator_library/core/utils.py` only if current usage normalization helper needs a small adapter. +- Provider files only for optional cost/quota declarations, not broad rewrites. + +## Normalized Usage Model + +Add `UsageRecord` dataclass: + +- `input_tokens` +- `output_tokens` +- `completion_tokens` +- `reasoning_tokens` +- `cache_read_tokens` +- `cache_write_tokens` +- `total_tokens` +- `raw_total_tokens` +- `request_count` +- `source` +- `provider` +- `model` +- `metadata` + +Token semantics: + +- `input_tokens` means billable non-cache-read prompt tokens when provider separates cache-read tokens. +- `cache_read_tokens` means prompt/cache tokens read from provider cache. +- `cache_write_tokens` means prompt/cache tokens written or created. +- `completion_tokens` means visible/output text/tool tokens, excluding reasoning when the provider reports reasoning separately. +- `reasoning_tokens` means hidden thinking/reasoning tokens. +- `output_tokens = completion_tokens + reasoning_tokens`. +- `total_tokens = input_tokens + cache_read_tokens + cache_write_tokens + completion_tokens + reasoning_tokens`. +- `raw_total_tokens` preserves provider-reported total before normalization for debugging. + +Rationale: current code already avoids double-counting reasoning by subtracting thinking from completion when necessary. Phase 9 makes that rule explicit and reusable. + +## Usage Extraction + +Add `extract_usage_record(response_or_usage, provider=None, model=None, source=None)`. + +Support: + +- LiteLLM/OpenAI object usage attributes. +- dict usage fields. +- OpenAI `prompt_tokens_details.cached_tokens`. +- OpenAI `completion_tokens_details.reasoning_tokens`. +- Anthropic `input_tokens`, `output_tokens`, `cache_creation_input_tokens`, `cache_read_input_tokens`. +- Gemini `usageMetadata`, `promptTokenCount`, `candidatesTokenCount`, `thoughtsTokenCount`, `cachedContentTokenCount`, `totalTokenCount`. +- Responses `input_tokens`, `output_tokens`, `output_tokens_details.reasoning_tokens`, `input_tokens_details.cached_tokens`. +- Existing stream usage dicts after `normalize_usage_for_response()`. + +Unknown usage shapes return an empty `UsageRecord` with source metadata rather than raising in runtime paths. + +Tests must cover dicts, objects, nested details, and double-count prevention. + +## Cost Model + +Add `CostBreakdown` dataclass: + +- `input_cost` +- `cache_read_cost` +- `cache_write_cost` +- `output_cost` +- `reasoning_cost` +- `total_cost` +- `currency` +- `pricing_source` +- `metadata` + +Add `ModelPricing` dataclass: + +- per-token prices for input, cache read, cache write, output, reasoning. +- `currency`. +- `source`. + +Add `CostCalculator`: + +- prefers explicit provider/model pricing declarations. +- falls back to LiteLLM model info/completion_cost where applicable. +- returns zero cost for `skip_cost_calculation`. +- does not call network. + +Minimal Phase 9 provider declarations: + +- Optional `get_model_pricing(model)` on `ProviderInterface`, default `None`. +- Providers can later define native pricing without touching accounting code. + +Env pricing can be minimal and Phase-10-ready: + +- `MODEL_PRICE_{PROVIDER}_{MODEL}_INPUT` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_OUTPUT` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_CACHE_READ` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_CACHE_WRITE` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_REASONING` + +If env parsing is too broad, implement programmatic pricing tests and leave env polish to Phase 10. + +## UsageManager Integration + +- Preserve existing `mark_success()` signature for compatibility. +- Add optional `usage_record` and `cost_breakdown` parameters only if needed, but do not force every call site to change. +- Preferred minimal integration: + - Executor and streaming code use `UsageRecord` internally, then pass existing numeric fields to `mark_success()`. + - `approx_cost` remains a float in current storage. + - Add optional metadata later only if storage impact is small. +- Do not change persisted JSON shape unless additive and backward-compatible. +- Add trace pass: + - `usage_accounting_summary` + - Include normalized token fields, raw total, cost total, pricing source, provider/model, and source. + - No credential secrets. + +## Quota Snapshot Model + +Add `QuotaSnapshot` dataclass: + +- `provider` +- `model` +- `quota_group` +- `credential_id` optional stable/masked identifier. +- `window_name` +- `limit` +- `used` +- `remaining` +- `reset_at` +- `source` +- `metadata` + +Add helper to build snapshots from `UsageManager` state: + +- per model. +- per quota group. +- per credential where available. + +Keep `WindowLimitChecker` behavior unchanged. + +Tests should prove snapshots reflect group windows and model windows without affecting limits. + +Optional `UsageAPI` helper can expose snapshots for future UI/TUI work if small. + +## Routing/Fallback Usage Behavior + +- Only successful target attempts record usage. +- Failed target attempts continue to record failures through existing paths. +- Route fallback summaries should include optional `usage_recorded: false` for failed targets only if already tracked internally; do not expose raw costs. +- Tests: + - first target failure + second target success records usage only on second provider/credential. + - native target success records normalized usage. + - explicit LiteLLM fallback target records normalized usage. + +## Streaming Usage + +- Replace ad-hoc stream usage extraction/cost calculation with `UsageRecord` and `CostCalculator`. +- Preserve existing `mark_success()` numeric values and final `[DONE]` behavior. +- Preserve `skip_cost_calculation`. +- Add tests: + - stream final usage with cache/read/write/reasoning records expected buckets. + - stream with missing usage records zero usage but still marks success as today. + - stream metrics from Phase 8 remain independent from token usage. + +## Responses Usage + +- Use `UsageRecord` when converting/storing Responses usage. +- Ensure `previous_response_id` storage is not affected. +- Add tests: + - non-streaming Responses usage maps to normalized fields. + - streaming Responses usage stores expected input/output totals. + - reasoning details are not double counted. + +## Native Provider Usage + +- In `NativeProviderExecutor.execute()`, protocol-formatted provider responses should expose usage in a shape `extract_usage_record()` can understand. +- Add trace summary but do not require the isolated native executor to mutate `UsageManager`. +- Live routed native execution via `RequestExecutor` should record normalized usage because it receives the native response and passes through the same executor accounting. +- Tests: + - native OpenAI-style response usage records normalized fields. + - native Gemini usage metadata records thought/cached tokens. + - native Responses usage records reasoning/details. + +## Quota/Cost Reporting + +- Keep existing quota viewer/API behavior unchanged unless adding optional fields is safe. +- Add cost totals to existing total/window stats already present as `approx_cost`. +- Do not implement hard cost caps unless trivial and isolated; cost caps can be a Phase 10+ config feature. +- Add clear docs/comments that `approx_cost` is advisory and depends on available pricing. + +## Transform Trace Requirements + +Add: + +- `usage_accounting_summary` +- `usage_cost_calculated` +- `quota_snapshot_built` for explicit snapshot APIs/tests, not every request. + +Metadata: + +- provider, model, source, pricing source, skip-cost flag. + +Data: + +- normalized usage record and cost breakdown only. + +Do not log credentials, raw headers, or raw provider errors. + +## Testing Plan + +Accounting tests: + +- OpenAI/LiteLLM object usage. +- OpenAI dict usage with prompt/completion details. +- Anthropic usage. +- Gemini `usageMetadata`. +- Responses usage. +- reasoning double-count prevention. +- missing/unknown usage shape returns empty record. + +Cost tests: + +- explicit pricing calculates all buckets. +- provider `skip_cost_calculation` returns zero. +- LiteLLM fallback path still returns a float when LiteLLM has model info. +- missing pricing returns zero with `pricing_source="unavailable"`. + +Quota snapshot tests: + +- model window snapshot. +- group window snapshot. +- missing window snapshot is empty/no-op. + +Executor tests: + +- non-stream success uses normalized usage. +- trace emits usage accounting summary. +- fallback second target success records usage once. + +Streaming tests: + +- final stream usage uses normalized accounting. +- stream cost honors skip-cost provider. +- Phase 8 metrics still emitted. + +Responses tests: + +- stored Responses usage shape remains compatible. +- normalized trace emitted when transaction logger exists. + +Native tests: + +- native response usage shapes normalize. + +Regression tests: + +- Phase 1 protocol tests. +- Phase 2 transform logging tests. +- Phase 3 adapter/cache tests. +- Phase 4 Responses tests. +- Phase 5 provider/native tests. +- Phase 6 routing tests. +- Phase 7 retry/cooldown tests. +- Phase 8 streaming tests. +- `test_session_tracking.py`. +- `test_selection_engine.py`. + +## Commit Checkpoints + +1. Add `UsageRecord`, extraction helpers, and tests. +2. Add `CostBreakdown`, pricing helpers/calculator, and tests. +3. Wire executor non-streaming accounting and trace summary with tests. +4. Wire streaming accounting/cost calculation with tests. +5. Add quota snapshot helpers and tests. +6. Add Responses/native usage accounting coverage. +7. Run focused and regression tests. +8. Review with `explore` and `explore-heavy`; fix findings; write uncommitted Phase 9 report. + +## Risks And Mitigations + +- Token normalization could change quota accounting. Mitigation: preserve current numeric `mark_success()` buckets and test current OpenAI/LiteLLM behavior. +- Cost estimates may be wrong for providers with unknown pricing. Mitigation: return zero/unavailable rather than guessing; make `approx_cost` advisory. +- Persisted usage JSON compatibility could break. Mitigation: keep existing storage fields and only add optional helpers unless tests prove serialization is safe. +- Provider-specific usage shapes are broad. Mitigation: implement common native shapes now and leave provider overrides possible. +- Fallback chains could double-record usage. Mitigation: only successful attempt calls `mark_success()`; add tests for failure then success. diff --git a/docs/experimental/phase-9b-usage-cost-corrections.md b/docs/experimental/phase-9b-usage-cost-corrections.md new file mode 100644 index 000000000..8cff0bd49 --- /dev/null +++ b/docs/experimental/phase-9b-usage-cost-corrections.md @@ -0,0 +1,76 @@ +# Phase 9b: Provider-Reported Cost, SSE Cost Events, And Quota Snapshot Hardening + +## Goal + +Correct Phase 9 validation findings. Phase 9 normalized token usage and added advisory cost calculation, but provider-reported costs and streaming cost events are not yet first-class runtime accounting inputs. + +## Non-Goals + +- Do not replace `UsageManager`, JSON persistence, or window/limit engines. +- Do not introduce SQLite or any new database. +- Do not make advisory pricing authoritative when providers report actual cost. +- Do not build an admin quota UI. +- Do not commit user-facing reports. +- Do not alter stream fallback/cooldown safety from Phases 6b-8b. + +## Current State + +- `protocols.types.CostDetails` exists and some protocol adapters populate it. +- `UsageRecord` has token buckets but no first-class provider-reported cost fields. +- `extract_usage_record()` ignores `cost_details`, `cost`, `total_cost`, and protocol `Usage.cost`. +- `CostCalculator.calculate()` prefers advisory pricing and drops actual provider cost values. +- `StreamingHandler` does not parse SSE comment/event cost frames. +- `mark_success(approx_cost=...)` can already store an approximate cost total without changing `UsageManager` shape. + +## Implementation Plan + +1. Extend `UsageRecord` with provider-reported cost fields. + - `provider_reported_cost: Optional[float]` + - `cost_currency: str = "USD"` + - `cost_source: Optional[str]` + - Include them in `to_dict()`. + +2. Extract provider-reported costs from provider shapes. + - OpenAI-like `cost_details.total_cost`, `cost_details.cost`, `cost`, `total_cost`, and `provider_reported_cost`. + - Anthropic/Gemini generic cost fields and `costMetadata`. + - Protocol `Usage.cost` values. + +3. Make provider-reported actual cost win in `CostCalculator`. + - `skip_cost_calculation` remains the highest-priority behavior. + - Add `provider_reported_cost` to `CostBreakdown` and make `total_cost` use it without double counting. + - Preserve advisory pricing paths when no provider-reported cost exists. + +4. Parse SSE cost comment/event frames. + - Support `: cost {json}`, `: cost 0.001`, and `event: cost` frames. + - Keep cost comments non-visible and out of session anchors. + - Final provider usage cost overrides earlier cost comments. + +5. Extend formatted SSE usage extraction. + - Ensure formatted SSE chunks with `usage.cost_details` flow through `UsageRecord` and cost calculation. + +6. Harden native and Responses cost traces. + - Native usage traces should include cost breakdown. + - Responses usage traces should show provider-reported cost when present. + +7. Keep quota snapshots honest. + - Do not invent cost totals if usage state does not store them. + - Document/request-token-window scope explicitly. + +## Tests + +- Provider-reported cost extraction tests in `test_usage_accounting.py`. +- Provider-reported cost precedence tests in `test_usage_costs.py`. +- SSE cost comment/event tests in `test_streaming_usage_accounting.py`. +- Native/Responses trace cost tests. +- Quota snapshot honesty tests. +- Phase 8b stream regression subset. + +## Acceptance Criteria + +- Provider-reported actual cost is preserved in `UsageRecord`. +- `CostCalculator` uses provider-reported cost before advisory pricing, while skip-cost still wins. +- SSE cost comments/events update streaming cost accounting and `mark_success(approx_cost=...)`. +- Cost comments remain non-visible for fallback/retry/session purposes. +- Native and Responses usage traces include provider-reported cost when present. +- Quota snapshots remain read-only request/token window reports and do not invent unsupported cost totals. +- Focused tests and dual-agent review pass. diff --git a/docs/experimental/phase-9c-usage-cost-correction.md b/docs/experimental/phase-9c-usage-cost-correction.md new file mode 100644 index 000000000..ec056965f --- /dev/null +++ b/docs/experimental/phase-9c-usage-cost-correction.md @@ -0,0 +1,61 @@ +# Phase 9c: Usage And Cost Correctness Completion + +## Goal + +Close all Phase 9/9b third-pass usage/cost findings while preserving `UsageManager` as the authoritative persistence/limit engine and keeping provider-reported cost precedence from Phase 9b. + +## Scope + +- Preserve provider-reported top-level response cost even when a `usage` object exists. +- Prevent OpenAI-like cache-write tokens from being double-counted in normalized totals/costs. +- Preserve/sum structured provider cost breakdowns that omit `total_cost`. +- Carry Responses streaming SSE cost comments/events into completed usage/cost records. +- Ensure `event: cost` frames are treated as metadata and never visible output for retry/fallback. +- Add native streaming `usage_accounting_summary` traces and preserve provider-reported stream cost where protocol events/raw chunks expose usage. + +## Non-Goals + +- Do not replace `UsageManager`, quota windows, or persisted usage JSON format. +- Do not invent cost totals when neither provider-reported cost nor configured/advisory pricing exists. +- Do not add billing persistence beyond existing usage state and traces. +- Do not implement admin quota/cost dashboards. +- Do not alter public token fields except to correct double-counting. + +## Implementation Plan + +1. Top-level provider cost preservation. + - Merge sibling top-level cost fields into nested `usage` payloads before normalization. + - Covered keys: `cost`, `total_cost`, `cost_details`, `provider_reported_cost`, `currency`, `costMetadata`. + +2. Structured provider cost breakdowns. + - Sum known provider-reported cost fields when no total is present. + - Preserve currency and source metadata. + +3. Cache-write double-count prevention. + - For OpenAI-like usage, subtract both cache-read and cache-write tokens from regular input tokens. + - Keep cache-write tokens as their own normalized bucket. + +4. Responses streaming SSE cost comments/events. + - Parse `: cost ...` comments and `event: cost` frames. + - Merge cost into in-progress/final Responses usage without emitting model output. + - Let final provider usage cost override earlier cost comments. + +5. Metadata visibility for `event: cost`. + - Add stream policy tests locking cost frames as non-visible and retry-safe. + +6. Native streaming usage/cost trace. + - Track stream usage/cost from unified events and raw chunks. + - Emit `usage_accounting_summary` on stream completion. + +7. Tests. + - Cover usage normalization, cost precedence, Responses streaming costs, stream visibility policy, and native streaming cost trace. + +## Acceptance Criteria + +- Top-level provider-reported cost is preserved when a `usage` object exists. +- Cache-write tokens are not double-counted in OpenAI-like normalized totals/costs. +- Structured provider cost breakdowns without totals are summed and treated as provider-reported cost. +- Responses streaming carries SSE cost comments/events into final usage/cost and treats them as metadata. +- `event: cost` frames do not block retry/fallback as visible output. +- Native streaming emits `usage_accounting_summary` and preserves stream cost when available. +- Focused tests pass and both reviewers report no blockers/highs/mediums. diff --git a/docs/experimental/third-pass-audit-findings.md b/docs/experimental/third-pass-audit-findings.md new file mode 100644 index 000000000..29658805c --- /dev/null +++ b/docs/experimental/third-pass-audit-findings.md @@ -0,0 +1,361 @@ +# Third-Pass Audit Findings + +This document preserves the complete third-pass validation findings so the remediation pass cannot lose track of them. + +References used by the reviewers: + +- `C:\Projects\test\new 1.txt` +- `C:\Projects\test\new 2.txt` +- General docs: `00-master-plan.md`, `01-protocol-architecture.md`, `02-transform-logging.md`, `03-field-cache-rules.md`, `04-provider-roadmap.md`, `05-routing-retry-usage-roadmap.md`, `06-phase-workflow.md`, `07-detailed-phase-roadmap.md` +- Phase first-pass plans: `phase-1-protocol-core.md` through `phase-10-config-polish.md` +- Phase second-pass plans: `phase-1b-protocol-breadth-operation-model.md` through `phase-10b-config-surface-wiring.md` + +## Overall Verdict + +The third-pass review is not clean. Each phase received one normal `explore` review and one `explore-heavy` review. The reviewers found remaining blockers, highs, mediums, and low-risk residuals across the implemented work. + +## Severity Summary + +| Phase | Highest Severity | Clean? | +|---|---:|---:| +| Phase 1 / 1b Protocols | High | No | +| Phase 2 / 2b Transform Tracing | High | No | +| Phase 3 / 3b Field Cache | Medium | No | +| Phase 4 / 4b Responses | Medium | No | +| Phase 5 / 5b Providers | Blocker | No | +| Phase 6 / 6b Routing/Fallback | Blocker | No | +| Phase 7 / 7b Retry/Cooldown | High | No | +| Phase 8 / 8b Streaming | Medium | No | +| Phase 9 / 9b Usage/Cost | High | No | +| Phase 10 / 10b Config | High | No | + +## Phase 1 / 1b: Protocols + +### Blockers + +- None reported. + +### High + +- Native OpenAI Chat and Responses formatting returns unified usage shape instead of protocol-native usage shape. + - OpenAI Chat `format_response()` can emit `Usage.to_dict()` fields like `input_tokens` / `output_tokens` instead of `prompt_tokens` / `completion_tokens` / `total_tokens`. + - Responses `format_response()` can do the same instead of Responses-compatible usage fields. + - Native executor returns formatted protocol bodies directly to clients, so this can leak wrong usage schemas into live native responses. + +### Medium + +- OpenAI legacy `function_call` is preserved only in `extra`, not modeled as a unified `ToolCall`. +- Ollama lacks a `format_response()` override, so adapter-mutated unified responses can be ignored when `raw` is present. +- Operation compatibility is not enforced in live native provider execution. + - Native execution resolves protocol and operation but does not call `supports_operation()`. + - Misconfigured provider/protocol/operation combinations can silently fall back or drift. + +### Low / Residual + +- No cross-protocol conversion tests such as Anthropic -> unified -> OpenAI or Gemini -> unified -> OpenAI. +- `format_stream_event()` is not overridden by Anthropic, Gemini, Responses, or Ollama; they rely on base raw-pass-through behavior. +- Ollama declares `jsonl` support without JSONL-specific stream formatting. +- Unified types have `to_dict()` but no `from_dict()` deserialization helpers. +- Duplicate helper functions exist across protocol modules. +- No explicit `format_response()` tests for OpenAI Chat, Anthropic Messages, or Gemini. +- `CostDetails` protocol extraction is present in OpenAI Chat and Responses only. +- Registry discovery imports protocol modules without import-error isolation. +- Missing tests for formatted usage schemas, legacy `function_call`, Ollama mutated response formatting, registry failure isolation, and native unsupported operation rejection. + +## Phase 2 / 2b: Transform Tracing + +### Blockers + +- None reported. + +### High + +- Anthropic compatibility path lacks transform-pass tracing for live conversion boundaries. + - Missing explicit traces for Anthropic raw request, Anthropic -> OpenAI conversion, OpenAI -> Anthropic final response, and per-event stream conversion. + - `AnthropicHandler` and `anthropic_streaming_wrapper` mostly rely on legacy transaction logger calls. + +### Medium + +- `raw_provider_stream_response` logs the stream iterator object rather than provider stream data. +- Streaming handler lifecycle traces are gated by `trace_metrics`; chunk traces remain, but lifecycle diagnostics can disappear. +- Non-streaming final response trace can be emitted before usage normalization, making `final_client_response` not truly final. +- Native streaming does not run or trace the stream-event adapter chain. +- Responses streaming does not trace final SSE formatting boundaries. +- Redaction misses common camelCase secret keys such as `apiKey`, `accessToken`, `refreshToken`, and `clientSecret`. +- Non-streaming Responses errors are not traced with standardized transform-failure shape. +- Built-in provider transform exceptions are not logged as transform failures. + +### Low / Residual + +- `ProviderTransforms.apply_sync()` has no trace support; currently not used by live execution. +- `AnthropicHandler.count_tokens()` has no transaction logger or trace entries. +- Provider logger trace entries cannot override per-entry `credential_id`; they rely on context correlation. +- No dedicated test for Anthropic transform tracing. +- Responses service has redundant `if transaction_logger` guards before `_trace()` calls. +- Native provider streaming does not apply adapter chains to stream events. +- Explicit LiteLLM fallback is traceable by metadata but has no pass named `litellm_fallback_request`. + +## Phase 3 / 3b: Adapters And Field Cache + +### Blockers + +- None reported. + +### High + +- None reported. + +### Medium + +- Runtime does not execute all declared field-cache sources and targets. + - `request`, `unified_request`, `unified_response`, and `unified_stream_event` are accepted, but live native execution only injects `request` and extracts `response` / `stream_event`. +- Adapter-chain traces can leak arbitrary configured provider-state fields. + - Native executor applies rule-aware redaction around its own traces, but `run_adapter_chain()` logs payloads directly. +- Credential-scoped field-cache rules do not fail closed when `credential_id` is missing. + - Missing credential can become a shared hashed `_none` bucket in direct/library native executor usage. + +### Low / Residual + +- Copilot provider has no field-cache rules by design. +- `ProviderCacheFieldStore.clear()` depends on injected cache exposing `clear()`. +- Turn-mode inference defaults to common `messages` shape; custom providers need explicit metadata. +- `per_tool_call` injection requires tool-call ID in context metadata or payload path. +- `FieldCacheEngine` defaults to a fresh in-memory store if used directly; `NativeProviderExecutor` correctly shares one per executor instance. +- Native streaming path extracts from stream events but does not inject into stream events, which is expected. +- `ProviderCacheFieldStore.append()` may serialize already wrapped values twice in edge cases. +- Process-local default store is per `NativeProviderExecutor` instance, not a global singleton. +- JSON parsing accepts non-positive `ttl_seconds`, which effectively disables expiry. +- Provider declarations are still conservative/mock-live for several priority providers. + +## Phase 4 / 4b: Responses API + +### Blockers + +- None reported. + +### High + +- None reported. + +### Medium + +- Responses storage is process-local by default; durable JSON/current-state or provider-cache-backed storage is not wired into app startup. +- `previous_response_id` bridge context is lossy. + - The bridge replays parent output only, not parent input or full lineage. +- Responses route errors are wrapped as FastAPI `detail.error` instead of top-level OpenAI-compatible `error` bodies. + +### Low / Residual + +- `ProviderCacheResponsesStore.delete()` returns `False` if the injected cache lacks `delete_async`. +- WebSocket formatter exists but has no FastAPI route or service-neutral integration test. +- `_record_responses_session_anchor()` no-ops if `session_tracker` or `session_id` is unavailable. +- No periodic TTL cleanup task; expired entries prune on access/write. +- Bridge output conversion handles text/tool calls but not all richer Responses item types. +- `store_in_progress` returns opt-in `in_progress` state without a formal retrieval schema. +- `ResponsesStreamState` uses fixed `msg_0` for bridge streams. +- No explicit concurrent access tests for `InMemoryResponsesStore`. +- No cancel endpoint. +- `validate_stream_request()` and `stream_events()` can load the same parent response twice. +- `stream_response(..., transport=...)` always uses SSE formatting. +- Formatted SSE frames themselves are not traced. +- Tests do not include end-to-end FastAPI + real `RotatingClient` sticky continuation or durable app-level storage wiring. + +## Phase 5 / 5b: Provider-Native Integrations + +### Blockers + +- Native provider responses are formatted back to the provider protocol, not the originating client protocol. + - `/v1/chat/completions` routed to Claude Code, Codex, or Antigravity can return Anthropic Messages, Responses API, or Gemini payloads instead of OpenAI Chat payloads. + - Native context does not carry a client target protocol; native executor uses the provider protocol for `format_response()`. + +### High + +- Claude Code native requests can omit required Anthropic `max_tokens`. +- Antigravity model normalization is unsafe for duplicated aliases and can lose low/high thinking-level behavior. + +### Medium + +- `get_native_headers()` and `get_native_endpoint()` are not declared on `ProviderInterface` even though the executor requires them dynamically. +- Native streaming is disabled for all priority providers. This is safe, but it leaves the native streaming path unexercised by real priority providers. +- Claude Code auth header convention may be wrong for API-key credentials because it always uses `Authorization: Bearer` with `anthropic-version`. +- ProviderInterface native contract lacks `supports_native_operation()` / `should_use_native_protocol()` style hooks, making non-stream native auto-selection all-or-nothing by protocol declaration. +- Native streaming fail-closed behavior is duplicated and only partially fail-closed; the executor duplicate catches fewer exceptions than the safe helper. +- Antigravity quota grouping regresses retired parity by grouping models too broadly. +- Field-cache declarations exist, but thinking/signature reinjection is not proven semantically correct for protocol-native continuation requirements. + +### Low / Residual + +- Copilot has no field-cache rules by design. +- Codex message-to-Responses conversion is minimal and does not handle tool calls, multipart content, images, or rich system formatting. +- Antigravity intentionally restores only a conservative safe subset; many retired features remain absent. +- Antigravity has quota groups but no usage reset configs or tier priorities. +- Priority provider model discovery catches broad exceptions and silently falls back to hardcoded lists. +- Execution-mode routing priority is correct and fail-closed. +- Native streaming fail-closed behavior is tested for explicit native streaming. +- Mock-live tests cover direct native request execution for all four priority providers, but not full end-to-end credential rotation/usage manager execution. +- Some provider docstrings still describe skeleton behavior. + +## Phase 6 / 6b: Routing And Fallback + +### Blockers + +- Stale fallback test file imports a missing module, so broad test collection can fail. + - `tests/test_fallback_groups.py` imports `rotator_library.fallback_groups.FallbackGroupManager`, which no longer exists. + - The same file asserts obsolete `RotatingClient(model_fallback_groups=...)` / `fallback_manager` behavior. + +### High + +- Requested-model-in-group promotion is not implemented. + - If the requested model is inside a fallback group, the router should try it first and preserve remaining group order. +- Streaming execution modes do not match non-streaming execution-mode behavior. + - Streaming can ignore explicit `@litellm_fallback` when a plugin has custom logic and can fall through differently for `@custom`. +- Explicit native non-streaming configuration errors can be treated as retryable fallback errors because some `RoutingExecutionError`s default to `unsupported_operation`. +- Structured route-error classification misses common status-less aliases such as `invalid_api_key`, `unauthorized`, `invalid_argument`, `rate_limited`, `too_many_requests`, `resource_exhausted`, and `unavailable`. + +### Medium + +- Fallback target session namespace is cloned from the first target instead of being re-inferred or adjusted per target, risking response anchors under the wrong provider/model namespace. + +### Low / Residual + +- No end-to-end test for env config -> `RequestContextBuilder` -> `RequestExecutor.execute()` full fallback scenario. +- `build_embedding_context()` does not resolve routing groups. +- No explicit test for `_execute_non_streaming_with_fallback()` when all targets return structured error responses. +- `_route_error_type_from_response()` falls back to text-summary scanning when structured fields are ambiguous. +- Fallback exhaustion summaries are sanitized and appear safe. +- Routing route aliases are lowercase keys only; there is no hyphen/space alias normalization. +- Wrapper-level streaming tests are thinner than helper-level coverage for some event frames. + +## Phase 7 / 7b: Retry, Cooldown, Backoff, Failure History + +### Blockers + +- None reported. + +### High + +- Provider/model cooldown can be bypassed when the remaining cooldown exceeds the request deadline budget. + - `_wait_for_cooldown()` logs and returns, and callers continue into credential acquisition/execution. +- Generic transient errors start provider-wide cooldown too aggressively. + - A single `server_error` or `api_connection` without `retry_after` can start provider cooldown immediately. + +### Medium + +- `decide_streaming_error_action()` is implemented and tested but not used by the live executor. +- Failure history lacks success reset; successes do not clear or reduce provider/model failure history. +- Structured retry attempt history is not populated beyond traces/summaries; `routing_attempt_history` exists but is unused in live executor. + +### Low / Residual + +- `FallbackAttemptRunner` is tested but not used by the live executor; inline wrappers remain. +- `decide_streaming_error_action()` does not accept `failure_history`, which would matter if it is adopted live. +- No `test_request_executor_group_policy.py` file despite the plan mentioning it. +- Streaming error branches in the executor are duplicated across several handlers. +- Model-capacity detection is narrow and misses variants such as `CAPACITY_EXHAUSTED` or `capacity-exhausted`. +- Tests miss cooldown-over-budget fail-fast/fallback, single generic transient non-cooldown, success reset, live model cooldown not blocking unrelated models, and live streaming parity with the decision helper. + +## Phase 8 / 8b: Streaming Hardening + +### Blockers + +- None reported. + +### High + +- None reported. + +### Medium + +- Responses streaming bypasses active Phase 8b hardening at the Responses layer. + - No direct TTFB/stall/heartbeat/disconnect/close policy is applied in `ResponsesService.stream_events()`. + - Responses formatter has no heartbeat formatter. +- Anthropic compatibility streaming can stop on disconnect without explicitly closing the upstream OpenAI stream. + +### Low / Residual + +- No formal `StreamTransportFormatter` protocol/ABC. +- Executor streaming error handling remains duplicated across branches. +- Responses streaming does not emit heartbeats or enforce TTFB/stall policy directly. +- Native streaming executor does not integrate with `StreamingHandler.wrap_stream()`. +- Native `_parse_stream_line()` returns raw text on JSON parse failure. +- No specific test for stall timeout after visible output. +- Normalized stream event parser comment says malformed chunks fail closed for visibility, but parser returns `visible_output=False`; routing safety is protected elsewhere. +- `decide_streaming_error_action()` is not used by the live executor. + +## Phase 9 / 9b: Usage, Quota, Cost + +### Blockers + +- None reported. + +### High + +- Provider-reported top-level response costs are ignored when a `usage` object exists. + - `extract_usage_record()` unwraps full responses to `response["usage"]`, discarding sibling `cost`, `total_cost`, `cost_details`, or `provider_reported_cost`. +- OpenAI-like cache-write tokens can be double-counted in normalized totals/costs. + - `input_tokens = prompt_tokens - cache_read`, while `cache_write_tokens` is stored separately and also added to total. +- Responses streaming drops SSE cost comments/events unless final usage carries cost. + +### Medium + +- `event: cost` frames can block streaming retry/fallback despite being metadata. +- Native streaming executor does not emit `usage_accounting_summary` trace. +- Native streaming executor does not preserve provider-reported cost from raw response. +- Structured provider cost breakdowns without `total_cost` are not preserved or summed. + +### Low / Residual + +- Responses bridge passes cost fields through verbatim rather than structuring them, which is acceptable but should be documented. +- No explicit native streaming cost-trace test. +- Streaming cost calculator does not pass `provider_plugin`, so provider explicit pricing can be missed in streaming mode when no provider-reported cost exists. +- Quota snapshot `used` maps to request count, not token count; metadata says request/token window but field naming is ambiguous. +- No integration-level test for Responses usage with provider-reported cost. +- No quota snapshot test for combined model and group filters. +- Planned trace pass names like `usage_cost_calculated` or `quota_snapshot_built` are not implemented if considered distinct from `usage_accounting_summary`. + +## Phase 10 / 10b: Config + +### Blockers + +- None reported. + +### High + +- Full `PROXY_API_KEY` is printed to console at startup. +- Provider/protocol/adapter/model/quota-checker config surfaces are still not live-wired, despite the `providers` JSON section being accepted/documented. + +### Medium + +- Responses streaming ignores Phase 10 streaming runtime settings. +- Pricing env parsing can fail requests instead of warning/ignoring invalid values. +- Routing config validation is not fully startup-safe; direct model-route target specs can parse later at request time. +- Secret rejection is key-name based and misses generic `credential` / `credentials` key names. + +### Low / Residual + +- `STREAM_RETRY_ON_REASONING_ONLY` is env-only and not part of `StreamRuntimeSettings` JSON config. +- `get_responses_store_settings()` returns `Any` rather than a precise type. +- `provider_cooldown_env()` remains a legacy tuple wrapper over richer retry settings. +- `STREAM_HEARTBEAT_SECONDS` legacy alias resolution is subtle. +- No standalone `config-reference.md` was created. +- `providers` section accepts raw dicts but does not validate protocol/adapter names; this is conservative but incomplete for future wiring. +- `.env.example` drift test does not verify the legacy `STREAM_HEARTBEAT_SECONDS=0` default line. + +## Cross-Phase Remediation Order + +1. Phase 5 blocker: native provider output must be converted back to the originating client protocol. +2. Phase 6 blocker: retire/update stale fallback tests and fix routing execution correctness. +3. Phase 10 high: stop printing secrets and make accepted config surfaces honest/live-wired or reject unsupported shapes. +4. Phase 1 high/medium protocol formatting and operation enforcement. +5. Phase 9 high usage/cost correctness. +6. Phase 7 high cooldown budget and transient cooldown behavior. +7. Phase 2 tracing gaps, especially Anthropic compatibility and redaction. +8. Phase 3 field-cache runtime/redaction gaps. +9. Phase 4 Responses storage/continuation/error-shape gaps. +10. Phase 8 streaming hardening gaps for Responses and Anthropic compatibility. + +## Required Completion Bar + +- Every blocker/high/medium listed above must be fixed or explicitly re-reviewed as no longer applicable. +- Low-risk items should be fixed when small/safe; otherwise they must be carried into phase reports as explicit residuals. +- Each remediation phase must follow the established loop: plan in conversation, write plan doc, implement focused commits, run focused tests, run `explore` and `explore-heavy`, fix findings, re-review, write user-facing report without committing the report. diff --git a/src/proxy_app/main.py b/src/proxy_app/main.py index 4aa689720..5cf4a4dc1 100644 --- a/src/proxy_app/main.py +++ b/src/proxy_app/main.py @@ -140,6 +140,8 @@ from proxy_app.request_logger import log_request_to_console from proxy_app.batch_manager import EmbeddingBatcher from proxy_app.detailed_logger import RawIOLogger + from rotator_library.responses import ResponsesService, ResponsesServiceError + from rotator_library.transaction_logger import TransactionLogger print(" → Discovering provider plugins...") # Provider lazy loading happens during import, so time it here @@ -592,6 +594,13 @@ async def process_credential(provider: str, path: str, provider_instance): # print(f"🔑 Credentials loaded: {_total_summary} (API: {_api_summary} | OAuth: {_oauth_summary})") client.background_refresher.start() # Start the background task app.state.rotating_client = client + # Phase 4 Responses API compatibility service. It currently bridges through + # the existing chat-completions client path; later native providers can reuse + # the same route/storage surface without changing clients. + from rotator_library.config.experimental import get_responses_store_settings + from rotator_library.responses import create_configured_responses_store + + app.state.responses_service = ResponsesService(store=create_configured_responses_store(), store_settings=get_responses_store_settings()) # Warn if no provider credentials are configured if not client.all_credentials: @@ -663,6 +672,19 @@ def get_embedding_batcher(request: Request) -> EmbeddingBatcher: return request.app.state.embedding_batcher +def get_responses_service(request: Request) -> ResponsesService: + """Dependency to get the Responses API service instance from app state.""" + + service = getattr(request.app.state, "responses_service", None) + if service is None: + from rotator_library.config.experimental import get_responses_store_settings + from rotator_library.responses import create_configured_responses_store + + service = ResponsesService(store=create_configured_responses_store(), store_settings=get_responses_store_settings()) + request.app.state.responses_service = service + return service + + async def verify_api_key(auth: str = Depends(api_key_header)): """Dependency to verify the proxy API key.""" # If PROXY_API_KEY is not set or empty, skip verification (open access) @@ -1011,6 +1033,99 @@ async def chat_completions( raise HTTPException(status_code=500, detail=str(e)) +def _responses_error_response(error: ResponsesServiceError) -> dict[str, Any]: + """Return an OpenAI-compatible error payload for Responses routes.""" + + return {"error": {"message": str(error), "type": error.error_type, "code": error.status_code}} + + +@app.post("/v1/responses") +async def responses_create( + request: Request, + client: RotatingClient = Depends(get_rotating_client), + service: ResponsesService = Depends(get_responses_service), + _=Depends(verify_api_key), +): + """OpenAI-compatible Responses API create endpoint. + + Phase 4 non-streaming requests bridge through the current completion engine. + Streaming requests are handled by the dedicated SSE checkpoint in this phase. + """ + + logger = RawIOLogger() if ENABLE_RAW_LOGGING else None + try: + request_data = await request.json() + except json.JSONDecodeError: + return JSONResponse(status_code=400, content={"error": {"message": "Invalid JSON in request body.", "type": "invalid_request_error", "code": 400}}) + if logger: + logger.log_request(headers=dict(request.headers), body=request_data) + transaction_logger = TransactionLogger("responses", request_data.get("model", "unknown")) if ENABLE_REQUEST_LOGGING else None + try: + if request_data.get("stream"): + await service.validate_stream_request(request_data) + return StreamingResponse( + service.stream_response(request_data, client, request=request, transaction_logger=transaction_logger), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, + ) + result = await service.create_response(request_data, client, request=request, transaction_logger=transaction_logger) + if logger: + logger.log_final_response(status_code=200, headers=None, body=result) + return JSONResponse(content=result) + except ResponsesServiceError as e: + payload = _responses_error_response(e) + if logger: + logger.log_final_response(status_code=e.status_code, headers=None, body=payload) + return JSONResponse(status_code=e.status_code, content=payload) + except Exception as e: + logging.error(f"Responses endpoint error: {e}") + if logger: + logger.log_final_response(status_code=500, headers=None, body={"error": str(e)}) + return JSONResponse(status_code=500, content={"error": {"message": str(e), "type": "internal_error", "code": 500}}) + + +@app.get("/v1/responses/{response_id}") +async def responses_get( + response_id: str, + service: ResponsesService = Depends(get_responses_service), + _=Depends(verify_api_key), +): + """Retrieve a stored Responses object by ID.""" + + try: + return JSONResponse(content=await service.get_response(response_id)) + except ResponsesServiceError as e: + return JSONResponse(status_code=e.status_code, content=_responses_error_response(e)) + + +@app.delete("/v1/responses/{response_id}") +async def responses_delete( + response_id: str, + service: ResponsesService = Depends(get_responses_service), + _=Depends(verify_api_key), +): + """Delete a stored Responses object by ID.""" + + try: + return JSONResponse(content=await service.delete_response(response_id)) + except ResponsesServiceError as e: + return JSONResponse(status_code=e.status_code, content=_responses_error_response(e)) + + +@app.get("/v1/responses/{response_id}/input_items") +async def responses_input_items( + response_id: str, + service: ResponsesService = Depends(get_responses_service), + _=Depends(verify_api_key), +): + """Return stored input items for a Responses object.""" + + try: + return JSONResponse(content=await service.list_input_items(response_id)) + except ResponsesServiceError as e: + return JSONResponse(status_code=e.status_code, content=_responses_error_response(e)) + + # --- Anthropic Messages API Endpoint --- @app.post("/v1/messages") async def anthropic_messages( diff --git a/src/rotator_library/adapters/__init__.py b/src/rotator_library/adapters/__init__.py new file mode 100644 index 000000000..0c28ed20c --- /dev/null +++ b/src/rotator_library/adapters/__init__.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Composable payload adapters used between protocols and providers.""" + +from .base import AdapterContext, PayloadAdapter, run_adapter_chain +from .registry import ( + ADAPTER_ALIASES, + ADAPTER_PLUGINS, + get_adapter, + get_adapter_class, + list_adapters, + register_adapter, + resolve_adapter_name, +) + +__all__ = [ + "ADAPTER_ALIASES", + "ADAPTER_PLUGINS", + "AdapterContext", + "PayloadAdapter", + "get_adapter", + "get_adapter_class", + "list_adapters", + "register_adapter", + "resolve_adapter_name", + "run_adapter_chain", +] diff --git a/src/rotator_library/adapters/base.py b/src/rotator_library/adapters/base.py new file mode 100644 index 000000000..89f4e48a8 --- /dev/null +++ b/src/rotator_library/adapters/base.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Base classes for configurable protocol/provider payload adapters.""" + +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Any, Iterable, Mapping, Optional + + +@dataclass +class AdapterContext: + """Context available to every adapter pass. + + Adapters are intentionally base implementations, not provider law. Provider + code can subclass adapters, copy and mutate them, or bypass them when a + protocol needs special behavior. The context mirrors transform trace fields + so each adapter pass can be debugged by provider/model/session/scope. + """ + + provider: Optional[str] = None + model: Optional[str] = None + protocol: Optional[str] = None + credential_id: Optional[str] = None + session_id: Optional[str] = None + scope_key: Optional[str] = None + classifier: Optional[str] = None + transport: Optional[str] = None + metadata: dict[str, Any] = field(default_factory=dict) + adapter_config: dict[str, dict[str, Any]] = field(default_factory=dict) + transaction_logger: Optional[Any] = None + + def config_for(self, adapter_name: str) -> dict[str, Any]: + """Return adapter-specific config without mutating the original context.""" + + return dict(self.adapter_config.get(adapter_name, {})) + + +class PayloadAdapter: + """Override-friendly base adapter. + + The default adapter is a no-op. Concrete adapters override only the stages + they support. Methods are async so future provider adapters can consult + caches, metadata services, or remote capability probes without changing the + chain runner API. + """ + + name: str = "" + aliases: tuple[str, ...] = () + supported_stages: tuple[str, ...] = ("request", "response", "stream_event") + + async def transform_request(self, payload: Any, context: AdapterContext) -> Any: + return payload + + async def transform_response(self, payload: Any, context: AdapterContext) -> Any: + return payload + + async def transform_stream_event(self, payload: Any, context: AdapterContext) -> Any: + return payload + + async def transform(self, stage: str, payload: Any, context: AdapterContext) -> Any: + """Dispatch a stage-specific transform with a useful error for typos.""" + + if stage not in self.supported_stages: + return payload + if stage == "request": + return await self.transform_request(payload, context) + if stage == "response": + return await self.transform_response(payload, context) + if stage == "stream_event": + return await self.transform_stream_event(payload, context) + raise ValueError(f"Unknown adapter stage: {stage}") + + +def _trace(context: AdapterContext, pass_name: str, payload: Any, *, stage: str, metadata: Mapping[str, Any]) -> None: + logger = context.transaction_logger + if not logger: + return + logger.log_transform_pass( + pass_name, + payload, + direction="stream" if stage == "stream_event" else stage, + stage="adapter", + protocol=context.protocol, + credential_id=context.credential_id, + transport=context.transport, + metadata=dict(metadata), + snapshot=stage != "stream_event", + ) + + +async def run_adapter_chain( + adapters: Iterable[PayloadAdapter], + payload: Any, + context: AdapterContext, + *, + stage: str, + mutate: bool = False, +) -> Any: + """Run adapters in order and emit trace entries around the chain. + + By default the payload is deep-copied before the first adapter. This keeps + Phase 3 isolated and prevents surprise mutations until runtime integration + explicitly chooses mutating behavior for performance. + """ + + current = payload if mutate else deepcopy(payload) + tracing_enabled = context.transaction_logger is not None + original = deepcopy(current) if tracing_enabled else None + adapter_names = [adapter.name for adapter in adapters] + _trace(context, "before_adapter_chain", current, stage=stage, metadata={"adapters": adapter_names}) + for adapter in adapters: + before = deepcopy(current) if tracing_enabled else None + try: + current = await adapter.transform(stage, current, context) + except Exception as exc: + if context.transaction_logger: + context.transaction_logger.log_transform_error( + f"adapter:{adapter.name}:{stage}", + exc, + payload=before if before is not None else current, + stage="adapter", + protocol=context.protocol, + transport=context.transport, + metadata={"adapter": adapter.name, "adapter_stage": stage}, + ) + raise + _trace( + context, + "after_adapter", + current, + stage=stage, + metadata={"adapter": adapter.name, "adapter_stage": stage, "changed": (current != before) if before is not None else None}, + ) + _trace( + context, + "after_adapter_chain", + current, + stage=stage, + metadata={"adapters": adapter_names, "adapter_stage": stage, "adapter_count": len(adapter_names), "changed": (original != current) if original is not None else None}, + ) + return current diff --git a/src/rotator_library/adapters/builtin.py b/src/rotator_library/adapters/builtin.py new file mode 100644 index 000000000..1475d70a7 --- /dev/null +++ b/src/rotator_library/adapters/builtin.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Built-in base adapters for common provider payload quirks.""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any +from uuid import uuid4 + +from ..field_cache.paths import extract_path, inject_path +from .base import AdapterContext, PayloadAdapter + + +class NoOpAdapter(PayloadAdapter): + """Adapter that intentionally leaves payloads unchanged.""" + + name = "noop" + aliases = ("none", "passthrough") + + +class ModelOverrideAdapter(PayloadAdapter): + """Replace the outbound model field from adapter config. + + Config shape: + `{ "model": "provider/native-model-name" }` + """ + + name = "model_override" + aliases = ("override_model",) + supported_stages = ("request",) + + async def transform_request(self, payload: Any, context: AdapterContext) -> Any: + config = context.config_for(self.name) + override = config.get("model") or context.metadata.get("model_override") + if not override or not isinstance(payload, dict): + return payload + updated = deepcopy(payload) + updated["model"] = override + return updated + + +class SuppressDeveloperRoleAdapter(PayloadAdapter): + """Convert or remove developer-role messages for providers that reject them. + + Config shape: + `{ "mode": "system" | "user" | "drop" }` + """ + + name = "suppress_developer_role" + aliases = ("developer_role",) + supported_stages = ("request",) + + async def transform_request(self, payload: Any, context: AdapterContext) -> Any: + if not isinstance(payload, dict) or not isinstance(payload.get("messages"), list): + return payload + mode = context.config_for(self.name).get("mode", "system") + if mode not in {"system", "user", "drop"}: + raise ValueError("suppress_developer_role mode must be system, user, or drop") + updated = deepcopy(payload) + messages = [] + for message in updated.get("messages", []): + if not isinstance(message, dict) or message.get("role") != "developer": + messages.append(message) + continue + if mode == "drop": + continue + converted = dict(message) + converted["role"] = mode + messages.append(converted) + updated["messages"] = messages + return updated + + +class ReasoningContentAdapter(PayloadAdapter): + """Normalize common reasoning fields on assistant messages. + + This base adapter copies `reasoning`, `reasoning_content`, or configured + aliases into the configured output field. It deliberately does not delete + source fields; provider-specific subclasses can choose stricter behavior. + """ + + name = "reasoning_content" + aliases = ("reasoning_rewrite",) + supported_stages = ("response",) + + async def transform_response(self, payload: Any, context: AdapterContext) -> Any: + if not isinstance(payload, dict): + return payload + config = context.config_for(self.name) + output_field = config.get("output_field", "reasoning_content") + source_fields = tuple(config.get("source_fields", ("reasoning_content", "reasoning"))) + updated = deepcopy(payload) + for choice in updated.get("choices", []) if isinstance(updated.get("choices"), list) else []: + message = choice.get("message") if isinstance(choice, dict) else None + if not isinstance(message, dict): + continue + if output_field in message: + continue + for source_field in source_fields: + if source_field in message: + message[output_field] = message[source_field] + break + return updated + + +class FieldRenameAdapter(PayloadAdapter): + """Copy configured values between paths on raw payloads. + + Config shape: + `{ "rules": [{ "source_path": "a.b", "target_path": "c.d", "stage": "request", "move": false }] }` + + This adapter is conservative by design: it copies the last matched value by + default, delegates target ambiguity checks to `inject_path`, and only removes + the source for simple dotted-key paths when `move=true`. + """ + + name = "field_rename" + aliases = ("field_copy",) + + async def transform_request(self, payload: Any, context: AdapterContext) -> Any: + return self._transform_stage(payload, context, "request") + + async def transform_response(self, payload: Any, context: AdapterContext) -> Any: + return self._transform_stage(payload, context, "response") + + async def transform_stream_event(self, payload: Any, context: AdapterContext) -> Any: + return self._transform_stage(payload, context, "stream_event") + + def _transform_stage(self, payload: Any, context: AdapterContext, stage: str) -> Any: + if not isinstance(payload, dict): + return payload + updated = deepcopy(payload) + for rule in context.config_for(self.name).get("rules", []): + if rule.get("stage", stage) != stage: + continue + values = extract_path(updated, rule["source_path"]) + if not values: + continue + value = values if rule.get("as_list") else values[-1] + inject_path( + updated, + rule["target_path"], + value, + when_missing_only=bool(rule.get("when_missing_only", False)), + ) + if rule.get("move"): + _delete_simple_path(updated, rule["source_path"]) + return updated + + +class AntigravityEnvelopeAdapter(PayloadAdapter): + """Wrap Gemini payloads in the Antigravity internal request envelope. + + The active provider restores only stable envelope fields. Device profiles, + fingerprints, and other volatile client-emulation fields are intentionally + not generated here until they are verified against current service behavior. + """ + + name = "antigravity_envelope" + supported_stages = ("request",) + + async def transform_request(self, payload: Any, context: AdapterContext) -> Any: + if not isinstance(payload, dict): + return payload + if _looks_like_antigravity_envelope(payload): + return payload + config = context.config_for(self.name) + model = payload.get("model") or context.model + request_payload = {key: deepcopy(value) for key, value in payload.items() if key != "model"} + envelope = { + "model": model, + "request": request_payload, + "requestType": config.get("request_type", "CHAT_COMPLETION"), + "requestId": str(uuid4()), + "userAgent": config.get("user_agent"), + } + project = config.get("project") + if project: + envelope["project"] = project + return {key: value for key, value in envelope.items() if value is not None} + + +def _looks_like_antigravity_envelope(payload: dict[str, Any]) -> bool: + """Return whether a payload already has the controlled envelope shape.""" + + return "request" in payload and "requestType" in payload and "requestId" in payload + + +def _delete_simple_path(payload: dict[str, Any], path: str) -> None: + """Delete a simple dotted dict path after a conservative move operation.""" + + parts = path.split(".") + if any("[" in part or part == "*" for part in parts): + return + current: Any = payload + for part in parts[:-1]: + if not isinstance(current, dict): + return + current = current.get(part) + if isinstance(current, dict): + current.pop(parts[-1], None) diff --git a/src/rotator_library/adapters/registry.py b/src/rotator_library/adapters/registry.py new file mode 100644 index 000000000..7d702c0be --- /dev/null +++ b/src/rotator_library/adapters/registry.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Auto-discovery registry for payload adapters.""" + +from __future__ import annotations + +import importlib +import inspect +import logging +import pkgutil +from typing import Type + +from .base import PayloadAdapter + +lib_logger = logging.getLogger("rotator_library") + +ADAPTER_PLUGINS: dict[str, Type[PayloadAdapter]] = {} +ADAPTER_ALIASES: dict[str, str] = {} +_ADAPTER_INSTANCES: dict[str, PayloadAdapter] = {} + +_INFRASTRUCTURE_MODULES = {"base", "registry"} + + +def register_adapter(adapter_class: Type[PayloadAdapter], *, replace: bool = False) -> Type[PayloadAdapter]: + """Register an adapter class and its aliases with collision checks.""" + + if not inspect.isclass(adapter_class) or not issubclass(adapter_class, PayloadAdapter): + raise TypeError("adapter_class must inherit PayloadAdapter") + if adapter_class is PayloadAdapter: + raise TypeError("cannot register PayloadAdapter itself") + name = adapter_class.name + if not name: + raise ValueError(f"Adapter {adapter_class.__name__} must define a name") + alias_owner = ADAPTER_ALIASES.get(name) + if alias_owner and alias_owner != name and not replace: + raise ValueError(f"Adapter name conflicts with registered alias: {name}") + existing = ADAPTER_PLUGINS.get(name) + if existing and existing is not adapter_class and not replace: + raise ValueError(f"Adapter name already registered: {name}") + if replace and existing and existing is not adapter_class: + for alias, owner in list(ADAPTER_ALIASES.items()): + if owner == name: + ADAPTER_ALIASES.pop(alias, None) + ADAPTER_PLUGINS[name] = adapter_class + _ADAPTER_INSTANCES.pop(name, None) + for alias in adapter_class.aliases: + existing_name = ADAPTER_ALIASES.get(alias) + if existing_name and existing_name != name and not replace: + raise ValueError(f"Adapter alias already registered: {alias}") + if alias in ADAPTER_PLUGINS and alias != name and not replace: + raise ValueError(f"Adapter alias conflicts with registered name: {alias}") + ADAPTER_ALIASES[alias] = name + lib_logger.debug("Registered adapter: %s", name) + return adapter_class + + +def resolve_adapter_name(name: str) -> str: + if name in ADAPTER_PLUGINS: + return name + if name in ADAPTER_ALIASES: + return ADAPTER_ALIASES[name] + raise KeyError(f"Unknown adapter: {name}") + + +def get_adapter_class(name: str) -> Type[PayloadAdapter]: + return ADAPTER_PLUGINS[resolve_adapter_name(name)] + + +def get_adapter(name: str) -> PayloadAdapter: + canonical = resolve_adapter_name(name) + if canonical not in _ADAPTER_INSTANCES: + _ADAPTER_INSTANCES[canonical] = ADAPTER_PLUGINS[canonical]() + return _ADAPTER_INSTANCES[canonical] + + +def list_adapters() -> list[str]: + return sorted(ADAPTER_PLUGINS) + + +def _register_adapters() -> None: + package = importlib.import_module(__package__ or "rotator_library.adapters") + for _, module_name, _ in pkgutil.iter_modules(package.__path__): + if module_name.startswith("_") or module_name in _INFRASTRUCTURE_MODULES: + continue + module = importlib.import_module(f"{package.__name__}.{module_name}") + for attribute_name in dir(module): + attribute = getattr(module, attribute_name) + if ( + inspect.isclass(attribute) + and issubclass(attribute, PayloadAdapter) + and attribute is not PayloadAdapter + and attribute.__module__ == module.__name__ + ): + register_adapter(attribute) + + +_register_adapters() diff --git a/src/rotator_library/anthropic_compat/streaming.py b/src/rotator_library/anthropic_compat/streaming.py index ecb074baa..126ed70ab 100644 --- a/src/rotator_library/anthropic_compat/streaming.py +++ b/src/rotator_library/anthropic_compat/streaming.py @@ -66,11 +66,76 @@ async def anthropic_streaming_wrapper( accumulated_text = "" # Track accumulated text for logging accumulated_thinking = "" # Track accumulated thinking for logging stop_reason_final = "end_turn" # Track final stop reason for logging + upstream_closed = False + stream_iterator = openai_stream.__aiter__() + + def trace_frame(pass_name: str, frame: str, payload: Any | None = None) -> str: + """Trace one Anthropic stream conversion frame and return it for yield.""" + + if transaction_logger: + transaction_logger.log_transform_pass( + pass_name, + payload if payload is not None else frame, + direction="stream", + stage="adapter", + protocol="anthropic_messages", + transport="sse", + snapshot=False, + ) + return frame + + async def close_upstream(reason: str) -> None: + """Close the wrapped OpenAI stream when Anthropic streaming stops early.""" + + nonlocal upstream_closed + if upstream_closed: + return + upstream_closed = True + for candidate in (stream_iterator, openai_stream): + close = getattr(candidate, "aclose", None) + if callable(close): + await close() + break + sync_close = getattr(candidate, "close", None) + if callable(sync_close): + sync_close() + break + if transaction_logger: + transaction_logger.log_transform_pass( + "anthropic_stream_upstream_closed", + {"reason": reason}, + direction="stream", + stage="adapter", + protocol="anthropic_messages", + transport="sse", + snapshot=False, + ) try: - async for chunk_str in openai_stream: + async for chunk_str in stream_iterator: + if transaction_logger: + transaction_logger.log_transform_pass( + "anthropic_stream_source_chunk", + chunk_str, + direction="stream", + stage="provider", + protocol="openai_chat", + transport="sse", + snapshot=False, + ) # Check for client disconnection if callback provided if is_disconnected is not None and await is_disconnected(): + if transaction_logger: + transaction_logger.log_transform_pass( + "anthropic_stream_disconnected", + {"reason": "client_disconnected"}, + direction="stream", + stage="adapter", + protocol="anthropic_messages", + transport="sse", + snapshot=False, + ) + await close_upstream("client_disconnected") break if not chunk_str.strip() or not chunk_str.startswith("data:"): @@ -103,25 +168,25 @@ async def anthropic_streaming_wrapper( "usage": usage_dict, }, } - yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n" + yield trace_frame("anthropic_stream_message_start", f"event: message_start\ndata: {json.dumps(message_start)}\n\n", message_start) message_started = True # Close any open thinking block if thinking_block_started: - yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + yield trace_frame("anthropic_stream_content_block_stop", f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n') current_block_index += 1 thinking_block_started = False # Close any open text block if content_block_started: - yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + yield trace_frame("anthropic_stream_content_block_stop", f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n') current_block_index += 1 content_block_started = False # Close all open tool_use blocks for tc_index in sorted(tool_block_indices.keys()): block_idx = tool_block_indices[tc_index] - yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {block_idx}}}\n\n' + yield trace_frame("anthropic_stream_content_block_stop", f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {block_idx}}}\n\n') # Determine stop_reason based on whether we had tool calls stop_reason = "tool_use" if tool_calls_by_index else "end_turn" @@ -134,10 +199,10 @@ async def anthropic_streaming_wrapper( final_usage["cache_creation_input_tokens"] = 0 # Send message_delta with final info - yield f'event: message_delta\ndata: {{"type": "message_delta", "delta": {{"stop_reason": "{stop_reason}", "stop_sequence": null}}, "usage": {json.dumps(final_usage)}}}\n\n' + yield trace_frame("anthropic_stream_message_delta", f'event: message_delta\ndata: {{"type": "message_delta", "delta": {{"stop_reason": "{stop_reason}", "stop_sequence": null}}, "usage": {json.dumps(final_usage)}}}\n\n') # Send message_stop - yield 'event: message_stop\ndata: {"type": "message_stop"}\n\n' + yield trace_frame("anthropic_stream_message_stop", 'event: message_stop\ndata: {"type": "message_stop"}\n\n') # Log final Anthropic response if logger provided if transaction_logger: @@ -242,7 +307,7 @@ async def anthropic_streaming_wrapper( "usage": usage_dict, }, } - yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n" + yield trace_frame("anthropic_stream_message_start", f"event: message_start\ndata: {json.dumps(message_start)}\n\n", message_start) message_started = True choices = chunk.get("choices") or [] @@ -261,7 +326,7 @@ async def anthropic_streaming_wrapper( "index": current_block_index, "content_block": {"type": "thinking", "thinking": ""}, } - yield f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n" + yield trace_frame("anthropic_stream_content_block_start", f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n", block_start) thinking_block_started = True # Send thinking delta @@ -270,7 +335,7 @@ async def anthropic_streaming_wrapper( "index": current_block_index, "delta": {"type": "thinking_delta", "thinking": reasoning_content}, } - yield f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n" + yield trace_frame("anthropic_stream_content_delta", f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n", block_delta) # Accumulate thinking for logging accumulated_thinking += reasoning_content @@ -279,7 +344,7 @@ async def anthropic_streaming_wrapper( if content: # If we were in a thinking block, close it first if thinking_block_started and not content_block_started: - yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + yield trace_frame("anthropic_stream_content_block_stop", f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n') current_block_index += 1 thinking_block_started = False @@ -290,7 +355,7 @@ async def anthropic_streaming_wrapper( "index": current_block_index, "content_block": {"type": "text", "text": ""}, } - yield f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n" + yield trace_frame("anthropic_stream_content_block_start", f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n", block_start) content_block_started = True # Send content delta @@ -299,7 +364,7 @@ async def anthropic_streaming_wrapper( "index": current_block_index, "delta": {"type": "text_delta", "text": content}, } - yield f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n" + yield trace_frame("anthropic_stream_content_delta", f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n", block_delta) # Accumulate text for logging accumulated_text += content @@ -312,13 +377,13 @@ async def anthropic_streaming_wrapper( if tc_index not in tool_calls_by_index: # Close previous thinking block if open if thinking_block_started: - yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + yield trace_frame("anthropic_stream_content_block_stop", f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n') current_block_index += 1 thinking_block_started = False # Close previous text block if open if content_block_started: - yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + yield trace_frame("anthropic_stream_content_block_stop", f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n') current_block_index += 1 content_block_started = False @@ -341,7 +406,7 @@ async def anthropic_streaming_wrapper( "input": {}, }, } - yield f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n" + yield trace_frame("anthropic_stream_content_block_start", f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n", block_start) # Increment for the next block current_block_index += 1 @@ -361,7 +426,7 @@ async def anthropic_streaming_wrapper( "partial_json": func["arguments"], }, } - yield f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n" + yield trace_frame("anthropic_stream_content_delta", f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n", block_delta) # Note: We intentionally ignore finish_reason here. # Block closing is handled when we receive [DONE] to avoid @@ -369,6 +434,15 @@ async def anthropic_streaming_wrapper( except Exception as e: logger.error(f"Error in Anthropic streaming wrapper: {e}") + if transaction_logger: + transaction_logger.log_transform_error( + "anthropic_stream_transform", + e, + payload={"request_id": request_id, "model": original_model}, + stage="adapter", + protocol="anthropic_messages", + transport="sse", + ) # If we haven't sent message_start yet, send it now so the client can display the error # Claude Code and other clients may ignore events that come before message_start @@ -395,7 +469,7 @@ async def anthropic_streaming_wrapper( "usage": usage_dict, }, } - yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n" + yield trace_frame("anthropic_stream_message_start", f"event: message_start\ndata: {json.dumps(message_start)}\n\n", message_start) # Send the error as a text content block so it's visible to the user error_message = f"Error: {str(e)}" @@ -404,16 +478,16 @@ async def anthropic_streaming_wrapper( "index": current_block_index, "content_block": {"type": "text", "text": ""}, } - yield f"event: content_block_start\ndata: {json.dumps(error_block_start)}\n\n" + yield trace_frame("anthropic_stream_content_block_start", f"event: content_block_start\ndata: {json.dumps(error_block_start)}\n\n", error_block_start) error_block_delta = { "type": "content_block_delta", "index": current_block_index, "delta": {"type": "text_delta", "text": error_message}, } - yield f"event: content_block_delta\ndata: {json.dumps(error_block_delta)}\n\n" + yield trace_frame("anthropic_stream_content_delta", f"event: content_block_delta\ndata: {json.dumps(error_block_delta)}\n\n", error_block_delta) - yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + yield trace_frame("anthropic_stream_content_block_stop", f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n') # Build final usage with cached tokens final_usage = {"output_tokens": 0} @@ -422,12 +496,15 @@ async def anthropic_streaming_wrapper( final_usage["cache_creation_input_tokens"] = 0 # Send message_delta and message_stop to properly close the stream - yield f'event: message_delta\ndata: {{"type": "message_delta", "delta": {{"stop_reason": "end_turn", "stop_sequence": null}}, "usage": {json.dumps(final_usage)}}}\n\n' - yield 'event: message_stop\ndata: {"type": "message_stop"}\n\n' + yield trace_frame("anthropic_stream_message_delta", f'event: message_delta\ndata: {{"type": "message_delta", "delta": {{"stop_reason": "end_turn", "stop_sequence": null}}, "usage": {json.dumps(final_usage)}}}\n\n') + yield trace_frame("anthropic_stream_message_stop", 'event: message_stop\ndata: {"type": "message_stop"}\n\n') # Also send the formal error event for clients that handle it error_event = { "type": "error", "error": {"type": "api_error", "message": str(e)}, } - yield f"event: error\ndata: {json.dumps(error_event)}\n\n" + yield trace_frame("anthropic_stream_error", f"event: error\ndata: {json.dumps(error_event)}\n\n", error_event) + finally: + if not upstream_closed: + await close_upstream("wrapper_exit") diff --git a/src/rotator_library/client/anthropic.py b/src/rotator_library/client/anthropic.py index 507e82fb9..47f1d1e31 100644 --- a/src/rotator_library/client/anthropic.py +++ b/src/rotator_library/client/anthropic.py @@ -94,9 +94,23 @@ async def messages( request.model_dump(exclude_none=True), filename="anthropic_request.json", ) + _trace_anthropic( + anthropic_logger, + "anthropic_raw_request", + request.model_dump(exclude_none=True), + direction="request", + stage="client", + ) # Translate Anthropic request to OpenAI format openai_request = translate_anthropic_request(request) + _trace_anthropic( + anthropic_logger, + "anthropic_to_openai_request", + openai_request, + direction="request", + stage="adapter", + ) # Pass parent log directory to acompletion for nested logging if anthropic_logger and anthropic_logger.log_dir: @@ -138,9 +152,23 @@ async def messages( if hasattr(response, "model_dump") else dict(response) ) + _trace_anthropic( + anthropic_logger, + "anthropic_openai_response", + openai_response, + direction="response", + stage="provider", + ) anthropic_response = openai_to_anthropic_response( openai_response, original_model ) + _trace_anthropic( + anthropic_logger, + "openai_to_anthropic_response", + anthropic_response, + direction="response", + stage="adapter", + ) # Override the ID with our request ID anthropic_response["id"] = request_id @@ -151,6 +179,13 @@ async def messages( anthropic_response, filename="anthropic_response.json", ) + _trace_anthropic( + anthropic_logger, + "anthropic_final_response", + anthropic_response, + direction="response", + stage="final", + ) return anthropic_response @@ -201,3 +236,26 @@ async def count_tokens( total_tokens = message_tokens + tool_tokens return {"input_tokens": total_tokens} + + +def _trace_anthropic( + logger: Optional[TransactionLogger], + pass_name: str, + payload: Any, + *, + direction: str, + stage: str, + transport: Optional[str] = None, +) -> None: + """Emit an Anthropic compatibility transform trace when logging is enabled.""" + + if not logger: + return + logger.log_transform_pass( + pass_name, + payload, + direction=direction, + stage=stage, + protocol="anthropic_messages", + transport=transport, + ) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index d83c61192..c1c062d71 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -18,6 +18,7 @@ import os import random import time +from copy import deepcopy from typing import ( Any, AsyncGenerator, @@ -62,12 +63,22 @@ from ..request_sanitizer import sanitize_request_payload from ..transaction_logger import TransactionLogger from ..failure_logger import log_failure +from ..retry_policy import FailureHistory, decide_provider_cooldown, provider_cooldown_env +from ..routing import FallbackPolicy, clone_context_for_target +from ..routing.policy import normalize_route_error_type +from ..routing.types import RouteTarget +from ..native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor +from ..native_provider.streaming import provider_supports_native_streaming as native_provider_supports_streaming +from ..field_cache.paths import FieldCachePathError, PathToken, parse_path +from ..transform_trace import REDACTED +from ..usage.accounting import UsageRecord, extract_usage_record +from ..usage.costs import CostBreakdown, CostCalculator from .types import RetryState, AvailabilityStats from .filters import CredentialFilter from .transforms import ProviderTransforms from .streaming import StreamingHandler -from .stream_retry_policy import can_retry_stream_after_error +from ..streaming.policy import can_retry_stream_after_error, is_stream_heartbeat_or_comment, is_visible_stream_output if TYPE_CHECKING: from ..usage import UsageManager @@ -75,6 +86,14 @@ lib_logger = logging.getLogger("rotator_library") +class RoutingExecutionError(RuntimeError): + """Internal error used when a routed target cannot use its requested mode.""" + + def __init__(self, message: str, error_type: str = "configuration_error") -> None: + super().__init__(message) + self.error_type = error_type + + class RequestExecutor: """ Unified retry/rotation logic for all request types. @@ -136,6 +155,8 @@ def __init__( self._litellm_logger_fn = litellm_logger_fn # StreamingHandler no longer needs usage_manager - we pass cred_context directly self._streaming_handler = StreamingHandler() + self._native_executor = NativeProviderExecutor() + self._failure_history = FailureHistory() def _get_transient_retry_delay(self) -> float: """Small jittered delay used before transient retries and rotations.""" @@ -358,6 +379,8 @@ async def _prepare_request_kwargs( model: str, cred: str, context: "RequestContext", + *, + credential_id: Optional[str] = None, ) -> Dict[str, Any]: """ Prepare request kwargs with transforms, sanitization, and provider params. @@ -378,17 +401,58 @@ async def _prepare_request_kwargs( cred, context.kwargs.copy(), provider_config_override=context.provider_config, + transaction_logger=context.transaction_logger, + credential_id=credential_id, + transport="sse" if context.streaming else "http", + trace_metadata={ + "session_id": context.session_id, + "scope_key": context.usage_manager_key, + "classifier": context.classifier, + }, ) - # Sanitize request payload + # Sanitize request payload. Some provider compatibility fields are + # intentionally removed here, so record it as its own transform pass. + before_sanitize = deepcopy(kwargs) if context.transaction_logger else None kwargs = sanitize_request_payload(kwargs, model) + self._log_executor_trace( + context, + "after_request_sanitization", + kwargs, + direction="request", + stage="client", + credential_id=credential_id, + changed_from_previous=(before_sanitize != kwargs) if before_sanitize is not None else None, + metadata={"provider": provider, "model": model}, + ) # Apply provider-specific LiteLLM params + before_params = deepcopy(kwargs) if context.transaction_logger else None self._apply_litellm_provider_params(provider, kwargs) + self._log_executor_trace( + context, + "after_litellm_provider_params", + kwargs, + direction="request", + stage="client", + credential_id=credential_id, + changed_from_previous=(before_params != kwargs) if before_params is not None else None, + metadata={"provider": provider, "model": model}, + ) # Add transaction context for provider logging if context.transaction_logger: kwargs["transaction_context"] = context.transaction_logger.get_context() + self._log_executor_trace( + context, + "after_transaction_context_attached", + kwargs, + direction="request", + stage="client", + credential_id=credential_id, + changed_from_previous=True, + metadata={"provider": provider, "model": model}, + ) return kwargs @@ -472,12 +536,198 @@ async def _run_pre_request_callback( """ if context.pre_request_callback: try: + before = deepcopy(kwargs) if context.transaction_logger else None await context.pre_request_callback(context.request, kwargs) + if before is not None and before != kwargs: + self._log_executor_trace( + context, + "after_pre_request_callback", + kwargs, + direction="request", + stage="client", + changed_from_previous=True, + snapshot=True, + ) except Exception as e: if self._abort_on_callback_error: raise PreRequestCallbackError(str(e)) from e lib_logger.warning(f"Pre-request callback failed: {e}") + async def _execute_provider_request( + self, + provider: str, + model: str, + plugin: Any, + credential_secret: str, + credential_id: str, + kwargs: Dict[str, Any], + context: RequestContext, + ) -> Any: + """Execute one provider request using routed execution-mode rules.""" + + target = _current_route_target(context) + execution = target.execution if target else "auto" + self._log_executor_trace( + context, + "pre_provider_execution_request", + kwargs, + direction="request", + stage="provider", + credential_id=credential_id, + metadata={"execution": execution, "provider": provider, "model": model}, + ) + if execution == "litellm_fallback": + self._log_routing_trace( + context, + "routing_litellm_fallback", + _target_trace(target) if target else {"provider": provider, "model": model}, + ) + return await self._execute_litellm_request(kwargs, credential_secret, context=context, credential_id=credential_id) + + if execution == "custom" or (execution == "auto" and plugin and plugin.has_custom_logic()): + if not plugin or not plugin.has_custom_logic(): + raise RoutingExecutionError(f"Provider {provider} does not support custom execution") + kwargs["credential_identifier"] = credential_secret + self._log_executor_trace( + context, + "provider_execution_request", + kwargs, + direction="request", + stage="provider", + credential_id=credential_id, + metadata={"execution": "custom", "provider": provider, "model": model}, + ) + return await plugin.acompletion(self._http_client, **kwargs) + + if execution == "native" or (execution == "auto" and _should_use_native_protocol(plugin, model, target, kwargs, stream=False, execution=execution)): + native_context, native_request = self._build_native_provider_context( + provider, + model, + plugin, + credential_secret, + credential_id, + context, + target, + raw_request=kwargs, + return_request=True, + ) + self._log_routing_trace( + context, + "routing_native_execution_selected", + _target_trace(target) if target else {"provider": provider, "model": model}, + metadata={"protocol": native_context.protocol_name, "operation": native_context.operation}, + ) + return await self._get_native_executor().execute(native_request, native_context, NativeHTTPTransport(self._http_client)) + + return await self._execute_litellm_request(kwargs, credential_secret, context=context, credential_id=credential_id) + + def _get_native_executor(self) -> NativeProviderExecutor: + """Return the shared native executor for process-local field-cache state.""" + + native_executor = getattr(self, "_native_executor", None) + if native_executor is None: + native_executor = NativeProviderExecutor() + self._native_executor = native_executor + return native_executor + + async def _execute_litellm_request( + self, + kwargs: Dict[str, Any], + credential_secret: str, + *, + context: Optional[RequestContext] = None, + credential_id: Optional[str] = None, + ) -> Any: + """Execute the existing LiteLLM request path.""" + + kwargs["api_key"] = credential_secret + self._apply_litellm_logger(kwargs) + kwargs.pop("transaction_context", None) + if context: + self._log_executor_trace( + context, + "provider_execution_request", + kwargs, + direction="request", + stage="provider", + credential_id=credential_id, + metadata={"execution": "litellm", "provider": context.provider, "model": context.model}, + ) + return await litellm.acompletion(**kwargs) + + def _build_native_provider_context( + self, + provider: str, + model: str, + plugin: Any, + credential_secret: str, + credential_id: str, + context: RequestContext, + target: Optional[RouteTarget], + raw_request: Optional[Dict[str, Any]] = None, + transport: str = "http", + stream: bool = False, + return_request: bool = False, + ) -> NativeProviderContext | tuple[NativeProviderContext, Dict[str, Any]]: + """Build native provider context from provider declarations.""" + + if not plugin: + raise RoutingExecutionError(f"Provider {provider} has no plugin for native execution") + protocol_name = _provider_native_protocol(plugin, model, target) + if not protocol_name: + raise RoutingExecutionError(f"Provider {provider} has no native protocol declaration") + public_model = model + native_model = plugin.normalize_native_model(model) if hasattr(plugin, "normalize_native_model") else _strip_provider_prefix(model) + request_payload = _native_request_payload(raw_request or {}) + request_payload["_proxy_model"] = public_model + if native_model: + request_payload["model"] = native_model + operation = plugin.get_native_operation(native_model, request_payload, stream=stream) if hasattr(plugin, "get_native_operation") else "chat" + if hasattr(plugin, "supports_native_operation") and not plugin.supports_native_operation(native_model, operation): + raise RoutingExecutionError(f"Provider {provider} does not support native operation {operation}") + if hasattr(plugin, "prepare_native_request"): + prepared = plugin.prepare_native_request(request_payload, model=native_model, operation=operation) + if prepared is not request_payload: + request_payload = dict(prepared) + request_payload.pop("_proxy_model", None) + self._log_executor_trace( + context, + "provider_native_request_prepared", + request_payload, + direction="request", + stage="provider", + credential_id=credential_id, + metadata={"provider": provider, "model": public_model, "native_model": native_model, "operation": operation}, + ) + request_payload.pop("_proxy_model", None) + try: + endpoint = plugin.get_native_endpoint(model=native_model, operation=operation) + headers = plugin.get_native_headers(credential_secret, model=native_model, operation=operation) + except NotImplementedError as exc: + raise RoutingExecutionError(str(exc)) from exc + native_context = NativeProviderContext( + provider=provider, + model=native_model, + protocol_name=protocol_name, + endpoint=endpoint, + operation=operation, + client_protocol_name="openai_chat", + headers=headers, + credential_id=credential_id, + session_id=context.session_id, + scope_key=context.usage_manager_key, + classifier=context.classifier, + transport=transport, + adapter_names=tuple(plugin.get_adapter_names(native_model) if hasattr(plugin, "get_adapter_names") else ()), + adapter_config=dict(plugin.get_adapter_config(native_model) if hasattr(plugin, "get_adapter_config") else {}), + field_cache_rules=_merged_field_cache_rules(provider, public_model, plugin), + transaction_logger=context.transaction_logger, + metadata={"public_model": public_model}, + ) + if return_request: + return native_context, request_payload + return native_context + async def execute( self, context: RequestContext, @@ -493,11 +743,294 @@ async def execute( Returns: Response object or async generator for streaming """ + if context.streaming and context.routing_targets: + return self._execute_streaming_with_fallback(context) if context.streaming: return self._execute_streaming(context) + elif context.routing_targets: + return await self._execute_non_streaming_with_fallback(context) else: return await self._execute_non_streaming(context) + async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> Any: + """Execute an ordered non-streaming fallback target chain. + + The normal single-target path remains `_execute_non_streaming()`. This + wrapper only runs when request building has populated `routing_targets`, + preserving existing behavior for all current requests. + """ + + targets = tuple(context.routing_targets or ()) + if not targets: + return await self._execute_non_streaming(context) + policy = FallbackPolicy() + last_failure: Any = None + target_failures: List[Dict[str, Any]] = [] + self._log_routing_trace( + context, + "routing_decision", + {"requested_model": context.model, "target_count": len(targets)}, + metadata={"group": context.routing_group_name, "targets": [_target_trace(target) for target in targets]}, + ) + for index, target in enumerate(targets): + attempt_started_at = time.monotonic() + target_context = clone_context_for_target( + context, + target, + target_index=index, + credentials=_target_scope_value(target, "credentials", context.credentials), + usage_manager_key=_target_scope_value(target, "usage_manager_key", target.provider), + provider_config=_target_scope_value(target, "provider_config", context.provider_config), + credential_secrets=_target_scope_value(target, "credential_secrets", context.credential_secrets), + ) + self._log_routing_trace( + context, + "routing_target_attempt_started", + _target_trace(target), + metadata={"target_index": index, "group": context.routing_group_name}, + ) + try: + result = await self._execute_non_streaming(target_context) + except Exception as exc: + last_failure = exc + error_type = _route_error_type(exc, target.provider) + self._log_routing_trace( + context, + "routing_target_attempt_failed", + _target_trace(target), + metadata={"target_index": index, "error_type": error_type, "exception": exc.__class__.__name__}, + ) + target_failures.append(_target_failure_summary(target, error_type)) + fallback_allowed = index < len(targets) - 1 and policy.should_fallback(error_type, group=context.routing_group) + _append_routing_attempt_history(context, target, index, success=False, error_type=error_type, fallback_allowed=fallback_allowed, duration_ms=_elapsed_ms(attempt_started_at)) + if not fallback_allowed: + self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type, "fallback_targets": target_failures}) + raise + self._log_routing_trace(context, "routing_fallback_selected", _target_trace(targets[index + 1]), metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type}) + continue + + error_type = _route_error_type_from_response(result) + if error_type: + last_failure = result + self._log_routing_trace( + context, + "routing_target_attempt_failed", + _target_trace(target), + metadata={"target_index": index, "error_type": error_type}, + ) + target_failures.append(_target_failure_summary(target, error_type, status_code=_route_status_code_from_response(result))) + fallback_allowed = index < len(targets) - 1 and policy.should_fallback(error_type, group=context.routing_group) + _append_routing_attempt_history(context, target, index, success=False, error_type=error_type, status_code=_route_status_code_from_response(result), fallback_allowed=fallback_allowed, duration_ms=_elapsed_ms(attempt_started_at)) + if fallback_allowed: + self._log_routing_trace(context, "routing_fallback_selected", _target_trace(targets[index + 1]), metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type}) + continue + self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type, "fallback_targets": target_failures}) + return _with_fallback_summary(result, target_failures, context.routing_attempt_history) + + self._log_routing_trace( + context, + "routing_target_attempt_succeeded", + _target_trace(target), + metadata={"target_index": index}, + ) + _append_routing_attempt_history(context, target, index, success=True, duration_ms=_elapsed_ms(attempt_started_at)) + return result + + if isinstance(last_failure, Exception): + raise last_failure + return last_failure + + async def _execute_streaming_with_fallback(self, context: RequestContext) -> AsyncGenerator[str, None]: + """Execute streaming fallback targets with pre-output-only failover.""" + + targets = tuple(context.routing_targets or ()) + if not targets: + async for chunk in self._execute_streaming(context): + yield chunk + return + policy = FallbackPolicy() + target_failures: List[Dict[str, Any]] = [] + self._log_routing_trace( + context, + "routing_decision", + {"requested_model": context.model, "target_count": len(targets), "stream": True}, + metadata={"group": context.routing_group_name, "streaming_policy": _group_streaming_policy(context.routing_group), "targets": [_target_trace(target) for target in targets]}, + ) + for index, target in enumerate(targets): + attempt_started_at = time.monotonic() + emitted_output = False + pending_chunks: List[str] = [] + terminal_error_type: Optional[str] = None + target_context = clone_context_for_target( + context, + target, + target_index=index, + credentials=_target_scope_value(target, "credentials", context.credentials), + usage_manager_key=_target_scope_value(target, "usage_manager_key", target.provider), + provider_config=_target_scope_value(target, "provider_config", context.provider_config), + credential_secrets=_target_scope_value(target, "credential_secrets", context.credential_secrets), + ) + self._log_routing_trace( + context, + "routing_stream_target_attempt_started", + _target_trace(target), + metadata={"target_index": index, "group": context.routing_group_name}, + ) + try: + async for chunk in self._execute_streaming(target_context): + if is_stream_heartbeat_or_comment(chunk): + yield chunk + continue + chunk_error_type = _stream_chunk_error_type(chunk) + if chunk_error_type and not emitted_output: + terminal_error_type = chunk_error_type + pending_chunks.append(chunk) + continue + if _stream_chunk_is_visible_output(chunk): + for pending in pending_chunks: + yield pending + pending_chunks.clear() + emitted_output = True + yield chunk + continue + pending_chunks.append(chunk) + if terminal_error_type and not emitted_output: + error_type = terminal_error_type + self._log_routing_trace( + context, + "routing_stream_target_attempt_failed", + _target_trace(target), + metadata={"target_index": index, "error_type": error_type, "emitted_output": emitted_output, "terminal_error_frame": True}, + ) + target_failures.append(_target_failure_summary(target, error_type)) + fallback_allowed = index < len(targets) - 1 and _streaming_policy_allows_fallback(context.routing_group) and policy.should_fallback(error_type, group=context.routing_group, stream=True, emitted_output=False) + _append_routing_attempt_history(context, target, index, success=False, error_type=error_type, emitted_output=emitted_output, fallback_allowed=fallback_allowed, duration_ms=_elapsed_ms(attempt_started_at)) + if fallback_allowed: + self._log_routing_trace( + context, + "routing_fallback_selected", + _target_trace(targets[index + 1]), + metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type, "stream": True}, + ) + continue + self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type, "stream": True, "streaming_policy": _group_streaming_policy(context.routing_group), "fallback_targets": target_failures}) + for pending in pending_chunks: + yield pending + return + for pending in pending_chunks: + yield pending + self._log_routing_trace( + context, + "routing_stream_target_attempt_succeeded", + _target_trace(target), + metadata={"target_index": index, "emitted_output": emitted_output}, + ) + _append_routing_attempt_history(context, target, index, success=True, emitted_output=emitted_output, duration_ms=_elapsed_ms(attempt_started_at)) + return + except Exception as exc: + error_type = _route_error_type(exc, target.provider) + self._log_routing_trace( + context, + "routing_stream_target_attempt_failed", + _target_trace(target), + metadata={"target_index": index, "error_type": error_type, "emitted_output": emitted_output}, + ) + target_failures.append(_target_failure_summary(target, error_type)) + fallback_allowed = (not emitted_output) and index < len(targets) - 1 and _streaming_policy_allows_fallback(context.routing_group) and policy.should_fallback(error_type, group=context.routing_group, stream=True, emitted_output=False) + _append_routing_attempt_history(context, target, index, success=False, error_type=error_type, emitted_output=emitted_output, fallback_allowed=fallback_allowed, duration_ms=_elapsed_ms(attempt_started_at)) + if emitted_output: + self._log_routing_trace( + context, + "routing_stream_fallback_blocked_after_output", + _target_trace(target), + metadata={"target_index": index, "error_type": error_type}, + ) + raise + if fallback_allowed: + self._log_routing_trace( + context, + "routing_fallback_selected", + _target_trace(targets[index + 1]), + metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type, "stream": True}, + ) + continue + self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type, "stream": True, "streaming_policy": _group_streaming_policy(context.routing_group), "fallback_targets": target_failures}) + raise + + @staticmethod + def _log_routing_trace(context: RequestContext, pass_name: str, data: Any, *, metadata: Optional[Dict[str, Any]] = None) -> None: + """Record routing trace entries without affecting request execution.""" + + if not context.transaction_logger: + return + context.transaction_logger.log_transform_pass( + pass_name, + data, + direction="metadata", + stage="routing", + metadata=metadata or {}, + snapshot=False, + ) + + @staticmethod + def _log_executor_trace( + context: RequestContext, + pass_name: str, + data: Any, + *, + direction: str, + stage: str, + credential_id: Optional[str] = None, + changed_from_previous: Optional[bool] = None, + metadata: Optional[Dict[str, Any]] = None, + snapshot: bool = True, + ) -> None: + """Record live executor trace boundaries without affecting requests.""" + + if not context.transaction_logger: + return + context.transaction_logger.log_transform_pass( + pass_name, + data, + direction=direction, + stage=stage, + credential_id=credential_id, + transport="sse" if context.streaming else "http", + changed_from_previous=changed_from_previous, + metadata={ + "provider": context.provider, + "model": context.model, + "session_id": context.session_id, + "scope_key": context.usage_manager_key, + "classifier": context.classifier, + **(metadata or {}), + }, + snapshot=snapshot, + ) + + def _terminal_stream_error_lines(self, context: RequestContext, error_data: Dict[str, Any]) -> Tuple[str, str]: + """Return executor-created terminal SSE lines and trace them first.""" + + error_line = f"data: {json.dumps(error_data)}\n\n" + done_line = "data: [DONE]\n\n" + self._log_executor_trace( + context, + "stream_error_event", + error_data, + direction="stream", + stage="client", + snapshot=False, + ) + self._log_executor_trace( + context, + "stream_done_event", + {"raw": done_line}, + direction="stream", + stage="final", + snapshot=False, + ) + return error_line, done_line + async def _prepare_execution( self, context: RequestContext, @@ -569,7 +1102,7 @@ async def _execute_non_streaming( break # Wait for provider cooldown - await self._wait_for_cooldown(provider, deadline) + await self._wait_for_cooldown(provider, deadline, model=model, context=context) # Acquire credential using context manager try: @@ -602,13 +1135,24 @@ async def _execute_non_streaming( try: # Prepare request kwargs kwargs = await self._prepare_request_kwargs( - provider, model, cred, context + provider, + model, + cred, + context, + credential_id=cred_context.stable_id, ) # Log transformed request if it differs from original if context.transaction_logger: context.transaction_logger.log_transformed_request( - kwargs, context.kwargs + kwargs, + context.kwargs, + credential_id=cred_context.stable_id, + metadata={ + "session_id": context.session_id, + "scope_key": context.usage_manager_key, + "classifier": context.classifier, + }, ) # Get provider plugin @@ -624,31 +1168,36 @@ async def _execute_non_streaming( # Pre-request callback await self._run_pre_request_callback(context, kwargs) - # Make the API call - if plugin and plugin.has_custom_logic(): - kwargs["credential_identifier"] = credential_secret - response = await plugin.acompletion( - self._http_client, **kwargs - ) - else: - # Standard LiteLLM call - kwargs["api_key"] = credential_secret - self._apply_litellm_logger(kwargs) - # Remove internal context before litellm call - kwargs.pop("transaction_context", None) - response = await litellm.acompletion(**kwargs) + response = await self._execute_provider_request( + provider, + model, + plugin, + credential_secret, + cred_context.stable_id, + kwargs, + context, + ) + trace_response = _redact_context_field_cache_paths(response, context, "response", plugin) + self._log_executor_trace( + context, + "raw_provider_response", + trace_response, + direction="response", + stage="provider", + credential_id=cred_context.stable_id, + metadata={"provider": provider, "model": model}, + ) # Success! Extract token usage if available - ( - prompt_tokens, - completion_tokens, - prompt_tokens_cached, - prompt_tokens_cache_write, - thinking_tokens, - ) = self._extract_usage_tokens(response) - approx_cost = self._calculate_cost( - provider, model, response + usage_record, cost_breakdown = self._account_for_response_usage( + provider, model, response, context ) + prompt_tokens = usage_record.prompt_tokens_for_mark_success + completion_tokens = usage_record.completion_tokens + prompt_tokens_cached = usage_record.cache_read_tokens + prompt_tokens_cache_write = usage_record.cache_write_tokens + thinking_tokens = usage_record.reasoning_tokens + approx_cost = cost_breakdown.total_cost response_headers = self._extract_response_headers( response ) @@ -663,30 +1212,57 @@ async def _execute_non_streaming( approx_cost=approx_cost, response_headers=response_headers, ) + self._clear_failure_history_on_success(provider, model) self._record_session_response(context, response) lib_logger.info( f"Recorded usage from response object for key {mask_credential(cred)}" ) - # Log response if transaction logging enabled + normalized_response = self._normalize_response_usage(response, model) + trace_normalized_response = _redact_context_field_cache_paths(normalized_response, context, "response", plugin) + self._log_executor_trace( + context, + "post_usage_normalization_response", + trace_normalized_response, + direction="response", + stage="final", + credential_id=cred_context.stable_id, + metadata={"provider": provider, "model": model}, + ) + # Legacy response.json and final_client_response must + # reflect the same post-normalization payload returned + # to the caller, not the pre-accounting SDK object. if context.transaction_logger: try: - response_data = ( - response.model_dump() - if hasattr(response, "model_dump") - else response - ) - context.transaction_logger.log_response( - response_data - ) + context.transaction_logger.log_response(trace_normalized_response) except Exception as log_err: lib_logger.debug( f"Failed to log response: {log_err}" ) + return normalized_response - return self._normalize_response_usage(response, model) - + except RoutingExecutionError as e: + if e.error_type == "configuration_error": + raise + last_exception = e + action = await self._handle_error_with_context( + e, + cred_context, + model, + provider, + attempt, + error_accumulator, + retry_state, + request_headers, + context, + ) + if action == ErrorAction.RETRY_SAME: + continue + elif action == ErrorAction.ROTATE: + break + else: + raise except Exception as e: last_exception = e action = await self._handle_error_with_context( @@ -698,6 +1274,7 @@ async def _execute_non_streaming( error_accumulator, retry_state, request_headers, + context, ) if action == ErrorAction.RETRY_SAME: @@ -709,6 +1286,9 @@ async def _execute_non_streaming( except PreRequestCallbackError: raise + except RoutingExecutionError as exc: + if exc.error_type == "configuration_error": + raise except Exception: # Let context manager handle cleanup pass @@ -758,8 +1338,8 @@ async def _execute_streaming( "type": "proxy_error", } } - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + for line in self._terminal_stream_error_lines(context, error_data): + yield line return error_accumulator = RequestErrorAccumulator() @@ -768,6 +1348,7 @@ async def _execute_streaming( retry_state = RetryState() last_exception: Optional[Exception] = None + last_stream_error_payload: Optional[Dict[str, Any]] = None try: while time.time() < deadline: @@ -785,7 +1366,7 @@ async def _execute_streaming( remaining = deadline - time.time() if remaining <= 0: break - await self._wait_for_cooldown(provider, deadline) + await self._wait_for_cooldown(provider, deadline, model=model, context=context) # Acquire credential using context manager try: @@ -818,13 +1399,24 @@ async def _execute_streaming( try: # Prepare request kwargs kwargs = await self._prepare_request_kwargs( - provider, model, cred, context + provider, + model, + cred, + context, + credential_id=cred_context.stable_id, ) # Log transformed request if it differs from original if context.transaction_logger: context.transaction_logger.log_transformed_request( - kwargs, context.kwargs + kwargs, + context.kwargs, + credential_id=cred_context.stable_id, + metadata={ + "session_id": context.session_id, + "scope_key": context.usage_manager_key, + "classifier": context.classifier, + }, ) # Add stream usage metadata for active providers. @@ -843,6 +1435,7 @@ async def _execute_streaming( # Execute request with retries for attempt in range(self._max_retries): last_streamed_chunk: Optional[str] = None + stream_visible_output_emitted = False try: lib_logger.info( @@ -854,20 +1447,93 @@ async def _execute_streaming( context, kwargs ) - # Make the API call - if plugin and plugin.has_custom_logic(): + target = _current_route_target(context) + execution = target.execution if target else "auto" + + # Make the API call. Keep execution-mode precedence aligned with + # the non-streaming path: explicit LiteLLM wins, explicit custom + # requires custom logic, explicit native fails closed, and auto + # prefers custom before native streaming. + if execution == "litellm_fallback": + kwargs["api_key"] = credential_secret + kwargs["stream"] = True + self._apply_litellm_logger(kwargs) + kwargs.pop("transaction_context", None) + self._log_executor_trace( + context, + "provider_execution_request", + kwargs, + direction="request", + stage="provider", + credential_id=cred_context.stable_id, + metadata={"execution": "litellm_stream", "provider": provider, "model": model}, + ) + stream = await litellm.acompletion(**kwargs) + elif execution == "custom" or (execution == "auto" and plugin and plugin.has_custom_logic()): + if not plugin or not plugin.has_custom_logic(): + raise RoutingExecutionError(f"Provider {provider} does not support custom execution") kwargs["credential_identifier"] = credential_secret + self._log_executor_trace( + context, + "provider_execution_request", + kwargs, + direction="request", + stage="provider", + credential_id=cred_context.stable_id, + metadata={"execution": "custom_stream", "provider": provider, "model": model}, + ) stream = await plugin.acompletion( self._http_client, **kwargs ) + elif _should_use_native_streaming(plugin, model, target, execution, provider): + native_context, native_request = self._build_native_provider_context( + provider, + model, + plugin, + credential_secret, + cred_context.stable_id, + context, + target, + raw_request=kwargs, + transport="sse", + stream=True, + return_request=True, + ) + self._log_routing_trace( + context, + "routing_native_stream_execution_selected", + _target_trace(target) if target else {"provider": provider, "model": model}, + metadata={"protocol": native_context.protocol_name, "operation": native_context.operation}, + ) + stream = self._get_native_executor().stream(native_request, native_context, NativeHTTPTransport(self._http_client)) else: kwargs["api_key"] = credential_secret kwargs["stream"] = True self._apply_litellm_logger(kwargs) # Remove internal context before litellm call kwargs.pop("transaction_context", None) + self._log_executor_trace( + context, + "provider_execution_request", + kwargs, + direction="request", + stage="provider", + credential_id=cred_context.stable_id, + metadata={"execution": "litellm_stream", "provider": provider, "model": model}, + ) stream = await litellm.acompletion(**kwargs) + self._log_executor_trace( + context, + "provider_stream_opened", + {"stream_type": f"{type(stream).__module__}.{type(stream).__name__}"}, + direction="response", + stage="provider", + credential_id=cred_context.stable_id, + metadata={"provider": provider, "model": model}, + snapshot=False, + ) + # Hand off to streaming handler with cred_context # The handler will call mark_success on completion base_stream = self._streaming_handler.wrap_stream( @@ -880,6 +1546,8 @@ async def _execute_streaming( response_callback=lambda response: self._record_session_response( context, response ), + success_callback=lambda: self._clear_failure_history_on_success(provider, model), + transaction_logger=context.transaction_logger, ) lib_logger.info( @@ -895,12 +1563,20 @@ async def _execute_streaming( base_stream, context.transaction_logger, context.kwargs, + context=context, + plugin=plugin, ): - last_streamed_chunk = chunk + if not is_stream_heartbeat_or_comment(chunk): + last_streamed_chunk = chunk + if _stream_chunk_is_visible_output(chunk): + stream_visible_output_emitted = True yield chunk else: async for chunk in base_stream: - last_streamed_chunk = chunk + if not is_stream_heartbeat_or_comment(chunk): + last_streamed_chunk = chunk + if _stream_chunk_is_visible_output(chunk): + stream_visible_output_emitted = True yield chunk return @@ -908,6 +1584,12 @@ async def _execute_streaming( last_exception = e original = getattr(e, "data", e) classified = classify_error(original, provider) + last_stream_error_payload = _streamed_error_payload(e, classified) + if _can_start_stream_provider_cooldown( + last_streamed_chunk, + emitted_output=stream_visible_output_emitted, + ): + await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=original) log_failure( api_key=cred, model=model, @@ -934,8 +1616,8 @@ async def _execute_streaming( "type": "quota_exhausted", } } - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + for line in self._terminal_stream_error_lines(context, error_data): + yield line return else: retry_state.reset_quota_failures() @@ -944,19 +1626,11 @@ async def _execute_streaming( cred_context.mark_failure(classified) raise - if not can_retry_stream_after_error( - last_streamed_chunk, - self._stream_retry_on_reasoning_only_enabled(), - ): + if not _can_retry_stream_after_error(last_streamed_chunk, self._stream_retry_on_reasoning_only_enabled(), emitted_output=stream_visible_output_emitted): cred_context.mark_failure(classified) - error_data = { - "error": { - "message": "Upstream stream failed after output began", - "type": classified.error_type, - } - } - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + error_data = _streamed_error_payload(e, classified) + for line in self._terminal_stream_error_lines(context, error_data): + yield line return small_cooldown_threshold = int( @@ -993,6 +1667,11 @@ async def _execute_streaming( except (RateLimitError, httpx.HTTPStatusError) as e: last_exception = e classified = classify_error(e, provider) + if _can_start_stream_provider_cooldown( + last_streamed_chunk, + emitted_output=stream_visible_output_emitted, + ): + await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=e) log_failure( api_key=cred, model=model, @@ -1019,8 +1698,8 @@ async def _execute_streaming( "type": "quota_exhausted", } } - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + for line in self._terminal_stream_error_lines(context, error_data): + yield line return else: retry_state.reset_quota_failures() @@ -1029,6 +1708,13 @@ async def _execute_streaming( cred_context.mark_failure(classified) raise + if not _can_retry_stream_after_error(last_streamed_chunk, self._stream_retry_on_reasoning_only_enabled(), emitted_output=stream_visible_output_emitted): + cred_context.mark_failure(classified) + error_data = {"error": {"message": "Upstream stream failed after output began", "type": classified.error_type}} + for line in self._terminal_stream_error_lines(context, error_data): + yield line + return + # Check for small cooldown - retry same key instead of rotating small_cooldown_threshold = int( os.environ.get( @@ -1068,6 +1754,11 @@ async def _execute_streaming( ) as e: last_exception = e classified = classify_error(e, provider) + if _can_start_stream_provider_cooldown( + last_streamed_chunk, + emitted_output=stream_visible_output_emitted, + ): + await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=e) log_failure( api_key=cred, model=model, @@ -1076,6 +1767,13 @@ async def _execute_streaming( request_headers=request_headers, ) + if not _can_retry_stream_after_error(last_streamed_chunk, self._stream_retry_on_reasoning_only_enabled(), emitted_output=stream_visible_output_emitted): + cred_context.mark_failure(classified) + error_data = {"error": {"message": "Upstream stream failed after output began", "type": classified.error_type}} + for line in self._terminal_stream_error_lines(context, error_data): + yield line + return + if attempt >= self._max_retries - 1: error_accumulator.record_error( cred, classified, str(e)[:150] @@ -1101,9 +1799,35 @@ async def _execute_streaming( continue # Retry + except RoutingExecutionError as e: + if e.error_type == "configuration_error": + raise + last_exception = e + classified = classify_error(e, provider) + if _can_start_stream_provider_cooldown(last_streamed_chunk, emitted_output=stream_visible_output_emitted): + await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=e) + log_failure(api_key=cred, model=model, attempt=attempt + 1, error=e, request_headers=request_headers) + error_accumulator.record_error(cred, classified, str(e)[:150]) + if not should_rotate_on_error(classified): + cred_context.mark_failure(classified) + raise + if not _can_retry_stream_after_error(last_streamed_chunk, self._stream_retry_on_reasoning_only_enabled(), emitted_output=stream_visible_output_emitted): + cred_context.mark_failure(classified) + error_data = {"error": {"message": "Upstream stream failed after output began", "type": classified.error_type}} + for line in self._terminal_stream_error_lines(context, error_data): + yield line + return + cred_context.mark_failure(classified) + break + except Exception as e: last_exception = e classified = classify_error(e, provider) + if _can_start_stream_provider_cooldown( + last_streamed_chunk, + emitted_output=stream_visible_output_emitted, + ): + await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=e) log_failure( api_key=cred, model=model, @@ -1119,6 +1843,13 @@ async def _execute_streaming( cred_context.mark_failure(classified) raise + if not _can_retry_stream_after_error(last_streamed_chunk, self._stream_retry_on_reasoning_only_enabled(), emitted_output=stream_visible_output_emitted): + cred_context.mark_failure(classified) + error_data = {"error": {"message": "Upstream stream failed after output began", "type": classified.error_type}} + for line in self._terminal_stream_error_lines(context, error_data): + yield line + return + small_cooldown_threshold = int( os.environ.get( "SMALL_COOLDOWN_RETRY_THRESHOLD", @@ -1152,6 +1883,9 @@ async def _execute_streaming( except PreRequestCallbackError: raise + except RoutingExecutionError as exc: + if exc.error_type == "configuration_error": + raise except Exception: # Let context manager handle cleanup pass @@ -1162,20 +1896,23 @@ async def _execute_streaming( # All credentials exhausted or timeout error_accumulator.timeout_occurred = time.time() >= deadline error_data = error_accumulator.build_client_error_response() - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + if last_stream_error_payload: + _merge_stream_error_details(error_data, last_stream_error_payload) + for line in self._terminal_stream_error_lines(context, error_data): + yield line except NoAvailableKeysError as e: lib_logger.error(f"No keys available: {e}") error_data = {"error": {"message": str(e), "type": "proxy_busy"}} - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + for line in self._terminal_stream_error_lines(context, error_data): + yield line except Exception as e: lib_logger.error(f"Unhandled exception in streaming: {e}", exc_info=True) - error_data = {"error": {"message": str(e), "type": "proxy_internal_error"}} - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + classified = classify_error(e, context.provider) + error_data = {"error": {"message": "Streaming request failed", "type": classified.error_type, "details": {"status_code": classified.status_code}}} + for line in self._terminal_stream_error_lines(context, error_data): + yield line def _apply_litellm_provider_params( self, provider: str, kwargs: Dict[str, Any] @@ -1209,9 +1946,12 @@ async def _wait_for_cooldown( self, provider: str, deadline: float, + *, + model: Optional[str] = None, + context: Optional[RequestContext] = None, ) -> None: """ - Wait for provider-level cooldown to end. + Wait for provider/model cooldown to end. Args: provider: Provider name @@ -1220,15 +1960,31 @@ async def _wait_for_cooldown( if not self._cooldown: return - remaining = await self._cooldown.get_remaining_cooldown(provider) + if hasattr(self._cooldown, "get_max_remaining"): + remaining = await self._cooldown.get_max_remaining(provider, model=model) + else: + remaining = await self._cooldown.get_remaining_cooldown(provider) if remaining > 0: budget = deadline - time.time() if remaining > budget: lib_logger.warning( f"Provider {provider} cooldown ({remaining:.1f}s) exceeds budget ({budget:.1f}s)" ) - return # Will fail on no keys available + self._log_provider_cooldown_trace( + context, + "cooldown_wait_exceeds_budget", + provider, + ClassifiedError(error_type="rate_limit", original_exception=RuntimeError("cooldown exceeds budget"), retry_after=int(remaining)), + int(remaining), + "cooldown_exceeds_budget", + model=model, + ) + raise RoutingExecutionError( + f"Provider {provider} cooldown exceeds request budget", + error_type="rate_limit", + ) lib_logger.info(f"Waiting {remaining:.1f}s for {provider} cooldown") + self._log_cooldown_wait_trace(context, provider, model, remaining) await asyncio.sleep(remaining) async def _handle_error_with_context( @@ -1241,6 +1997,7 @@ async def _handle_error_with_context( error_accumulator: RequestErrorAccumulator, retry_state: RetryState, request_headers: Dict[str, Any], + context: Optional[RequestContext] = None, ) -> str: """ Handle an error and determine next action. @@ -1289,6 +2046,8 @@ async def _handle_error_with_context( cred_context.mark_failure(classified) return ErrorAction.FAIL + await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=error) + # Check if should retry same key (including small cooldown auto-retry) small_cooldown_threshold = int( os.environ.get( @@ -1334,6 +2093,135 @@ async def _handle_error_with_context( ) return ErrorAction.ROTATE + async def _maybe_start_provider_cooldown( + self, + provider: str, + classified: ClassifiedError, + *, + context: Optional[RequestContext], + model: Optional[str] = None, + original_error: Any = None, + ) -> None: + """Start provider-wide cooldown for large provider-level throttles. + + This is intentionally conservative: small retry-after values stay on the + same credential path, and quota cooldown is disabled by default because + most quota errors are per credential or account. + """ + + if not self._cooldown: + return + small_cooldown_threshold = int( + os.environ.get( + "SMALL_COOLDOWN_RETRY_THRESHOLD", DEFAULT_SMALL_COOLDOWN_RETRY_THRESHOLD + ) + ) + min_seconds, default_seconds, cooldown_on_quota = provider_cooldown_env() + decision = decide_provider_cooldown( + classified, + small_cooldown_threshold=small_cooldown_threshold, + provider_cooldown_min_seconds=min_seconds, + default_duration=default_seconds, + cooldown_on_quota=cooldown_on_quota, + provider=provider, + model=model, + original_error=original_error, + failure_history=getattr(self, "_failure_history", None), + ) + if not decision.should_start: + if decision.reason == "transient_backoff_threshold_not_met" and classified.error_type in {"server_error", "api_connection"}: + history = getattr(self, "_failure_history", None) + if history is not None: + history.record(provider=provider, model=decision.model, error_type=classified.error_type, scope=decision.scope, duration=0, reason=decision.reason) + self._log_provider_cooldown_trace( + context, + "provider_cooldown_skipped", + provider, + classified, + decision.duration, + decision.reason, + scope=decision.scope, + model=decision.model, + backoff_level=decision.backoff_level, + ) + return + try: + if hasattr(self._cooldown, "start_scoped_cooldown"): + await self._cooldown.start_scoped_cooldown(provider, decision.duration, model=decision.model, scope=decision.scope, reason=decision.reason) + else: + await self._cooldown.start_cooldown(provider, decision.duration) + history = getattr(self, "_failure_history", None) + if history is not None: + history.record(provider=provider, model=decision.model, error_type=classified.error_type, scope=decision.scope, duration=decision.duration, reason=decision.reason) + self._log_provider_cooldown_trace( + context, + "provider_cooldown_started", + provider, + classified, + decision.duration, + decision.reason, + scope=decision.scope, + model=decision.model, + backoff_level=decision.backoff_level, + ) + except Exception as exc: + lib_logger.debug("Failed to start provider cooldown for %s: %s", provider, exc) + + def _clear_failure_history_on_success(self, provider: str, model: Optional[str]) -> None: + """Clear transient failure history after a successful provider/model call.""" + + history = getattr(self, "_failure_history", None) + if history is not None and hasattr(history, "clear"): + history.clear(provider=provider, model=model) + + @staticmethod + def _log_provider_cooldown_trace( + context: Optional[RequestContext], + pass_name: str, + provider: str, + classified: ClassifiedError, + duration: int, + reason: str, + *, + scope: str = "provider", + model: Optional[str] = None, + backoff_level: int = 0, + ) -> None: + if not context or not context.transaction_logger: + return + context.transaction_logger.log_transform_pass( + pass_name, + {"provider": provider, "model": model, "scope": scope, "error_type": classified.error_type, "duration": duration}, + direction="metadata", + stage="retry", + metadata={ + "provider": provider, + "duration": duration, + "scope": scope, + "model": model, + "backoff_level": backoff_level, + "error_type": classified.error_type, + "retry_after_present": classified.retry_after is not None, + "reason": reason, + }, + snapshot=False, + ) + + @staticmethod + def _log_cooldown_wait_trace(context: Optional[RequestContext], provider: str, model: Optional[str], remaining: float) -> None: + """Trace cooldown waits without exposing credentials or payloads.""" + + if not context or not context.transaction_logger: + return + context.transaction_logger.log_transform_pass( + "cooldown_wait", + {"provider": provider, "model": model, "remaining": remaining}, + direction="metadata", + stage="retry", + metadata={"provider": provider, "model": model, "remaining": remaining}, + snapshot=False, + ) + def _record_session_response(self, context: RequestContext, response: Any) -> None: """Let the tracker learn anchors emitted by a successful response. @@ -1390,60 +2278,52 @@ async def _validate_request( if isinstance(result, str): raise ValueError(result) - def _extract_usage_tokens(self, response: Any) -> tuple[int, int, int, int, int]: - prompt_tokens = 0 - completion_tokens = 0 - cached_tokens = 0 - cache_write_tokens = 0 - thinking_tokens = 0 - - if hasattr(response, "usage") and response.usage: - prompt_tokens = getattr(response.usage, "prompt_tokens", 0) or 0 - completion_tokens = getattr(response.usage, "completion_tokens", 0) or 0 - - prompt_details = getattr(response.usage, "prompt_tokens_details", None) - if prompt_details: - if isinstance(prompt_details, dict): - cached_tokens = prompt_details.get("cached_tokens", 0) or 0 - cache_write_tokens = ( - prompt_details.get("cache_creation_tokens", 0) or 0 - ) - else: - cached_tokens = getattr(prompt_details, "cached_tokens", 0) or 0 - cache_write_tokens = ( - getattr(prompt_details, "cache_creation_tokens", 0) or 0 - ) + def _account_for_response_usage( + self, + provider: str, + model: str, + response: Any, + context: RequestContext, + ) -> tuple[UsageRecord, CostBreakdown]: + """Normalize usage and advisory cost for one successful response.""" - completion_details = getattr( - response.usage, "completion_tokens_details", None - ) - if completion_details: - if isinstance(completion_details, dict): - thinking_tokens = completion_details.get("reasoning_tokens", 0) or 0 - else: - thinking_tokens = ( - getattr(completion_details, "reasoning_tokens", 0) or 0 - ) + usage_record = extract_usage_record( + response, + provider=provider, + model=model, + source="executor_response", + ) + plugin = self._get_plugin_instance(provider) + cost_breakdown = CostCalculator(provider_plugin=plugin).calculate( + usage_record, + model=model, + response=response, + ) + self._trace_usage_accounting(context, usage_record, cost_breakdown) + return usage_record, cost_breakdown - cache_read_tokens = getattr(response.usage, "cache_read_tokens", None) - if cache_read_tokens is not None: - cached_tokens = cache_read_tokens or 0 - cache_creation_tokens = getattr( - response.usage, "cache_creation_tokens", None - ) - if cache_creation_tokens is not None: - cache_write_tokens = cache_creation_tokens or 0 - - if thinking_tokens and completion_tokens >= thinking_tokens: - completion_tokens = completion_tokens - thinking_tokens - - uncached_prompt = max(0, prompt_tokens - cached_tokens) - return ( - uncached_prompt, - completion_tokens, - cached_tokens, - cache_write_tokens, - thinking_tokens, + @staticmethod + def _trace_usage_accounting( + context: RequestContext, + usage_record: UsageRecord, + cost_breakdown: CostBreakdown, + ) -> None: + """Record normalized usage/cost trace data without affecting requests.""" + + if not context.transaction_logger: + return + context.transaction_logger.log_transform_pass( + "usage_accounting_summary", + {"usage": usage_record.to_dict(), "cost": cost_breakdown.to_dict()}, + direction="metadata", + stage="final", + metadata={ + "provider": usage_record.provider, + "model": usage_record.model, + "source": usage_record.source, + "pricing_source": cost_breakdown.pricing_source, + }, + snapshot=False, ) @staticmethod @@ -1453,39 +2333,22 @@ def _normalize_response_usage(response: Any, model: str) -> Any: Delegates to normalize_usage_for_response which handles both dicts (streaming) and pydantic objects (non-streaming). - Internal tracking values from _extract_usage_tokens are unaffected. + Internal tracking values from UsageRecord accounting are unaffected. """ - if hasattr(response, "usage") and response.usage: + if isinstance(response, dict): + normalize_usage_for_response(response.get("usage"), model) + elif hasattr(response, "usage") and response.usage: normalize_usage_for_response(response.usage, model) return response - def _calculate_cost(self, provider: str, model: str, response: Any) -> float: - plugin = self._get_plugin_instance(provider) - if plugin and getattr(plugin, "skip_cost_calculation", False): - return 0.0 - - try: - if isinstance(response, litellm.EmbeddingResponse): - model_info = litellm.get_model_info(model) - input_cost = model_info.get("input_cost_per_token") - if input_cost: - return (response.usage.prompt_tokens or 0) * input_cost - return 0.0 - - cost = litellm.completion_cost( - completion_response=response, - model=model, - ) - return float(cost) if cost is not None else 0.0 - except Exception as exc: - lib_logger.debug(f"Cost calculation failed for {model}: {exc}") - return 0.0 - async def _transaction_logging_stream_wrapper( self, stream: AsyncGenerator[str, None], transaction_logger: TransactionLogger, request_kwargs: Dict[str, Any], + *, + context: Optional[RequestContext] = None, + plugin: Any = None, ) -> AsyncGenerator[str, None]: """ Wrap a stream to log chunks and final response to TransactionLogger. @@ -1503,6 +2366,24 @@ async def _transaction_logging_stream_wrapper( chunks = [] async for sse_line in stream: + trace_sse_line = _redact_stream_sse_for_trace(sse_line, context, plugin) + transaction_logger.log_transform_pass( + "raw_stream_chunk", + trace_sse_line, + direction="stream", + stage="client", + transport="sse", + snapshot=False, + ) + if sse_line.startswith("data: [DONE]"): + transaction_logger.log_transform_pass( + "stream_done_event", + {"raw": trace_sse_line}, + direction="stream", + stage="final", + transport="sse", + snapshot=False, + ) yield sse_line # Parse and accumulate for final logging @@ -1514,7 +2395,17 @@ async def _transaction_logging_stream_wrapper( if content: chunk_data = json.loads(content) chunks.append(chunk_data) - transaction_logger.log_stream_chunk(chunk_data) + trace_chunk_data = _redact_context_field_cache_paths(chunk_data, context, "stream", plugin) if context else chunk_data + transaction_logger.log_stream_chunk(trace_chunk_data) + if isinstance(chunk_data, dict) and chunk_data.get("error") is not None: + transaction_logger.log_transform_pass( + "stream_error_event", + trace_chunk_data, + direction="stream", + stage="client", + transport="sse", + snapshot=False, + ) except json.JSONDecodeError: lib_logger.debug( f"Failed to parse chunk for logging: {sse_line[:100]}" @@ -1524,8 +2415,618 @@ async def _transaction_logging_stream_wrapper( if chunks: try: final_response = TransactionLogger.assemble_streaming_response(chunks) - transaction_logger.log_response(final_response) + trace_final_response = _redact_context_field_cache_paths(final_response, context, "stream", plugin) if context else final_response + transaction_logger.log_transform_pass( + "assembled_stream_response", + trace_final_response, + direction="response", + stage="client", + transport="sse", + ) + transaction_logger.log_response(trace_final_response) except Exception as e: lib_logger.debug( f"Failed to assemble/log final streaming response: {e}" ) + + +def _target_trace(target: RouteTarget) -> Dict[str, Any]: + """Return non-secret route target metadata for transaction traces.""" + + return { + "name": target.name, + "provider": target.provider, + "model": target.prefixed_model, + "execution": target.execution, + "protocol": target.protocol, + } + + +def _current_route_target(context: RequestContext) -> Optional[RouteTarget]: + """Return the currently selected route target from context metadata.""" + + targets = tuple(context.routing_targets or ()) + if not targets: + return None + if context.routing_target_index < 0 or context.routing_target_index >= len(targets): + return None + return targets[context.routing_target_index] + + +def _provider_native_protocol(plugin: Any, model: str, target: Optional[RouteTarget]) -> Optional[str]: + """Resolve native protocol from target override or provider declaration.""" + + if target and target.protocol: + return target.protocol + if plugin and hasattr(plugin, "get_protocol_name"): + return plugin.get_protocol_name(model) + return None + + +def _should_use_native_protocol(plugin: Any, model: str, target: Optional[RouteTarget], kwargs: Dict[str, Any], *, stream: bool, execution: str) -> bool: + """Return whether auto routing should use provider-native execution.""" + + protocol_name = _provider_native_protocol(plugin, model, target) + if not plugin or not protocol_name: + return False + native_model = plugin.normalize_native_model(model) if hasattr(plugin, "normalize_native_model") else _strip_provider_prefix(model) + request_payload = _native_request_payload(kwargs) + request_payload["_proxy_model"] = model + if native_model: + request_payload["model"] = native_model + operation = plugin.get_native_operation(native_model, request_payload, stream=stream) if hasattr(plugin, "get_native_operation") else "chat" + hook = getattr(plugin, "should_use_native_protocol", None) + if callable(hook): + return bool(hook(model=native_model, operation=operation, stream=stream, execution=execution)) + support = getattr(plugin, "supports_native_operation", None) + return bool(support(native_model, operation) if callable(support) else True) + + +def _strip_provider_prefix(model: str) -> str: + """Return model without the proxy-facing provider prefix.""" + + return model.split("/", 1)[1] if "/" in model else model + + +_NATIVE_REQUEST_DROP_KEYS = { + "api_base", + "api_key", + "api_type", + "api_version", + "base_url", + "custom_llm_provider", + "drop_params", + "logger_fn", + "litellm_call_id", + "mock_response", + "organization", + "project", + "transaction_context", +} + + +def _native_request_payload(kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Return provider-visible kwargs for native protocol parsing. + + Full executor calls prepare kwargs for LiteLLM before execution-mode routing. + Native providers must not receive LiteLLM-only routing, logging, or transport + controls because protocol adapters preserve unknown fields intentionally. + """ + + return {key: deepcopy(value) for key, value in kwargs.items() if key not in _NATIVE_REQUEST_DROP_KEYS and not key.startswith("litellm_")} + + +def _provider_supports_native_streaming(plugin: Any, model: str) -> bool: + """Return whether a provider declares native streaming support.""" + + support = getattr(plugin, "supports_native_streaming", None) + if not callable(support): + return False + operation = "chat" + resolver = getattr(plugin, "get_native_operation", None) + if callable(resolver): + try: + operation = resolver(model, {"model": model, "stream": True}, stream=True) + except TypeError: + try: + operation = resolver(model) + except Exception: + return False + except Exception: + return False + return native_provider_supports_streaming(plugin, model=model, operation=operation) + + +def _should_use_native_streaming(plugin: Any, model: str, target: Optional[RouteTarget], execution: str, provider: str) -> bool: + """Return whether streaming may use the native executor. + + Explicit `@native` routing is still constrained by provider capability. + Falling through to native streaming when a provider has not opted in is unsafe + because the generic stream wrapper currently expects LiteLLM-shaped chunks. + """ + + if execution == "native": + if _provider_supports_native_streaming(plugin, model): + return True + raise RoutingExecutionError( + f"Provider {provider} does not support native streaming for {model}", + error_type="configuration_error", + ) + if execution != "auto" or not plugin: + return False + if not _should_use_native_protocol(plugin, model, target, {"model": model, "stream": True}, stream=True, execution=execution): + return False + return _provider_supports_native_streaming(plugin, model) + + +def _redact_context_field_cache_paths(payload: Any, context: RequestContext, direction: str, plugin: Any) -> Any: + """Redact configured field-cache paths before executor-level traces. + + Native provider execution already redacts its internal trace passes. The + client executor also logs returned responses, so it must apply the same + rule-aware redaction there without mutating the client-facing response. + """ + + if direction not in {"response", "stream"} or not plugin or not _provider_native_protocol(plugin, context.model, _current_route_target(context)): + return payload + try: + rules = _merged_field_cache_rules(context.provider, context.model, plugin) + except RoutingExecutionError: + raise + except Exception: + return payload + if not rules: + return payload + redacted = deepcopy(payload) + for rule in rules: + if direction == "response" and getattr(rule, "source", None) not in {"response", "unified_response"}: + continue + if direction == "stream" and getattr(rule, "source", None) not in {"stream_event", "unified_stream_event", "response", "unified_response"}: + continue + for path in _trace_redaction_paths((rule.path,), direction=direction): + try: + tokens = parse_path(path) + _redact_trace_path(redacted, tokens) + _redact_trace_leaf_key(redacted, tokens) + except (FieldCachePathError, TypeError, ValueError): + continue + return redacted + + +def _trace_redaction_paths(paths: tuple[str, ...] | list[str], *, direction: str) -> list[str]: + """Return configured paths plus raw-stream envelope fallbacks for traces.""" + + expanded: list[str] = [] + for path in paths: + expanded.append(path) + if direction == "stream" and path.startswith("raw."): + expanded.append(path[4:]) + return expanded + + +def _redact_stream_sse_for_trace(sse_line: str, context: Optional[RequestContext], plugin: Any) -> str: + """Return a trace-only SSE line with configured native cache paths redacted.""" + + if not context or not isinstance(sse_line, str) or not sse_line.startswith("data: ") or sse_line.startswith("data: [DONE]"): + return sse_line + try: + payload = json.loads(sse_line[6:].strip()) + except json.JSONDecodeError: + return sse_line + redacted = _redact_context_field_cache_paths(payload, context, "stream", plugin) + if redacted is payload: + return sse_line + return f"data: {json.dumps(redacted, separators=(',', ':'))}\n\n" + + +def _redact_trace_path(value: Any, tokens: tuple[PathToken, ...]) -> None: + if not tokens: + return + token = tokens[0] + rest = tokens[1:] + if token.kind == "key": + if isinstance(value, dict) and token.value in value: + if rest: + _redact_trace_path(value[token.value], rest) + else: + value[token.value] = REDACTED + return + if token.kind == "index": + if isinstance(value, list) and value: + index = int(token.value) + if -len(value) <= index < len(value): + if rest: + _redact_trace_path(value[index], rest) + else: + value[index] = REDACTED + return + if token.kind == "wildcard": + if isinstance(value, dict): + for key in list(value.keys()): + if rest: + _redact_trace_path(value[key], rest) + else: + value[key] = REDACTED + elif isinstance(value, list): + for index, item in enumerate(value): + if rest: + _redact_trace_path(item, rest) + else: + value[index] = REDACTED + + +def _redact_trace_leaf_key(value: Any, tokens: tuple[PathToken, ...]) -> None: + """Redact terminal configured cache keys across duplicated trace envelopes.""" + + leaf = next((token.value for token in reversed(tokens) if token.kind == "key"), None) + if not leaf: + return + if isinstance(value, dict): + for key, item in list(value.items()): + if key == leaf: + value[key] = REDACTED + else: + _redact_trace_leaf_key(item, tokens) + elif isinstance(value, list): + for item in value: + _redact_trace_leaf_key(item, tokens) + + +def _merged_field_cache_rules(provider: str, model: str, plugin: Any) -> tuple[Any, ...]: + """Merge provider-declared and JSON-configured field-cache rules. + + Provider declarations are the safe default. Optional JSON config can add or + replace rules by name so operators can tune protocol-state preservation per + provider/model without editing provider code. The import is local to keep the + experimental config layer out of executor module initialization. + """ + + declared = list(plugin.get_field_cache_rules(model) if plugin and hasattr(plugin, "get_field_cache_rules") else ()) + try: + from ..config.experimental import load_experimental_config, parse_field_cache_rules + + configured = list(parse_field_cache_rules(load_experimental_config(), provider, model)) + except Exception as exc: + raise RoutingExecutionError(f"Invalid field-cache configuration for {provider}/{model}", error_type="configuration_error") from exc + if not configured: + return tuple(declared) + merged: dict[str, Any] = {getattr(rule, "name", str(index)): rule for index, rule in enumerate(declared)} + order = [getattr(rule, "name", str(index)) for index, rule in enumerate(declared)] + for rule in configured: + name = getattr(rule, "name", "") + if name and name not in merged: + order.append(name) + if name: + merged[name] = rule + return tuple(merged[name] for name in order if name in merged) + + +def _target_scope_value(target: RouteTarget, key: str, default: Any) -> Any: + """Read request-scope metadata attached by routing resolution.""" + + scope = target.metadata.get("request_scope") if isinstance(target.metadata, dict) else None + if isinstance(scope, dict) and key in scope: + return scope[key] + return default + + +def _route_error_type(error: BaseException, provider: Optional[str] = None) -> str: + """Map an exception to a fallback-policy error type.""" + + if isinstance(error, asyncio.CancelledError): + return "cancelled" + explicit = getattr(error, "error_type", None) + if explicit: + return normalize_route_error_type(str(explicit)) + classified = classify_error(error, provider) + return normalize_route_error_type(classified.error_type) + + +def _route_error_type_from_response(response: Any) -> Optional[str]: + """Infer retryability from the proxy's structured error response.""" + + if not isinstance(response, dict) or not isinstance(response.get("error"), dict): + return None + error = response["error"] + details = error.get("details") if isinstance(error.get("details"), dict) else {} + candidates = _structured_error_candidates(error, details) + hard_stop = _first_route_error_candidate(candidates, _HARD_STOP_ROUTE_ERRORS) + if hard_stop: + return hard_stop + retryable = _first_route_error_candidate(candidates, _RETRYABLE_ROUTE_ERRORS) + if retryable: + return retryable + normal_summary = str(details.get("normal_error_summary", "")).lower() + if any(token in normal_summary for token in ("authentication", "forbidden", "invalid_request", "context_window", "credential_reauth", "configuration_error")): + return _summary_hard_stop_type(normal_summary) + if any(token in normal_summary for token in ("rate_limit", "quota", "capacity")): + return "quota_exceeded" if "quota" in normal_summary else "rate_limit" + if any(token in normal_summary for token in ("server_error", "api_connection", "transient")): + return "api_connection" if "api_connection" in normal_summary else "server_error" + error_type = normalize_route_error_type(str(error.get("type", ""))) + if error_type in {"proxy_timeout", "proxy_all_credentials_exhausted"}: + return "rate_limit" + return None + + +def _route_status_code_from_response(response: Any) -> Optional[int]: + """Return a structured status code without reading raw provider text.""" + + if not isinstance(response, dict) or not isinstance(response.get("error"), dict): + return None + error = response["error"] + details = error.get("details") if isinstance(error.get("details"), dict) else {} + for candidate in (details.get("status_code"), details.get("status"), error.get("status_code"), error.get("status"), error.get("code")): + try: + return int(candidate) + except (TypeError, ValueError): + continue + return None + + +def _target_failure_summary(target: RouteTarget, error_type: str, *, status_code: Optional[int] = None) -> Dict[str, Any]: + """Return a client-safe fallback target failure summary.""" + + summary = { + "target": target.name, + "provider": target.provider, + "model": target.prefixed_model, + "execution": target.execution, + "error_type": normalize_route_error_type(error_type), + # Provider error text can contain raw upstream payload fragments or + # credential-like identifiers. Keep the cross-target summary structural; + # detailed per-credential errors remain in the existing sanitized error + # accumulator. + "message": "", + } + if status_code is not None: + summary["status_code"] = status_code + return summary + + +def _append_routing_attempt_history( + context: RequestContext, + target: RouteTarget, + target_index: int, + *, + success: bool, + error_type: Optional[str] = None, + status_code: Optional[int] = None, + emitted_output: Optional[bool] = None, + fallback_allowed: Optional[bool] = None, + duration_ms: Optional[int] = None, +) -> None: + """Append sanitized live fallback-attempt metadata to the request context.""" + + entry: Dict[str, Any] = { + "target_index": target_index, + "target": target.name, + "provider": target.provider, + "model": target.prefixed_model, + "execution": target.execution, + "success": success, + } + if error_type is not None: + entry["error_type"] = normalize_route_error_type(error_type) + if status_code is not None: + entry["status_code"] = status_code + if emitted_output is not None: + entry["emitted_output"] = bool(emitted_output) + if fallback_allowed is not None: + entry["fallback_allowed"] = bool(fallback_allowed) + if duration_ms is not None: + entry["duration_ms"] = max(0, duration_ms) + context.routing_attempt_history.append(entry) + + +def _elapsed_ms(started_at: float) -> int: + return int((time.monotonic() - started_at) * 1000) + + +def _group_streaming_policy(group: Any) -> str: + """Return the active streaming fallback policy for trace and decisions.""" + + return str(getattr(group, "streaming_policy", "pre_output_only") or "pre_output_only") + + +def _streaming_policy_allows_fallback(group: Any) -> bool: + """Return whether a group permits pre-output streaming fallback.""" + + return _group_streaming_policy(group) != "never" + + +_HARD_STOP_ROUTE_ERRORS = { + "authentication", + "forbidden", + "invalid_request", + "context_window_exceeded", + "credential_reauth_needed", + "pre_request_callback_error", + "cancelled", + "configuration_error", +} +_RETRYABLE_ROUTE_ERRORS = {"rate_limit", "quota_exceeded", "server_error", "api_connection", "unsupported_operation"} + + +def _structured_error_candidates(error: Dict[str, Any], details: Dict[str, Any]) -> List[str]: + """Return normalized structured error hints before reading free-form text.""" + + values: List[Any] = [ + error.get("type"), + error.get("code"), + error.get("status"), + details.get("classification"), + details.get("error_type"), + details.get("status"), + ] + for abnormal in details.get("abnormal_errors") or []: + if isinstance(abnormal, dict): + values.append(abnormal.get("error_type")) + values.append(abnormal.get("status_code")) + status_code = _route_status_code_from_response({"error": error}) + if status_code == 400: + values.append("invalid_request") + elif status_code == 401: + values.append("authentication") + elif status_code == 403: + values.append("forbidden") + elif status_code == 429: + values.append("rate_limit") + elif status_code is not None and status_code >= 500: + values.append("server_error") + return [normalize_route_error_type(str(value)) for value in values if value not in (None, "")] + + +def _first_route_error_candidate(candidates: List[str], allowed: set[str]) -> Optional[str]: + """Return the first structured candidate in an allowed policy set.""" + + for candidate in candidates: + if candidate in allowed: + return candidate + return None + + +def _summary_hard_stop_type(summary: str) -> str: + """Map legacy summary text to a hard-stop category without using raw text.""" + + if "credential_reauth" in summary: + return "credential_reauth_needed" + if "context_window" in summary: + return "context_window_exceeded" + if "configuration_error" in summary or "config" in summary: + return "configuration_error" + if "forbidden" in summary: + return "forbidden" + if "authentication" in summary: + return "authentication" + return "invalid_request" + + +def _with_fallback_summary(response: Any, target_failures: List[Dict[str, Any]], attempt_history: Optional[List[Dict[str, Any]]] = None) -> Any: + """Attach fallback target summaries to a structured error response.""" + + if not target_failures or not isinstance(response, dict) or not isinstance(response.get("error"), dict): + return response + details = response["error"].setdefault("details", {}) + if isinstance(details, dict): + details["fallback_targets"] = list(target_failures) + if attempt_history: + details["routing_attempt_history"] = list(attempt_history) + return response + + +def _stream_chunk_is_visible_output(chunk: str) -> bool: + """Return whether a stream chunk should block cross-target fallback. + + Only client-visible model output should lock the route. Empty frames, DONE + sentinels, and structured error frames are not considered visible output, so + a provider that fails before producing content can still fall through to the + next ordered target. + """ + + return is_visible_stream_output(chunk) + + +def _stream_chunk_error_type(chunk: str) -> Optional[str]: + """Return a route error type for terminal stream error frames. + + Per-target stream executors can emit a structured error SSE and `[DONE]` + instead of raising. Fallback wrappers must treat those frames as target + failures before visible output, while still forwarding them if no fallback is + available. + """ + + payload = _stream_chunk_payload(chunk) + if not isinstance(payload, dict): + return None + event_type = normalize_route_error_type(str(payload.get("event_type") or payload.get("type") or "")) + if event_type == "error": + error = payload.get("error") if isinstance(payload.get("error"), dict) else payload + return _route_error_type_from_response({"error": error}) or "server_error" + if event_type == "response.failed": + error = payload.get("error") if isinstance(payload.get("error"), dict) else {"type": "server_error"} + return _route_error_type_from_response({"error": error}) or "server_error" + if isinstance(payload.get("error"), dict): + return _route_error_type_from_response({"error": payload["error"]}) or "server_error" + return None + + +def _stream_chunk_payload(chunk: str) -> Optional[Dict[str, Any]]: + """Parse a minimal SSE payload for routing decisions only.""" + + text = str(chunk or "").strip() + if not text: + return None + event_type = None + data_lines: List[str] = [] + for line in text.splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith(":"): + continue + if stripped.startswith("event:"): + event_type = stripped[6:].strip() + continue + if stripped.startswith("data:"): + data_lines.append(stripped[5:].strip()) + if not data_lines: + return {"event_type": event_type} if event_type else None + data = "\n".join(data_lines).strip() + if not data or data == "[DONE]": + return None + try: + parsed = json.loads(data) + except json.JSONDecodeError: + return None + if not isinstance(parsed, dict): + return None + if event_type and "event_type" not in parsed: + parsed["event_type"] = event_type + return parsed + + +def _can_start_stream_provider_cooldown(last_streamed_chunk: Optional[str], *, emitted_output: bool = False) -> bool: + """Return whether a streaming failure occurred before visible output.""" + + if emitted_output: + return False + return last_streamed_chunk is None or not _stream_chunk_is_visible_output(last_streamed_chunk) + + +def _can_retry_stream_after_error(last_streamed_chunk: Optional[str], allow_reasoning_only_retry: bool, *, emitted_output: bool = False) -> bool: + """Return whether a stream can retry/rotate without duplicating output.""" + + if emitted_output: + return False + return can_retry_stream_after_error(last_streamed_chunk, allow_reasoning_only_retry) + + +def _streamed_error_payload(error: StreamedAPIError, classified: Any) -> Dict[str, Any]: + """Return a terminal stream error while preserving structured timeout data.""" + + data = getattr(error, "data", None) + if isinstance(data, dict) and isinstance(data.get("error"), dict): + source = data["error"] + return { + "error": { + "message": source.get("message") or "Upstream stream failed after output began", + "type": source.get("type") or classified.error_type, + "details": source.get("details") or {"status_code": classified.status_code}, + } + } + return {"error": {"message": "Upstream stream failed after output began", "type": classified.error_type, "details": {"status_code": classified.status_code}}} + + +def _merge_stream_error_details(error_data: Dict[str, Any], stream_error: Dict[str, Any]) -> None: + """Preserve structured stream timeout metadata in aggregate stream errors.""" + + target = error_data.get("error") if isinstance(error_data, dict) else None + source = stream_error.get("error") if isinstance(stream_error, dict) else None + if not isinstance(target, dict) or not isinstance(source, dict): + return + details = source.get("details") + if not isinstance(details, dict) or "timeout_type" not in details: + return + merged = dict(target.get("details") or {}) if isinstance(target.get("details"), dict) else {} + merged.update(details) + target["details"] = merged + target["type"] = source.get("type") or target.get("type") diff --git a/src/rotator_library/client/request_builder.py b/src/rotator_library/client/request_builder.py index c9455ec23..80a4ec728 100644 --- a/src/rotator_library/client/request_builder.py +++ b/src/rotator_library/client/request_builder.py @@ -8,6 +8,9 @@ from typing import Any, Awaitable, Callable, Dict, Optional from ..core.types import RequestContext +from ..routing import FallbackResolver, RoutingConfigError, load_routing_config_from_env +from ..routing.types import RouteTarget, RoutingDecision +from ..session_tracking import SessionTrackingHints from ..transaction_logger import TransactionLogger @@ -35,13 +38,14 @@ def __init__( self._get_provider_instance = get_provider_instance @staticmethod - def _pop_scope_kwargs(kwargs: Dict[str, Any]) -> tuple[Optional[str], Any, Any, bool]: + def _pop_scope_kwargs(kwargs: Dict[str, Any]) -> tuple[Optional[str], Any, Any, bool, Any]: classifier = kwargs.pop("classifier", None) request_api_keys = kwargs.pop("api_keys", None) request_providers = kwargs.pop("providers", None) private = bool(kwargs.pop("private", False)) + session_tracking_hints = kwargs.pop("_session_tracking_hints", None) kwargs.pop("model_filters", None) - return classifier, request_api_keys, request_providers, private + return classifier, request_api_keys, request_providers, private, session_tracking_hints @staticmethod def _provider_from_model(model: str) -> str: @@ -51,6 +55,45 @@ def _provider_from_model(model: str) -> str: def _raise_no_provider(model: str) -> None: raise ValueError(f"Invalid model format or no credentials for provider: {model}") + def _resolve_routing_decision(self, model: str) -> Optional[RoutingDecision]: + """Resolve env-configured fallback routing, if any applies.""" + + config = load_routing_config_from_env() + if not config.fallback_groups and not config.model_routes: + return None + try: + decision = FallbackResolver(config).resolve(model) + if decision.reason == "direct_provider_model" and model.lower() not in config.model_routes: + return None + return decision + except RoutingConfigError: + if "/" in model: + return None + raise + + @staticmethod + def _with_request_scope(target: RouteTarget, scope: Dict[str, Any]) -> RouteTarget: + """Attach per-provider request scope to a route target without secrets in traces.""" + + metadata = dict(target.metadata) + metadata["request_scope"] = { + "credentials": list(scope["credentials"]), + "usage_manager_key": scope["usage_manager_key"], + "provider_config": scope["provider_config"], + "credential_secrets": dict(scope["credential_secrets"]), + } + return RouteTarget( + provider=target.provider, + model=target.model, + name=target.name, + protocol=target.protocol, + execution=target.execution, + priority=target.priority, + weight=target.weight, + conditions=dict(target.conditions), + metadata=metadata, + ) + async def _get_session_hints( self, provider: str, @@ -88,17 +131,53 @@ async def _get_session_hints( ) return None + @staticmethod + def _merge_session_hints(*hints: Any) -> Any: + """Merge proxy-internal and provider session evidence. + + Internal hints are removed from request kwargs before provider execution. + They let services such as Responses expose stable continuation IDs to the + centralized tracker without adding provider-visible payload fields. + """ + + merged = SessionTrackingHints() + seen = False + for hint in hints: + if not hint: + continue + if isinstance(hint, dict): + hint = SessionTrackingHints( + strong_anchors=list(hint.get("strong_anchors") or []), + medium_anchors=list(hint.get("medium_anchors") or []), + weak_anchors=list(hint.get("weak_anchors") or []), + affinity_key=hint.get("affinity_key"), + session_scope=hint.get("session_scope"), + ) + if not isinstance(hint, SessionTrackingHints): + continue + seen = True + merged.strong_anchors.extend(hint.strong_anchors) + merged.medium_anchors.extend(hint.medium_anchors) + merged.weak_anchors.extend(hint.weak_anchors) + if not merged.affinity_key and hint.affinity_key: + merged.affinity_key = hint.affinity_key + if not merged.session_scope and hint.session_scope: + merged.session_scope = hint.session_scope + return merged if seen else None + async def build_completion_context( self, request: Optional[Any], pre_request_callback: Optional[Callable], kwargs: Dict[str, Any], ) -> RequestContext: - classifier, request_api_keys, request_providers, private = self._pop_scope_kwargs( + classifier, request_api_keys, request_providers, private, internal_session_hints = self._pop_scope_kwargs( kwargs ) model = kwargs.get("model", "") - provider = self._provider_from_model(model) + routing_decision = self._resolve_routing_decision(model) + routing_targets = routing_decision.targets if routing_decision else None + provider = routing_targets[0].provider if routing_targets else self._provider_from_model(model) if not provider: self._raise_no_provider(model) @@ -112,8 +191,23 @@ async def build_completion_context( if not scope["credentials"]: self._raise_no_provider(model) + if routing_targets: + scoped_targets = [] + for index, target in enumerate(routing_targets): + target_scope = scope if index == 0 else await self._resolve_scope_for_provider( + target.provider, + classifier, + request_api_keys, + request_providers, + private, + ) + if not target_scope["credentials"]: + self._raise_no_provider(target.prefixed_model) + scoped_targets.append(self._with_request_scope(target, target_scope)) + routing_targets = tuple(scoped_targets) + parent_log_dir = kwargs.pop("_parent_log_dir", None) - resolved_model = self._model_resolver.resolve_model_id(model, provider) + resolved_model = self._model_resolver.resolve_model_id(routing_targets[0].prefixed_model if routing_targets else model, provider) kwargs["model"] = resolved_model transaction_logger = None @@ -131,8 +225,17 @@ async def build_completion_context( provider=provider, model=resolved_model, scope_key=scope["usage_manager_key"], - hints=await self._get_session_hints(provider, resolved_model, kwargs), + hints=self._merge_session_hints( + internal_session_hints, + await self._get_session_hints(provider, resolved_model, kwargs), + ), ) + if transaction_logger: + transaction_logger.set_trace_context( + session_id=session.session_id, + scope_key=scope["usage_manager_key"], + classifier=scope["classifier"], + ) return RequestContext( model=resolved_model, @@ -154,6 +257,9 @@ async def build_completion_context( provider_config=scope["provider_config"], credential_secrets=scope["credential_secrets"], classifier=scope["classifier"], + routing_targets=routing_targets, + routing_group_name=routing_decision.group_name if routing_decision else None, + routing_group=routing_decision.group if routing_decision else None, ) async def build_embedding_context( @@ -162,7 +268,7 @@ async def build_embedding_context( pre_request_callback: Optional[Callable], kwargs: Dict[str, Any], ) -> RequestContext: - classifier, request_api_keys, request_providers, private = self._pop_scope_kwargs( + classifier, request_api_keys, request_providers, private, internal_session_hints = self._pop_scope_kwargs( kwargs ) model = kwargs.get("model", "") @@ -185,7 +291,10 @@ async def build_embedding_context( provider=provider, model=model, scope_key=scope["usage_manager_key"], - hints=await self._get_session_hints(provider, model, kwargs), + hints=self._merge_session_hints( + internal_session_hints, + await self._get_session_hints(provider, model, kwargs), + ), ) return RequestContext( diff --git a/src/rotator_library/client/rotating_client.py b/src/rotator_library/client/rotating_client.py index 7d496ecc9..e9cd3804a 100644 --- a/src/rotator_library/client/rotating_client.py +++ b/src/rotator_library/client/rotating_client.py @@ -535,9 +535,12 @@ async def acompletion( pre_request_callback: Optional[callable] = None, **kwargs, ) -> Union[Any, AsyncGenerator[str, None]]: + request_context_callback = kwargs.pop("_request_context_callback", None) context = await self._request_builder.build_completion_context( request, pre_request_callback, kwargs ) + if callable(request_context_callback): + request_context_callback(context) return await self._executor.execute(context) async def aembedding( diff --git a/src/rotator_library/client/stream_retry_policy.py b/src/rotator_library/client/stream_retry_policy.py index ccf0a593c..c776d5cdd 100644 --- a/src/rotator_library/client/stream_retry_policy.py +++ b/src/rotator_library/client/stream_retry_policy.py @@ -1,78 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-only # Copyright (c) 2026 Mirrowel -"""Retry-safety policy for stream errors after output has started.""" +"""Backward-compatible import path for stream retry policy.""" -import json -from typing import Any, Optional +from ..streaming.policy import can_retry_stream_after_error - -_REASONING_FIELDS = ( - "reasoning", - "reasoning_content", - "thinking", - "thinking_content", -) - - -def can_retry_stream_after_error( - last_streamed_chunk: Optional[str], - allow_reasoning_only_retry: bool, -) -> bool: - """ - Return whether a stream can be retried after an upstream error. - - Retrying is always safe before any chunk has been emitted. After that, it is - only allowed when explicitly enabled and the latest chunk is clearly - reasoning/thinking-only. Ambiguous chunks fail closed. - """ - if last_streamed_chunk is None: - return True - if not allow_reasoning_only_retry: - return False - - payload = last_streamed_chunk.strip() - if not payload.startswith("data:"): - return False - payload = payload[5:].strip() - if not payload or payload == "[DONE]": - return False - try: - data = json.loads(payload) - except json.JSONDecodeError: - return False - - has_reasoning = False - choices = data.get("choices") - if not isinstance(choices, list): - return False - - for choice in choices: - if not isinstance(choice, dict): - return False - for source in (choice, choice.get("delta"), choice.get("message")): - if not isinstance(source, dict): - continue - if ( - _has_text(source.get("content")) - or _has_text(source.get("text")) - or source.get("tool_calls") - or source.get("function_call") - ): - return False - if any(_has_text(source.get(key)) for key in _REASONING_FIELDS): - has_reasoning = True - - return has_reasoning - - -def _has_text(value: Any) -> bool: - if isinstance(value, str): - return bool(value.strip()) - if isinstance(value, list): - for item in value: - if isinstance(item, str) and item.strip(): - return True - if isinstance(item, dict) and _has_text(item.get("text")): - return True - return False +__all__ = ["can_retry_stream_after_error"] diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py index 01900b50d..eb73df6fe 100644 --- a/src/rotator_library/client/streaming.py +++ b/src/rotator_library/client/streaming.py @@ -13,17 +13,23 @@ - Client disconnect handling """ +import asyncio import codecs +import contextlib import json import logging import re +import time +from dataclasses import replace from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional, TYPE_CHECKING -import litellm - from ..core.errors import StreamedAPIError, CredentialNeedsReauthError from ..core.types import ProcessedChunk from ..core.utils import normalize_usage_for_response +from ..streaming import StreamEvent, StreamMonitor, stream_event_from_sse_chunk +from ..streaming.transport import SSEStreamFormatter +from ..usage.accounting import UsageRecord, extract_usage_record +from ..usage.costs import CostBreakdown, CostCalculator if TYPE_CHECKING: from ..usage.manager import CredentialContext @@ -50,6 +56,8 @@ async def wrap_stream( cred_context: Optional["CredentialContext"] = None, skip_cost_calculation: bool = False, response_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + success_callback: Optional[Callable[[], None]] = None, + transaction_logger: Optional[Any] = None, ) -> AsyncGenerator[str, None]: """ Wrap a LiteLLM stream with error handling and usage tracking. @@ -79,12 +87,57 @@ async def wrap_stream( prompt_tokens_uncached = 0 completion_tokens = 0 thinking_tokens = 0 + usage_record = UsageRecord(source="stream") assistant_parts: List[str] = [] tool_call_ids: List[str] = [] + monitor = StreamMonitor(clock=time.monotonic) + from ..config.experimental import get_stream_runtime_settings + + stream_settings = get_stream_runtime_settings() + formatter = SSEStreamFormatter() + upstream_closed = False + stream_cancelled = False + last_heartbeat_at = monitor.metrics.started_at + lifecycle_logger = transaction_logger + self._log_stream_lifecycle( + lifecycle_logger, + "stream_started", + monitor, + StreamEvent("started", protocol="openai_chat"), + ) # Use manual iteration to allow continue after partial JSON errors stream_iterator = stream.__aiter__() + async def close_upstream(reason: str, *, force: bool = False) -> None: + """Best-effort close for upstream async streams. + + Client disconnects and timeout failures should not leave provider + HTTP streams running in the background. Close failures are logged as + lifecycle metadata only and never replace the original stream error. + """ + + nonlocal upstream_closed + if upstream_closed or (not force and not stream_settings.cancel_upstream_on_disconnect): + return + upstream_closed = True + for candidate in (stream_iterator, stream): + try: + closer = getattr(candidate, "aclose", None) + if closer: + await closer() + self._log_stream_lifecycle(lifecycle_logger, "stream_upstream_cancelled", monitor, StreamEvent("cancelled", protocol="openai_chat", data={"reason": reason})) + return + closer = getattr(candidate, "close", None) + if closer: + closer() + self._log_stream_lifecycle(lifecycle_logger, "stream_upstream_cancelled", monitor, StreamEvent("cancelled", protocol="openai_chat", data={"reason": reason})) + return + except Exception as exc: + lib_logger.debug("Failed to close upstream stream: %s", exc) + self._log_stream_lifecycle(lifecycle_logger, "stream_upstream_close_failed", monitor, StreamEvent("error", protocol="openai_chat", data={"reason": reason, "error_type": type(exc).__name__})) + return + try: while True: try: @@ -95,23 +148,101 @@ async def wrap_stream( ) break - chunk = await stream_iterator.__anext__() + next_task = asyncio.create_task(stream_iterator.__anext__()) + try: + while True: + wait_seconds = _next_stream_wait_seconds(monitor, stream_settings, last_heartbeat_at) + wait_tasks = {next_task} + disconnect_task = None + if request is not None: + disconnect_task = asyncio.create_task(request.is_disconnected()) + wait_tasks.add(disconnect_task) + done, _ = await asyncio.wait(wait_tasks, timeout=wait_seconds) + if disconnect_task is not None: + if disconnect_task in done and disconnect_task.result(): + stream_cancelled = True + next_task.cancel() + with contextlib.suppress(asyncio.CancelledError, StopAsyncIteration): + await next_task + await close_upstream("client_disconnect") + return + if not disconnect_task.done(): + disconnect_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await disconnect_task + if next_task in done: + chunk = next_task.result() + break + + timeout_error = _stream_timeout_error(monitor, stream_settings) + if timeout_error: + next_task.cancel() + with contextlib.suppress(asyncio.CancelledError, StopAsyncIteration): + await next_task + await close_upstream(timeout_error[0], force=True) + self._log_stream_lifecycle(lifecycle_logger, timeout_error[2], monitor, StreamEvent("error", protocol="openai_chat", data={"error": timeout_error[1]})) + raise StreamedAPIError(timeout_error[1]["message"], data={"error": timeout_error[1]}) + + if _heartbeat_due(monitor, stream_settings, last_heartbeat_at): + heartbeat = formatter.format_heartbeat() + last_heartbeat_at = time.monotonic() + self._log_stream_lifecycle(lifecycle_logger, "stream_heartbeat", monitor, StreamEvent("heartbeat", protocol="openai_chat", visible_output=False)) + yield heartbeat + except Exception: + if not next_task.done(): + next_task.cancel() + with contextlib.suppress(asyncio.CancelledError, StopAsyncIteration): + await next_task + raise + + raw_event = StreamEvent( + "raw_chunk", + protocol="openai_chat", + raw=chunk, + data={"chunk_index": monitor.metrics.chunk_count + 1}, + ) + if monitor.metrics.first_byte_at is None: + self._log_stream_lifecycle( + lifecycle_logger, + "stream_first_byte", + monitor, + raw_event, + ) # Clear error buffer on successful chunk receipt error_buffer.reset() # Process chunk + cost_event_record = _usage_record_from_sse_cost_chunk(chunk, model=model) processed = self._process_chunk( chunk, accumulated_finish_reason, has_tool_calls, model, ) + if cost_event_record.provider_reported_cost is not None: + usage_record = _merge_usage_cost(usage_record, cost_event_record) + if not processed.sse_string: + stream_completed = True + break self._collect_session_response_anchors( processed.sse_string, assistant_parts, tool_call_ids, ) + event = stream_event_from_sse_chunk(processed.sse_string) + first_visible = ( + event.visible_output + and monitor.metrics.first_visible_output_at is None + ) + monitor.record_event(event) + if first_visible: + self._log_stream_lifecycle( + lifecycle_logger, + "stream_first_visible_output", + monitor, + event, + ) # Update tracking state if processed.has_tool_calls: @@ -121,52 +252,20 @@ async def wrap_stream( # Only update if not already tool_calls (highest priority) accumulated_finish_reason = processed.finish_reason if processed.usage and isinstance(processed.usage, dict): - # Extract token counts from final chunk - prompt_tokens = processed.usage.get("prompt_tokens", 0) - completion_tokens = processed.usage.get("completion_tokens", 0) - prompt_details = processed.usage.get("prompt_tokens_details") - if prompt_details: - if isinstance(prompt_details, dict): - prompt_tokens_cached = ( - prompt_details.get("cached_tokens", 0) or 0 - ) - prompt_tokens_cache_write = ( - prompt_details.get("cache_creation_tokens", 0) or 0 - ) - else: - prompt_tokens_cached = ( - getattr(prompt_details, "cached_tokens", 0) or 0 - ) - prompt_tokens_cache_write = ( - getattr(prompt_details, "cache_creation_tokens", 0) - or 0 - ) - completion_details = processed.usage.get( - "completion_tokens_details" - ) - if completion_details: - if isinstance(completion_details, dict): - thinking_tokens = ( - completion_details.get("reasoning_tokens", 0) or 0 - ) - else: - thinking_tokens = ( - getattr(completion_details, "reasoning_tokens", 0) - or 0 - ) - if processed.usage.get("cache_read_tokens") is not None: - prompt_tokens_cached = ( - processed.usage.get("cache_read_tokens") or 0 - ) - if processed.usage.get("cache_creation_tokens") is not None: - prompt_tokens_cache_write = ( - processed.usage.get("cache_creation_tokens") or 0 - ) - if thinking_tokens and completion_tokens >= thinking_tokens: - completion_tokens = completion_tokens - thinking_tokens - prompt_tokens_uncached = max( - 0, prompt_tokens - prompt_tokens_cached + next_usage_record = extract_usage_record( + processed.usage, + model=model, + source="stream_final_chunk", ) + if next_usage_record.provider_reported_cost is None and usage_record.provider_reported_cost is not None: + next_usage_record = _merge_usage_cost(next_usage_record, usage_record) + usage_record = next_usage_record + prompt_tokens = usage_record.input_tokens + usage_record.cache_read_tokens + completion_tokens = usage_record.completion_tokens + thinking_tokens = usage_record.reasoning_tokens + prompt_tokens_cached = usage_record.cache_read_tokens + prompt_tokens_cache_write = usage_record.cache_write_tokens + prompt_tokens_uncached = usage_record.prompt_tokens_for_mark_success yield processed.sse_string @@ -221,31 +320,51 @@ async def wrap_stream( ) # Not a JSON-related error, re-raise + monitor.metrics.error_count += 1 + await close_upstream("stream_exception", force=True) raise except StreamedAPIError: # Re-raise for retry loop + await close_upstream("streamed_api_error", force=True) + raise + + except asyncio.CancelledError: + stream_cancelled = True + monitor.cancel() + self._log_stream_lifecycle( + lifecycle_logger, + "stream_cancelled", + monitor, + StreamEvent("cancelled", protocol="openai_chat", data={"reason": "task_cancelled"}), + ) + await close_upstream("task_cancelled", force=True) raise finally: # Record usage if stream completed if stream_completed: if cred_context: - approx_cost = 0.0 - if not skip_cost_calculation: - approx_cost = self._calculate_stream_cost( - model, - prompt_tokens_uncached + prompt_tokens_cached, - completion_tokens + thinking_tokens, - ) + cost_breakdown = self._calculate_stream_cost_breakdown( + model, + usage_record, + skip_cost_calculation=skip_cost_calculation, + ) + self._log_stream_usage_accounting( + transaction_logger, + usage_record, + cost_breakdown, + ) cred_context.mark_success( prompt_tokens=prompt_tokens_uncached, completion_tokens=completion_tokens, thinking_tokens=thinking_tokens, prompt_tokens_cache_read=prompt_tokens_cached, prompt_tokens_cache_write=prompt_tokens_cache_write, - approx_cost=approx_cost, + approx_cost=cost_breakdown.total_cost, ) + if success_callback: + success_callback() if response_callback and (assistant_parts or tool_call_ids): # Intentionally only record response anchors after a complete @@ -268,9 +387,61 @@ async def wrap_stream( } ) + monitor.complete() + self._log_stream_lifecycle( + lifecycle_logger, + "stream_completed", + monitor, + StreamEvent("completed", protocol="openai_chat"), + ) + self._log_stream_lifecycle( + lifecycle_logger, + "stream_metrics_final", + monitor, + StreamEvent("metadata", protocol="openai_chat"), + ) + # Yield [DONE] for completed streams yield "data: [DONE]\n\n" + elif request and await request.is_disconnected(): + stream_cancelled = True + monitor.cancel() + self._log_stream_lifecycle( + lifecycle_logger, + "stream_cancelled", + monitor, + StreamEvent("cancelled", protocol="openai_chat"), + ) + await close_upstream("client_disconnect") + elif stream_cancelled: + await close_upstream("stream_cancelled", force=True) + + @staticmethod + def _log_stream_lifecycle( + transaction_logger: Optional[Any], + pass_name: str, + monitor: StreamMonitor, + event: StreamEvent, + ) -> None: + """Emit stream lifecycle metrics without affecting stream delivery.""" + + if not transaction_logger: + return + try: + transaction_logger.log_transform_pass( + pass_name, + {"event": event.to_dict(), "metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="client", + protocol=event.protocol, + transport="sse", + metadata={"event_type": event.event_type}, + snapshot=False, + ) + except Exception as exc: + lib_logger.debug("Stream lifecycle trace failed: %s", exc) + def _collect_session_response_anchors( self, sse_string: str, @@ -325,6 +496,13 @@ def _process_chunk( ProcessedChunk with SSE string and metadata """ # Convert chunk to dict + if isinstance(chunk, str): + stripped = chunk.strip() + if stripped == "[DONE]" or stripped == "data: [DONE]": + return ProcessedChunk(sse_string="", finish_reason="stop") + if stripped.startswith("data:") or stripped.startswith("event:") or stripped.startswith(":"): + usage = _usage_from_sse_string(chunk) + return ProcessedChunk(sse_string=chunk if chunk.endswith("\n\n") else f"{chunk}\n\n", usage=usage) if hasattr(chunk, "model_dump"): chunk_dict = chunk.model_dump() elif hasattr(chunk, "dict"): @@ -332,8 +510,9 @@ def _process_chunk( else: chunk_dict = chunk - # Extract metadata before modifying - usage = chunk_dict.get("usage") + # Extract metadata before modifying. Providers can report cost as a + # sibling of `usage`, so keep those fields attached for normalization. + usage = _usage_from_chunk_dict(chunk_dict) finish_reason = None chunk_has_tool_calls = False @@ -387,7 +566,7 @@ def _process_chunk( # INTERMEDIATE CHUNK: Never emit finish_reason choice["finish_reason"] = None - usage = chunk_dict.get("usage") + usage = _usage_from_chunk_dict(chunk_dict) if isinstance(usage, dict): normalize_usage_for_response(usage, model) @@ -446,25 +625,214 @@ def _try_extract_error( return None - def _calculate_stream_cost( + def _calculate_stream_cost_breakdown( self, model: str, - prompt_tokens: int, - completion_tokens: int, - ) -> float: + usage_record: UsageRecord, + *, + skip_cost_calculation: bool, + ) -> CostBreakdown: + """Calculate advisory stream cost through the shared cost helper.""" + + if skip_cost_calculation: + return CostBreakdown(pricing_source="skipped") + return CostCalculator().calculate(usage_record, model=model) + + @staticmethod + def _log_stream_usage_accounting( + transaction_logger: Optional[Any], + usage_record: UsageRecord, + cost_breakdown: CostBreakdown, + ) -> None: + """Trace normalized stream usage without affecting stream delivery.""" + + if not transaction_logger: + return + transaction_logger.log_transform_pass( + "usage_accounting_summary", + {"usage": usage_record.to_dict(), "cost": cost_breakdown.to_dict()}, + direction="metadata", + stage="final", + transport="sse", + metadata={ + "source": usage_record.source, + "pricing_source": cost_breakdown.pricing_source, + }, + snapshot=False, + ) + + +def _next_stream_wait_seconds( + monitor: StreamMonitor, + settings: Any, + last_heartbeat_at: float, +) -> Optional[float]: + """Return the next active stream policy deadline. + + A pending upstream `__anext__()` task is not cancelled for heartbeats. The + handler waits until the nearest policy deadline, emits heartbeat metadata if + needed, and keeps waiting on the same upstream task. + """ + + candidates: list[float] = [] + now = time.monotonic() + if settings.heartbeat_seconds: + candidates.append(max(0.0, last_heartbeat_at + settings.heartbeat_seconds - now)) + if settings.ttfb_timeout_seconds and monitor.metrics.first_byte_at is None: + candidates.append(max(0.0, monitor.metrics.started_at + settings.ttfb_timeout_seconds - now)) + if settings.stall_timeout_seconds and monitor.metrics.first_byte_at is not None: + last_chunk_at = monitor.metrics.last_chunk_at or monitor.metrics.first_byte_at + candidates.append(max(0.0, last_chunk_at + settings.stall_timeout_seconds - now)) + return min(candidates) if candidates else None + + +def _heartbeat_due(monitor: StreamMonitor, settings: Any, last_heartbeat_at: float) -> bool: + """Return whether a heartbeat should be emitted while upstream is idle.""" + + if not settings.heartbeat_seconds: + return False + return time.monotonic() - last_heartbeat_at >= settings.heartbeat_seconds + + +def _stream_timeout_error(monitor: StreamMonitor, settings: Any) -> Optional[tuple[str, dict[str, Any], str]]: + """Return a structured stream timeout error when a configured deadline expires.""" + + now = time.monotonic() + if settings.ttfb_timeout_seconds and monitor.metrics.first_byte_at is None: + if now - monitor.metrics.started_at >= settings.ttfb_timeout_seconds: + return ( + "ttfb_timeout", + { + "message": "Stream timed out before first byte", + "type": "api_connection", + "details": {"timeout_type": "ttfb", "timeout_seconds": settings.ttfb_timeout_seconds}, + }, + "stream_ttfb_timeout", + ) + if settings.stall_timeout_seconds and monitor.metrics.first_byte_at is not None: + last_chunk_at = monitor.metrics.last_chunk_at or monitor.metrics.first_byte_at + if now - last_chunk_at >= settings.stall_timeout_seconds: + return ( + "stall_timeout", + { + "message": "Stream stalled while waiting for provider data", + "type": "api_connection", + "details": {"timeout_type": "stall", "timeout_seconds": settings.stall_timeout_seconds}, + }, + "stream_stall_timeout", + ) + return None + + +def _usage_from_sse_string(chunk: str) -> Optional[dict[str, Any]]: + """Extract usage from already formatted SSE chunks when native streams pass through.""" + + text = chunk.strip() + data_lines: list[str] = [] + for line in text.splitlines(): + stripped = line.strip() + if stripped.startswith("data:"): + data_lines.append(stripped[5:].strip()) + if not data_lines: + return None + payload = "\n".join(data_lines).strip() + if not payload or payload == "[DONE]": + return None + try: + data = json.loads(payload) + except json.JSONDecodeError: + return None + if not isinstance(data, dict): + return None + usage = data.get("usage") + if not isinstance(usage, dict): + return None + merged = dict(usage) + for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + if key in data and key not in merged: + merged[key] = data[key] + return merged + + +def _usage_from_chunk_dict(chunk: Any) -> Optional[dict[str, Any]]: + """Return dict chunk usage with top-level provider cost siblings.""" + + if not isinstance(chunk, dict): + return None + usage = chunk.get("usage") + if not isinstance(usage, dict): + return None + merged = dict(usage) + for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + if key in chunk and key not in merged: + merged[key] = chunk[key] + return merged + + +def _usage_record_from_sse_cost_chunk(chunk: Any, *, model: str) -> UsageRecord: + """Extract provider-reported stream cost from SSE comments/events.""" + + if not isinstance(chunk, str): + return UsageRecord(source="stream_cost_event", model=model) + cost_payload = _sse_cost_payload(chunk) + if cost_payload is None: + return UsageRecord(source="stream_cost_event", model=model) + if isinstance(cost_payload, (int, float, str)): + cost_payload = {"provider_reported_cost": cost_payload, "source": "sse_cost"} + if not isinstance(cost_payload, dict): + return UsageRecord(source="stream_cost_event", model=model) + return extract_usage_record( + {"usage": {"provider_reported_cost": cost_payload.get("provider_reported_cost", cost_payload.get("request_cost_usd", cost_payload.get("total_cost", cost_payload.get("cost", cost_payload.get("estimated_cost"))))), "currency": cost_payload.get("currency", "USD"), "cost_details": cost_payload}}, + model=model, + source="stream_cost_event", + ) + + +def _sse_cost_payload(chunk: str) -> Any: + """Parse `: cost ...` comments and `event: cost` frames.""" + + event_type: Optional[str] = None + data_lines: list[str] = [] + for line in chunk.strip().splitlines(): + stripped = line.strip() + if stripped.startswith(":"): + comment = stripped[1:].strip() + if comment.startswith("cost"): + return _parse_cost_text(comment[4:].strip()) + continue + if stripped.startswith("event:"): + event_type = stripped[6:].strip() + continue + if stripped.startswith("data:"): + data_lines.append(stripped[5:].strip()) + if event_type == "cost" and data_lines: + return _parse_cost_text("\n".join(data_lines).strip()) + return None + + +def _parse_cost_text(text: str) -> Any: + if not text: + return None + try: + return json.loads(text) + except json.JSONDecodeError: try: - model_info = litellm.get_model_info(model) - input_cost = model_info.get("input_cost_per_token") - output_cost = model_info.get("output_cost_per_token") - total_cost = 0.0 - if input_cost: - total_cost += prompt_tokens * input_cost - if output_cost: - total_cost += completion_tokens * output_cost - return total_cost - except Exception as exc: - lib_logger.debug(f"Stream cost calculation failed for {model}: {exc}") - return 0.0 + return float(text) + except ValueError: + return None + + +def _merge_usage_cost(base: UsageRecord, cost_record: UsageRecord) -> UsageRecord: + """Copy provider-reported cost onto an existing normalized usage record.""" + + if cost_record.provider_reported_cost is None: + return base + return replace( + base, + provider_reported_cost=cost_record.provider_reported_cost, + cost_currency=cost_record.cost_currency, + cost_source=cost_record.cost_source, + ) class StreamBuffer: diff --git a/src/rotator_library/client/transforms.py b/src/rotator_library/client/transforms.py index fb2a02d88..78f052ba4 100644 --- a/src/rotator_library/client/transforms.py +++ b/src/rotator_library/client/transforms.py @@ -15,6 +15,7 @@ """ import logging +from copy import deepcopy from typing import Any, Callable, Dict, List, Optional lib_logger = logging.getLogger("rotator_library") @@ -81,6 +82,11 @@ async def apply( credential: str, kwargs: Dict[str, Any], provider_config_override: Optional[Dict[str, Any]] = None, + *, + transaction_logger: Optional[Any] = None, + credential_id: Optional[str] = None, + transport: Optional[str] = None, + trace_metadata: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """ Apply all applicable transforms to request kwargs. @@ -95,43 +101,146 @@ async def apply( Modified kwargs """ modifications: List[str] = [] + trace_metadata = dict(trace_metadata or {}) + _trace_transform_pass( + transaction_logger, + "pre_provider_transform_request", + kwargs, + provider=provider, + model=model, + credential_id=credential_id, + transport=transport, + metadata={"phase": "start", **trace_metadata}, + ) # 1. Apply built-in transforms for transform_provider, transforms in self._transforms.items(): # Check if transform applies (provider match or model contains pattern) if transform_provider == provider or transform_provider in model.lower(): for transform in transforms: - result = transform(kwargs, model, provider) + before = deepcopy(kwargs) if transaction_logger else None + try: + result = transform(kwargs, model, provider) + except Exception as exc: + if transaction_logger: + transaction_logger.log_transform_error( + "builtin_provider_transform", + exc, + payload=before if before is not None else kwargs, + stage="client", + transport=transport, + metadata={ + "provider": provider, + "model": model, + "credential_id": credential_id, + "transform_provider": transform_provider, + "transform_name": getattr(transform, "__name__", repr(transform)), + **trace_metadata, + }, + ) + raise if result: modifications.append(result) + _trace_transform_pass( + transaction_logger, + "after_builtin_provider_transform", + kwargs, + provider=provider, + model=model, + credential_id=credential_id, + transport=transport, + changed_from_previous=(before != kwargs) if before is not None else None, + metadata={ + "transform_provider": transform_provider, + "transform_name": getattr(transform, "__name__", repr(transform)), + "modification": result, + **trace_metadata, + }, + ) # 2. Apply provider hook transforms (async) plugin = self._get_plugin_instance(provider) if plugin and hasattr(plugin, "transform_request"): try: + before = deepcopy(kwargs) if transaction_logger else None hook_result = await plugin.transform_request(kwargs, model, credential) if hook_result: modifications.extend(hook_result) + if hook_result or (before is not None and before != kwargs): + _trace_transform_pass( + transaction_logger, + "after_provider_hook_transform", + kwargs, + provider=provider, + model=model, + credential_id=credential_id, + transport=transport, + changed_from_previous=(before != kwargs) if before is not None else None, + metadata={"modifications": hook_result or [], **trace_metadata}, + ) except Exception as e: lib_logger.debug(f"Provider transform_request hook failed: {e}") + if transaction_logger: + transaction_logger.log_transform_error( + "provider_hook_transform", + e, + payload=kwargs, + stage="client", + transport=transport, + metadata={"provider": provider, "model": model, "credential_id": credential_id, **trace_metadata}, + ) # 3. Apply model-specific options from provider if plugin and hasattr(plugin, "get_model_options"): model_options = plugin.get_model_options(model) if model_options: + before = deepcopy(kwargs) if transaction_logger else None for key, value in model_options.items(): if key == "reasoning_effort": kwargs["reasoning_effort"] = value elif key not in kwargs: kwargs[key] = value modifications.append(f"applied model options for {model}") + _trace_transform_pass( + transaction_logger, + "after_provider_model_options", + kwargs, + provider=provider, + model=model, + credential_id=credential_id, + transport=transport, + changed_from_previous=(before != kwargs) if before is not None else None, + metadata={"model_options": deepcopy(model_options) if transaction_logger else None, **trace_metadata}, + ) # 4. Apply LiteLLM conversion if config available if self._config and hasattr(self._config, "convert_for_litellm"): + before = deepcopy(kwargs) if transaction_logger else None + _trace_transform_pass( + transaction_logger, + "before_litellm_conversion", + kwargs, + provider=provider, + model=model, + credential_id=credential_id, + transport=transport, + metadata={"provider_config_override": bool(provider_config_override), **trace_metadata}, + ) kwargs = self._config.convert_for_litellm( provider_override=provider_config_override, **kwargs, ) + _trace_transform_pass( + transaction_logger, + "after_litellm_conversion", + kwargs, + provider=provider, + model=model, + credential_id=credential_id, + transport=transport, + changed_from_previous=(before != kwargs) if before is not None else None, + metadata={"provider_config_override": bool(provider_config_override), **trace_metadata}, + ) if modifications: lib_logger.debug( @@ -312,3 +421,36 @@ def _transform_dedaluslabs_tool_choice( # caused 400 errors on models that don't support those categories (e.g. Gemma). # See gemini_provider.py for the full removal comment with previous defaults. # ========================================================================= + + +def _trace_transform_pass( + transaction_logger: Optional[Any], + pass_name: str, + payload: Dict[str, Any], + *, + provider: str, + model: str, + credential_id: Optional[str], + transport: Optional[str], + metadata: Dict[str, Any], + changed_from_previous: Optional[bool] = None, +) -> None: + """Record provider-transform states without changing transform behavior. + + Transform tracing is observability-only. This helper centralizes pass + metadata so individual transforms can stay focused on payload mutation while + the transaction trace still shows each live request state. + """ + + if not transaction_logger: + return + transaction_logger.log_transform_pass( + pass_name, + payload, + direction="request", + stage="client", + credential_id=credential_id, + transport=transport, + changed_from_previous=changed_from_previous, + metadata={"provider": provider, "model": model, **metadata}, + ) diff --git a/src/rotator_library/config/experimental.py b/src/rotator_library/config/experimental.py new file mode 100644 index 000000000..be1e76346 --- /dev/null +++ b/src/rotator_library/config/experimental.py @@ -0,0 +1,500 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Optional structured configuration for experimental native features. + +The proxy remains environment-first: existing `.env` variables keep working and +environment variables override this JSON layer. This module intentionally avoids +secrets. API keys, OAuth tokens, bearer headers, and similar values must remain +in environment variables or provider-managed credential files. +""" + +from __future__ import annotations + +import json +import os +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Mapping, Optional + +from ..field_cache import FieldCacheInjection, FieldCacheRule +from ..usage.costs import ModelPricing + +_CONFIG_ENV_KEYS = ("LLM_PROXY_CONFIG_FILE", "PROXY_CONFIG_FILE") +_KNOWN_SECTIONS = {"routing", "pricing", "streaming", "field_cache", "providers", "retry", "responses"} +_SECRET_KEY_PARTS = ("api_key", "apikey", "authorization", "access_token", "accesstoken", "refresh_token", "refreshtoken", "oauth_token", "oauthtoken", "oauth_token_secret", "oauthtokensecret", "id_token", "idtoken", "token_secret", "tokensecret", "client_secret", "clientsecret", "secret_key", "secretkey", "bearer_token", "bearertoken", "password") + + +class ExperimentalConfigError(ValueError): + """Raised when optional structured config is malformed or unsafe.""" + + +@dataclass(frozen=True) +class ExperimentalConfig: + """Parsed optional JSON config. + + Sections are stored as dictionaries rather than deep custom classes so Phase + 10 can layer config onto existing feature-specific parsers without creating + a second full application configuration system. + """ + + routing: dict[str, Any] = field(default_factory=dict) + pricing: dict[str, Any] = field(default_factory=dict) + streaming: dict[str, Any] = field(default_factory=dict) + field_cache: dict[str, Any] = field(default_factory=dict) + providers: dict[str, Any] = field(default_factory=dict) + retry: dict[str, Any] = field(default_factory=dict) + responses: dict[str, Any] = field(default_factory=dict) + unknown_sections: dict[str, Any] = field(default_factory=dict) + warnings: tuple[str, ...] = () + path: Optional[str] = None + + @property + def is_empty(self) -> bool: + return not (self.routing or self.pricing or self.streaming or self.field_cache or self.providers or self.retry or self.responses or self.unknown_sections) + + +@dataclass(frozen=True) +class StreamRuntimeSettings: + """Runtime stream observability settings. + + Timeout and heartbeat values default to disabled so existing long-running + reasoning streams keep working. Operators can opt into active stream + hardening through env or JSON config without changing provider code. + """ + + ttfb_timeout_seconds: Optional[float] = None + stall_timeout_seconds: Optional[float] = None + heartbeat_seconds: Optional[float] = None + cancel_upstream_on_disconnect: bool = True + trace_metrics: bool = True + + +@dataclass(frozen=True) +class RetryRuntimeSettings: + """Runtime retry/cooldown settings layered from JSON and env.""" + + provider_cooldown_min_seconds: int = 10 + provider_cooldown_default_seconds: int = 30 + provider_cooldown_on_quota: bool = False + provider_backoff_window_seconds: int = 60 + provider_backoff_threshold: int = 3 + provider_backoff_base_seconds: Optional[int] = None + provider_backoff_max_seconds: int = 300 + failure_history_max_entries: int = 200 + + +@dataclass(frozen=True) +class ResponsesStoreRuntimeSettings: + """Runtime backend selection for Responses storage.""" + + backend: str = "memory" + cache_name: str = "responses" + cache_prefix: str = "responses" + cache_dir: Optional[str] = None + cache_memory_ttl_seconds: int = 3600 + cache_disk_ttl_seconds: int = 172800 + + +def load_experimental_config(path: str | os.PathLike[str] | None = None, env: Mapping[str, str] | None = None) -> ExperimentalConfig: + """Load optional JSON config from an explicit path or config env var.""" + + source = env if env is not None else os.environ + resolved = Path(path) if path is not None else _path_from_env(source) + if resolved is None or not resolved.exists(): + return ExperimentalConfig(path=str(resolved) if resolved is not None else None) + try: + data = json.loads(resolved.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + raise ExperimentalConfigError(f"Invalid JSON config at {resolved}: {exc.msg}") from exc + if not isinstance(data, dict): + raise ExperimentalConfigError("JSON config root must be an object") + _reject_secret_keys(data) + warnings = tuple(f"Unknown config section '{key}' ignored by current runtime" for key in data if key not in _KNOWN_SECTIONS) + unknown = {key: value for key, value in data.items() if key not in _KNOWN_SECTIONS} + return ExperimentalConfig( + routing=_dict_section(data, "routing"), + pricing=_dict_section(data, "pricing"), + streaming=_dict_section(data, "streaming"), + field_cache=_dict_section(data, "field_cache"), + providers=_dict_section(data, "providers"), + retry=_dict_section(data, "retry"), + responses=_dict_section(data, "responses"), + unknown_sections=unknown, + warnings=warnings, + path=str(resolved), + ) + + +def load_config_from_mapping(data: Mapping[str, Any]) -> ExperimentalConfig: + """Build config from an in-memory mapping for tests and provider helpers.""" + + _reject_secret_keys(data) + warnings = tuple(f"Unknown config section '{key}' ignored by current runtime" for key in data if key not in _KNOWN_SECTIONS) + return ExperimentalConfig( + routing=_dict_section(data, "routing"), + pricing=_dict_section(data, "pricing"), + streaming=_dict_section(data, "streaming"), + field_cache=_dict_section(data, "field_cache"), + providers=_dict_section(data, "providers"), + retry=_dict_section(data, "retry"), + responses=_dict_section(data, "responses"), + unknown_sections={key: value for key, value in data.items() if key not in _KNOWN_SECTIONS}, + warnings=warnings, + ) + + +def get_configured_model_pricing( + provider: str, + model: str, + *, + config: ExperimentalConfig | None = None, + env: Mapping[str, str] | None = None, +) -> Optional[ModelPricing]: + """Return JSON/env pricing for a provider/model, with env taking priority.""" + + source = env if env is not None else os.environ + env_pricing = _pricing_from_env(provider, model, source) + if env_pricing: + return env_pricing + active = config if config is not None else load_experimental_config(env=source) + pricing_section = active.pricing.get(provider, {}) if isinstance(active.pricing, dict) else {} + raw = pricing_section.get(model) if isinstance(pricing_section, dict) else None + if not isinstance(raw, dict): + return None + return ModelPricing( + input_cost_per_token=as_float(raw.get("input", raw.get("input_cost_per_token", 0.0)), name="pricing.input"), + output_cost_per_token=as_float(raw.get("output", raw.get("output_cost_per_token", 0.0)), name="pricing.output"), + cache_read_cost_per_token=as_float(raw.get("cache_read", raw.get("cache_read_cost_per_token", 0.0)), name="pricing.cache_read"), + cache_write_cost_per_token=as_float(raw.get("cache_write", raw.get("cache_write_cost_per_token", 0.0)), name="pricing.cache_write"), + reasoning_cost_per_token=as_float(raw.get("reasoning", raw.get("reasoning_cost_per_token", 0.0)), name="pricing.reasoning"), + currency=str(raw.get("currency", "USD")), + source="json_config", + ) + + +def get_stream_runtime_settings( + *, + config: ExperimentalConfig | None = None, + env: Mapping[str, str] | None = None, +) -> StreamRuntimeSettings: + """Return stream runtime settings with environment overriding JSON.""" + + source = env if env is not None else os.environ + active = config if config is not None else load_experimental_config(env=source) + streaming = active.streaming if isinstance(active.streaming, dict) else {} + + return StreamRuntimeSettings( + ttfb_timeout_seconds=_optional_positive_float(_env_or_json(source, "STREAM_TTFB_TIMEOUT_SECONDS", streaming, "ttfb_timeout_seconds"), "STREAM_TTFB_TIMEOUT_SECONDS"), + stall_timeout_seconds=_optional_positive_float(_env_or_json(source, "STREAM_STALL_TIMEOUT_SECONDS", streaming, "stall_timeout_seconds"), "STREAM_STALL_TIMEOUT_SECONDS"), + heartbeat_seconds=_optional_positive_float(_env_or_json(source, "STREAM_HEARTBEAT_INTERVAL_SECONDS", streaming, "heartbeat_interval_seconds", default=_env_or_json(source, "STREAM_HEARTBEAT_SECONDS", streaming, "heartbeat_seconds")), "STREAM_HEARTBEAT_INTERVAL_SECONDS"), + cancel_upstream_on_disconnect=as_bool(_env_or_json(source, "STREAM_CANCEL_UPSTREAM_ON_DISCONNECT", streaming, "cancel_upstream_on_disconnect", default=True), name="STREAM_CANCEL_UPSTREAM_ON_DISCONNECT"), + trace_metrics=as_bool(_env_or_json(source, "STREAM_TRACE_METRICS", streaming, "trace_metrics", default=True), name="STREAM_TRACE_METRICS"), + ) + + +def get_retry_runtime_settings( + *, + config: ExperimentalConfig | None = None, + env: Mapping[str, str] | None = None, +) -> RetryRuntimeSettings: + """Return retry/cooldown settings with environment overriding JSON.""" + + source = env if env is not None else os.environ + active = config if config is not None else load_experimental_config(env=source) + retry = active.retry if isinstance(active.retry, dict) else {} + cooldown = retry.get("provider_cooldown", {}) if isinstance(retry.get("provider_cooldown"), dict) else retry + backoff = retry.get("backoff", {}) if isinstance(retry.get("backoff"), dict) else retry + return RetryRuntimeSettings( + provider_cooldown_min_seconds=max(0, _int_setting(source, "PROVIDER_COOLDOWN_MIN_SECONDS", cooldown, "provider_cooldown_min_seconds", 10)), + provider_cooldown_default_seconds=max(0, _int_setting(source, "PROVIDER_COOLDOWN_DEFAULT_SECONDS", cooldown, "provider_cooldown_default_seconds", 30)), + provider_cooldown_on_quota=_bool_setting(source, "PROVIDER_COOLDOWN_ON_QUOTA", cooldown, "provider_cooldown_on_quota", False), + provider_backoff_window_seconds=max(0, _int_setting(source, "PROVIDER_BACKOFF_WINDOW_SECONDS", backoff, "provider_backoff_window_seconds", 60)), + provider_backoff_threshold=max(1, _int_setting(source, "PROVIDER_BACKOFF_THRESHOLD", backoff, "provider_backoff_threshold", 3)), + provider_backoff_base_seconds=_optional_int_setting(source, "PROVIDER_BACKOFF_BASE_SECONDS", backoff, "provider_backoff_base_seconds"), + provider_backoff_max_seconds=max(1, _int_setting(source, "PROVIDER_BACKOFF_MAX_SECONDS", backoff, "provider_backoff_max_seconds", 300)), + failure_history_max_entries=max(1, _int_setting(source, "FAILURE_HISTORY_MAX_ENTRIES", backoff, "failure_history_max_entries", 200)), + ) + + +def get_responses_store_settings( + *, + config: ExperimentalConfig | None = None, + env: Mapping[str, str] | None = None, +) -> Any: + """Return Responses store settings with environment overriding JSON.""" + + from ..responses import ResponsesStoreSettings + + source = env if env is not None else os.environ + active = config if config is not None else load_experimental_config(env=source) + responses = active.responses if isinstance(active.responses, dict) else {} + store = responses.get("store", {}) if isinstance(responses.get("store"), dict) else responses + ttl_seconds = _optional_positive_int(_env_or_json(source, "RESPONSES_STORE_TTL_SECONDS", store, "ttl_seconds"), "RESPONSES_STORE_TTL_SECONDS") + max_items = _optional_positive_int(_env_or_json(source, "RESPONSES_STORE_MAX_ITEMS", store, "max_items"), "RESPONSES_STORE_MAX_ITEMS") + return ResponsesStoreSettings( + ttl_seconds=ttl_seconds, + max_items=max_items, + store_failed=as_bool(_env_or_json(source, "RESPONSES_STORE_FAILED", store, "store_failed", default=True), name="RESPONSES_STORE_FAILED"), + store_in_progress=as_bool(_env_or_json(source, "RESPONSES_STORE_IN_PROGRESS", store, "store_in_progress", default=False), name="RESPONSES_STORE_IN_PROGRESS"), + ) + + +def get_responses_store_runtime_settings( + *, + config: ExperimentalConfig | None = None, + env: Mapping[str, str] | None = None, +) -> ResponsesStoreRuntimeSettings: + """Return Responses store backend settings with env overriding JSON.""" + + source = env if env is not None else os.environ + active = config if config is not None else load_experimental_config(env=source) + responses = active.responses if isinstance(active.responses, dict) else {} + store = responses.get("store", {}) if isinstance(responses.get("store"), dict) else responses + backend = str(_env_or_json(source, "RESPONSES_STORE_BACKEND", store, "backend", default="memory")).strip().lower() + if backend not in {"memory", "provider_cache"}: + raise ExperimentalConfigError("RESPONSES_STORE_BACKEND must be 'memory' or 'provider_cache'") + return ResponsesStoreRuntimeSettings( + backend=backend, + cache_name=str(_env_or_json(source, "RESPONSES_STORE_CACHE_NAME", store, "cache_name", default="responses")), + cache_prefix=str(_env_or_json(source, "RESPONSES_STORE_CACHE_PREFIX", store, "cache_prefix", default="responses")), + cache_dir=_optional_string(_env_or_json(source, "RESPONSES_STORE_CACHE_DIR", store, "cache_dir")), + cache_memory_ttl_seconds=max(1, _int_setting(source, "RESPONSES_STORE_CACHE_MEMORY_TTL_SECONDS", store, "cache_memory_ttl_seconds", 3600)), + cache_disk_ttl_seconds=max(1, _int_setting(source, "RESPONSES_STORE_CACHE_DISK_TTL_SECONDS", store, "cache_disk_ttl_seconds", 172800)), + ) + + +def parse_field_cache_rules(config: ExperimentalConfig, provider: str, model: str) -> tuple[FieldCacheRule, ...]: + """Parse configured field-cache rules for a provider/model. + + Wildcard model rules are returned before exact model rules so providers can + define general preservation behavior and then append model-specific rules. + This helper is intentionally not auto-wired into providers; providers decide + whether external config is appropriate for their protocol state. + """ + + provider_rules = config.field_cache.get(provider, {}) if isinstance(config.field_cache, dict) else {} + if not isinstance(provider_rules, dict): + raise ExperimentalConfigError("field_cache provider section must be an object") + raw_rules: list[Any] = [] + keys = ["*"] + if "/" in model: + keys.append(model.split("/", 1)[1]) + keys.append(model) + for key in dict.fromkeys(keys): + value = provider_rules.get(key, []) + if isinstance(value, list): + raw_rules.extend(value) + elif value not in (None, []): + raise ExperimentalConfigError("field_cache model rules must be a list") + parsed_rules = [] + for rule in raw_rules: + if not isinstance(rule, dict): + raise ExperimentalConfigError("field_cache rule entries must be objects") + parsed_rules.append(_field_cache_rule_from_dict(rule)) + return tuple(parsed_rules) + + +def as_bool(value: Any, *, name: str) -> bool: + """Parse a JSON/env boolean value.""" + + if isinstance(value, bool): + return value + if isinstance(value, (int, float)) and value in (0, 1): + return bool(value) + if isinstance(value, str): + text = value.strip().lower() + if text in {"1", "true", "yes", "on"}: + return True + if text in {"0", "false", "no", "off"}: + return False + raise ExperimentalConfigError(f"Invalid boolean for {name}") + + +def as_float(value: Any, *, name: str) -> float: + """Parse a JSON/env float value with redacted errors.""" + + try: + return float(value) + except (TypeError, ValueError) as exc: + raise ExperimentalConfigError(f"Invalid number for {name}") from exc + + +def as_int(value: Any, *, name: str) -> int: + """Parse a JSON/env integer value with redacted errors.""" + + try: + return int(value) + except (TypeError, ValueError) as exc: + raise ExperimentalConfigError(f"Invalid integer for {name}") from exc + + +def env_price_key(provider: str, model: str, suffix: str) -> str: + """Return normalized model-price environment variable name.""" + + return f"MODEL_PRICE_{_env_part(provider)}_{_env_part(model)}_{_env_part(suffix)}" + + +def _path_from_env(env: Mapping[str, str]) -> Optional[Path]: + for key in _CONFIG_ENV_KEYS: + value = env.get(key) + if value: + return Path(value) + return None + + +def _dict_section(data: Mapping[str, Any], key: str) -> dict[str, Any]: + value = data.get(key, {}) + if value is None: + return {} + if not isinstance(value, dict): + raise ExperimentalConfigError(f"Config section '{key}' must be an object") + return dict(value) + + +def _reject_secret_keys(value: Any, path: str = "config") -> None: + if isinstance(value, Mapping): + for key, nested in value.items(): + key_text = str(key).lower() + compact_key = re.sub(r"[^a-z0-9]+", "", key_text) + underscored_key = re.sub(r"[^a-z0-9]+", "_", key_text) + if any(part in key_text or part in compact_key or part in underscored_key for part in _SECRET_KEY_PARTS): + raise ExperimentalConfigError(f"Unsafe secret-like key in JSON config at {path}.{key}") + _reject_secret_keys(nested, f"{path}.{key}") + elif isinstance(value, list): + for index, nested in enumerate(value): + _reject_secret_keys(nested, f"{path}[{index}]") + + +def _pricing_from_env(provider: str, model: str, env: Mapping[str, str]) -> Optional[ModelPricing]: + suffixes = { + "input": "INPUT", + "output": "OUTPUT", + "cache_read": "CACHE_READ", + "cache_write": "CACHE_WRITE", + "reasoning": "REASONING", + } + values: dict[str, float] = {} + for field_name, suffix in suffixes.items(): + key = env_price_key(provider, model, suffix) + raw = env.get(key) + if raw not in (None, ""): + values[field_name] = as_float(raw, name=key) + if not values: + return None + return ModelPricing( + input_cost_per_token=values.get("input", 0.0), + output_cost_per_token=values.get("output", 0.0), + cache_read_cost_per_token=values.get("cache_read", 0.0), + cache_write_cost_per_token=values.get("cache_write", 0.0), + reasoning_cost_per_token=values.get("reasoning", 0.0), + source="env", + ) + + +def _env_or_json(env: Mapping[str, str], env_key: str, data: Mapping[str, Any], json_key: str, default: Any = None) -> Any: + if env_key in env: + return env[env_key] + return data.get(json_key, default) + + +def _int_setting(env: Mapping[str, str], env_key: str, data: Mapping[str, Any], json_key: str, default: int) -> int: + if env_key in env: + try: + return int(env.get(env_key) or default) + except (TypeError, ValueError): + return default + return as_int(data.get(json_key, default), name=env_key) + + +def _optional_int_setting(env: Mapping[str, str], env_key: str, data: Mapping[str, Any], json_key: str) -> Optional[int]: + if env_key in env: + try: + parsed = int(env.get(env_key) or 0) + except (TypeError, ValueError): + return None + return parsed if parsed > 0 else None + return _optional_positive_int(data.get(json_key), env_key) + + +def _bool_setting(env: Mapping[str, str], env_key: str, data: Mapping[str, Any], json_key: str, default: bool) -> bool: + if env_key in env: + try: + return as_bool(env.get(env_key), name=env_key) + except ExperimentalConfigError: + return default + return as_bool(data.get(json_key, default), name=env_key) + + +def _optional_positive_float(value: Any, name: str) -> Optional[float]: + if value in (None, ""): + return None + parsed = as_float(value, name=name) + # Zero and negative values mean "not configured" for timeout-like knobs. + # Runtime enforcement is intentionally disabled by default. + return parsed if parsed > 0 else None + + +def _optional_positive_int(value: Any, name: str) -> Optional[int]: + if value in (None, ""): + return None + parsed = as_int(value, name=name) + return parsed if parsed > 0 else None + + +def _optional_string(value: Any) -> Optional[str]: + if value in (None, ""): + return None + return str(value) + + +def _field_cache_rule_from_dict(data: Mapping[str, Any]) -> FieldCacheRule: + inject_data = data.get("inject") + inject = None + if isinstance(inject_data, Mapping): + inject = FieldCacheInjection( + target=str(inject_data.get("target", "request")), + path=str(inject_data.get("path", data.get("target_path", ""))), + when_missing_only=as_bool(inject_data.get("when_missing_only", False), name="field_cache.inject.when_missing_only"), + insert=as_bool(inject_data.get("insert", False), name="field_cache.inject.insert"), + as_list=as_bool(inject_data.get("as_list", False), name="field_cache.inject.as_list"), + ) + elif inject_data is not None: + raise ExperimentalConfigError("field_cache.inject must be an object") + elif data.get("target_path"): + inject = FieldCacheInjection(target=str(data.get("target", "request")), path=str(data["target_path"])) + scope = data.get("scope", ("provider", "model", "classifier", "session")) + if isinstance(scope, str): + scope_values = tuple(part.strip() for part in scope.split(",") if part.strip()) + elif isinstance(scope, (list, tuple)): + scope_values = tuple(str(part) for part in scope) + else: + raise ExperimentalConfigError("field_cache.scope must be a string or list") + try: + return FieldCacheRule( + name=str(data["name"]), + source=str(data["source"]), + path=str(data["path"]), + mode=str(data.get("mode", "last")), + scope=scope_values, + inject=inject, + enabled=as_bool(data.get("enabled", True), name="field_cache.enabled"), + ttl_seconds=int(data["ttl_seconds"]) if data.get("ttl_seconds") is not None else None, + metadata=_metadata_dict(data.get("metadata", {})), + allow_missing_session=as_bool(data.get("allow_missing_session", False), name="field_cache.allow_missing_session"), + ) + except KeyError as exc: + raise ExperimentalConfigError(f"Missing field-cache rule key {exc.args[0]}") from exc + except ValueError as exc: + raise ExperimentalConfigError(str(exc)) from exc + + +def _metadata_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return dict(value) + raise ExperimentalConfigError("field_cache.metadata must be an object") + + +def _env_part(value: str) -> str: + return re.sub(r"[^A-Z0-9]+", "_", value.upper()).strip("_") diff --git a/src/rotator_library/cooldown_manager.py b/src/rotator_library/cooldown_manager.py index 0d1bb63ee..a4dbafefd 100644 --- a/src/rotator_library/cooldown_manager.py +++ b/src/rotator_library/cooldown_manager.py @@ -3,9 +3,21 @@ import asyncio import time +from dataclasses import dataclass from typing import Dict +@dataclass(frozen=True) +class CooldownSnapshot: + """Read-only view of one active cooldown scope for tests/observability.""" + + provider: str + scope: str + model: str | None + remaining: float + reason: str | None = None + + class CooldownManager: """ Manages global cooldown periods for API providers to handle IP-based rate limiting. @@ -15,34 +27,115 @@ class CooldownManager: def __init__(self): self._cooldowns: Dict[str, float] = {} + self._metadata: Dict[str, dict[str, str | None]] = {} self._lock = asyncio.Lock() async def is_cooling_down(self, provider: str) -> bool: """Checks if a provider is currently in a cooldown period.""" - async with self._lock: - return ( - provider in self._cooldowns and time.time() < self._cooldowns[provider] - ) + return await self.is_scoped_cooling_down(provider, scope="provider") async def start_cooldown(self, provider: str, duration: int): """ Initiates or extends a cooldown period for a provider. The cooldown is set to the current time plus the specified duration. + + A shorter new cooldown must not shorten an existing longer cooldown; + provider-wide throttles often arrive concurrently from several requests. """ - async with self._lock: - self._cooldowns[provider] = time.time() + duration + await self.start_scoped_cooldown(provider, duration, scope="provider") async def get_cooldown_remaining(self, provider: str) -> float: """ Returns the remaining cooldown time in seconds for a provider. Returns 0 if the provider is not in a cooldown period. """ - async with self._lock: - if provider in self._cooldowns: - remaining = self._cooldowns[provider] - time.time() - return max(0, remaining) - return 0 + return await self.get_scoped_remaining(provider, scope="provider") async def get_remaining_cooldown(self, provider: str) -> float: """Backward-compatible alias for get_cooldown_remaining.""" return await self.get_cooldown_remaining(provider) + + async def start_scoped_cooldown( + self, + provider: str, + duration: int, + *, + model: str | None = None, + scope: str = "provider", + reason: str | None = None, + ) -> None: + """Start or extend a provider/model cooldown scope. + + Model scopes are intentionally separate from provider scopes because + capacity failures often belong to one model deployment rather than the + entire provider. Provider scopes remain available for provider-wide + throttles and block every model for the provider. + """ + + key = _cooldown_key(provider, scope=scope, model=model) + async with self._lock: + new_expiry = time.time() + max(0, duration) + current_expiry = self._cooldowns.get(key, 0) + if new_expiry > current_expiry: + self._cooldowns[key] = new_expiry + self._metadata[key] = {"provider": provider, "scope": scope, "model": model, "reason": reason} + + async def get_scoped_remaining(self, provider: str, *, model: str | None = None, scope: str = "provider") -> float: + """Return remaining seconds for one cooldown scope.""" + + key = _cooldown_key(provider, scope=scope, model=model) + async with self._lock: + return self._remaining_for_key(key, time.time()) + + async def get_max_remaining(self, provider: str, *, model: str | None = None) -> float: + """Return the max provider/model cooldown remaining for a request.""" + + async with self._lock: + now = time.time() + provider_remaining = self._remaining_for_key(_cooldown_key(provider, scope="provider"), now) + model_remaining = self._remaining_for_key(_cooldown_key(provider, scope="model", model=model), now) if model else 0 + return max(provider_remaining, model_remaining) + + async def is_scoped_cooling_down(self, provider: str, *, model: str | None = None, scope: str = "provider") -> bool: + """Return whether one scoped cooldown is currently active.""" + + return (await self.get_scoped_remaining(provider, model=model, scope=scope)) > 0 + + async def snapshot(self) -> tuple[CooldownSnapshot, ...]: + """Return active cooldown scopes for tests and future observability.""" + + async with self._lock: + now = time.time() + snapshots = [] + for key, expires_at in list(self._cooldowns.items()): + remaining = max(0, expires_at - now) + if remaining <= 0: + continue + metadata = self._metadata.get(key, {}) + snapshots.append( + CooldownSnapshot( + provider=str(metadata.get("provider") or key), + scope=str(metadata.get("scope") or "provider"), + model=metadata.get("model"), + remaining=remaining, + reason=metadata.get("reason"), + ) + ) + return tuple(snapshots) + + def _remaining_for_key(self, key: str, now: float) -> float: + expires_at = self._cooldowns.get(key) + if expires_at is None: + return 0 + remaining = expires_at - now + if remaining <= 0: + self._cooldowns.pop(key, None) + self._metadata.pop(key, None) + return 0 + return remaining + + +def _cooldown_key(provider: str, *, scope: str, model: str | None = None) -> str: + if scope == "model" and model: + return f"provider:{provider}:model:{model}" + return f"provider:{provider}" diff --git a/src/rotator_library/core/types.py b/src/rotator_library/core/types.py index c502592ec..f9f07a44b 100644 --- a/src/rotator_library/core/types.py +++ b/src/rotator_library/core/types.py @@ -89,6 +89,11 @@ class RequestContext: provider_config: Optional[Dict[str, Any]] = None credential_secrets: Dict[str, str] = field(default_factory=dict) classifier: Optional[str] = None + routing_targets: Optional[Any] = None + routing_group_name: Optional[str] = None + routing_group: Optional[Any] = None + routing_target_index: int = 0 + routing_attempt_history: List[Dict[str, Any]] = field(default_factory=list) @dataclass diff --git a/src/rotator_library/error_handler.py b/src/rotator_library/error_handler.py index e0130c17c..5388b1c7d 100644 --- a/src/rotator_library/error_handler.py +++ b/src/rotator_library/error_handler.py @@ -759,12 +759,26 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr if isinstance(e, dict): payload = e.get("error", e) if isinstance(payload, dict): + details = payload.get("details") if isinstance(payload.get("details"), dict) else {} code = payload.get("code") + status_value = payload.get("status_code") or payload.get("status") or details.get("status_code") or details.get("status") or code status = str(payload.get("status", "")).upper() try: - status_code = int(code) if code is not None else None + status_code = int(status_value) if status_value is not None else None except (TypeError, ValueError): status_code = None + structured_type = _classify_structured_error_text(payload, details) + if status_code == 400: + return ClassifiedError(error_type=structured_type or "invalid_request", original_exception=e, status_code=status_code) + if status_code == 401: + return ClassifiedError(error_type="authentication", original_exception=e, status_code=status_code) + if status_code == 403: + return ClassifiedError(error_type="forbidden", original_exception=e, status_code=status_code) + if status_code == 429: + body = str(payload).lower() + return ClassifiedError(error_type="quota_exceeded" if "quota" in body or "resource_exhausted" in body else "rate_limit", original_exception=e, status_code=status_code) + if structured_type: + return ClassifiedError(error_type=structured_type, original_exception=e, status_code=status_code) if (status_code is not None and status_code >= 500) or status in { "INTERNAL", "UNAVAILABLE", @@ -775,6 +789,14 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr status_code=status_code or 503, ) + explicit_error_type = getattr(e, "error_type", None) + if explicit_error_type: + return ClassifiedError( + error_type=str(explicit_error_type).strip().lower().replace("-", "_").replace(" ", "_"), + original_exception=e, + status_code=getattr(e, "status_code", None), + ) + error_text = str(e) error_type_name = type(e).__name__ if ( @@ -925,11 +947,6 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr original_exception=e, status_code=status_code, ) - return ClassifiedError( - error_type="invalid_request", - original_exception=e, - status_code=status_code, - ) if 400 <= status_code < 500: # Other 4xx errors - generally client errors return ClassifiedError( @@ -1073,6 +1090,13 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr status_code=status_code or 503, ) + if "model_capacity_exhausted" in error_text.lower() or "model capacity" in error_text.lower() or "capacity exhausted" in error_text.lower(): + return ClassifiedError( + error_type="server_error", + original_exception=e, + status_code=status_code or 503, + ) + # Fallback for any other unclassified errors return ClassifiedError( error_type="unknown", original_exception=e, status_code=status_code @@ -1084,6 +1108,34 @@ def is_rate_limit_error(e: Exception) -> bool: return isinstance(e, RateLimitError) +def _classify_structured_error_text(payload: dict, details: dict) -> Optional[str]: + """Classify structured error type/code/status text without raw messages.""" + + values = [payload.get("type"), payload.get("code"), payload.get("status"), details.get("error_type"), details.get("classification"), details.get("status")] + normalized = {str(value).strip().lower().replace("-", "_").replace(" ", "_") for value in values if value not in (None, "")} + if normalized & {"authentication", "auth", "unauthorized", "invalid_api_key"}: + return "authentication" + if normalized & {"forbidden", "permission_denied", "access_denied"}: + return "forbidden" + if normalized & {"context_window_exceeded", "context_length", "context_length_exceeded", "too_many_tokens"}: + return "context_window_exceeded" + if normalized & {"invalid_request", "bad_request", "validation", "invalid_argument"}: + return "invalid_request" + if normalized & {"credential_reauth_needed"}: + return "credential_reauth_needed" + if normalized & {"configuration_error", "config", "configuration"}: + return "configuration_error" + if normalized & {"rate_limit", "rate_limited", "too_many_requests"}: + return "rate_limit" + if normalized & {"quota_exceeded", "resource_exhausted", "quota"}: + return "quota_exceeded" + if normalized & {"server_error", "internal", "unavailable"}: + return "server_error" + if normalized & {"api_connection", "network", "connection"}: + return "api_connection" + return None + + def is_server_error(e: Exception) -> bool: """Checks if the exception is a temporary server-side error.""" return isinstance( diff --git a/src/rotator_library/field_cache/__init__.py b/src/rotator_library/field_cache/__init__.py new file mode 100644 index 000000000..ae5d62aa2 --- /dev/null +++ b/src/rotator_library/field_cache/__init__.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Configurable provider field-cache rules and helpers.""" + +from .paths import FieldCachePathError, extract_path, inject_path, parse_path +from .engine import FieldCacheEngine, FieldCacheOperation, build_cache_key +from .store import FieldCacheStore, InMemoryFieldCacheStore, ProviderCacheFieldStore +from .types import FieldCacheContext, FieldCacheInjection, FieldCacheRule + +__all__ = [ + "FieldCacheContext", + "FieldCacheEngine", + "FieldCacheInjection", + "FieldCacheOperation", + "FieldCachePathError", + "FieldCacheRule", + "FieldCacheStore", + "InMemoryFieldCacheStore", + "ProviderCacheFieldStore", + "build_cache_key", + "extract_path", + "inject_path", + "parse_path", +] diff --git a/src/rotator_library/field_cache/engine.py b/src/rotator_library/field_cache/engine.py new file mode 100644 index 000000000..c9a2134da --- /dev/null +++ b/src/rotator_library/field_cache/engine.py @@ -0,0 +1,461 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Extraction and injection engine for field-cache rules.""" + +from __future__ import annotations + +import hashlib +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Any, Iterable, Optional + +from .paths import FieldCachePathError, extract_path, inject_path, parse_path +from .store import FieldCacheStore, InMemoryFieldCacheStore +from .types import FieldCacheContext, FieldCacheRule + + +@dataclass +class FieldCacheOperation: + """Summary of one field-cache rule application.""" + + rule_name: str + cache_key: Optional[str] + matched: int = 0 + changed: bool = False + hit: bool = False + skipped: bool = False + reason: Optional[str] = None + sample_values: list[Any] = field(default_factory=list) + + +def _safe_scope_value(value: str) -> str: + digest = hashlib.sha256(value.encode("utf-8")).hexdigest()[:16] + return digest + + +def build_cache_key(rule: FieldCacheRule, context: FieldCacheContext) -> Optional[str]: + """Build a scoped cache key or return None when required scope is absent.""" + + parts = [f"rule={rule.name}"] + for scope in rule.scope: + value = context.value_for_scope(scope) + if value is None or value == "": + if scope == "credential": + return None + if scope == "session" and not rule.allow_missing_session: + return None + value = "_none" + if scope in {"provider", "model"}: + safe_value = value.replace("/", "_").replace("\\", "_").replace(":", "_") + else: + safe_value = _safe_scope_value(value) + parts.append(f"{scope}={safe_value}") + return "|".join(parts) + + +class FieldCacheEngine: + """Apply field-cache extraction and injection rules. + + The engine preserves provider protocol state; it is not session tracking. + It defaults to copying payloads before injection so providers can opt into + mutation explicitly. Turn and tool-call modes skip safely when the requested + context cannot be inferred rather than silently falling back to `last`. + """ + + def __init__(self, rules: Iterable[FieldCacheRule], store: Optional[FieldCacheStore] = None) -> None: + self.rules = tuple(rules) + self.store = store or InMemoryFieldCacheStore() + self._validate_rules() + + def _validate_rules(self) -> None: + names: set[str] = set() + for rule in self.rules: + if rule.name in names: + raise ValueError(f"Duplicate field-cache rule name: {rule.name}") + names.add(rule.name) + parse_path(rule.path) + if rule.inject: + parse_path(rule.inject.path) + if rule.mode == "per_tool_call" and not rule.metadata.get("tool_call_id_path"): + raise ValueError("per_tool_call field-cache mode requires metadata.tool_call_id_path") + + async def extract( + self, + source: str, + payload: Any, + context: FieldCacheContext, + *, + transaction_logger: Optional[Any] = None, + ) -> list[FieldCacheOperation]: + operations: list[FieldCacheOperation] = [] + rules = self._rules_for_source(source) + self._trace_summary(transaction_logger, "field_cache_extraction_start", payload, source=source, target=None, rules=rules, operations=operations) + for rule in rules: + operation = FieldCacheOperation(rule_name=rule.name, cache_key=build_cache_key(rule, context)) + self._trace(transaction_logger, "before_field_cache_extraction", payload, rule, operation, source=source) + if not operation.cache_key: + operation.skipped = True + operation.reason = "missing_required_scope" + operations.append(operation) + self._trace(transaction_logger, "after_field_cache_extraction", payload, rule, operation, source=source) + continue + try: + values = extract_path(payload, rule.path) + operation.matched = len(values) + operation.sample_values = _sample_values(values) + if values or rule.mode in {"last_user_turn", "last_assistant_turn", "per_tool_call"}: + operation.changed = await self._store_values(rule, operation.cache_key, values, payload, operation) + except Exception as exc: + self._log_error(transaction_logger, "field_cache_extract", exc, payload, rule) + raise + operations.append(operation) + self._trace(transaction_logger, "after_field_cache_extraction", payload, rule, operation, source=source) + self._trace_summary(transaction_logger, "field_cache_extraction_complete", payload, source=source, target=None, rules=rules, operations=operations) + return operations + + async def inject( + self, + target: str, + payload: Any, + context: FieldCacheContext, + *, + transaction_logger: Optional[Any] = None, + mutate: bool = False, + ) -> tuple[Any, list[FieldCacheOperation]]: + updated = payload if mutate else deepcopy(payload) + operations: list[FieldCacheOperation] = [] + rules = self._rules_for_injection(target) + self._trace_summary(transaction_logger, "field_cache_injection_start", updated, source=None, target=target, rules=rules, operations=operations) + for rule in rules: + operation = FieldCacheOperation(rule_name=rule.name, cache_key=build_cache_key(rule, context)) + if not rule.inject: + continue + self._trace(transaction_logger, "before_field_cache_injection", updated, rule, operation, target=target) + if not operation.cache_key: + operation.skipped = True + operation.reason = "missing_required_scope" + operations.append(operation) + self._trace(transaction_logger, "after_field_cache_injection", updated, rule, operation, target=target) + continue + try: + cached = await self.store.get(operation.cache_key) + if cached is None: + operation.reason = "cache_miss" + operations.append(operation) + self._trace(transaction_logger, "after_field_cache_injection", updated, rule, operation, target=target) + continue + operation.hit = True + value = self._injection_value(rule, cached, updated, context, operation) + if operation.skipped: + operations.append(operation) + self._trace(transaction_logger, "after_field_cache_injection", updated, rule, operation, target=target) + continue + operation.changed = inject_path( + updated, + rule.inject.path, + value, + when_missing_only=rule.inject.when_missing_only, + insert=rule.inject.insert, + ) + operation.sample_values = _sample_values(value if isinstance(value, list) else [value]) + except Exception as exc: + self._log_error(transaction_logger, "field_cache_inject", exc, updated, rule) + raise + operations.append(operation) + self._trace(transaction_logger, "after_field_cache_injection", updated, rule, operation, target=target) + self._trace_summary(transaction_logger, "field_cache_injection_complete", updated, source=None, target=target, rules=rules, operations=operations) + return updated, operations + + def _rules_for_source(self, source: str) -> list[FieldCacheRule]: + return [rule for rule in self.rules if rule.enabled and rule.source == source] + + def _rules_for_injection(self, target: str) -> list[FieldCacheRule]: + return [rule for rule in self.rules if rule.enabled and rule.inject and rule.inject.target == target] + + async def _store_values(self, rule: FieldCacheRule, cache_key: str, values: list[Any], payload: Any, operation: FieldCacheOperation) -> bool: + if rule.mode == "all": + await self._store_append(cache_key, values, ttl_seconds=rule.ttl_seconds) + return True + if rule.mode == "last": + await self._store_set(cache_key, _wrap_cached_value(values[-1]), ttl_seconds=rule.ttl_seconds) + return True + if rule.mode in {"last_user_turn", "last_assistant_turn"}: + role = "user" if rule.mode == "last_user_turn" else "assistant" + turn_values = _turn_values(rule, payload, role) + if not turn_values: + operation.skipped = True + operation.reason = "turn_context_not_found" + return False + operation.matched = len(turn_values) + operation.sample_values = _sample_values(turn_values) + await self._store_set(cache_key, _wrap_cached_value(turn_values[-1]), ttl_seconds=rule.ttl_seconds) + return True + if rule.mode == "per_tool_call": + stored = _tool_call_values(rule, payload, values) + if not stored: + operation.skipped = True + operation.reason = "tool_call_id_not_found" + return False + operation.matched = len(stored) + operation.sample_values = _sample_values(list(stored.values())) + await self._store_set(cache_key, stored, ttl_seconds=rule.ttl_seconds) + return True + raise ValueError(f"Unsupported field-cache mode: {rule.mode}") + + async def _store_set(self, cache_key: str, value: Any, *, ttl_seconds: Optional[int]) -> None: + try: + await self.store.set(cache_key, value, ttl_seconds=ttl_seconds) + except TypeError: + # Preserve compatibility with simple injected stores that implement + # the original set(key, value) shape. TTL is best-effort there. + await self.store.set(cache_key, value) + + async def _store_append(self, cache_key: str, values: list[Any], *, ttl_seconds: Optional[int]) -> None: + try: + await self.store.append(cache_key, values, ttl_seconds=ttl_seconds) + except TypeError: + await self.store.append(cache_key, values) + + def _injection_value(self, rule: FieldCacheRule, cached: Any, payload: Any, context: FieldCacheContext, operation: FieldCacheOperation) -> Any: + """Select the cached value to inject for the rule's mode. + + Per-tool-call maps require a current tool-call ID so the engine never + injects an arbitrary provider signature into the wrong tool result. + """ + + if rule.mode == "per_tool_call": + if not isinstance(cached, dict): + operation.skipped = True + operation.reason = "invalid_tool_call_cache" + return None + ids = _injection_tool_ids(rule, payload, context) + if not ids: + operation.skipped = True + operation.reason = "tool_call_id_not_found" + return None + matches = [_unwrap_cached_value(cached[str(tool_id)]) for tool_id in ids if str(tool_id) in cached] + if not matches: + operation.skipped = True + operation.reason = "tool_call_cache_miss" + return None + if rule.inject and rule.inject.as_list: + return matches + if len(matches) == 1: + return matches[0] + operation.skipped = True + operation.reason = "ambiguous_tool_call_values" + return None + if rule.mode == "all": + return cached if isinstance(cached, list) else [cached] + if rule.inject and rule.inject.as_list: + unwrapped = _unwrap_cached_value(cached) + return unwrapped if isinstance(unwrapped, list) else [unwrapped] + return _unwrap_cached_value(cached) + + def _trace( + self, + transaction_logger: Optional[Any], + pass_name: str, + payload: Any, + rule: FieldCacheRule, + operation: FieldCacheOperation, + **extra_metadata: Any, + ) -> None: + if not transaction_logger: + return + transaction_logger.log_transform_pass( + pass_name, + _payload_shape(payload), + direction=_trace_direction(pass_name, rule.source, extra_metadata), + stage="adapter", + metadata={ + "rule_name": rule.name, + "source": rule.source, + "path": rule.path, + "mode": rule.mode, + "scope": list(rule.scope), + "cache_key": operation.cache_key, + "matched": operation.matched, + "changed": operation.changed, + "hit": operation.hit, + "skipped": operation.skipped, + "reason": operation.reason, + # Cached fields can include provider signatures or session keys. + # Trace only shape/count metadata; keep raw samples out of logs. + "sample_value_count": len(operation.sample_values), + "sample_value_types": [type(value).__name__ for value in operation.sample_values[:3]], + **extra_metadata, + }, + snapshot=rule.source != "stream_event", + ) + + def _trace_summary( + self, + transaction_logger: Optional[Any], + pass_name: str, + payload: Any, + *, + source: Optional[str], + target: Optional[str], + rules: list[FieldCacheRule], + operations: list[FieldCacheOperation], + ) -> None: + """Record cache-pass boundaries even when no individual rule matches.""" + + if not transaction_logger: + return + transaction_logger.log_transform_pass( + pass_name, + _payload_shape(payload), + direction="request" if target or source == "request" else "response" if source == "response" else "stream" if source == "stream_event" else "metadata", + stage="adapter", + metadata={ + "source": source, + "target": target, + "rule_count": len(rules), + "operation_count": len(operations), + "matched_count": sum(operation.matched for operation in operations), + "changed_count": sum(1 for operation in operations if operation.changed), + "hit_count": sum(1 for operation in operations if operation.hit), + "skipped_count": sum(1 for operation in operations if operation.skipped), + }, + snapshot=(source != "stream_event"), + ) + + def _log_error(self, transaction_logger: Optional[Any], pass_name: str, error: BaseException, payload: Any, rule: FieldCacheRule) -> None: + if not transaction_logger: + return + transaction_logger.log_transform_error( + pass_name, + error, + payload=_payload_shape(payload), + stage="adapter", + metadata={"rule_name": rule.name, "path": rule.path, "mode": rule.mode}, + ) + + +def _last_value(value: Any) -> Any: + if isinstance(value, list): + return value[-1] if value else None + return value + + +def _wrap_cached_value(value: Any) -> dict[str, Any]: + """Wrap one extracted value so list-valued fields stay intact on injection.""" + + return {"__field_cache_value__": True, "value": deepcopy(value)} + + +def _unwrap_cached_value(value: Any) -> Any: + if isinstance(value, dict) and value.get("__field_cache_value__") is True: + return deepcopy(value.get("value")) + return _last_value(value) + + +def _turn_values(rule: FieldCacheRule, payload: Any, role: str) -> list[Any]: + """Return values from the latest turn matching `role`. + + Rules can provide explicit turn paths for provider-specific payloads. The + default handles the common `messages[*]` shape used by OpenAI-compatible and + Responses-like requests. + """ + + container_path = rule.metadata.get("turn_container_path", "messages") + role_path = rule.metadata.get("turn_role_path", "role") + value_path = rule.metadata.get("turn_value_path") or _message_relative_path(rule.path, container_path) + if not value_path: + return [] + turns = extract_path(payload, str(container_path)) + if len(turns) == 1 and isinstance(turns[0], list): + turns = turns[0] + latest: list[Any] = [] + for turn in turns: + roles = extract_path(turn, str(role_path)) + if roles and str(roles[0]) == role: + values = extract_path(turn, str(value_path)) + if values: + latest = values + return latest + + +def _message_relative_path(path: str, container_path: str) -> Optional[str]: + prefixes = (f"{container_path}.*.", f"{container_path}[-1].") + for prefix in prefixes: + if path.startswith(prefix): + return path[len(prefix) :] + return None + + +def _tool_call_values(rule: FieldCacheRule, payload: Any, values: list[Any]) -> dict[str, Any]: + """Correlate cached values to provider tool-call IDs.""" + + container_path = rule.metadata.get("tool_container_path") + tool_id_path = rule.metadata.get("tool_call_id_path") + tool_value_path = rule.metadata.get("tool_value_path") + stored: dict[str, Any] = {} + if container_path and tool_value_path: + containers = extract_path(payload, str(container_path)) + if len(containers) == 1 and isinstance(containers[0], list): + containers = containers[0] + for container in containers: + tool_ids = extract_path(container, str(tool_id_path)) if tool_id_path else [] + tool_values = extract_path(container, str(tool_value_path)) + if tool_ids and tool_values: + stored[str(tool_ids[0])] = _wrap_cached_value(tool_values[-1]) + return stored + for value in values: + tool_ids = extract_path(value, str(tool_id_path)) if tool_id_path else [] + if tool_ids: + stored[str(tool_ids[0])] = _wrap_cached_value(value) + return stored + + +def _injection_tool_ids(rule: FieldCacheRule, payload: Any, context: FieldCacheContext) -> list[str]: + configured = context.metadata.get("tool_call_id") + if configured: + return [str(configured)] + inject_path_value = rule.metadata.get("inject_tool_call_id_path") + if inject_path_value: + return [str(value) for value in extract_path(payload, str(inject_path_value))] + return [] + + +def _trace_direction(pass_name: str, source: str, metadata: dict[str, Any]) -> str: + if "injection" in pass_name: + target = metadata.get("target") + if target in {"stream_event", "unified_stream_event"}: + return "stream" + if target in {"request", "unified_request", "metadata"}: + return "request" + return "response" + if source in {"stream_event", "unified_stream_event"}: + return "stream" + if source in {"request", "unified_request"}: + return "request" + return "response" + + +def _sample_values(values: list[Any], *, max_items: int = 3, max_text: int = 500) -> list[Any]: + samples: list[Any] = [] + for value in values[:max_items]: + if isinstance(value, str) and len(value) > max_text: + samples.append(f"{value[:max_text]}...") + else: + samples.append(value) + return samples + + +def _payload_shape(payload: Any) -> dict[str, Any]: + """Return non-sensitive payload shape metadata for cache traces. + + Cache rules often target provider signatures, session IDs, and other opaque + state. Logging full payloads would expose exactly the fields the cache is + designed to preserve, so field-cache traces record structure only. + """ + + if isinstance(payload, dict): + return {"payload_type": "dict", "keys": sorted(str(key) for key in payload.keys())[:20]} + if isinstance(payload, list): + return {"payload_type": "list", "length": len(payload)} + return {"payload_type": type(payload).__name__} diff --git a/src/rotator_library/field_cache/paths.py b/src/rotator_library/field_cache/paths.py new file mode 100644 index 000000000..1dfc29203 --- /dev/null +++ b/src/rotator_library/field_cache/paths.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Small JSON-path-like selector used by field-cache rules.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal + + +class FieldCachePathError(ValueError): + """Raised when a field-cache path is malformed or cannot be injected.""" + + +@dataclass(frozen=True) +class PathToken: + kind: Literal["key", "index", "wildcard"] + value: str | int | None = None + + +def parse_path(path: str) -> tuple[PathToken, ...]: + """Parse a minimal dotted path with indexes and wildcards. + + Supported examples: `choices.0.message.content`, `choices.*.message`, + `messages[-1].reasoning_content`. Escaping is intentionally unsupported in + Phase 3 so malformed provider configs fail early and clearly. + """ + + if not path or path.startswith(".") or path.endswith(".") or ".." in path: + raise FieldCachePathError(f"Malformed field-cache path: {path!r}") + tokens: list[PathToken] = [] + for segment in path.split("."): + if not segment: + raise FieldCachePathError(f"Malformed field-cache path: {path!r}") + _parse_segment(segment, tokens, path) + return tuple(tokens) + + +def _parse_segment(segment: str, tokens: list[PathToken], full_path: str) -> None: + if segment == "*": + tokens.append(PathToken("wildcard")) + return + cursor = 0 + while cursor < len(segment): + bracket = segment.find("[", cursor) + if bracket == -1: + value = segment[cursor:] + if value: + if value.lstrip("-").isdigit(): + tokens.append(PathToken("index", int(value))) + else: + tokens.append(PathToken("key", value)) + return + if bracket > cursor: + tokens.append(PathToken("key", segment[cursor:bracket])) + close = segment.find("]", bracket) + if close == -1: + raise FieldCachePathError(f"Unclosed index in field-cache path: {full_path!r}") + raw_index = segment[bracket + 1 : close] + if raw_index == "*": + tokens.append(PathToken("wildcard")) + else: + try: + tokens.append(PathToken("index", int(raw_index))) + except ValueError as exc: + raise FieldCachePathError(f"Invalid index {raw_index!r} in field-cache path: {full_path!r}") from exc + cursor = close + 1 + + +def extract_path(payload: Any, path: str) -> list[Any]: + """Return every value matching path in stable traversal order.""" + + tokens = parse_path(path) + current = [payload] + for token in tokens: + next_values: list[Any] = [] + for value in current: + next_values.extend(_extract_token(value, token)) + current = next_values + if not current: + break + return current + + +def _extract_token(value: Any, token: PathToken) -> list[Any]: + if token.kind == "key": + if isinstance(value, dict) and token.value in value: + return [value[token.value]] + return [] + if token.kind == "index": + if isinstance(value, list) and value: + index = int(token.value) + if -len(value) <= index < len(value): + return [value[index]] + return [] + if token.kind == "wildcard": + if isinstance(value, dict): + return list(value.values()) + if isinstance(value, list): + return list(value) + return [] + raise FieldCachePathError(f"Unknown path token: {token}") + + +def inject_path(payload: Any, path: str, injected_value: Any, *, when_missing_only: bool = False, insert: bool = False) -> bool: + """Inject a value at a simple path, creating dict containers as needed. + + Wildcard injection is rejected because creating multiple branches can be + ambiguous and provider-specific. List indexes must already exist; this keeps + mutation predictable for message-tail use cases like `messages[-1].field`. + `insert=True` is intentionally limited to final list-index tokens so rules + cannot accidentally create provider-specific list structures. + """ + + tokens = parse_path(path) + if any(token.kind == "wildcard" for token in tokens): + raise FieldCachePathError("Wildcard injection is not supported") + if not isinstance(payload, dict): + raise FieldCachePathError("Field-cache injection root must be a dict") + current = payload + for index, token in enumerate(tokens): + is_last = index == len(tokens) - 1 + if token.kind == "key": + if not isinstance(current, dict): + raise FieldCachePathError(f"Cannot inject key {token.value!r} into non-dict value") + key = str(token.value) + if is_last: + if insert: + raise FieldCachePathError("insert=True requires a final list index token") + if when_missing_only and key in current: + return False + changed = current.get(key) != injected_value + current[key] = injected_value + return changed + if key not in current or current[key] is None: + current[key] = [] if tokens[index + 1].kind == "index" else {} + current = current[key] + continue + if token.kind == "index": + if not isinstance(current, list) or (not current and not (is_last and insert)): + raise FieldCachePathError("Cannot inject into missing or empty list") + list_index = int(token.value) + if is_last and insert: + if list_index < 0: + list_index = max(0, len(current) + list_index) + if not (0 <= list_index <= len(current)): + raise FieldCachePathError(f"List index out of range for field-cache insertion: {list_index}") + if when_missing_only: + return False + current.insert(list_index, injected_value) + return True + if not (-len(current) <= list_index < len(current)): + raise FieldCachePathError(f"List index out of range for field-cache injection: {list_index}") + if is_last: + if when_missing_only and current[list_index] is not None: + return False + changed = current[list_index] != injected_value + current[list_index] = injected_value + return changed + current = current[list_index] + continue + raise FieldCachePathError(f"Unsupported injection token: {token}") + return False diff --git a/src/rotator_library/field_cache/store.py b/src/rotator_library/field_cache/store.py new file mode 100644 index 000000000..659e0c35f --- /dev/null +++ b/src/rotator_library/field_cache/store.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Async stores for field-cache values.""" + +from __future__ import annotations + +import json +import time +from copy import deepcopy +from typing import Any, Callable, Protocol + +from ..protocols import serialize_value + +_TTL_ENVELOPE_MARKER = "__llm_proxy_field_cache_ttl_v1__" + + +class FieldCacheStore(Protocol): + """Minimal async store interface used by `FieldCacheEngine`.""" + + async def get(self, key: str) -> Any: ... + + async def set(self, key: str, value: Any, *, ttl_seconds: int | None = None) -> None: ... + + async def append(self, key: str, values: list[Any], *, ttl_seconds: int | None = None) -> list[Any]: ... + + async def clear(self) -> None: ... + + +class InMemoryFieldCacheStore: + """Simple process-local store with optional per-key TTL. + + This is the default native runtime store. It intentionally persists only for + the Python process and avoids a database while still preserving protocol + state across requests handled by the same executor instance. + """ + + def __init__(self, *, clock: Callable[[], float] | None = None) -> None: + self._values: dict[str, Any] = {} + self._expires_at: dict[str, float] = {} + self._clock = clock or time.monotonic + + async def get(self, key: str) -> Any: + if self._is_expired(key): + self._values.pop(key, None) + self._expires_at.pop(key, None) + return None + return deepcopy(self._values.get(key)) + + async def set(self, key: str, value: Any, *, ttl_seconds: int | None = None) -> None: + self._values[key] = deepcopy(value) + self._set_expiry(key, ttl_seconds) + + async def append(self, key: str, values: list[Any], *, ttl_seconds: int | None = None) -> list[Any]: + current = await self.get(key) + if not isinstance(current, list): + current = [] + current = deepcopy(current) + deepcopy(values) + self._values[key] = current + self._set_expiry(key, ttl_seconds) + return deepcopy(current) + + async def clear(self) -> None: + self._values.clear() + self._expires_at.clear() + + def _set_expiry(self, key: str, ttl_seconds: int | None) -> None: + if ttl_seconds is None or ttl_seconds <= 0: + self._expires_at.pop(key, None) + return + self._expires_at[key] = self._clock() + ttl_seconds + + def _is_expired(self, key: str) -> bool: + expires_at = self._expires_at.get(key) + return expires_at is not None and expires_at <= self._clock() + + +class ProviderCacheFieldStore: + """Field-cache store backed by an injected `ProviderCache` instance. + + The wrapper does not create `ProviderCache` itself because that class starts + background async tasks during initialization. Providers or later config code + should own that lifecycle and pass an initialized cache here. + """ + + def __init__(self, provider_cache: Any) -> None: + self._cache = provider_cache + + async def get(self, key: str) -> Any: + raw = await self._cache.retrieve_async(key) + if raw is None: + return None + value = json.loads(raw) + if isinstance(value, dict) and value.get(_TTL_ENVELOPE_MARKER) is True: + expires_at = value.get("expires_at") + if isinstance(expires_at, (int, float)) and expires_at <= time.time(): + return None + return value.get("value") + return value + + async def set(self, key: str, value: Any, *, ttl_seconds: int | None = None) -> None: + payload = serialize_value(value) + if ttl_seconds is not None and ttl_seconds > 0: + payload = {_TTL_ENVELOPE_MARKER: True, "expires_at": time.time() + ttl_seconds, "value": payload} + await self._cache.store_async(key, json.dumps(payload, ensure_ascii=False)) + + async def append(self, key: str, values: list[Any], *, ttl_seconds: int | None = None) -> list[Any]: + current = await self.get(key) + if not isinstance(current, list): + current = [] + current = current + serialize_value(values) + await self.set(key, current, ttl_seconds=ttl_seconds) + return current + + async def clear(self) -> None: + await self._cache.clear() diff --git a/src/rotator_library/field_cache/types.py b/src/rotator_library/field_cache/types.py new file mode 100644 index 000000000..1eeb8152f --- /dev/null +++ b/src/rotator_library/field_cache/types.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Data types for provider field-cache rules.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal, Optional + +FieldCacheSource = Literal[ + "request", + "response", + "stream_event", + "unified_request", + "unified_response", + "unified_stream_event", +] +FieldCacheTarget = Literal["request", "unified_request", "metadata"] +FieldCacheMode = Literal["last", "all", "last_user_turn", "last_assistant_turn", "per_tool_call"] +FieldCacheScope = Literal["provider", "model", "credential", "session", "conversation", "classifier"] + +DEFAULT_SCOPE: tuple[FieldCacheScope, ...] = ("provider", "model", "classifier", "session") +_VALID_SOURCES = {"request", "response", "stream_event", "unified_request", "unified_response", "unified_stream_event"} +_VALID_TARGETS = {"request", "unified_request", "metadata"} +_VALID_SCOPES = {"provider", "model", "credential", "session", "conversation", "classifier"} + + +@dataclass(frozen=True) +class FieldCacheInjection: + """Where and how a cached value should be injected into a later payload.""" + + target: FieldCacheTarget + path: str + when_missing_only: bool = False + insert: bool = False + as_list: bool = False + + +@dataclass(frozen=True) +class FieldCacheRule: + """Declarative rule for extracting and re-injecting provider state. + + Rules are protocol/provider extensions, not session-affinity logic. Session + tracking decides continuity; field-cache rules preserve protocol state such + as reasoning content, thought signatures, prompt cache keys, and response IDs. + """ + + name: str + source: FieldCacheSource + path: str + mode: FieldCacheMode = "last" + scope: tuple[FieldCacheScope, ...] = DEFAULT_SCOPE + inject: Optional[FieldCacheInjection] = None + enabled: bool = True + ttl_seconds: Optional[int] = None + metadata: dict[str, Any] = field(default_factory=dict) + allow_missing_session: bool = False + + def __post_init__(self) -> None: + if not self.name or any(char in self.name for char in "/\\:"): + raise ValueError("FieldCacheRule.name must be non-empty and filesystem-safe") + if self.mode not in {"last", "all", "last_user_turn", "last_assistant_turn", "per_tool_call"}: + raise ValueError(f"Unsupported field-cache mode: {self.mode}") + if self.source not in _VALID_SOURCES: + raise ValueError(f"Unsupported field-cache source: {self.source}") + if self.inject and self.inject.target not in _VALID_TARGETS: + raise ValueError(f"Unsupported field-cache injection target: {self.inject.target}") + if not self.scope: + raise ValueError("FieldCacheRule.scope must contain at least one dimension") + invalid_scopes = [scope for scope in self.scope if scope not in _VALID_SCOPES] + if invalid_scopes: + raise ValueError(f"Unsupported field-cache scope: {invalid_scopes[0]}") + + +@dataclass(frozen=True) +class FieldCacheContext: + """Scope values used to isolate cached provider fields.""" + + provider: Optional[str] = None + model: Optional[str] = None + credential_id: Optional[str] = None + session_id: Optional[str] = None + conversation_id: Optional[str] = None + classifier: Optional[str] = None + metadata: dict[str, Any] = field(default_factory=dict) + + def value_for_scope(self, scope: FieldCacheScope) -> Optional[str]: + if scope == "provider": + return self.provider + if scope == "model": + return self.model + if scope == "credential": + return self.credential_id + if scope == "session": + return self.session_id + if scope == "conversation": + return self.conversation_id + if scope == "classifier": + return self.classifier + raise ValueError(f"Unsupported field-cache scope: {scope}") diff --git a/src/rotator_library/native_provider/__init__.py b/src/rotator_library/native_provider/__init__.py new file mode 100644 index 000000000..dc63572ad --- /dev/null +++ b/src/rotator_library/native_provider/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Opt-in native provider execution helpers.""" + +from .context import NativeProviderContext +from .executor import NativeProviderExecutor +from .http import NativeHTTPTransport +from .streaming import stream_event_payload + +__all__ = ["NativeHTTPTransport", "NativeProviderContext", "NativeProviderExecutor", "stream_event_payload"] diff --git a/src/rotator_library/native_provider/context.py b/src/rotator_library/native_provider/context.py new file mode 100644 index 000000000..cbfb91012 --- /dev/null +++ b/src/rotator_library/native_provider/context.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Context objects for native provider execution.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional + +from ..adapters import AdapterContext +from ..field_cache import FieldCacheContext, FieldCacheRule +from ..protocols import ProtocolContext + + +@dataclass +class NativeProviderContext: + """Metadata needed to execute a provider through native protocol helpers. + + This context intentionally mirrors the trace, adapter, protocol, and + field-cache contexts so provider-native execution can remain opt-in and + testable without changing the existing LiteLLM-backed path. + """ + + provider: str + model: str + protocol_name: str + endpoint: str + operation: str = "chat" + client_protocol_name: Optional[str] = None + headers: dict[str, str] = field(default_factory=dict) + credential_id: Optional[str] = None + session_id: Optional[str] = None + scope_key: Optional[str] = None + classifier: Optional[str] = None + transport: str = "http" + adapter_names: tuple[str, ...] = () + adapter_config: dict[str, dict[str, Any]] = field(default_factory=dict) + field_cache_rules: tuple[FieldCacheRule, ...] = () + transaction_logger: Optional[Any] = None + metadata: dict[str, Any] = field(default_factory=dict) + + def protocol_context(self, *, target_protocol: Optional[str] = None) -> ProtocolContext: + """Build a protocol context for parse/build/format passes.""" + + return ProtocolContext( + provider=self.provider, + model=self.model, + source_protocol=self.protocol_name, + target_protocol=target_protocol or self.client_protocol_name or self.protocol_name, + session_id=self.session_id, + credential_stable_id=self.credential_id, + transport=self.transport, + provider_options={"operation": self.operation}, + metadata={"operation": self.operation, **dict(self.metadata)}, + ) + + def adapter_context(self) -> AdapterContext: + """Build an adapter context for provider payload adapters.""" + + return AdapterContext( + provider=self.provider, + model=self.model, + protocol=self.protocol_name, + credential_id=self.credential_id, + session_id=self.session_id, + scope_key=self.scope_key, + classifier=self.classifier, + transport=self.transport, + metadata={"operation": self.operation, **dict(self.metadata)}, + adapter_config=dict(self.adapter_config), + transaction_logger=self.transaction_logger, + ) + + def field_cache_context(self) -> FieldCacheContext: + """Build a field-cache context with provider isolation metadata.""" + + return FieldCacheContext( + provider=self.provider, + model=self.model, + credential_id=self.credential_id, + session_id=self.session_id, + conversation_id=self.scope_key, + classifier=self.classifier, + metadata={"operation": self.operation, **dict(self.metadata)}, + ) diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py new file mode 100644 index 000000000..6b234048a --- /dev/null +++ b/src/rotator_library/native_provider/executor.py @@ -0,0 +1,476 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Opt-in executor for provider-native protocol calls.""" + +from __future__ import annotations + +from copy import deepcopy +from dataclasses import replace +from typing import Any, AsyncGenerator + +from ..adapters import get_adapter, run_adapter_chain +from ..field_cache import FieldCacheEngine, InMemoryFieldCacheStore +from ..field_cache.paths import FieldCachePathError, PathToken, parse_path +from ..protocols import ProtocolError, get_protocol, serialize_value +from ..protocols.types import UnifiedRequest +from ..transform_trace import REDACTED +from ..usage.accounting import extract_usage_record +from ..usage.costs import CostCalculator +from .context import NativeProviderContext +from .http import NativeHTTPTransport +from .streaming import stream_event_payload + + +class NativeProviderExecutor: + """Run one native provider request through protocol/adapter/cache passes. + + The default field-cache store is process-local per executor. That preserves + provider protocol state across native requests without adding a database; + production callers can still inject a persistent store when needed. + """ + + def __init__(self, *, field_cache_store: Any = None) -> None: + self.field_cache_store = field_cache_store or InMemoryFieldCacheStore() + + async def execute(self, raw_request: dict[str, Any], context: NativeProviderContext, transport: NativeHTTPTransport) -> dict[str, Any]: + """Execute a non-streaming native provider request.""" + + logger = context.transaction_logger + protocol = get_protocol(context.protocol_name) + self._ensure_supported_operation(protocol, context) + self._trace(context, "native_protocol_selected", {"protocol": protocol.name}, direction="metadata", stage="protocol") + try: + self._trace(context, "raw_native_client_request", raw_request, direction="request", stage="client") + cache_engine = FieldCacheEngine(context.field_cache_rules, store=self.field_cache_store) + context = await self._inject_metadata(context, cache_engine) + protocol_context = context.protocol_context() + unified_request = protocol.parse_request(raw_request, protocol_context) + self._trace(context, "parsed_native_unified_request", unified_request, direction="request", stage="protocol") + await cache_engine.extract("unified_request", serialize_value(unified_request), context.field_cache_context(), transaction_logger=logger) + self._trace(context, "after_unified_request_field_cache_extraction", {"source": "unified_request"}, direction="request", stage="adapter", snapshot=False) + unified_request = await self._inject_unified_request(unified_request, context, cache_engine) + provider_request = protocol.build_request(unified_request, protocol_context) + self._trace(context, "built_native_provider_request", provider_request, direction="request", stage="protocol") + adapters = [get_adapter(name) for name in context.adapter_names] + adapter_context = context.adapter_context() + adapter_context.transaction_logger = None + provider_request = await run_adapter_chain(adapters, provider_request, adapter_context, stage="request") + self._trace(context, "after_request_adapter_chain", provider_request, direction="request", stage="adapter") + provider_request, _ = await cache_engine.inject( + "request", + provider_request, + context.field_cache_context(), + transaction_logger=logger, + ) + self._trace(context, "after_field_cache_injection", provider_request, direction="request", stage="adapter") + await cache_engine.extract("request", provider_request, context.field_cache_context(), transaction_logger=logger) + self._trace(context, "after_request_field_cache_extraction", {"source": "request"}, direction="request", stage="adapter", snapshot=False) + self._trace(context, "native_provider_request", provider_request, direction="request", stage="provider") + raw_response = await transport.post_json(context.endpoint, headers=context.headers, payload=provider_request) + self._trace(context, "raw_native_provider_response", raw_response, direction="response", stage="provider") + await cache_engine.extract("response", raw_response, context.field_cache_context(), transaction_logger=logger) + self._trace(context, "after_response_field_cache_extraction", {"source": "response", "payload": "raw_provider_response"}, direction="response", stage="adapter", snapshot=False) + unified_response = protocol.parse_response(raw_response, protocol_context) + self._trace(context, "parsed_native_unified_response", unified_response, direction="response", stage="protocol") + await cache_engine.extract("unified_response", serialize_value(unified_response), context.field_cache_context(), transaction_logger=logger) + self._trace(context, "after_unified_response_field_cache_extraction", {"source": "unified_response"}, direction="response", stage="adapter", snapshot=False) + response_protocol = get_protocol(context.client_protocol_name) if context.client_protocol_name else protocol + response_context = context.protocol_context(target_protocol=response_protocol.name) + self._trace(context, "native_response_protocol_selected", {"protocol": response_protocol.name}, direction="metadata", stage="protocol", snapshot=False) + provider_response = response_protocol.format_response(unified_response, response_context) + self._trace(context, "formatted_native_response", provider_response, direction="response", stage="protocol") + adapter_context = context.adapter_context() + adapter_context.transaction_logger = None + provider_response = await run_adapter_chain(adapters, provider_response, adapter_context, stage="response") + self._trace(context, "after_response_adapter_chain", provider_response, direction="response", stage="adapter") + usage_record = extract_usage_record( + provider_response, + provider=context.provider, + model=context.model, + source="native_provider_response", + ) + raw_usage_record = extract_usage_record( + raw_response, + provider=context.provider, + model=context.model, + source="native_provider_raw_response", + ) + if usage_record.provider_reported_cost is None and raw_usage_record.provider_reported_cost is not None: + usage_record = replace( + usage_record, + provider_reported_cost=raw_usage_record.provider_reported_cost, + cost_currency=raw_usage_record.cost_currency, + cost_source=raw_usage_record.cost_source, + ) + cost_breakdown = CostCalculator().calculate(usage_record, model=context.model, provider=context.provider) + self._trace( + context, + "usage_accounting_summary", + {"usage": usage_record.to_dict(), "cost": cost_breakdown.to_dict()}, + direction="metadata", + stage="final", + snapshot=False, + ) + self._trace(context, "final_client_response", provider_response, direction="response", stage="final") + return provider_response + except Exception as exc: + if logger: + logger.log_transform_error( + "native_provider_execute", + exc, + payload=raw_request, + stage="provider", + protocol=context.protocol_name, + metadata={"provider": context.provider, "model": context.model}, + ) + raise + + async def stream(self, raw_request: dict[str, Any], context: NativeProviderContext, transport: NativeHTTPTransport) -> AsyncGenerator[Any, None]: + """Execute a streaming native provider request and yield client events.""" + + logger = context.transaction_logger + protocol = get_protocol(context.protocol_name) + self._ensure_supported_operation(protocol, context) + self._trace(context, "native_protocol_selected", {"protocol": protocol.name}, direction="metadata", stage="protocol") + try: + self._trace(context, "raw_native_client_request", raw_request, direction="request", stage="client") + cache_engine = FieldCacheEngine(context.field_cache_rules, store=self.field_cache_store) + context = await self._inject_metadata(context, cache_engine) + protocol_context = context.protocol_context() + request_payload = dict(raw_request) + request_payload["stream"] = True + unified_request = protocol.parse_request(request_payload, protocol_context) + self._trace(context, "parsed_native_unified_request", unified_request, direction="request", stage="protocol") + await cache_engine.extract("unified_request", serialize_value(unified_request), context.field_cache_context(), transaction_logger=logger) + self._trace(context, "after_unified_request_field_cache_extraction", {"source": "unified_request"}, direction="request", stage="adapter", snapshot=False) + unified_request = await self._inject_unified_request(unified_request, context, cache_engine) + provider_request = protocol.build_request(unified_request, protocol_context) + self._trace(context, "built_native_provider_request", provider_request, direction="request", stage="protocol") + adapters = [get_adapter(name) for name in context.adapter_names] + adapter_context = context.adapter_context() + adapter_context.transaction_logger = None + provider_request = await run_adapter_chain(adapters, provider_request, adapter_context, stage="request") + self._trace(context, "after_request_adapter_chain", provider_request, direction="request", stage="adapter") + provider_request, _ = await cache_engine.inject( + "request", + provider_request, + context.field_cache_context(), + transaction_logger=logger, + ) + self._trace(context, "after_field_cache_injection", provider_request, direction="request", stage="adapter") + await cache_engine.extract("request", provider_request, context.field_cache_context(), transaction_logger=logger) + self._trace(context, "after_request_field_cache_extraction", {"source": "request"}, direction="request", stage="adapter", snapshot=False) + self._trace(context, "native_provider_stream_request", provider_request, direction="request", stage="provider") + usage_record = extract_usage_record(None, provider=context.provider, model=context.model, source="native_provider_stream") + async for raw_chunk in transport.stream_json_lines(context.endpoint, headers=context.headers, payload=provider_request): + self._trace(context, "raw_native_provider_stream_chunk", raw_chunk, direction="stream", stage="provider") + event = protocol.parse_stream_event(raw_chunk, protocol_context) + self._trace(context, "parsed_native_unified_stream_event", event, direction="stream", stage="protocol", snapshot=False) + usage_record = _merge_stream_usage_records( + usage_record, + extract_usage_record(serialize_value(event), provider=context.provider, model=context.model, source="native_stream_event"), + extract_usage_record(raw_chunk, provider=context.provider, model=context.model, source="native_raw_stream_event"), + ) + if event.type == "done": + event_payload = stream_event_payload(event) + self._trace(context, "parsed_native_stream_event", event_payload, direction="stream", stage="protocol") + break + adapter_context = context.adapter_context() + # Native stream traces apply field-cache path redaction below. + # Suppress generic adapter-chain snapshots here so provider state + # cannot leak before rule-aware redaction runs. + adapter_context.transaction_logger = None + event = await run_adapter_chain(adapters, event, adapter_context, stage="stream_event") + self._trace(context, "after_stream_event_adapter_chain", event, direction="stream", stage="adapter", snapshot=False) + await cache_engine.extract("unified_stream_event", serialize_value(event), context.field_cache_context(), transaction_logger=logger) + self._trace(context, "after_unified_stream_event_field_cache_extraction", {"source": "unified_stream_event"}, direction="stream", stage="adapter", snapshot=False) + event_payload = stream_event_payload(event) + self._trace(context, "parsed_native_stream_event", event_payload, direction="stream", stage="protocol") + await cache_engine.extract("stream_event", event_payload, context.field_cache_context(), transaction_logger=logger) + self._trace( + context, + "after_field_cache_stream_extraction", + {"source": "stream_event"}, + direction="stream", + stage="adapter", + snapshot=False, + ) + response_protocol = get_protocol(context.client_protocol_name) if context.client_protocol_name else protocol + formatted = response_protocol.format_stream_event(event, context.protocol_context(target_protocol=response_protocol.name)) + self._trace(context, "formatted_client_stream_event", formatted, direction="stream", stage="final", snapshot=False) + yield formatted + cost_breakdown = CostCalculator().calculate(usage_record, model=context.model, provider=context.provider) + self._trace( + context, + "usage_accounting_summary", + {"usage": usage_record.to_dict(), "cost": cost_breakdown.to_dict()}, + direction="metadata", + stage="final", + snapshot=False, + ) + except Exception as exc: + if logger: + logger.log_transform_error( + "native_provider_stream", + exc, + payload=raw_request, + stage="provider", + protocol=context.protocol_name, + transport=context.transport, + metadata={"provider": context.provider, "model": context.model}, + ) + raise + + @staticmethod + def _trace( + context: NativeProviderContext, + pass_name: str, + data: Any, + *, + direction: str, + stage: str, + metadata: dict[str, Any] | None = None, + snapshot: bool = True, + ) -> None: + if not context.transaction_logger: + return + context.transaction_logger.log_transform_pass( + pass_name, + _redact_field_cache_paths(data, context, direction), + direction=direction, + stage=stage, + protocol=context.protocol_name, + credential_id=context.credential_id, + transport=context.transport, + metadata={ + "provider": context.provider, + "model": context.model, + "session_id": context.session_id, + "scope_key": context.scope_key, + "classifier": context.classifier, + **(metadata or {}), + }, + snapshot=snapshot, + ) + + @staticmethod + def _ensure_supported_operation(protocol: Any, context: NativeProviderContext) -> None: + """Fail before transport when provider and protocol operations disagree.""" + + if protocol.supports_operation(context.operation): + return + raise ProtocolError( + f"provider {context.provider} requested unsupported operation {context.operation!r}", + protocol=protocol.name, + pass_name="native_operation_check", + payload={"provider": context.provider, "model": context.model, "operation": context.operation}, + ) + + async def _inject_metadata(self, context: NativeProviderContext, cache_engine: FieldCacheEngine) -> NativeProviderContext: + """Inject cached metadata before protocol/adapter contexts are built.""" + + metadata, operations = await cache_engine.inject( + "metadata", + dict(context.metadata), + context.field_cache_context(), + transaction_logger=context.transaction_logger, + ) + if operations: + self._trace(context, "after_metadata_field_cache_injection", metadata, direction="metadata", stage="adapter", snapshot=False) + if metadata == context.metadata: + return context + return replace(context, metadata=metadata) + + async def _inject_unified_request( + self, + unified_request: UnifiedRequest, + context: NativeProviderContext, + cache_engine: FieldCacheEngine, + ) -> UnifiedRequest: + """Inject cached values into a serialized unified request and hydrate it.""" + + serialized = serialize_value(unified_request) + injected, operations = await cache_engine.inject( + "unified_request", + serialized, + context.field_cache_context(), + transaction_logger=context.transaction_logger, + ) + if operations: + self._trace(context, "after_unified_request_field_cache_injection", injected, direction="request", stage="adapter") + if injected == serialized: + return unified_request + return _hydrate_unified_request(unified_request, injected) + + +def _merge_stream_usage_records(base: Any, event_record: Any, raw_record: Any) -> Any: + """Merge native stream usage, preserving raw provider cost when needed.""" + + selected = event_record if _usage_record_has_token_values(event_record) else base + if not _usage_record_has_token_values(selected) and _usage_record_has_token_values(raw_record): + selected = raw_record + if selected.provider_reported_cost is None and base.provider_reported_cost is not None: + selected = replace( + selected, + provider_reported_cost=base.provider_reported_cost, + cost_currency=base.cost_currency, + cost_source=base.cost_source, + ) + if selected.provider_reported_cost is None and raw_record.provider_reported_cost is not None: + selected = replace( + selected, + provider_reported_cost=raw_record.provider_reported_cost, + cost_currency=raw_record.cost_currency, + cost_source=raw_record.cost_source, + ) + return selected + + +def _usage_record_has_values(record: Any) -> bool: + return bool( + record.input_tokens + or record.completion_tokens + or record.reasoning_tokens + or record.cache_read_tokens + or record.cache_write_tokens + or record.raw_total_tokens + or record.provider_reported_cost is not None + ) + + +def _usage_record_has_token_values(record: Any) -> bool: + return bool( + record.input_tokens + or record.completion_tokens + or record.reasoning_tokens + or record.cache_read_tokens + or record.cache_write_tokens + or record.raw_total_tokens + ) + + +def _redact_field_cache_paths(data: Any, context: NativeProviderContext, direction: str) -> Any: + """Redact configured cache paths before broad native payload traces. + + Field-cache rules can inject opaque state under arbitrary configured keys, + so key-based trace redaction is not enough. Native traces apply the active + rules' source and injection paths to a copy before handing data to the normal + transaction trace sanitizer. + """ + + if not context.field_cache_rules: + return data + redacted = serialize_value(deepcopy(data)) + for rule in context.field_cache_rules: + paths: list[str] = [] + if direction == "request" and rule.inject: + paths.append(rule.inject.path) + if direction == "metadata" and rule.inject and rule.inject.target == "metadata": + paths.append(rule.inject.path) + if direction in {"response", "stream"}: + paths.append(rule.path) + for path in _trace_redaction_paths(paths, direction=direction): + try: + tokens = parse_path(path) + _redact_path(redacted, tokens) + _redact_leaf_key(redacted, tokens) + except (FieldCachePathError, TypeError, ValueError): + continue + return redacted + + +def _hydrate_unified_request(original: UnifiedRequest, injected: Any) -> UnifiedRequest: + """Hydrate common unified-request fields after serialized cache injection. + + Field-cache path injection operates on JSON-like dictionaries. Protocol + builders still expect `UnifiedRequest`, so this helper copies supported + top-level fields back onto the dataclass while leaving complex message/tool + objects untouched unless providers handle them through provider-payload rules. + """ + + if not isinstance(injected, dict): + return original + safe_fields = { + "operation", + "model", + "stream", + "input", + "modalities", + "files", + "generation_params", + "response_format", + "previous_response_id", + "metadata", + "raw", + "extra", + } + values = {field_name: getattr(original, field_name) for field_name in UnifiedRequest._fields} + for field_name in safe_fields: + if field_name in injected: + values[field_name] = injected[field_name] + return UnifiedRequest(**values) + + +def _trace_redaction_paths(paths: list[str], *, direction: str) -> list[str]: + """Return configured paths plus raw-stream envelope fallbacks for traces.""" + + expanded: list[str] = [] + for path in paths: + expanded.append(path) + if direction == "stream" and path.startswith("raw."): + expanded.append(path[4:]) + return expanded + + +def _redact_path(value: Any, tokens: tuple[PathToken, ...]) -> None: + if not tokens: + return + token = tokens[0] + rest = tokens[1:] + if token.kind == "key": + if isinstance(value, dict) and token.value in value: + if rest: + _redact_path(value[token.value], rest) + else: + value[token.value] = REDACTED + return + if token.kind == "index": + if isinstance(value, list) and value: + index = int(token.value) + if -len(value) <= index < len(value): + if rest: + _redact_path(value[index], rest) + else: + value[index] = REDACTED + return + if token.kind == "wildcard": + if isinstance(value, dict): + for key in list(value.keys()): + if rest: + _redact_path(value[key], rest) + else: + value[key] = REDACTED + elif isinstance(value, list): + for index, item in enumerate(value): + if rest: + _redact_path(item, rest) + else: + value[index] = REDACTED + + +def _redact_leaf_key(value: Any, tokens: tuple[PathToken, ...]) -> None: + """Redact the configured terminal key wherever stream traces duplicate it.""" + + leaf = next((token.value for token in reversed(tokens) if token.kind == "key"), None) + if not leaf: + return + if isinstance(value, dict): + for key, item in list(value.items()): + if key == leaf: + value[key] = REDACTED + else: + _redact_leaf_key(item, tokens) + elif isinstance(value, list): + for item in value: + _redact_leaf_key(item, tokens) diff --git a/src/rotator_library/native_provider/http.py b/src/rotator_library/native_provider/http.py new file mode 100644 index 000000000..a5f7118e7 --- /dev/null +++ b/src/rotator_library/native_provider/http.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Small HTTP transport wrapper for native provider calls.""" + +from __future__ import annotations + +import json +from typing import Any, AsyncIterator + + +class NativeHTTPTransport: + """Execute provider-native JSON HTTP requests through an injected client.""" + + def __init__(self, client: Any) -> None: + self.client = client + + async def post_json(self, endpoint: str, *, headers: dict[str, str], payload: dict[str, Any]) -> Any: + """POST JSON and return a decoded response body. + + The wrapper keeps HTTP behavior easy to mock. It does not own retries or + credential rotation; those remain in the existing executor/usage layer. + """ + + response = await self.client.post(endpoint, headers=headers, json=payload) + if hasattr(response, "raise_for_status"): + response.raise_for_status() + if hasattr(response, "json"): + return response.json() + return response + + async def stream_json_lines(self, endpoint: str, *, headers: dict[str, str], payload: dict[str, Any]) -> AsyncIterator[Any]: + """Yield provider stream chunks from an injected streaming-capable client. + + Provider-specific test clients can still expose `stream_json_lines()`. + When a normal `httpx.AsyncClient`-style object is injected, this method + now uses `client.stream()` directly so native streaming has a real HTTP + seam without enabling any provider that has not opted in safely. + """ + + if hasattr(self.client, "stream_json_lines"): + async for chunk in self.client.stream_json_lines(endpoint, headers=headers, json=payload): + yield chunk + return + if hasattr(self.client, "stream"): + async with self.client.stream("POST", endpoint, headers=headers, json=payload) as response: + if hasattr(response, "raise_for_status"): + response.raise_for_status() + if hasattr(response, "aiter_lines"): + async for line in response.aiter_lines(): + parsed = _parse_stream_line(line) + if parsed is not None: + yield parsed + return + if hasattr(response, "aiter_bytes"): + buffer = "" + async for chunk in response.aiter_bytes(): + text = chunk.decode("utf-8", errors="replace") if isinstance(chunk, (bytes, bytearray)) else str(chunk) + buffer += text + while "\n" in buffer: + line, buffer = buffer.split("\n", 1) + parsed = _parse_stream_line(line) + if parsed is not None: + yield parsed + parsed = _parse_stream_line(buffer) + if parsed is not None: + yield parsed + return + raise NotImplementedError("Injected native HTTP client does not expose streaming support") + + +def _parse_stream_line(line: Any) -> Any: + """Parse one HTTP streaming line while preserving provider sentinels.""" + + if line is None: + return None + text = line.decode("utf-8", errors="replace") if isinstance(line, (bytes, bytearray)) else str(line) + text = text.strip() + if not text: + return None + if text.startswith(":"): + return None + if text.startswith("data:"): + text = text[len("data:") :].strip() + if text == "[DONE]": + return "[DONE]" + try: + return json.loads(text) + except json.JSONDecodeError: + return text diff --git a/src/rotator_library/native_provider/streaming.py b/src/rotator_library/native_provider/streaming.py new file mode 100644 index 000000000..b043a2338 --- /dev/null +++ b/src/rotator_library/native_provider/streaming.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Streaming helpers for provider-native execution.""" + +from __future__ import annotations + +from typing import Any + +from ..streaming import StreamEvent, stream_event_from_sse_chunk + + +def stream_event_payload(event: Any) -> Any: + """Return a JSON-safe payload for stream field-cache and trace passes.""" + + if hasattr(event, "to_dict"): + return event.to_dict() + return event + + +def provider_supports_native_streaming(provider: Any, *, model: str = "", operation: str = "chat") -> bool: + """Return explicit provider native-streaming support. + + Providers must opt in. Missing methods and exceptions fail closed so routed + streaming can fallback before output rather than accidentally claiming native + streaming support. + """ + + method = getattr(provider, "supports_native_streaming", None) + if not method: + return False + try: + return bool(method(model=model, operation=operation)) + except TypeError: + return bool(method(model, operation)) + except Exception: + return False + + +def native_stream_event_from_formatted(formatted: Any, *, protocol: str = "openai_chat") -> StreamEvent: + """Convert a formatted native stream chunk into the common event seam.""" + + if isinstance(formatted, str): + return stream_event_from_sse_chunk(formatted, protocol=protocol) + return StreamEvent("parsed_chunk", protocol=protocol, data=formatted, raw=formatted) diff --git a/src/rotator_library/protocols/__init__.py b/src/rotator_library/protocols/__init__.py new file mode 100644 index 000000000..f528f2046 --- /dev/null +++ b/src/rotator_library/protocols/__init__.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Native protocol adapters for provider-independent request handling. + +Importing this package auto-discovers built-in protocol adapters, similar to the +provider plugin system. Runtime execution is not changed by Phase 1; the package +only exposes reusable protocol primitives for later phases. +""" + +from .base import ProtocolAdapter +from .operation import ( + OPERATION_AUDIO_TRANSCRIPTION, + OPERATION_AUDIO_TRANSLATION, + OPERATION_CHAT, + OPERATION_COUNT_TOKENS, + OPERATION_EMBEDDINGS, + OPERATION_IMAGE_EDIT, + OPERATION_IMAGE_GENERATION, + OPERATION_IMAGE_VARIATION, + OPERATION_MCP, + OPERATION_MESSAGES, + OPERATION_OLLAMA_CHAT, + OPERATION_OLLAMA_GENERATE, + OPERATION_RESPONSES, + OPERATION_SPEECH, + OPERATION_UNKNOWN, + normalize_operation, +) +from .registry import ( + PROTOCOL_ALIASES, + PROTOCOL_PLUGINS, + get_protocol, + get_protocol_class, + list_protocols, + register_protocol, + resolve_protocol_name, +) +from .types import ( + ContentBlock, + CostDetails, + ProtocolContext, + ProtocolError, + ReasoningBlock, + ToolCall, + ToolDefinition, + ToolResult, + UnifiedMessage, + UnifiedRequest, + UnifiedResponse, + UnifiedStreamEvent, + Usage, + first_text, + serialize_value, + text_blocks, +) + +__all__ = [ + "PROTOCOL_ALIASES", + "PROTOCOL_PLUGINS", + "ContentBlock", + "CostDetails", + "OPERATION_AUDIO_TRANSCRIPTION", + "OPERATION_AUDIO_TRANSLATION", + "OPERATION_CHAT", + "OPERATION_COUNT_TOKENS", + "OPERATION_EMBEDDINGS", + "OPERATION_IMAGE_EDIT", + "OPERATION_IMAGE_GENERATION", + "OPERATION_IMAGE_VARIATION", + "OPERATION_MCP", + "OPERATION_MESSAGES", + "OPERATION_OLLAMA_CHAT", + "OPERATION_OLLAMA_GENERATE", + "OPERATION_RESPONSES", + "OPERATION_SPEECH", + "OPERATION_UNKNOWN", + "ProtocolAdapter", + "ProtocolContext", + "ProtocolError", + "ReasoningBlock", + "ToolCall", + "ToolDefinition", + "ToolResult", + "UnifiedMessage", + "UnifiedRequest", + "UnifiedResponse", + "UnifiedStreamEvent", + "Usage", + "first_text", + "get_protocol", + "get_protocol_class", + "list_protocols", + "normalize_operation", + "register_protocol", + "resolve_protocol_name", + "serialize_value", + "text_blocks", +] diff --git a/src/rotator_library/protocols/anthropic_messages.py b/src/rotator_library/protocols/anthropic_messages.py new file mode 100644 index 000000000..024bdfae5 --- /dev/null +++ b/src/rotator_library/protocols/anthropic_messages.py @@ -0,0 +1,385 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Anthropic Messages protocol adapter. + +This adapter captures the native Messages shape as a reusable base. The existing +compatibility routes remain active; this module gives future provider-native +execution a loss-conscious parser/builder with thinking and tool block support. +""" + +from __future__ import annotations + +import json +from copy import deepcopy +from typing import Any, ClassVar, Iterable + +from .base import ProtocolAdapter +from .operation import OPERATION_COUNT_TOKENS, OPERATION_MESSAGES, OPERATION_UNKNOWN, normalize_operation +from .types import ( + ContentBlock, + ProtocolContext, + ReasoningBlock, + ToolCall, + ToolDefinition, + ToolResult, + UnifiedMessage, + UnifiedRequest, + UnifiedResponse, + UnifiedStreamEvent, + Usage, + first_text, + text_blocks, +) + +_GENERATION_PARAMS = { + "max_tokens", + "metadata", + "stop_sequences", + "temperature", + "thinking", + "tool_choice", + "top_k", + "top_p", +} + +_REQUEST_CORE_FIELDS = {"model", "messages", "system", "tools", "stream", *_GENERATION_PARAMS} + + +class AnthropicMessagesProtocol(ProtocolAdapter): + """Adapter for Anthropic Messages requests, responses, and stream events. + + Thinking and redacted-thinking blocks are represented as reasoning blocks so + later field-cache rules can extract signatures without relying on a bespoke + provider implementation. + """ + + name: ClassVar[str] = "anthropic_messages" + aliases: ClassVar[tuple[str, ...]] = ("anthropic", "messages", "claude_messages") + supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_MESSAGES, OPERATION_COUNT_TOKENS) + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + return UnifiedRequest( + operation=_operation_from_context(context, OPERATION_MESSAGES), + model=str(request.get("model") or getattr(context, "model", None) or ""), + messages=[self._parse_message(message) for message in request.get("messages") or []], + system=self._parse_system(request.get("system")), + tools=[self._parse_tool_definition(tool) for tool in request.get("tools") or []], + stream=bool(request.get("stream", False)), + generation_params={k: deepcopy(request[k]) for k in _GENERATION_PARAMS if k in request and k != "metadata"}, + metadata=deepcopy(request.get("metadata") or {}), + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _REQUEST_CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload: dict[str, Any] = { + "model": unified_request.model, + "messages": [self._format_message(message) for message in unified_request.messages], + } + system = self._format_system(unified_request.system) + if system is not None: + payload["system"] = system + if unified_request.tools: + payload["tools"] = [self._format_tool_definition(tool) for tool in unified_request.tools] + if unified_request.stream: + payload["stream"] = True + if unified_request.metadata: + payload["metadata"] = deepcopy(unified_request.metadata) + payload.update(deepcopy(unified_request.generation_params)) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + response = _as_dict(raw_response) + operation = _response_operation(response, context) + message = UnifiedMessage( + role=str(response.get("role") or "assistant"), + content=self._parse_content(response.get("content")), + raw=deepcopy(response), + extra={"type": response.get("type")}, + ) + self._promote_message_blocks(message) + return UnifiedResponse( + operation=operation, + id=response.get("id"), + model=response.get("model") or getattr(context, "model", None), + messages=[] if operation == OPERATION_COUNT_TOKENS else [message] if response else [], + stop_reason=response.get("stop_reason"), + usage=self.extract_usage(response, context), + metadata={"stop_sequence": response.get("stop_sequence"), "type": response.get("type")}, + raw=deepcopy(response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"id", "type", "role", "content", "model", "stop_reason", "stop_sequence", "usage"}}, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + if unified_response.operation == OPERATION_COUNT_TOKENS: + usage = unified_response.usage + payload = deepcopy(unified_response.extra) + # Normalized usage wins over raw preserved fields so later adapters + # can correct counts without stale provider keys shadowing them. + payload["input_tokens"] = usage.input_tokens if usage else 0 + return payload + message = unified_response.messages[0] if unified_response.messages else UnifiedMessage(role="assistant") + payload = { + "id": unified_response.id, + "type": unified_response.metadata.get("type", "message"), + "role": message.role, + "content": self._format_content(message.content), + "model": unified_response.model, + "stop_reason": unified_response.stop_reason, + "stop_sequence": unified_response.metadata.get("stop_sequence"), + "usage": self._format_usage(unified_response.usage), + } + payload.update(deepcopy(unified_response.extra)) + return {k: v for k, v in payload.items() if v is not None} + + def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: + event = _decode_sse_data(raw_event) + if event == "[DONE]": + return UnifiedStreamEvent(type="done", operation=OPERATION_MESSAGES, raw=deepcopy(raw_event)) + data = _as_dict(event) + event_type = str(data.get("type") or "chunk") + + if event_type == "error" or data.get("error") is not None: + return UnifiedStreamEvent(type="error", operation=OPERATION_MESSAGES, error=deepcopy(data.get("error", data)), raw=deepcopy(raw_event), extra={"payload": data}) + if event_type == "message_start": + response = self.parse_response(data.get("message") or {}, context) + return UnifiedStreamEvent(type="message_start", operation=OPERATION_MESSAGES, message=response.messages[0] if response.messages else None, usage=response.usage, raw=deepcopy(raw_event), extra={"payload": data}) + if event_type == "message_delta": + return UnifiedStreamEvent(type="message_delta", operation=OPERATION_MESSAGES, usage=self.extract_usage(data.get("usage") or {}, context), raw=deepcopy(raw_event), extra={"payload": data, "stop_reason": (data.get("delta") or {}).get("stop_reason")}) + if event_type in {"content_block_start", "content_block_delta", "content_block_stop"}: + return self._parse_content_stream_event(data, raw_event) + return UnifiedStreamEvent(type=event_type, operation=OPERATION_MESSAGES, raw=deepcopy(raw_event), extra={"payload": data}) + + def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: + if isinstance(raw_or_unified, (UnifiedResponse, UnifiedStreamEvent)): + return raw_or_unified.usage + payload = _as_dict(raw_or_unified) + usage = payload.get("usage") if isinstance(payload.get("usage"), dict) else payload + if not isinstance(usage, dict) or not any(k.endswith("tokens") for k in usage): + return None + input_tokens = int(usage.get("input_tokens") or 0) + output_tokens = int(usage.get("output_tokens") or 0) + cache_write = int(usage.get("cache_creation_input_tokens") or 0) + cache_read = int(usage.get("cache_read_input_tokens") or 0) + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=int(usage.get("total_tokens") or input_tokens + output_tokens), + cache_read_tokens=cache_read, + cache_write_tokens=cache_write, + raw=deepcopy(usage), + ) + + def _parse_system(self, system: Any) -> list[ContentBlock]: + if system is None: + return [] + if isinstance(system, str): + return text_blocks(system) + return self._parse_content(system) + + def _format_system(self, blocks: Iterable[ContentBlock]) -> Any: + block_list = list(blocks) + if not block_list: + return None + if all(block.type == "text" and not isinstance(block.raw, dict) and not block.extra for block in block_list): + return first_text(block_list) or "" + return self._format_content(block_list) + + def _parse_message(self, message: dict[str, Any]) -> UnifiedMessage: + payload = dict(message or {}) + unified = UnifiedMessage( + role=str(payload.get("role") or "user"), + content=self._parse_content(payload.get("content")), + raw=deepcopy(message), + extra={k: deepcopy(v) for k, v in payload.items() if k not in {"role", "content"}}, + ) + self._promote_message_blocks(unified) + return unified + + def _format_message(self, message: UnifiedMessage) -> dict[str, Any]: + payload = {"role": message.role, "content": self._format_content(message.content)} + payload.update(deepcopy(message.extra)) + return payload + + def _parse_content(self, content: Any) -> list[ContentBlock]: + if content is None: + return [] + if isinstance(content, str): + return text_blocks(content) + if not isinstance(content, list): + return [ContentBlock(type="unknown", raw=deepcopy(content))] + return [self._parse_content_block(block) for block in content] + + def _parse_content_block(self, block: Any) -> ContentBlock: + if isinstance(block, str): + return ContentBlock(type="text", text=block, raw=block) + if not isinstance(block, dict): + return ContentBlock(type="unknown", raw=deepcopy(block)) + block_type = str(block.get("type") or "text") + if block_type == "text": + return ContentBlock(type="text", text=block.get("text", ""), raw=deepcopy(block)) + if block_type in {"image", "document"}: + return ContentBlock(type=block_type, source=deepcopy(block.get("source")), raw=deepcopy(block), extra=_without(block, {"type", "source"})) + if block_type in {"thinking", "redacted_thinking"}: + reasoning = ReasoningBlock( + type=block_type, + text=block.get("thinking"), + signature=block.get("signature"), + redacted=block_type == "redacted_thinking", + raw=deepcopy(block), + extra=_without(block, {"type", "thinking", "signature"}), + ) + return ContentBlock(type=block_type, reasoning=reasoning, raw=deepcopy(block)) + if block_type == "tool_use": + return ContentBlock( + type="tool_use", + tool_call=ToolCall(id=block.get("id"), name=block.get("name"), arguments=deepcopy(block.get("input")), type="tool_use", raw=deepcopy(block)), + raw=deepcopy(block), + extra=_without(block, {"type", "id", "name", "input"}), + ) + if block_type == "tool_result": + return ContentBlock( + type="tool_result", + tool_result=ToolResult(tool_call_id=block.get("tool_use_id"), content=deepcopy(block.get("content")), is_error=block.get("is_error"), raw=deepcopy(block)), + raw=deepcopy(block), + extra=_without(block, {"type", "tool_use_id", "content", "is_error"}), + ) + return ContentBlock(type=block_type, raw=deepcopy(block), extra=_without(block, {"type"})) + + def _format_content(self, blocks: Iterable[ContentBlock]) -> list[dict[str, Any]]: + formatted = [] + for block in blocks: + if block.type == "text": + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": "text"} + payload["type"] = "text" + payload["text"] = block.text or "" + formatted.append(payload) + elif block.reasoning: + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": block.reasoning.type} + payload["type"] = block.reasoning.type + if block.reasoning.text is not None: + payload["thinking"] = block.reasoning.text + if block.reasoning.signature is not None: + payload["signature"] = block.reasoning.signature + payload.update(deepcopy(block.reasoning.extra)) + formatted.append(payload) + elif block.tool_call: + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": "tool_use"} + payload.update({"type": "tool_use", "id": block.tool_call.id, "name": block.tool_call.name, "input": deepcopy(block.tool_call.arguments)}) + formatted.append(payload) + elif block.tool_result: + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": "tool_result"} + payload.update({"type": "tool_result", "tool_use_id": block.tool_result.tool_call_id, "content": deepcopy(block.tool_result.content)}) + if block.tool_result.is_error is not None: + payload["is_error"] = block.tool_result.is_error + formatted.append(payload) + else: + payload = {"type": block.type} + payload.update(deepcopy(block.extra)) + formatted.append(payload) + return formatted + + def _parse_tool_definition(self, tool: dict[str, Any]) -> ToolDefinition: + payload = dict(tool or {}) + return ToolDefinition( + name=str(payload.get("name") or ""), + description=payload.get("description"), + input_schema=deepcopy(payload.get("input_schema") or {}), + type="tool", + extra=_without(payload, {"name", "description", "input_schema"}), + ) + + def _format_tool_definition(self, tool: ToolDefinition) -> dict[str, Any]: + payload = {"name": tool.name, "input_schema": deepcopy(tool.input_schema)} + if tool.description is not None: + payload["description"] = tool.description + payload.update(deepcopy(tool.extra)) + return payload + + def _format_usage(self, usage: Usage | None) -> dict[str, int] | None: + if usage is None: + return None + payload = {"input_tokens": usage.input_tokens, "output_tokens": usage.output_tokens} + if usage.cache_write_tokens: + payload["cache_creation_input_tokens"] = usage.cache_write_tokens + if usage.cache_read_tokens: + payload["cache_read_input_tokens"] = usage.cache_read_tokens + return payload + + def _promote_message_blocks(self, message: UnifiedMessage) -> None: + for block in message.content: + if block.tool_call: + message.tool_calls.append(block.tool_call) + if block.reasoning: + message.reasoning.append(block.reasoning) + + def _parse_content_stream_event(self, data: dict[str, Any], raw_event: Any) -> UnifiedStreamEvent: + block = data.get("content_block") if isinstance(data.get("content_block"), dict) else None + delta = data.get("delta") if isinstance(data.get("delta"), dict) else None + content_block = None + if block: + content_block = self._parse_content_block(block) + elif delta: + delta_type = delta.get("type") + if delta_type == "text_delta": + content_block = ContentBlock(type="text", text=delta.get("text"), raw=deepcopy(delta)) + elif delta_type in {"thinking_delta", "signature_delta"}: + reasoning = ReasoningBlock(type=str(delta_type), text=delta.get("thinking"), signature=delta.get("signature"), extra=_without(delta, {"type", "thinking", "signature"})) + content_block = ContentBlock(type=str(delta_type), reasoning=reasoning, raw=deepcopy(delta)) + message = UnifiedMessage(role="assistant", content=[content_block] if content_block else []) + self._promote_message_blocks(message) + return UnifiedStreamEvent(type=str(data.get("type") or "content_block_delta"), operation=OPERATION_MESSAGES, delta=message, raw=deepcopy(raw_event), extra={"payload": data, "index": data.get("index")}) + + +def _operation_from_context(context: ProtocolContext | None, default: str) -> str: + supported = {OPERATION_MESSAGES, OPERATION_COUNT_TOKENS} + if context and isinstance(context.provider_options, dict): + operation = normalize_operation(context.provider_options.get("operation")) + if operation in supported: + return operation + if context and isinstance(context.metadata, dict): + operation = normalize_operation(context.metadata.get("operation")) + if operation in supported: + return operation + return default + + +def _response_operation(response: dict[str, Any], context: ProtocolContext | None) -> str: + requested = _operation_from_context(context, OPERATION_MESSAGES) + if requested == OPERATION_COUNT_TOKENS: + return OPERATION_COUNT_TOKENS + if "input_tokens" in response and not response.get("content") and not response.get("id"): + return OPERATION_COUNT_TOKENS + return OPERATION_MESSAGES + + +def _as_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return deepcopy(value) + if hasattr(value, "model_dump"): + return value.model_dump() + if hasattr(value, "dict"): + return value.dict() + return {} + + +def _decode_sse_data(raw_event: Any) -> Any: + if not isinstance(raw_event, str): + return raw_event + text = raw_event.strip() + if text.startswith("data:"): + text = text[5:].strip() + if text == "[DONE]": + return text + try: + return json.loads(text) + except json.JSONDecodeError: + return raw_event + + +def _without(payload: dict[str, Any], keys: set[str]) -> dict[str, Any]: + return {k: deepcopy(v) for k, v in payload.items() if k not in keys} diff --git a/src/rotator_library/protocols/base.py b/src/rotator_library/protocols/base.py new file mode 100644 index 000000000..293db1c05 --- /dev/null +++ b/src/rotator_library/protocols/base.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Base classes for native protocol adapters. + +Protocol adapters are intentionally override-friendly. They provide reusable +defaults for custom providers, but providers can override any method when a +service uses an almost-standard protocol with provider-specific quirks. +""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, ClassVar + +from .operation import OPERATION_UNKNOWN, normalize_operation +from .types import ( + ProtocolContext, + ProtocolError, + UnifiedRequest, + UnifiedResponse, + UnifiedStreamEvent, + Usage, +) + + +class ProtocolAdapter: + """Base adapter for converting between raw protocol payloads and unified types. + + Subclasses should override only the methods they need. The default behavior + is deliberately conservative: preserve raw payloads and avoid lossy + assumptions. Later phases will layer transform logging, field-cache rules, + and provider override hooks around this interface. + """ + + name: ClassVar[str] = "base" + aliases: ClassVar[tuple[str, ...]] = () + supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + future_transports: ClassVar[tuple[str, ...]] = () + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_UNKNOWN,) + + def supports_transport(self, transport_name: str) -> bool: + """Return whether this protocol can format the requested transport.""" + + return transport_name in self.supported_transports + + def is_future_transport(self, transport_name: str) -> bool: + """Return whether this protocol has an intentional future transport seam.""" + + return transport_name in self.future_transports + + def supports_operation(self, operation_name: str) -> bool: + """Return whether this adapter natively models an operation. + + Operation names are string based so custom protocols can add their own + values. The default base adapter only claims ``unknown`` and keeps raw + payloads intact; concrete adapters should list every operation they can + parse/build without relying on LiteLLM. + """ + + return normalize_operation(operation_name) in self.supported_operations + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + """Parse a raw client/provider request into a unified request.""" + + request = dict(raw_request or {}) + return UnifiedRequest( + operation=normalize_operation(request.get("operation")), + model=str(request.get("model") or getattr(context, "model", None) or ""), + stream=bool(request.get("stream", False)), + input=deepcopy(request.get("input")), + modalities=list(request.get("modalities") or []), + files=list(request.get("files") or []), + raw=deepcopy(raw_request), + extra={ + k: deepcopy(v) + for k, v in request.items() + if k not in {"operation", "model", "stream", "input", "modalities", "files"} + }, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + """Build a provider request from a unified request. + + The base implementation returns the original raw dict when present. This + keeps fallback providers safe and gives custom protocol subclasses a + predictable starting point. + """ + + if isinstance(unified_request.raw, dict): + return deepcopy(unified_request.raw) + if not isinstance(unified_request.raw, type(None)): + raise ProtocolError( + "cannot build dict request from non-dict raw payload", + protocol=self.name, + pass_name="build_request", + payload=unified_request.raw, + ) + payload = {"model": unified_request.model, "stream": unified_request.stream} + if unified_request.operation != OPERATION_UNKNOWN: + payload["operation"] = unified_request.operation + if unified_request.input is not None: + payload["input"] = deepcopy(unified_request.input) + if unified_request.modalities: + payload["modalities"] = deepcopy(unified_request.modalities) + if unified_request.files: + payload["files"] = deepcopy(unified_request.files) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + """Parse a raw response into a unified response.""" + + response = raw_response if isinstance(raw_response, dict) else {} + return UnifiedResponse( + operation=normalize_operation(response.get("operation") if isinstance(response, dict) else None), + id=response.get("id") if isinstance(response, dict) else None, + model=response.get("model") if isinstance(response, dict) else getattr(context, "model", None), + data=deepcopy(response.get("data") or []) if isinstance(response, dict) else [], + content_type=response.get("content_type") if isinstance(response, dict) else None, + raw=deepcopy(raw_response), + extra=deepcopy(response) if isinstance(response, dict) else {}, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + """Format a unified response for a client protocol.""" + + if isinstance(unified_response.raw, dict): + return deepcopy(unified_response.raw) + payload = deepcopy(unified_response.extra) + if unified_response.operation != OPERATION_UNKNOWN: + payload.setdefault("operation", unified_response.operation) + if unified_response.id is not None: + payload.setdefault("id", unified_response.id) + if unified_response.model is not None: + payload.setdefault("model", unified_response.model) + if unified_response.data: + payload.setdefault("data", deepcopy(unified_response.data)) + if unified_response.content_type is not None: + payload.setdefault("content_type", unified_response.content_type) + return payload + + def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: + """Parse one raw stream event. + + Subclasses should preserve the original event in ``raw`` because Phase 2 + transform logging needs both provider-native and unified states. + """ + + event_type = "done" if raw_event == "[DONE]" else "chunk" + return UnifiedStreamEvent(type=event_type, raw=deepcopy(raw_event)) + + def format_stream_event(self, unified_event: UnifiedStreamEvent, context: ProtocolContext | None = None) -> Any: + """Format one unified stream event for the target transport.""" + + return deepcopy(unified_event.raw) if unified_event.raw is not None else unified_event.to_dict() + + def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: + """Extract normalized usage when the protocol can identify it.""" + + if isinstance(raw_or_unified, UnifiedResponse): + return raw_or_unified.usage + if isinstance(raw_or_unified, UnifiedStreamEvent): + return raw_or_unified.usage + if isinstance(raw_or_unified, dict) and isinstance(raw_or_unified.get("usage"), dict): + usage = raw_or_unified["usage"] + return Usage( + input_tokens=int(usage.get("prompt_tokens") or usage.get("input_tokens") or 0), + output_tokens=int(usage.get("completion_tokens") or usage.get("output_tokens") or 0), + total_tokens=int(usage.get("total_tokens") or 0), + raw=deepcopy(usage), + ) + return None diff --git a/src/rotator_library/protocols/gemini.py b/src/rotator_library/protocols/gemini.py new file mode 100644 index 000000000..de9157d23 --- /dev/null +++ b/src/rotator_library/protocols/gemini.py @@ -0,0 +1,395 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Gemini generateContent protocol adapter. + +The adapter preserves Gemini-native content parts, thought signatures, safety +settings, tools, and generation configuration so later native providers can use +the same base without forcing an OpenAI-compatible intermediate shape. +""" + +from __future__ import annotations + +import json +from copy import deepcopy +from typing import Any, ClassVar, Iterable + +from .base import ProtocolAdapter +from .operation import OPERATION_CHAT, OPERATION_COUNT_TOKENS, OPERATION_UNKNOWN, normalize_operation +from .types import ( + ContentBlock, + ProtocolContext, + ReasoningBlock, + ToolCall, + ToolDefinition, + ToolResult, + UnifiedMessage, + UnifiedRequest, + UnifiedResponse, + UnifiedStreamEvent, + Usage, +) + +_REQUEST_CORE_FIELDS = { + "model", + "contents", + "systemInstruction", + "system_instruction", + "tools", + "generationConfig", + "generation_config", + "safetySettings", + "safety_settings", + "stream", +} + + +class GeminiProtocol(ProtocolAdapter): + """Adapter for Gemini ``generateContent`` and stream event shapes. + + Gemini parts are richer than simple chat messages. Unknown part fields remain + in ``extra`` and raw payloads are preserved so provider-specific subclasses + can refine behavior without losing data. + """ + + name: ClassVar[str] = "gemini" + aliases: ClassVar[tuple[str, ...]] = ("google_gemini", "generate_content") + supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_CHAT, OPERATION_COUNT_TOKENS, "generate", "stream_generate") + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + generation_config = deepcopy(request.get("generationConfig") or request.get("generation_config") or {}) + safety_settings = deepcopy(request.get("safetySettings") or request.get("safety_settings") or []) + return UnifiedRequest( + operation=_operation_from_context(context, OPERATION_CHAT), + model=str(request.get("model") or getattr(context, "model", None) or ""), + messages=[self._parse_content(content) for content in request.get("contents") or []], + system=self._parse_system(request.get("systemInstruction") or request.get("system_instruction")), + tools=self._parse_tools(request.get("tools") or []), + stream=bool(request.get("stream", False)), + generation_params={"generationConfig": generation_config, "safetySettings": safety_settings}, + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _REQUEST_CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload: dict[str, Any] = { + "contents": [self._format_content(message) for message in unified_request.messages], + } + if unified_request.model: + payload["model"] = unified_request.model + if unified_request.system: + payload["systemInstruction"] = {"parts": self._format_parts(unified_request.system)} + generation_config = unified_request.generation_params.get("generationConfig") + safety_settings = unified_request.generation_params.get("safetySettings") + if generation_config: + payload["generationConfig"] = deepcopy(generation_config) + if safety_settings: + payload["safetySettings"] = deepcopy(safety_settings) + if unified_request.tools: + payload["tools"] = self._format_tools(unified_request.tools) + if unified_request.stream: + payload["stream"] = True + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + response = _as_dict(raw_response) + messages: list[UnifiedMessage] = [] + stop_reason = None + for candidate in response.get("candidates") or []: + if not isinstance(candidate, dict): + continue + content = candidate.get("content") if isinstance(candidate.get("content"), dict) else {} + message = self._parse_content(content) + message.extra["candidate"] = _without(candidate, {"content"}) + messages.append(message) + if candidate.get("finishReason") is not None: + stop_reason = candidate.get("finishReason") + return UnifiedResponse( + operation=_response_operation(response, context), + id=response.get("responseId") or response.get("id"), + model=response.get("modelVersion") or getattr(context, "model", None), + messages=messages, + stop_reason=stop_reason, + usage=self.extract_usage(response, context), + metadata={"promptFeedback": deepcopy(response.get("promptFeedback")), "modelVersion": response.get("modelVersion")}, + raw=deepcopy(response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"responseId", "id", "modelVersion", "candidates", "usageMetadata", "promptFeedback"}}, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + if unified_response.operation == OPERATION_COUNT_TOKENS: + usage = unified_response.usage + payload = deepcopy(unified_response.extra) + # Normalized usage wins over raw preserved fields so later adapters + # can correct counts without stale provider keys shadowing them. + payload["totalTokens"] = usage.total_tokens if usage else 0 + return payload + candidates = [] + for index, message in enumerate(unified_response.messages): + candidate = {"index": index, "content": self._format_content(message)} + if unified_response.stop_reason: + candidate["finishReason"] = unified_response.stop_reason + candidate.update(deepcopy(message.extra.get("candidate") or {})) + candidates.append(candidate) + payload = { + "responseId": unified_response.id, + "modelVersion": unified_response.model, + "candidates": candidates, + "usageMetadata": self._format_usage(unified_response.usage), + "promptFeedback": deepcopy(unified_response.metadata.get("promptFeedback")), + } + payload.update(deepcopy(unified_response.extra)) + return {k: v for k, v in payload.items() if v is not None} + + def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: + event = _decode_sse_data(raw_event) + if event == "[DONE]": + return UnifiedStreamEvent(type="done", operation=OPERATION_CHAT, raw=deepcopy(raw_event)) + data = _as_dict(event) + response = self.parse_response(data, context) + message = response.messages[0] if response.messages else None + return UnifiedStreamEvent( + type="message_delta" if message else "chunk", + operation=response.operation, + delta=message, + usage=response.usage, + raw=deepcopy(raw_event), + extra={"payload": data, "finish_reason": response.stop_reason}, + ) + + def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: + if isinstance(raw_or_unified, (UnifiedResponse, UnifiedStreamEvent)): + return raw_or_unified.usage + payload = _as_dict(raw_or_unified) + usage = payload.get("usageMetadata") if isinstance(payload.get("usageMetadata"), dict) else payload + if not isinstance(usage, dict) or (not any(key.endswith("TokenCount") for key in usage) and "totalTokens" not in usage): + return None + input_tokens = int(usage.get("promptTokenCount") or 0) + output_tokens = int(usage.get("candidatesTokenCount") or 0) + reasoning_tokens = int(usage.get("thoughtsTokenCount") or 0) + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=int(usage.get("totalTokenCount") or usage.get("totalTokens") or input_tokens + output_tokens + reasoning_tokens), + cache_read_tokens=int(usage.get("cachedContentTokenCount") or 0), + reasoning_tokens=reasoning_tokens, + raw=deepcopy(usage), + ) + + def _parse_system(self, system: Any) -> list[ContentBlock]: + if system is None: + return [] + if isinstance(system, str): + return [ContentBlock(type="text", text=system, raw=system)] + if isinstance(system, dict): + return self._parse_parts(system.get("parts") or []) + return [] + + def _parse_content(self, content: dict[str, Any]) -> UnifiedMessage: + payload = dict(content or {}) + role = str(payload.get("role") or "model") + # Gemini uses "model" where chat protocols usually say "assistant". + normalized_role = "assistant" if role == "model" else role + message = UnifiedMessage( + role=normalized_role, + content=self._parse_parts(payload.get("parts") or []), + raw=deepcopy(content), + extra={"gemini_role": role, **_without(payload, {"role", "parts"})}, + ) + for block in message.content: + if block.tool_call: + message.tool_calls.append(block.tool_call) + if block.reasoning: + message.reasoning.append(block.reasoning) + return message + + def _format_content(self, message: UnifiedMessage) -> dict[str, Any]: + role = message.extra.get("gemini_role") or ("model" if message.role == "assistant" else message.role) + payload = {"role": role, "parts": self._format_parts(message.content)} + payload.update({k: deepcopy(v) for k, v in message.extra.items() if k != "gemini_role"}) + return payload + + def _parse_parts(self, parts: Iterable[Any]) -> list[ContentBlock]: + blocks = [] + for part in parts: + blocks.append(self._parse_part(part)) + return blocks + + def _parse_part(self, part: Any) -> ContentBlock: + if isinstance(part, str): + return ContentBlock(type="text", text=part, raw=part) + if not isinstance(part, dict): + return ContentBlock(type="unknown", raw=deepcopy(part)) + if "text" in part: + reasoning = None + if part.get("thought") or part.get("thoughtSignature"): + reasoning = ReasoningBlock(type="gemini_thought", text=part.get("text"), signature=part.get("thoughtSignature"), raw=deepcopy(part), extra=_without(part, {"text", "thought", "thoughtSignature"})) + return ContentBlock(type="text", text=part.get("text", ""), reasoning=reasoning, raw=deepcopy(part), extra=_without(part, {"text"})) + if "inlineData" in part or "inline_data" in part: + source = part.get("inlineData") or part.get("inline_data") + return ContentBlock(type="inline_data", source=deepcopy(source), raw=deepcopy(part), extra=_without(part, {"inlineData", "inline_data"})) + if "fileData" in part or "file_data" in part: + source = part.get("fileData") or part.get("file_data") + return ContentBlock(type="file_data", source=deepcopy(source), raw=deepcopy(part), extra=_without(part, {"fileData", "file_data"})) + if "functionCall" in part or "function_call" in part: + call = part.get("functionCall") or part.get("function_call") or {} + return ContentBlock(type="function_call", tool_call=ToolCall(name=call.get("name"), arguments=deepcopy(call.get("args")), type="function_call", raw=deepcopy(call)), raw=deepcopy(part), extra=_without(part, {"functionCall", "function_call"})) + if "functionResponse" in part or "function_response" in part: + response = part.get("functionResponse") or part.get("function_response") or {} + return ContentBlock(type="function_response", tool_result=ToolResult(tool_call_id=response.get("name"), content=deepcopy(response.get("response")), raw=deepcopy(response)), raw=deepcopy(part), extra=_without(part, {"functionResponse", "function_response"})) + return ContentBlock(type="unknown", raw=deepcopy(part), extra=deepcopy(part)) + + def _format_parts(self, blocks: Iterable[ContentBlock]) -> list[dict[str, Any]]: + parts = [] + for block in blocks: + if block.tool_call: + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {} + payload["functionCall"] = {"name": block.tool_call.name, "args": deepcopy(block.tool_call.arguments)} + parts.append(payload) + elif block.tool_result: + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {} + payload["functionResponse"] = {"name": block.tool_result.tool_call_id, "response": deepcopy(block.tool_result.content)} + parts.append(payload) + elif block.type == "inline_data": + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {} + payload["inlineData"] = deepcopy(block.source) + parts.append(payload) + elif block.type == "file_data": + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {} + payload["fileData"] = deepcopy(block.source) + parts.append(payload) + else: + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {} + payload["text"] = block.text or "" + if block.reasoning: + payload["thought"] = True + if block.reasoning.signature: + payload["thoughtSignature"] = block.reasoning.signature + payload.update(deepcopy(block.extra)) + parts.append(payload) + return parts + + def _parse_tools(self, tools: Iterable[dict[str, Any]]) -> list[ToolDefinition]: + parsed: list[ToolDefinition] = [] + for container_index, tool in enumerate(tools): + payload = dict(tool or {}) + declarations = payload.get("functionDeclarations") or payload.get("function_declarations") or [] + if declarations: + for index, declaration in enumerate(declarations): + if not isinstance(declaration, dict): + continue + parsed.append( + ToolDefinition( + name=str(declaration.get("name") or ""), + description=declaration.get("description"), + input_schema=deepcopy(declaration.get("parameters") or {}), + type="function", + extra={"raw_container": deepcopy(tool), "container_index": container_index, "declaration_index": index}, + ) + ) + continue + parsed.append( + ToolDefinition( + name=str(payload.get("name") or payload.get("type") or "gemini_tool"), + description=payload.get("description"), + input_schema=deepcopy(payload.get("parameters") or {}), + type=str(payload.get("type") or next(iter(payload.keys()), "tool")), + extra={"raw": deepcopy(tool)}, + ) + ) + return parsed + + def _format_tools(self, tools: Iterable[ToolDefinition]) -> list[dict[str, Any]]: + grouped: dict[int, dict[str, Any]] = {} + ungrouped: list[dict[str, Any]] = [] + for tool in tools: + raw_container = tool.extra.get("raw_container") + container_index = tool.extra.get("container_index") + declaration_index = tool.extra.get("declaration_index") + if isinstance(raw_container, dict) and isinstance(container_index, int) and isinstance(declaration_index, int): + container = grouped.setdefault(container_index, deepcopy(raw_container)) + declarations = container.setdefault("functionDeclarations", []) + while len(declarations) <= declaration_index: + declarations.append({}) + declaration = deepcopy(declarations[declaration_index]) if isinstance(declarations[declaration_index], dict) else {} + declaration["name"] = tool.name + if tool.description is not None: + declaration["description"] = tool.description + declaration["parameters"] = deepcopy(tool.input_schema) + declarations[declaration_index] = declaration + continue + ungrouped.append(self._format_tool(tool)) + return [grouped[index] for index in sorted(grouped)] + ungrouped + + def _format_tool(self, tool: ToolDefinition) -> dict[str, Any]: + raw = tool.extra.get("raw") + if isinstance(raw, dict): + return deepcopy(raw) + return {"functionDeclarations": [{"name": tool.name, "description": tool.description, "parameters": deepcopy(tool.input_schema)}]} + + def _format_usage(self, usage: Usage | None) -> dict[str, int] | None: + if usage is None: + return None + payload = { + "promptTokenCount": usage.input_tokens, + "candidatesTokenCount": usage.output_tokens, + "totalTokenCount": usage.total_tokens, + } + if usage.reasoning_tokens: + payload["thoughtsTokenCount"] = usage.reasoning_tokens + if usage.cache_read_tokens: + payload["cachedContentTokenCount"] = usage.cache_read_tokens + return payload + + +def _as_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return deepcopy(value) + if hasattr(value, "model_dump"): + return value.model_dump() + if hasattr(value, "dict"): + return value.dict() + return {} + + +def _operation_from_context(context: ProtocolContext | None, default: str) -> str: + supported = {OPERATION_CHAT, OPERATION_COUNT_TOKENS, "generate", "stream_generate"} + if context and isinstance(context.provider_options, dict): + operation = normalize_operation(context.provider_options.get("operation")) + if operation in supported: + return operation + if context and isinstance(context.metadata, dict): + operation = normalize_operation(context.metadata.get("operation")) + if operation in supported: + return operation + return default + + +def _response_operation(response: dict[str, Any], context: ProtocolContext | None) -> str: + requested = _operation_from_context(context, OPERATION_CHAT) + if requested == OPERATION_COUNT_TOKENS: + return OPERATION_COUNT_TOKENS + if "totalTokens" in response and "candidates" not in response: + return OPERATION_COUNT_TOKENS + return requested if requested in {"generate", "stream_generate"} else OPERATION_CHAT + + +def _decode_sse_data(raw_event: Any) -> Any: + if not isinstance(raw_event, str): + return raw_event + text = raw_event.strip() + if text.startswith("data:"): + text = text[5:].strip() + if text == "[DONE]": + return text + try: + return json.loads(text) + except json.JSONDecodeError: + return raw_event + + +def _without(payload: dict[str, Any], keys: set[str]) -> dict[str, Any]: + return {k: deepcopy(v) for k, v in payload.items() if k not in keys} diff --git a/src/rotator_library/protocols/litellm_fallback.py b/src/rotator_library/protocols/litellm_fallback.py new file mode 100644 index 000000000..6c66352c6 --- /dev/null +++ b/src/rotator_library/protocols/litellm_fallback.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Explicit fallback protocol for the existing LiteLLM-shaped path. + +This adapter intentionally does very little beyond preserving raw payloads. It +gives later execution and transform-logging code a named protocol path whenever +native adapters do not yet cover a provider or request shape. +""" + +from __future__ import annotations + +from typing import ClassVar + +from .base import ProtocolAdapter +from .operation import OPERATION_UNKNOWN + + +class LiteLLMFallbackProtocol(ProtocolAdapter): + """Protocol adapter that keeps the current LiteLLM-compatible payload shape. + + Providers should prefer native protocol adapters when available. This class + remains as a safe compatibility base for unsupported provider shapes and as + a clear marker in future transaction transform traces. + """ + + name: ClassVar[str] = "litellm_fallback" + aliases: ClassVar[tuple[str, ...]] = ("litellm", "fallback") + supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_UNKNOWN,) diff --git a/src/rotator_library/protocols/mcp.py b/src/rotator_library/protocols/mcp.py new file mode 100644 index 000000000..e51a1ec5c --- /dev/null +++ b/src/rotator_library/protocols/mcp.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""MCP JSON-RPC carrier protocol adapter. + +This is not a full MCP proxy implementation. It gives the native protocol layer a +lossless request/response carrier for future MCP gateway work, keeping method, +params, ids, results, errors, and JSON-RPC batch arrays intact for transform +logging and routing. +""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, ClassVar + +from .base import ProtocolAdapter +from .operation import OPERATION_MCP +from .types import ProtocolContext, UnifiedRequest, UnifiedResponse, UnifiedStreamEvent + +_REQUEST_CORE_FIELDS = {"jsonrpc", "id", "method", "params"} +_RESPONSE_CORE_FIELDS = {"jsonrpc", "id", "result", "error"} + + +class MCPProtocol(ProtocolAdapter): + """Adapter for MCP-style JSON-RPC request and response envelopes.""" + + name: ClassVar[str] = "mcp" + aliases: ClassVar[tuple[str, ...]] = ("model_context_protocol", "jsonrpc_mcp") + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_MCP,) + supported_transports: ClassVar[tuple[str, ...]] = ("http",) + future_transports: ClassVar[tuple[str, ...]] = ("sse", "websocket") + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + if isinstance(raw_request, list): + return UnifiedRequest(operation=OPERATION_MCP, input=deepcopy(raw_request), metadata={"batch": True}, raw=deepcopy(raw_request)) + request = dict(raw_request or {}) + metadata = { + "jsonrpc": request.get("jsonrpc", "2.0"), + "method": request.get("method"), + "has_id": "id" in request, + "has_params": "params" in request, + } + if "id" in request: + metadata["id"] = deepcopy(request.get("id")) + return UnifiedRequest( + operation=OPERATION_MCP, + input=deepcopy(request.get("params")) if "params" in request else None, + metadata=metadata, + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _REQUEST_CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> Any: + if unified_request.metadata.get("batch"): + return deepcopy(unified_request.input or []) + payload = { + "jsonrpc": unified_request.metadata.get("jsonrpc", "2.0"), + "method": unified_request.metadata.get("method"), + } + if unified_request.metadata.get("has_params", True): + payload["params"] = deepcopy(unified_request.input) + if unified_request.metadata.get("has_id"): + payload["id"] = deepcopy(unified_request.metadata.get("id")) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + if isinstance(raw_response, list): + return UnifiedResponse(operation=OPERATION_MCP, data=deepcopy(raw_response), metadata={"batch": True}, raw=deepcopy(raw_response)) + response = raw_response if isinstance(raw_response, dict) else {} + metadata = {"jsonrpc": response.get("jsonrpc", "2.0"), "has_id": "id" in response} + if "id" in response: + metadata["id"] = deepcopy(response.get("id")) + extra = {k: deepcopy(v) for k, v in response.items() if k not in _RESPONSE_CORE_FIELDS} + if "error" in response: + # JSON-RPC errors are not provider exceptions here; they are protocol + # payloads that must survive transform logging and response rebuilds. + extra["error"] = deepcopy(response["error"]) + return UnifiedResponse( + operation=OPERATION_MCP, + data=[deepcopy(response["result"])] if "result" in response else [], + metadata=metadata, + raw=deepcopy(raw_response), + extra=extra, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> Any: + if unified_response.metadata.get("batch"): + return deepcopy(unified_response.data) + payload = {"jsonrpc": unified_response.metadata.get("jsonrpc", "2.0")} + if unified_response.metadata.get("has_id"): + payload["id"] = deepcopy(unified_response.metadata.get("id")) + if unified_response.extra.get("error") is not None: + payload["error"] = deepcopy(unified_response.extra["error"]) + else: + payload["result"] = deepcopy(unified_response.data[0] if unified_response.data else None) + payload.update({k: deepcopy(v) for k, v in unified_response.extra.items() if k != "error"}) + return payload + + def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: + data = raw_event if isinstance(raw_event, dict) else {"event": raw_event} + return UnifiedStreamEvent(type=str(data.get("method") or data.get("type") or "message"), operation=OPERATION_MCP, error=deepcopy(data.get("error")), raw=deepcopy(raw_event), extra=deepcopy(data)) diff --git a/src/rotator_library/protocols/ollama.py b/src/rotator_library/protocols/ollama.py new file mode 100644 index 000000000..748cca1c7 --- /dev/null +++ b/src/rotator_library/protocols/ollama.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Ollama-native chat, generate, and embeddings protocol adapter.""" + +from __future__ import annotations + +import json +from copy import deepcopy +from typing import Any, ClassVar + +from .base import ProtocolAdapter +from .operation import OPERATION_EMBEDDINGS, OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, normalize_operation +from .types import ContentBlock, ProtocolContext, UnifiedMessage, UnifiedRequest, UnifiedResponse, UnifiedStreamEvent, Usage, first_text + +_OPTION_FIELDS = {"options", "format", "keep_alive", "template", "context", "raw", "suffix"} +_CORE_FIELDS = {"operation", "model", "messages", "prompt", "input", "stream", "system", *_OPTION_FIELDS} + + +class OllamaProtocol(ProtocolAdapter): + """Adapter for Ollama `/api/chat`, `/api/generate`, and embeddings shapes.""" + + name: ClassVar[str] = "ollama" + aliases: ClassVar[tuple[str, ...]] = ("ollama_native",) + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, OPERATION_EMBEDDINGS) + supported_transports: ClassVar[tuple[str, ...]] = ("http", "jsonl") + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + operation = _ollama_operation(request, context) + input_field = "input" + if operation == OPERATION_EMBEDDINGS and "input" not in request and "prompt" in request: + # Older Ollama embeddings endpoints use `prompt`; newer endpoints use + # `input`. Preserve the original spelling so round-trips are lossless. + input_field = "prompt" + return UnifiedRequest( + operation=operation, + model=str(request.get("model") or getattr(context, "model", None) or ""), + messages=[_message_from_ollama(message) for message in request.get("messages") or []], + system=[ContentBlock(type="text", text=str(request["system"]))] if "system" in request else [], + stream=bool(request.get("stream", False)), + input=deepcopy(request.get(input_field) if operation == OPERATION_EMBEDDINGS else request.get("prompt")), + generation_params={k: deepcopy(request[k]) for k in _OPTION_FIELDS if k in request}, + metadata={**({"embedding_input_field": input_field} if operation == OPERATION_EMBEDDINGS else {}), "has_stream": "stream" in request}, + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload: dict[str, Any] = {"model": unified_request.model} + if unified_request.stream or unified_request.metadata.get("has_stream"): + payload["stream"] = unified_request.stream + if unified_request.operation == OPERATION_OLLAMA_CHAT: + payload["messages"] = [_message_to_ollama(message) for message in unified_request.messages] + elif unified_request.operation == OPERATION_EMBEDDINGS: + input_field = str(unified_request.metadata.get("embedding_input_field") or "input") + payload[input_field] = deepcopy(unified_request.input) + else: + payload["prompt"] = deepcopy(unified_request.input) + if unified_request.system: + payload["system"] = "".join(block.text or "" for block in unified_request.system if block.type == "text") + payload.update(deepcopy(unified_request.generation_params)) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + response = raw_response if isinstance(raw_response, dict) else {} + output = [] + messages: list[UnifiedMessage] = [] + if isinstance(response.get("message"), dict): + message = _message_from_ollama(response["message"]) + messages.append(message) + output.append(message.to_dict()) + elif "response" in response: + output.append(response.get("response")) + return UnifiedResponse( + operation=_ollama_operation(response, context), + model=response.get("model") or getattr(context, "model", None), + messages=messages, + output=output, + data=deepcopy(response.get("embeddings") or response.get("embedding") or []), + usage=_ollama_usage(response), + raw=deepcopy(raw_response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"model", "message", "response", "embeddings", "embedding", "prompt_eval_count", "eval_count", "total_duration", "load_duration", "prompt_eval_duration", "eval_duration"}}, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + """Format a unified response back to an Ollama response shape. + + Ollama responses are often mutated by adapters after parsing. Do not + return `raw` wholesale here; rebuild the public fields from unified state + and then merge preserved extras/timing values. + """ + + payload = deepcopy(unified_response.extra) + if unified_response.model: + payload["model"] = unified_response.model + operation = unified_response.operation or _ollama_operation(payload, context) + if operation == OPERATION_OLLAMA_CHAT: + if unified_response.messages: + payload["message"] = _message_to_ollama(unified_response.messages[0]) + elif operation == OPERATION_EMBEDDINGS: + raw = unified_response.raw if isinstance(unified_response.raw, dict) else {} + key = "embedding" if "embedding" in raw and "embeddings" not in raw else "embeddings" + payload[key] = deepcopy(unified_response.data) + else: + payload["response"] = _ollama_response_text(unified_response) + if unified_response.usage and isinstance(unified_response.usage.raw, dict): + for key, value in unified_response.usage.raw.items(): + if key.endswith("duration"): + payload[key] = deepcopy(value) + if unified_response.usage.input_tokens: + payload["prompt_eval_count"] = unified_response.usage.input_tokens + if unified_response.usage.output_tokens: + payload["eval_count"] = unified_response.usage.output_tokens + return {k: v for k, v in payload.items() if v is not None} + + def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: + data = _json_event(raw_event) + if not isinstance(data, dict): + return UnifiedStreamEvent(type="metadata", operation=OPERATION_OLLAMA_GENERATE, raw=deepcopy(raw_event), extra={"unparsed": True}) + event_type = "done" if data.get("done") else "message_delta" + delta = None + if isinstance(data.get("message"), dict): + delta = _message_from_ollama(data["message"]) + elif data.get("response"): + delta = UnifiedMessage(role="assistant", content=[ContentBlock(type="text", text=str(data["response"]))]) + return UnifiedStreamEvent(type=event_type, operation=_ollama_operation(data, context), delta=delta, usage=_ollama_usage(data), raw=deepcopy(raw_event), extra=deepcopy(data)) + + +def _ollama_operation(request: dict[str, Any], context: ProtocolContext | None = None) -> str: + explicit = normalize_operation(request.get("operation")) + if explicit in {OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, OPERATION_EMBEDDINGS}: + return explicit + if context and isinstance(context.provider_options, dict): + operation = normalize_operation(context.provider_options.get("operation")) + if operation in {OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, OPERATION_EMBEDDINGS}: + return operation + if context and isinstance(context.metadata, dict): + operation = normalize_operation(context.metadata.get("operation")) + if operation in {OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, OPERATION_EMBEDDINGS}: + return operation + if "messages" in request or "message" in request: + return OPERATION_OLLAMA_CHAT + if "prompt" in request and request.get("endpoint") != "embeddings": + return OPERATION_OLLAMA_GENERATE + if "embeddings" in request or "embedding" in request or "input" in request or request.get("endpoint") == "embeddings": + return OPERATION_EMBEDDINGS + return OPERATION_OLLAMA_GENERATE + + +def _message_from_ollama(message: dict[str, Any]) -> UnifiedMessage: + return UnifiedMessage(role=str(message.get("role") or "assistant"), content=[ContentBlock(type="text", text=str(message.get("content") or ""))], raw=deepcopy(message), extra={k: deepcopy(v) for k, v in message.items() if k not in {"role", "content"}}) + + +def _message_to_ollama(message: UnifiedMessage) -> dict[str, Any]: + payload = {"role": message.role, "content": "".join(block.text or "" for block in message.content if block.type == "text")} + payload.update(deepcopy(message.extra)) + return payload + + +def _ollama_response_text(response: UnifiedResponse) -> str: + if response.output: + return "".join(str(item) for item in response.output if item is not None) + if response.messages: + text = first_text(response.messages[0].content) + if text is not None: + return text + return "" + + +def _json_event(raw_event: Any) -> Any: + if isinstance(raw_event, dict): + return raw_event + if not isinstance(raw_event, str): + return None + text = raw_event.strip() + if text.startswith("data:"): + text = text[5:].strip() + try: + return json.loads(text) + except json.JSONDecodeError: + return None + + +def _ollama_usage(response: dict[str, Any]) -> Usage | None: + prompt_tokens = int(response.get("prompt_eval_count") or 0) + output_tokens = int(response.get("eval_count") or 0) + if not prompt_tokens and not output_tokens: + return None + return Usage(input_tokens=prompt_tokens, output_tokens=output_tokens, raw={k: deepcopy(v) for k, v in response.items() if k.endswith("count") or k.endswith("duration")}) diff --git a/src/rotator_library/protocols/openai_audio.py b/src/rotator_library/protocols/openai_audio.py new file mode 100644 index 000000000..392f5b19e --- /dev/null +++ b/src/rotator_library/protocols/openai_audio.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""OpenAI-compatible audio transcription/translation and speech protocol adapter.""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, ClassVar + +from .base import ProtocolAdapter +from .operation import OPERATION_AUDIO_TRANSCRIPTION, OPERATION_AUDIO_TRANSLATION, OPERATION_SPEECH, normalize_operation +from .types import ProtocolContext, UnifiedRequest, UnifiedResponse + +_AUDIO_OPTION_FIELDS = {"language", "response_format", "temperature", "timestamp_granularities"} +_SPEECH_OPTION_FIELDS = {"voice", "response_format", "speed"} +_CORE_FIELDS = {"operation", "model", "file", "input", "prompt", *_AUDIO_OPTION_FIELDS, *_SPEECH_OPTION_FIELDS} + + +class OpenAIAudioProtocol(ProtocolAdapter): + """Adapter for OpenAI audio transcription/translation and speech requests. + + Audio file and generated-audio bytes are intentionally preserved instead of + interpreted. Transports/providers decide multipart and binary handling; this + protocol only records enough structure for routing, tracing, and tests. + """ + + name: ClassVar[str] = "openai_audio" + aliases: ClassVar[tuple[str, ...]] = ("audio", "audio_transcription", "speech", "tts") + supported_operations: ClassVar[tuple[str, ...]] = ( + OPERATION_AUDIO_TRANSCRIPTION, + OPERATION_AUDIO_TRANSLATION, + OPERATION_SPEECH, + ) + supported_transports: ClassVar[tuple[str, ...]] = ("http",) + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + operation = _audio_operation(request, context) + files = [] + if "file" in request: + files.append({"field": "file", "value": deepcopy(request["file"])}) + options = _SPEECH_OPTION_FIELDS if operation == OPERATION_SPEECH else _AUDIO_OPTION_FIELDS + return UnifiedRequest( + operation=operation, + model=str(request.get("model") or getattr(context, "model", None) or ""), + input=deepcopy(request.get("input") if operation == OPERATION_SPEECH else request.get("prompt")), + files=files, + generation_params={k: deepcopy(request[k]) for k in options if k in request}, + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload: dict[str, Any] = {"model": unified_request.model} + if unified_request.operation == OPERATION_SPEECH: + payload["input"] = deepcopy(unified_request.input) + elif unified_request.input is not None: + payload["prompt"] = deepcopy(unified_request.input) + for file_entry in unified_request.files: + if isinstance(file_entry, dict) and file_entry.get("field"): + payload[str(file_entry["field"])] = deepcopy(file_entry.get("value")) + payload.update(deepcopy(unified_request.generation_params)) + payload.update(deepcopy(unified_request.extra)) + return payload + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> Any: + """Format audio responses without forcing binary/text payloads into JSON. + + Speech endpoints can return raw audio bytes while transcription endpoints + can return JSON or plain text depending on `response_format`. Returning + the preserved raw body for non-dict responses keeps the protocol adapter + honest until a transport layer decides headers and streaming behavior. + """ + + if unified_response.raw is not None and not isinstance(unified_response.raw, dict): + return deepcopy(unified_response.raw) + payload: dict[str, Any] = {} + if unified_response.output: + payload["text"] = deepcopy(unified_response.output[0]) + if unified_response.data: + payload["data"] = deepcopy(unified_response.data) + if unified_response.content_type: + payload["content_type"] = unified_response.content_type + payload.update(deepcopy(unified_response.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + if isinstance(raw_response, dict): + response = raw_response + operation = normalize_operation(response.get("operation")) + if operation == "unknown": + operation = OPERATION_AUDIO_TRANSCRIPTION + return UnifiedResponse( + operation=_context_operation(context, operation), + model=response.get("model") or getattr(context, "model", None), + output=[deepcopy(response["text"])] if "text" in response else [], + data=deepcopy(response.get("data") or []), + content_type=response.get("content_type") or "application/json", + raw=deepcopy(raw_response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"operation", "model", "text", "data", "content_type"}}, + ) + content_type = "text/plain" if isinstance(raw_response, str) else "application/octet-stream" + # Providers normally return transcription text as JSON/text and speech as + # bytes. Context can still override this heuristic for unusual endpoints. + default_operation = OPERATION_AUDIO_TRANSCRIPTION if isinstance(raw_response, str) else OPERATION_SPEECH + return UnifiedResponse(operation=_context_operation(context, default_operation), content_type=content_type, raw=deepcopy(raw_response), output=[deepcopy(raw_response)] if isinstance(raw_response, str) else []) + + +def _audio_operation(request: dict[str, Any], context: ProtocolContext | None = None) -> str: + explicit = normalize_operation(request.get("operation")) + if explicit in {OPERATION_AUDIO_TRANSCRIPTION, OPERATION_AUDIO_TRANSLATION, OPERATION_SPEECH}: + return explicit + contextual = _context_operation(context, "unknown") + if contextual in {OPERATION_AUDIO_TRANSCRIPTION, OPERATION_AUDIO_TRANSLATION, OPERATION_SPEECH}: + return contextual + if "voice" in request or ("input" in request and "file" not in request): + return OPERATION_SPEECH + return OPERATION_AUDIO_TRANSCRIPTION + + +def _context_operation(context: ProtocolContext | None, default: str) -> str: + if context and isinstance(context.provider_options, dict): + operation = normalize_operation(context.provider_options.get("operation")) + if operation in {OPERATION_AUDIO_TRANSCRIPTION, OPERATION_AUDIO_TRANSLATION, OPERATION_SPEECH}: + return operation + if context and isinstance(context.metadata, dict): + operation = normalize_operation(context.metadata.get("operation")) + if operation in {OPERATION_AUDIO_TRANSCRIPTION, OPERATION_AUDIO_TRANSLATION, OPERATION_SPEECH}: + return operation + return default diff --git a/src/rotator_library/protocols/openai_chat.py b/src/rotator_library/protocols/openai_chat.py new file mode 100644 index 000000000..c21452510 --- /dev/null +++ b/src/rotator_library/protocols/openai_chat.py @@ -0,0 +1,532 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""OpenAI Chat Completions protocol adapter. + +The adapter models the common OpenAI-compatible chat shape used by many current +providers. It is a reusable base, not a final authority: providers can subclass +or override pieces when they need non-standard fields, stricter ordering, or +different stream semantics. +""" + +from __future__ import annotations + +import json +from copy import deepcopy +from typing import Any, ClassVar, Iterable + +from .base import ProtocolAdapter +from .operation import OPERATION_CHAT +from .types import ( + ContentBlock, + CostDetails, + ProtocolContext, + ReasoningBlock, + ToolCall, + ToolDefinition, + ToolResult, + UnifiedMessage, + UnifiedRequest, + UnifiedResponse, + UnifiedStreamEvent, + Usage, + first_text, + serialize_value, + text_blocks, +) + +_GENERATION_PARAMS = { + "frequency_penalty", + "logit_bias", + "logprobs", + "max_completion_tokens", + "max_tokens", + "n", + "parallel_tool_calls", + "presence_penalty", + "reasoning_effort", + "seed", + "service_tier", + "stop", + "stream_options", + "temperature", + "tool_choice", + "top_logprobs", + "top_p", + "user", +} + +_REQUEST_CORE_FIELDS = { + "model", + "messages", + "tools", + "stream", + "response_format", + "metadata", + *_GENERATION_PARAMS, +} + + +class OpenAIChatProtocol(ProtocolAdapter): + """Adapter for OpenAI Chat Completions request, response, and stream chunks. + + Unknown OpenAI-compatible extension fields are preserved in ``extra`` so a + custom provider can still use them through later adapter or field-cache + phases. Lossy conversions are avoided unless the source shape itself uses a + compact representation, such as string message content. + """ + + name: ClassVar[str] = "openai_chat" + aliases: ClassVar[tuple[str, ...]] = ( + "openai", + "chat_completions", + "openai_chat_completions", + ) + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_CHAT,) + supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + messages = [self._parse_message(message) for message in request.get("messages") or []] + tools = [self._parse_tool_definition(tool) for tool in request.get("tools") or []] + generation_params = {k: deepcopy(request[k]) for k in _GENERATION_PARAMS if k in request} + extra = {k: deepcopy(v) for k, v in request.items() if k not in _REQUEST_CORE_FIELDS} + + return UnifiedRequest( + operation=OPERATION_CHAT, + model=str(request.get("model") or getattr(context, "model", None) or ""), + messages=messages, + tools=tools, + stream=bool(request.get("stream", False)), + generation_params=generation_params, + response_format=deepcopy(request.get("response_format")), + metadata=deepcopy(request.get("metadata") or {}), + raw=deepcopy(raw_request), + extra=extra, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload: dict[str, Any] = { + "model": unified_request.model, + "messages": [self._format_message(message) for message in unified_request.messages], + } + if unified_request.tools: + payload["tools"] = [self._format_tool_definition(tool) for tool in unified_request.tools] + if unified_request.stream: + payload["stream"] = True + if unified_request.response_format is not None: + payload["response_format"] = deepcopy(unified_request.response_format) + if unified_request.metadata: + payload["metadata"] = deepcopy(unified_request.metadata) + payload.update(deepcopy(unified_request.generation_params)) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + response = _as_dict(raw_response) + messages: list[UnifiedMessage] = [] + stop_reason = None + for choice in response.get("choices") or []: + if not isinstance(choice, dict): + continue + message_payload = choice.get("message") or {} + if message_payload: + messages.append(self._parse_message(message_payload)) + if choice.get("finish_reason") is not None: + stop_reason = choice.get("finish_reason") + + return UnifiedResponse( + operation=OPERATION_CHAT, + id=response.get("id"), + model=response.get("model") or getattr(context, "model", None), + messages=messages, + stop_reason=stop_reason, + usage=self.extract_usage(response, context), + metadata={ + "object": response.get("object"), + "created": response.get("created"), + "system_fingerprint": response.get("system_fingerprint"), + }, + raw=deepcopy(response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"id", "object", "created", "model", "choices", "usage", "system_fingerprint"}}, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + choices = [] + for index, message in enumerate(unified_response.messages): + choices.append( + { + "index": index, + "message": _format_response_message(self._format_message(message), message), + "finish_reason": unified_response.stop_reason, + } + ) + payload = { + "id": unified_response.id, + "object": unified_response.metadata.get("object", "chat.completion"), + "created": unified_response.metadata.get("created"), + "model": unified_response.model, + "choices": choices, + "usage": _format_openai_usage(unified_response.usage), + } + payload.update(deepcopy(unified_response.extra)) + return {k: v for k, v in payload.items() if v is not None} + + def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: + event = _decode_sse_data(raw_event) + if event == "[DONE]": + return UnifiedStreamEvent(type="done", operation=OPERATION_CHAT, raw=deepcopy(raw_event)) + data = _as_dict(event) + if data.get("error") is not None: + return UnifiedStreamEvent(type="error", operation=OPERATION_CHAT, error=deepcopy(data["error"]), raw=deepcopy(raw_event), extra={"payload": data}) + + delta_message = None + finish_reason = None + for choice in data.get("choices") or []: + if not isinstance(choice, dict): + continue + delta = choice.get("delta") or {} + if delta: + delta_message = self._parse_message({"role": delta.get("role", "assistant"), **delta}) + finish_reason = choice.get("finish_reason") if choice.get("finish_reason") is not None else finish_reason + break + + usage = self.extract_usage(data, context) + return UnifiedStreamEvent( + type="message_delta" if delta_message else "chunk", + operation=OPERATION_CHAT, + delta=delta_message, + usage=usage, + raw=deepcopy(raw_event), + extra={ + "id": data.get("id"), + "model": data.get("model"), + "finish_reason": finish_reason, + "payload": data, + }, + ) + + def format_stream_event(self, unified_event: UnifiedStreamEvent, context: ProtocolContext | None = None) -> Any: + if unified_event.type == "done": + return "data: [DONE]\n\n" + if unified_event.raw is not None and (context is None or context.source_protocol in {None, "openai_chat"}): + payload = deepcopy(unified_event.raw) + if unified_event.delta is not None and isinstance(payload, dict) and isinstance(payload.get("choices"), list) and payload["choices"]: + choice = payload["choices"][0] + if isinstance(choice, dict): + formatted_delta = self._format_message(unified_event.delta) + original_delta = choice.get("delta") if isinstance(choice.get("delta"), dict) else {} + if "role" not in original_delta: + formatted_delta.pop("role", None) + choice["delta"] = formatted_delta + return payload + if unified_event.delta is not None: + delta = _format_response_message(self._format_message(unified_event.delta), unified_event.delta) + if unified_event.extra.get("finish_reason") is None: + delta.pop("role", None) + payload = { + "id": unified_event.extra.get("id"), + "object": "chat.completion.chunk", + "model": unified_event.extra.get("model"), + "choices": [{"index": 0, "delta": delta, "finish_reason": unified_event.extra.get("finish_reason")}], + "usage": _format_openai_usage(unified_event.usage), + } + return f"data: {json.dumps({k: v for k, v in payload.items() if v is not None})}\n\n" + return f"data: {json.dumps(unified_event.to_dict())}\n\n" + + def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: + if isinstance(raw_or_unified, (UnifiedResponse, UnifiedStreamEvent)): + return raw_or_unified.usage + payload = _as_dict(raw_or_unified) + usage = payload.get("usage") + if not isinstance(usage, dict): + return None + prompt_details = usage.get("prompt_tokens_details") or {} + completion_details = usage.get("completion_tokens_details") or {} + if not isinstance(prompt_details, dict): + prompt_details = {} + if not isinstance(completion_details, dict): + completion_details = {} + cost = None + cost_details = usage.get("cost_details") + if isinstance(cost_details, dict): + provider_cost = cost_details.get("total_cost") or cost_details.get("request_cost_usd") or cost_details.get("cost") or cost_details.get("estimated_cost") + cost = CostDetails( + provider_reported_cost=float(provider_cost) if provider_cost is not None else None, + currency=str(cost_details.get("currency") or "USD"), + source="usage.cost_details", + metadata={k: deepcopy(v) for k, v in cost_details.items() if k not in {"total_cost", "cost", "currency"}}, + ) + return Usage( + input_tokens=int(usage.get("prompt_tokens") or usage.get("input_tokens") or 0), + output_tokens=int(usage.get("completion_tokens") or usage.get("output_tokens") or 0), + total_tokens=int(usage.get("total_tokens") or 0), + cache_read_tokens=int(prompt_details.get("cached_tokens") or usage.get("cache_read_tokens") or 0), + cache_write_tokens=int(prompt_details.get("cache_creation_tokens") or usage.get("cache_creation_tokens") or 0), + reasoning_tokens=int(completion_details.get("reasoning_tokens") or usage.get("reasoning_tokens") or 0), + cost=cost, + raw=deepcopy(usage), + ) + + def _parse_message(self, message: dict[str, Any]) -> UnifiedMessage: + payload = dict(message or {}) + reasoning = _extract_reasoning(payload) + return UnifiedMessage( + role=str(payload.get("role") or "assistant"), + content=self._parse_content(payload.get("content")), + name=payload.get("name"), + tool_call_id=payload.get("tool_call_id"), + tool_calls=self._parse_message_tool_calls(payload), + reasoning=reasoning, + raw=deepcopy(message), + extra={k: deepcopy(v) for k, v in payload.items() if k not in {"role", "content", "name", "tool_call_id", "tool_calls", "reasoning", "reasoning_content"}}, + ) + + def _format_message(self, message: UnifiedMessage) -> dict[str, Any]: + payload: dict[str, Any] = {"role": message.role} + if message.name: + payload["name"] = message.name + if message.tool_call_id: + payload["tool_call_id"] = message.tool_call_id + content = self._format_content(message.content) + if content is not None: + payload["content"] = content + extra = deepcopy(message.extra) + legacy_function_call = extra.get("function_call") + if message.tool_calls and not legacy_function_call: + payload["tool_calls"] = [self._format_tool_call(call) for call in message.tool_calls] + elif message.tool_calls and legacy_function_call: + call = message.tool_calls[0] + extra["function_call"] = {"name": call.name or "", "arguments": _format_arguments(call.arguments)} + if message.reasoning: + # OpenAI-compatible providers use multiple names for reasoning text. + # Prefer the common extension field while keeping all blocks in extra. + text = "".join(block.text or "" for block in message.reasoning if block.text) + if text: + payload["reasoning_content"] = text + payload.update(extra) + return payload + + def _parse_message_tool_calls(self, payload: dict[str, Any]) -> list[ToolCall]: + """Return modern and legacy OpenAI function calls as unified tools.""" + + modern_calls = payload.get("tool_calls") or [] + if modern_calls: + return [self._parse_tool_call(call) for call in modern_calls] + legacy_call = payload.get("function_call") + if isinstance(legacy_call, dict): + return [ + ToolCall( + id=None, + name=legacy_call.get("name"), + arguments=legacy_call.get("arguments"), + type="function", + raw=deepcopy(legacy_call), + extra={"legacy_function_call": True}, + ) + ] + return [] + + def _parse_content(self, content: Any) -> list[ContentBlock]: + if content is None: + return [] + if isinstance(content, str): + return text_blocks(content) + if not isinstance(content, list): + return [ContentBlock(type="unknown", raw=deepcopy(content))] + blocks = [] + for block in content: + if isinstance(block, str): + blocks.append(ContentBlock(type="text", text=block, raw=block)) + continue + if not isinstance(block, dict): + blocks.append(ContentBlock(type="unknown", raw=deepcopy(block))) + continue + block_type = block.get("type", "text") + if block_type == "text": + blocks.append(ContentBlock(type="text", text=block.get("text", ""), raw=deepcopy(block), extra=_without(block, {"type", "text"}))) + elif block_type in {"image_url", "input_image"}: + blocks.append(ContentBlock(type=block_type, source=deepcopy(block.get("image_url") or block.get("source")), raw=deepcopy(block), extra=_without(block, {"type", "image_url", "source"}))) + else: + blocks.append(ContentBlock(type=str(block_type), raw=deepcopy(block), extra=_without(block, {"type"}))) + return blocks + + def _format_content(self, blocks: Iterable[ContentBlock]) -> Any: + block_list = list(blocks) + if not block_list: + return None + if all(block.type == "text" and not isinstance(block.raw, dict) and not block.extra for block in block_list): + return first_text(block_list) or "" + formatted = [] + for block in block_list: + if block.type == "text": + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": "text"} + payload["type"] = payload.get("type", "text") + payload["text"] = block.text or "" + payload.update(deepcopy(block.extra)) + formatted.append(payload) + elif block.type in {"image_url", "input_image"}: + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": block.type} + payload["type"] = block.type + payload["image_url"] = deepcopy(block.source) + payload.update(deepcopy(block.extra)) + formatted.append(payload) + elif isinstance(block.raw, dict): + formatted.append(deepcopy(block.raw)) + else: + payload = {"type": block.type} + payload.update(deepcopy(block.extra)) + formatted.append(payload) + return formatted + + def _parse_tool_definition(self, tool: dict[str, Any]) -> ToolDefinition: + payload = dict(tool or {}) + function = payload.get("function") if isinstance(payload.get("function"), dict) else payload + return ToolDefinition( + name=str(function.get("name") or ""), + description=function.get("description"), + input_schema=deepcopy(function.get("parameters") or function.get("input_schema") or {}), + type=str(payload.get("type") or "function"), + extra={"raw": deepcopy(tool), **_without(payload, {"type", "function"})}, + ) + + def _format_tool_definition(self, tool: ToolDefinition) -> dict[str, Any]: + raw = tool.extra.get("raw") + if isinstance(raw, dict): + return deepcopy(raw) + return { + "type": tool.type, + "function": { + "name": tool.name, + "description": tool.description, + "parameters": deepcopy(tool.input_schema), + }, + } + + def _parse_tool_call(self, call: dict[str, Any]) -> ToolCall: + payload = dict(call or {}) + function = payload.get("function") if isinstance(payload.get("function"), dict) else {} + arguments: Any = function.get("arguments") + return ToolCall( + id=payload.get("id"), + name=function.get("name") or payload.get("name"), + arguments=arguments, + type=str(payload.get("type") or "function"), + index=payload.get("index"), + raw=deepcopy(call), + extra={**_without(function, {"name", "arguments"}), **_without(payload, {"id", "function", "type", "index", "name"})}, + ) + + def _format_tool_call(self, call: ToolCall) -> dict[str, Any]: + payload = deepcopy(call.raw) if isinstance(call.raw, dict) else {} + payload["type"] = call.type + if call.id: + payload["id"] = call.id + if call.index is not None: + payload["index"] = call.index + function = deepcopy(payload.get("function")) if isinstance(payload.get("function"), dict) else {} + function["name"] = call.name or "" + function["arguments"] = _format_arguments(call.arguments) + payload["function"] = function + return payload + + +def _as_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return deepcopy(value) + if hasattr(value, "model_dump"): + return value.model_dump() + if hasattr(value, "dict"): + return value.dict() + return {} + + +def _decode_sse_data(raw_event: Any) -> Any: + if not isinstance(raw_event, str): + return raw_event + text = raw_event.strip() + if text.startswith("data:"): + text = text[5:].strip() + if text == "[DONE]": + return text + try: + return json.loads(text) + except json.JSONDecodeError: + return raw_event + + +def _extract_reasoning(payload: dict[str, Any]) -> list[ReasoningBlock]: + blocks = [] + for field_name in ("reasoning_content", "reasoning"): + value = payload.get(field_name) + if value: + blocks.append(ReasoningBlock(type=field_name, text=str(value), extra={"source_field": field_name})) + return blocks + + +def _format_arguments(arguments: Any) -> str: + if arguments is None: + return "" + if isinstance(arguments, str): + return arguments + return json.dumps(serialize_value(arguments), separators=(",", ":")) + + +def _format_openai_usage(usage: Usage | None) -> dict[str, Any] | None: + """Format normalized usage using OpenAI Chat's public field names.""" + + if usage is None: + return None + payload: dict[str, Any] = { + "prompt_tokens": usage.input_tokens, + "completion_tokens": usage.output_tokens, + "total_tokens": usage.total_tokens or (usage.input_tokens + usage.output_tokens), + } + prompt_details: dict[str, Any] = {} + if usage.cache_read_tokens: + prompt_details["cached_tokens"] = usage.cache_read_tokens + if usage.cache_write_tokens: + prompt_details["cache_creation_tokens"] = usage.cache_write_tokens + if prompt_details: + payload["prompt_tokens_details"] = prompt_details + completion_details: dict[str, Any] = {} + if usage.reasoning_tokens: + completion_details["reasoning_tokens"] = usage.reasoning_tokens + if completion_details: + payload["completion_tokens_details"] = completion_details + if usage.cost: + cost_details: dict[str, Any] = dict(usage.cost.metadata) + if usage.cost.provider_reported_cost is not None: + cost_details["total_cost"] = usage.cost.provider_reported_cost + elif usage.cost.estimated_cost is not None: + cost_details["estimated_cost"] = usage.cost.estimated_cost + cost_details["currency"] = usage.cost.currency + if usage.cost.source: + cost_details["source"] = usage.cost.source + payload["cost_details"] = cost_details + return payload + + +def _format_response_message(payload: dict[str, Any], message: UnifiedMessage) -> dict[str, Any]: + """Return Chat Completions response-message shape. + + Request messages may legitimately preserve content-part arrays. Assistant + response messages from non-chat native protocols often arrive as text parts; + Chat Completions clients expect the final message content to be a string in + that common case. + """ + + if payload.get("content") is not None and message.content: + if all(block.type in {"text", "input_text", "output_text"} and not block.extra for block in message.content): + payload = dict(payload) + payload["content"] = _first_response_text(message.content) or "" + return payload + + +def _first_response_text(blocks: Iterable[ContentBlock]) -> Optional[str]: + parts = [block.text for block in blocks if block.type in {"text", "input_text", "output_text"} and block.text] + return "".join(parts) if parts else first_text(blocks) + + +def _without(payload: dict[str, Any], keys: set[str]) -> dict[str, Any]: + return {k: deepcopy(v) for k, v in payload.items() if k not in keys} diff --git a/src/rotator_library/protocols/openai_embeddings.py b/src/rotator_library/protocols/openai_embeddings.py new file mode 100644 index 000000000..1ba955b99 --- /dev/null +++ b/src/rotator_library/protocols/openai_embeddings.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""OpenAI-compatible embeddings protocol adapter.""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, ClassVar + +from .base import ProtocolAdapter +from .operation import OPERATION_EMBEDDINGS +from .types import ProtocolContext, UnifiedRequest, UnifiedResponse, Usage + +_REQUEST_CORE_FIELDS = {"model", "input", "encoding_format", "dimensions", "user", "operation"} +_REQUEST_OPTION_FIELDS = {"encoding_format", "dimensions", "user"} + + +class OpenAIEmbeddingsProtocol(ProtocolAdapter): + """Adapter for `/v1/embeddings` style request and response payloads. + + The adapter intentionally treats embedding vectors as opaque data entries. + That keeps it usable for OpenAI-compatible providers with additional index, + metadata, or sparse-vector fields without narrowing the schema too early. + """ + + name: ClassVar[str] = "openai_embeddings" + aliases: ClassVar[tuple[str, ...]] = ("embeddings", "openai_embedding") + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_EMBEDDINGS,) + supported_transports: ClassVar[tuple[str, ...]] = ("http",) + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + return UnifiedRequest( + operation=OPERATION_EMBEDDINGS, + model=str(request.get("model") or getattr(context, "model", None) or ""), + input=deepcopy(request.get("input")), + generation_params={k: deepcopy(request[k]) for k in _REQUEST_OPTION_FIELDS if k in request}, + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _REQUEST_CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload = {"model": unified_request.model, "input": deepcopy(unified_request.input)} + payload.update(deepcopy(unified_request.generation_params)) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + response = raw_response if isinstance(raw_response, dict) else {} + return UnifiedResponse( + operation=OPERATION_EMBEDDINGS, + model=response.get("model") or getattr(context, "model", None), + data=deepcopy(response.get("data") or []), + usage=self.extract_usage(response, context), + raw=deepcopy(raw_response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"model", "data", "usage"}}, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + payload = {"object": "list", "data": deepcopy(unified_response.data)} + if unified_response.model: + payload["model"] = unified_response.model + if unified_response.usage: + payload["usage"] = unified_response.usage.raw or unified_response.usage.to_dict() + payload.update(deepcopy(unified_response.extra)) + return payload + + def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: + if isinstance(raw_or_unified, UnifiedResponse): + return raw_or_unified.usage + usage = raw_or_unified.get("usage") if isinstance(raw_or_unified, dict) else None + if not isinstance(usage, dict): + return None + return Usage( + input_tokens=int(usage.get("prompt_tokens") or usage.get("input_tokens") or usage.get("total_tokens") or 0), + output_tokens=int(usage.get("output_tokens") or 0), + total_tokens=int(usage.get("total_tokens") or 0), + raw=deepcopy(usage), + ) diff --git a/src/rotator_library/protocols/openai_images.py b/src/rotator_library/protocols/openai_images.py new file mode 100644 index 000000000..fbf382dbb --- /dev/null +++ b/src/rotator_library/protocols/openai_images.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""OpenAI-compatible image generation/edit/variation protocol adapter.""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, ClassVar + +from .base import ProtocolAdapter +from .operation import OPERATION_IMAGE_EDIT, OPERATION_IMAGE_GENERATION, OPERATION_IMAGE_VARIATION, normalize_operation +from .types import ProtocolContext, UnifiedRequest, UnifiedResponse + +_OPTION_FIELDS = {"n", "size", "quality", "style", "response_format", "user", "background", "moderation"} +_CORE_FIELDS = {"operation", "model", "prompt", "image", "mask", *_OPTION_FIELDS} + + +class OpenAIImagesProtocol(ProtocolAdapter): + """Adapter for image generation, edit, and variation request shapes. + + File references are preserved as metadata in ``UnifiedRequest.files``. The + adapter never reads file contents; multipart assembly belongs to transport or + provider execution code so protocol parsing remains side-effect free. + """ + + name: ClassVar[str] = "openai_images" + aliases: ClassVar[tuple[str, ...]] = ("images", "image_generation", "openai_image") + supported_operations: ClassVar[tuple[str, ...]] = ( + OPERATION_IMAGE_GENERATION, + OPERATION_IMAGE_EDIT, + OPERATION_IMAGE_VARIATION, + ) + supported_transports: ClassVar[tuple[str, ...]] = ("http",) + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + operation = _image_operation(request) + files = [] + if "image" in request: + files.append({"field": "image", "value": deepcopy(request["image"])}) + if "mask" in request: + files.append({"field": "mask", "value": deepcopy(request["mask"])}) + return UnifiedRequest( + operation=operation, + model=str(request.get("model") or getattr(context, "model", None) or ""), + input=deepcopy(request.get("prompt")), + files=files, + generation_params={k: deepcopy(request[k]) for k in _OPTION_FIELDS if k in request}, + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload: dict[str, Any] = {} + if unified_request.model: + payload["model"] = unified_request.model + if unified_request.input is not None: + payload["prompt"] = deepcopy(unified_request.input) + for file_entry in unified_request.files: + if isinstance(file_entry, dict) and file_entry.get("field"): + payload[str(file_entry["field"])] = deepcopy(file_entry.get("value")) + payload.update(deepcopy(unified_request.generation_params)) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + response = raw_response if isinstance(raw_response, dict) else {} + return UnifiedResponse( + operation=_context_operation(context, OPERATION_IMAGE_GENERATION), + model=response.get("model") or getattr(context, "model", None), + data=deepcopy(response.get("data") or []), + raw=deepcopy(raw_response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"model", "data"}}, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + payload = {"data": deepcopy(unified_response.data)} + if unified_response.model: + payload["model"] = unified_response.model + payload.update(deepcopy(unified_response.extra)) + return payload + + +def _image_operation(request: dict[str, Any]) -> str: + explicit = normalize_operation(request.get("operation")) + if explicit in {OPERATION_IMAGE_GENERATION, OPERATION_IMAGE_EDIT, OPERATION_IMAGE_VARIATION}: + return explicit + if "image" in request and "prompt" in request: + return OPERATION_IMAGE_EDIT + if "image" in request: + return OPERATION_IMAGE_VARIATION + return OPERATION_IMAGE_GENERATION + + +def _context_operation(context: ProtocolContext | None, default: str) -> str: + if context and isinstance(context.provider_options, dict): + operation = normalize_operation(context.provider_options.get("operation")) + if operation in {OPERATION_IMAGE_GENERATION, OPERATION_IMAGE_EDIT, OPERATION_IMAGE_VARIATION}: + return operation + if context and isinstance(context.metadata, dict): + operation = normalize_operation(context.metadata.get("operation")) + if operation in {OPERATION_IMAGE_GENERATION, OPERATION_IMAGE_EDIT, OPERATION_IMAGE_VARIATION}: + return operation + return default diff --git a/src/rotator_library/protocols/operation.py b/src/rotator_library/protocols/operation.py new file mode 100644 index 000000000..91d7cb5b8 --- /dev/null +++ b/src/rotator_library/protocols/operation.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Shared protocol operation names. + +Operations are deliberately plain strings instead of a closed enum. The native +protocol system is meant to be extended by local/custom providers, so the core +must provide well-known names without blocking new operations that are not known +today. +""" + +from __future__ import annotations + +from typing import Final + +OPERATION_UNKNOWN: Final[str] = "unknown" +OPERATION_CHAT: Final[str] = "chat" +OPERATION_MESSAGES: Final[str] = "messages" +OPERATION_RESPONSES: Final[str] = "responses" +OPERATION_COUNT_TOKENS: Final[str] = "count_tokens" +OPERATION_EMBEDDINGS: Final[str] = "embeddings" +OPERATION_IMAGE_GENERATION: Final[str] = "image_generation" +OPERATION_IMAGE_EDIT: Final[str] = "image_edit" +OPERATION_IMAGE_VARIATION: Final[str] = "image_variation" +OPERATION_AUDIO_TRANSCRIPTION: Final[str] = "audio_transcription" +OPERATION_AUDIO_TRANSLATION: Final[str] = "audio_translation" +OPERATION_SPEECH: Final[str] = "speech" +OPERATION_OLLAMA_CHAT: Final[str] = "ollama_chat" +OPERATION_OLLAMA_GENERATE: Final[str] = "ollama_generate" +OPERATION_MCP: Final[str] = "mcp" + + +def normalize_operation(operation: str | None) -> str: + """Normalize an operation name while preserving custom extensions.""" + + if not operation: + return OPERATION_UNKNOWN + return str(operation).strip().lower() or OPERATION_UNKNOWN diff --git a/src/rotator_library/protocols/registry.py b/src/rotator_library/protocols/registry.py new file mode 100644 index 000000000..a27a192fa --- /dev/null +++ b/src/rotator_library/protocols/registry.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Auto-discovery registry for native protocol adapters.""" + +from __future__ import annotations + +import importlib +import inspect +import logging +import pkgutil +from typing import Type + +from .base import ProtocolAdapter + +lib_logger = logging.getLogger("rotator_library") + +PROTOCOL_PLUGINS: dict[str, Type[ProtocolAdapter]] = {} +PROTOCOL_ALIASES: dict[str, str] = {} +_PROTOCOL_INSTANCES: dict[str, ProtocolAdapter] = {} + +_INFRASTRUCTURE_MODULES = {"base", "operation", "registry", "types"} + + +def register_protocol(protocol_class: Type[ProtocolAdapter], *, replace: bool = False) -> Type[ProtocolAdapter]: + """Register a protocol adapter class and its aliases. + + The registry mirrors the provider plugin system while staying stricter about + duplicate names. This matters for custom protocol modules: accidental alias + collisions should fail early instead of silently changing conversion logic. + """ + + if not inspect.isclass(protocol_class) or not issubclass(protocol_class, ProtocolAdapter): + raise TypeError("protocol_class must inherit ProtocolAdapter") + if protocol_class is ProtocolAdapter: + raise TypeError("cannot register ProtocolAdapter itself") + + name = protocol_class.name + if not name: + raise ValueError(f"Protocol {protocol_class.__name__} must define a name") + alias_owner = PROTOCOL_ALIASES.get(name) + if alias_owner and alias_owner != name and not replace: + raise ValueError(f"Protocol name conflicts with registered alias: {name}") + + existing = PROTOCOL_PLUGINS.get(name) + if existing and existing is not protocol_class and not replace: + raise ValueError(f"Protocol name already registered: {name}") + + if replace and existing and existing is not protocol_class: + for alias, owner in list(PROTOCOL_ALIASES.items()): + if owner == name: + PROTOCOL_ALIASES.pop(alias, None) + PROTOCOL_PLUGINS[name] = protocol_class + _PROTOCOL_INSTANCES.pop(name, None) + + for alias in protocol_class.aliases: + existing_name = PROTOCOL_ALIASES.get(alias) + if existing_name and existing_name != name and not replace: + raise ValueError(f"Protocol alias already registered: {alias}") + if alias in PROTOCOL_PLUGINS and alias != name and not replace: + raise ValueError(f"Protocol alias conflicts with registered name: {alias}") + PROTOCOL_ALIASES[alias] = name + + lib_logger.debug("Registered protocol: %s", name) + return protocol_class + + +def resolve_protocol_name(name: str) -> str: + """Resolve aliases to canonical protocol names.""" + + if name in PROTOCOL_PLUGINS: + return name + if name in PROTOCOL_ALIASES: + return PROTOCOL_ALIASES[name] + raise KeyError(f"Unknown protocol: {name}") + + +def get_protocol_class(name: str) -> Type[ProtocolAdapter]: + """Return a registered protocol adapter class by name or alias.""" + + return PROTOCOL_PLUGINS[resolve_protocol_name(name)] + + +def get_protocol(name: str) -> ProtocolAdapter: + """Return a shared stateless protocol adapter instance by name or alias.""" + + canonical = resolve_protocol_name(name) + if canonical not in _PROTOCOL_INSTANCES: + _PROTOCOL_INSTANCES[canonical] = PROTOCOL_PLUGINS[canonical]() + return _PROTOCOL_INSTANCES[canonical] + + +def list_protocols() -> list[str]: + """Return canonical protocol names in deterministic order.""" + + return sorted(PROTOCOL_PLUGINS) + + +def _register_protocols() -> None: + """Discover protocol modules in this package and register adapter classes. + + Private modules and infrastructure modules are skipped so local experiments + can live next to production protocols without being imported accidentally. + """ + + package = importlib.import_module(__package__ or "rotator_library.protocols") + for _, module_name, _ in pkgutil.iter_modules(package.__path__): + if module_name.startswith("_") or module_name in _INFRASTRUCTURE_MODULES: + continue + module = importlib.import_module(f"{package.__name__}.{module_name}") + for attribute_name in dir(module): + attribute = getattr(module, attribute_name) + if ( + inspect.isclass(attribute) + and issubclass(attribute, ProtocolAdapter) + and attribute is not ProtocolAdapter + and attribute.__module__ == module.__name__ + ): + register_protocol(attribute) + + +_register_protocols() diff --git a/src/rotator_library/protocols/responses.py b/src/rotator_library/protocols/responses.py new file mode 100644 index 000000000..170b7cb46 --- /dev/null +++ b/src/rotator_library/protocols/responses.py @@ -0,0 +1,431 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""OpenAI Responses protocol adapter. + +Responses is important enough to model natively rather than forcing it through a +chat-completions shape. This adapter focuses on loss-conscious parsing and +formatting; storage, routes, and WebSocket transport are later phases. +""" + +from __future__ import annotations + +import json +from copy import deepcopy +from typing import Any, ClassVar, Iterable + +from .base import ProtocolAdapter +from .operation import OPERATION_RESPONSES +from .types import ( + ContentBlock, + CostDetails, + ProtocolContext, + ReasoningBlock, + ToolCall, + ToolDefinition, + ToolResult, + UnifiedMessage, + UnifiedRequest, + UnifiedResponse, + UnifiedStreamEvent, + Usage, + first_text, + text_blocks, +) + +_GENERATION_PARAMS = { + "include", + "instructions", + "max_output_tokens", + "parallel_tool_calls", + "reasoning", + "store", + "temperature", + "text", + "tool_choice", + "top_p", + "truncation", + "user", +} + +_REQUEST_CORE_FIELDS = { + "model", + "input", + "metadata", + "previous_response_id", + "stream", + "tools", + *_GENERATION_PARAMS, +} + + +class ResponsesProtocol(ProtocolAdapter): + """Adapter for OpenAI Responses request, response, and event stream shapes. + + The protocol keeps output items in addition to parsed messages because later + response storage and continuation features need item-level fidelity. + """ + + name: ClassVar[str] = "responses" + aliases: ClassVar[tuple[str, ...]] = ("openai_responses", "response_api") + supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_RESPONSES,) + future_transports: ClassVar[tuple[str, ...]] = ("websocket",) + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + generation_params = {k: deepcopy(request[k]) for k in _GENERATION_PARAMS if k in request and k != "instructions"} + return UnifiedRequest( + operation=OPERATION_RESPONSES, + model=str(request.get("model") or getattr(context, "model", None) or ""), + messages=self._parse_input(request.get("input")), + system=text_blocks(request.get("instructions")) if request.get("instructions") is not None else [], + tools=[self._parse_tool(tool) for tool in request.get("tools") or []], + stream=bool(request.get("stream", False)), + generation_params=generation_params, + previous_response_id=request.get("previous_response_id"), + metadata=deepcopy(request.get("metadata") or {}), + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _REQUEST_CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload: dict[str, Any] = { + "model": unified_request.model, + "input": [self._format_input_message(message) for message in unified_request.messages], + } + instructions = first_text(unified_request.system) + if instructions is not None: + payload["instructions"] = instructions + if unified_request.previous_response_id: + payload["previous_response_id"] = unified_request.previous_response_id + if unified_request.tools: + payload["tools"] = [self._format_tool(tool) for tool in unified_request.tools] + if unified_request.stream: + payload["stream"] = True + if unified_request.metadata: + payload["metadata"] = deepcopy(unified_request.metadata) + payload.update(deepcopy(unified_request.generation_params)) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + response = _as_dict(raw_response) + output = deepcopy(response.get("output") or []) + messages: list[UnifiedMessage] = [] + for index, item in enumerate(output): + if isinstance(item, dict): + parsed = self._parse_output_item(item) + if parsed: + parsed.extra["_output_index"] = index + messages.append(parsed) + return UnifiedResponse( + operation=OPERATION_RESPONSES, + id=response.get("id"), + model=response.get("model") or getattr(context, "model", None), + messages=messages, + output=output, + stop_reason=response.get("status"), + usage=self.extract_usage(response, context), + metadata={"object": response.get("object"), "created_at": response.get("created_at")}, + raw=deepcopy(response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"id", "object", "created_at", "model", "output", "usage", "status"}}, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + if unified_response.output: + output = deepcopy(unified_response.output) + for fallback_index, message in enumerate(unified_response.messages): + output_index = message.extra.get("_output_index", fallback_index) + if isinstance(output_index, int) and 0 <= output_index < len(output): + output[output_index] = self._format_output_message(message, output_index) + else: + output.append(self._format_output_message(message, fallback_index)) + else: + output = [self._format_output_message(message, index) for index, message in enumerate(unified_response.messages)] + payload = { + "id": unified_response.id, + "object": unified_response.metadata.get("object", "response"), + "created_at": unified_response.metadata.get("created_at"), + "model": unified_response.model, + "status": unified_response.stop_reason, + "output": output, + "usage": _format_responses_usage(unified_response.usage), + } + payload.update(deepcopy(unified_response.extra)) + return {k: v for k, v in payload.items() if v is not None} + + def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: + event = _decode_sse_data(raw_event) + if event == "[DONE]": + return UnifiedStreamEvent(type="done", operation=OPERATION_RESPONSES, raw=deepcopy(raw_event)) + data = _as_dict(event) + event_type = str(data.get("type") or data.get("event") or "chunk") + if event_type in {"error", "response.error"} or data.get("error") is not None: + return UnifiedStreamEvent(type="error", operation=OPERATION_RESPONSES, error=deepcopy(data.get("error", data)), raw=deepcopy(raw_event), extra={"payload": data}) + if event_type in {"response.completed", "response.failed", "response.incomplete"}: + response = self.parse_response(data.get("response") or {}, context) + return UnifiedStreamEvent(type=event_type, operation=OPERATION_RESPONSES, message=response.messages[0] if response.messages else None, usage=response.usage, raw=deepcopy(raw_event), extra={"payload": data}) + if event_type == "response.output_text.delta": + message = UnifiedMessage(role="assistant", content=text_blocks(data.get("delta") or "")) + return UnifiedStreamEvent(type="message_delta", operation=OPERATION_RESPONSES, delta=message, raw=deepcopy(raw_event), extra={"payload": data, "output_index": data.get("output_index"), "content_index": data.get("content_index")}) + if event_type in {"response.output_item.added", "response.output_item.done"} and isinstance(data.get("item"), dict): + message = self._parse_output_item(data["item"]) + return UnifiedStreamEvent(type=event_type, operation=OPERATION_RESPONSES, message=message, raw=deepcopy(raw_event), extra={"payload": data}) + return UnifiedStreamEvent(type=event_type, operation=OPERATION_RESPONSES, raw=deepcopy(raw_event), extra={"payload": data}) + + def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: + if isinstance(raw_or_unified, (UnifiedResponse, UnifiedStreamEvent)): + return raw_or_unified.usage + payload = _as_dict(raw_or_unified) + usage = payload.get("usage") if isinstance(payload.get("usage"), dict) else payload + if not isinstance(usage, dict) or not any(key.endswith("tokens") for key in usage): + return None + input_details = usage.get("input_tokens_details") if isinstance(usage.get("input_tokens_details"), dict) else {} + output_details = usage.get("output_tokens_details") if isinstance(usage.get("output_tokens_details"), dict) else {} + cost = None + cost_details = usage.get("cost_details") + if isinstance(cost_details, dict): + provider_cost = cost_details.get("total_cost") or cost_details.get("request_cost_usd") or cost_details.get("cost") or cost_details.get("estimated_cost") + cost = CostDetails( + provider_reported_cost=float(provider_cost) if provider_cost is not None else None, + currency=str(cost_details.get("currency") or "USD"), + source="usage.cost_details", + metadata={k: deepcopy(v) for k, v in cost_details.items() if k not in {"total_cost", "cost", "currency"}}, + ) + return Usage( + input_tokens=int(usage.get("input_tokens") or 0), + output_tokens=int(usage.get("output_tokens") or 0), + total_tokens=int(usage.get("total_tokens") or 0), + cache_read_tokens=int(input_details.get("cached_tokens") or 0), + cache_write_tokens=int(input_details.get("cache_creation_tokens") or usage.get("cache_creation_tokens") or 0), + reasoning_tokens=int(output_details.get("reasoning_tokens") or 0), + cost=cost, + raw=deepcopy(usage), + ) + + def _parse_input(self, input_value: Any) -> list[UnifiedMessage]: + if input_value is None: + return [] + if isinstance(input_value, str): + return [UnifiedMessage(role="user", content=text_blocks(input_value), raw=input_value)] + if not isinstance(input_value, list): + return [UnifiedMessage(role="user", content=[ContentBlock(type="unknown", raw=deepcopy(input_value))], raw=deepcopy(input_value))] + messages = [] + for item in input_value: + if isinstance(item, dict): + messages.append(self._parse_input_item(item)) + else: + messages.append(UnifiedMessage(role="user", content=text_blocks(str(item)), raw=deepcopy(item))) + return messages + + def _parse_input_item(self, item: dict[str, Any]) -> UnifiedMessage: + item_type = item.get("type") + if item_type in {"message", None}: + return UnifiedMessage( + role=str(item.get("role") or "user"), + content=self._parse_content(item.get("content")), + raw=deepcopy(item), + extra={k: deepcopy(v) for k, v in item.items() if k not in {"type", "role", "content"}}, + ) + if item_type == "function_call_output": + return UnifiedMessage( + role="tool", + content=[ContentBlock(type="tool_result", tool_result=ToolResult(tool_call_id=item.get("call_id"), content=item.get("output")), raw=deepcopy(item))], + tool_call_id=item.get("call_id"), + raw=deepcopy(item), + ) + return UnifiedMessage(role=str(item.get("role") or "user"), content=[ContentBlock(type=str(item_type or "unknown"), raw=deepcopy(item))], raw=deepcopy(item)) + + def _format_input_message(self, message: UnifiedMessage) -> dict[str, Any]: + if isinstance(message.raw, dict): + payload = deepcopy(message.raw) + if payload.get("type") == "function_call_output": + payload["call_id"] = message.tool_call_id or payload.get("call_id") + result = message.content[0].tool_result if message.content and message.content[0].tool_result else None + if result: + payload["output"] = deepcopy(result.content) + return payload + payload["role"] = message.role + payload["content"] = self._format_content(message.content) + return payload + return {"type": "message", "role": message.role, "content": self._format_content(message.content)} + + def _parse_output_item(self, item: dict[str, Any]) -> UnifiedMessage | None: + item_type = item.get("type") + if item_type == "message": + return UnifiedMessage( + role=str(item.get("role") or "assistant"), + content=self._parse_content(item.get("content")), + raw=deepcopy(item), + extra={k: deepcopy(v) for k, v in item.items() if k not in {"type", "role", "content"}}, + ) + if item_type == "reasoning": + reasoning = ReasoningBlock(type="reasoning", text=_reasoning_text(item), extra={k: deepcopy(v) for k, v in item.items() if k not in {"type", "summary"}}) + reasoning.raw = deepcopy(item) + return UnifiedMessage(role="assistant", content=[ContentBlock(type="reasoning", reasoning=reasoning, raw=deepcopy(item))], reasoning=[reasoning], raw=deepcopy(item)) + if item_type in {"function_call", "custom_tool_call"}: + call = ToolCall(id=item.get("call_id") or item.get("id"), name=item.get("name"), arguments=item.get("arguments") or item.get("input"), type=str(item_type), raw=deepcopy(item)) + return UnifiedMessage(role="assistant", content=[ContentBlock(type=str(item_type), tool_call=call, raw=deepcopy(item))], tool_calls=[call], raw=deepcopy(item)) + return None + + def _format_output_message(self, message: UnifiedMessage, index: int) -> dict[str, Any]: + if isinstance(message.raw, dict): + payload = deepcopy(message.raw) + item_type = payload.get("type") + if item_type == "message": + payload["role"] = message.role + payload["content"] = self._format_content(message.content) + return payload + if item_type == "reasoning" and message.reasoning: + payload["summary"] = [{"type": "summary_text", "text": message.reasoning[0].text or ""}] + return payload + if item_type in {"function_call", "custom_tool_call"} and message.tool_calls: + call = message.tool_calls[0] + payload["call_id"] = call.id + payload["name"] = call.name + payload["arguments"] = deepcopy(call.arguments) + return payload + return {"id": f"msg_{index}", "type": "message", "role": message.role, "content": self._format_content(message.content)} + + def _parse_content(self, content: Any) -> list[ContentBlock]: + if content is None: + return [] + if isinstance(content, str): + return text_blocks(content) + if not isinstance(content, list): + return [ContentBlock(type="unknown", raw=deepcopy(content))] + blocks = [] + for block in content: + if isinstance(block, str): + blocks.append(ContentBlock(type="input_text", text=block, raw=block)) + continue + if not isinstance(block, dict): + blocks.append(ContentBlock(type="unknown", raw=deepcopy(block))) + continue + block_type = str(block.get("type") or "text") + if block_type in {"input_text", "output_text", "text"}: + blocks.append(ContentBlock(type=block_type, text=block.get("text", ""), raw=deepcopy(block), extra=_without(block, {"type", "text"}))) + elif block_type in {"input_image", "image_url"}: + blocks.append(ContentBlock(type=block_type, source=deepcopy(block.get("image_url") or block.get("source")), raw=deepcopy(block), extra=_without(block, {"type", "image_url", "source"}))) + else: + blocks.append(ContentBlock(type=block_type, raw=deepcopy(block), extra=_without(block, {"type"}))) + return blocks + + def _format_content(self, blocks: Iterable[ContentBlock]) -> list[dict[str, Any]]: + formatted = [] + for block in blocks: + if block.type in {"input_text", "output_text", "text"}: + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": block.type} + payload["type"] = block.type + payload["text"] = block.text or "" + payload.update(deepcopy(block.extra)) + formatted.append(payload) + elif block.type in {"input_image", "image_url"}: + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": block.type} + payload["type"] = block.type + payload["image_url"] = deepcopy(block.source) + payload.update(deepcopy(block.extra)) + formatted.append(payload) + else: + payload = {"type": block.type} + payload.update(deepcopy(block.extra)) + formatted.append(payload) + return formatted + + def _parse_tool(self, tool: dict[str, Any]) -> ToolDefinition: + payload = dict(tool or {}) + parameters = payload.get("parameters") or payload.get("input_schema") or {} + return ToolDefinition( + name=str(payload.get("name") or ""), + description=payload.get("description"), + input_schema=deepcopy(parameters), + type=str(payload.get("type") or "function"), + extra={k: deepcopy(v) for k, v in payload.items() if k not in {"type", "name", "description", "parameters", "input_schema"}}, + ) + + def _format_tool(self, tool: ToolDefinition) -> dict[str, Any]: + payload = {"type": tool.type, "name": tool.name, "parameters": deepcopy(tool.input_schema)} + if tool.description is not None: + payload["description"] = tool.description + payload.update(deepcopy(tool.extra)) + return payload + + +def _as_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return deepcopy(value) + if hasattr(value, "model_dump"): + return value.model_dump() + if hasattr(value, "dict"): + return value.dict() + return {} + + +def _decode_sse_data(raw_event: Any) -> Any: + if not isinstance(raw_event, str): + return raw_event + text = raw_event.strip() + if text.startswith("data:"): + text = text[5:].strip() + if text == "[DONE]": + return text + try: + return json.loads(text) + except json.JSONDecodeError: + return raw_event + + +def _reasoning_text(item: dict[str, Any]) -> str | None: + summary = item.get("summary") + if isinstance(summary, list): + parts = [] + for part in summary: + if isinstance(part, dict) and part.get("text"): + parts.append(str(part["text"])) + elif isinstance(part, str): + parts.append(part) + return "".join(parts) if parts else None + return str(summary) if summary else None + + +def _without(payload: dict[str, Any], keys: set[str]) -> dict[str, Any]: + return {k: deepcopy(v) for k, v in payload.items() if k not in keys} + + +def _format_responses_usage(usage: Usage | None) -> dict[str, Any] | None: + """Format normalized usage using OpenAI Responses public field names.""" + + if usage is None: + return None + payload: dict[str, Any] = { + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + "total_tokens": usage.total_tokens or (usage.input_tokens + usage.output_tokens), + } + input_details: dict[str, Any] = {} + if usage.cache_read_tokens: + input_details["cached_tokens"] = usage.cache_read_tokens + if usage.cache_write_tokens: + # OpenAI Responses does not have a universal cache-write field, but this + # extension keeps provider-reported cache creation visible without + # leaking the unified internal `cache_write_tokens` key. + input_details["cache_creation_tokens"] = usage.cache_write_tokens + if input_details: + payload["input_tokens_details"] = input_details + output_details: dict[str, Any] = {} + if usage.reasoning_tokens: + output_details["reasoning_tokens"] = usage.reasoning_tokens + if output_details: + payload["output_tokens_details"] = output_details + if usage.cost: + cost_details: dict[str, Any] = dict(usage.cost.metadata) + if usage.cost.provider_reported_cost is not None: + cost_details["total_cost"] = usage.cost.provider_reported_cost + elif usage.cost.estimated_cost is not None: + cost_details["estimated_cost"] = usage.cost.estimated_cost + cost_details["currency"] = usage.cost.currency + if usage.cost.source: + cost_details["source"] = usage.cost.source + payload["cost_details"] = cost_details + return payload diff --git a/src/rotator_library/protocols/types.py b/src/rotator_library/protocols/types.py new file mode 100644 index 000000000..b0fda21af --- /dev/null +++ b/src/rotator_library/protocols/types.py @@ -0,0 +1,429 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Protocol-neutral data structures used by native protocol adapters. + +These types intentionally model common LLM API concepts without pretending the +set is complete. Providers and future protocol implementations can preserve +non-standard fields in ``extra`` or ``raw`` instead of dropping them. That +preservation is important for transform-pass logging, field-cache rules, and +provider-specific overrides added in later experimental phases. +""" + +from __future__ import annotations + +from copy import deepcopy +from dataclasses import asdict, dataclass, field, is_dataclass +from datetime import date, datetime +from decimal import Decimal +from pathlib import Path +from typing import Any, ClassVar, Iterable, Mapping, Optional + +from .operation import OPERATION_UNKNOWN + + +JsonObject = dict[str, Any] + + +def serialize_value(value: Any) -> Any: + """Return a JSON-friendly copy of a protocol value. + + The transaction logging phase needs reliable snapshots after every protocol + and adapter pass. This helper keeps that concern centralized and avoids + mutating live request/response objects while preparing logs or fixtures. + """ + + if isinstance(value, ProtocolSerializable): + return value.to_dict() + if is_dataclass(value): + return serialize_value(asdict(value)) + if isinstance(value, Mapping): + return {str(k): serialize_value(v) for k, v in value.items()} + if isinstance(value, (list, tuple, set, frozenset)): + return [serialize_value(v) for v in value] + if isinstance(value, bytes): + try: + return value.decode("utf-8") + except UnicodeDecodeError: + return value.hex() + if isinstance(value, (datetime, date)): + return value.isoformat() + if isinstance(value, Decimal): + return float(value) + if isinstance(value, Path): + return str(value) + try: + import json + + json.dumps(value) + return deepcopy(value) + except (TypeError, ValueError): + return repr(value) + + +def copy_mapping(value: Optional[Mapping[str, Any]]) -> JsonObject: + """Return a deep-copied dict for extension fields.""" + + return serialize_value(dict(value or {})) + + +class ProtocolSerializable: + """Mixin for dataclasses that need stable dict serialization. + + Concrete protocol types define ``_fields`` so ``to_dict`` stays explicit and + future additions do not accidentally disappear from transform logs. + """ + + _fields: ClassVar[tuple[str, ...]] = () + + def to_dict(self) -> JsonObject: + return {field_name: serialize_value(getattr(self, field_name)) for field_name in self._fields} + + +@dataclass +class CostDetails(ProtocolSerializable): + """Normalized cost metadata from provider-reported or estimated sources.""" + + provider_reported_cost: Optional[float] = None + estimated_cost: Optional[float] = None + currency: str = "USD" + source: Optional[str] = None + metadata: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "provider_reported_cost", + "estimated_cost", + "currency", + "source", + "metadata", + ) + + +@dataclass +class Usage(ProtocolSerializable): + """Protocol-neutral token and cost usage values. + + Existing usage tracking has provider-specific extraction paths. Native + protocols use this shape first, then later phases can normalize it into the + current usage manager without replacing that engine. + """ + + input_tokens: int = 0 + output_tokens: int = 0 + total_tokens: int = 0 + cache_read_tokens: int = 0 + cache_write_tokens: int = 0 + reasoning_tokens: int = 0 + cost: Optional[CostDetails] = None + raw: Any = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "input_tokens", + "output_tokens", + "total_tokens", + "cache_read_tokens", + "cache_write_tokens", + "reasoning_tokens", + "cost", + "raw", + "extra", + ) + + def __post_init__(self) -> None: + if self.total_tokens <= 0: + # Most APIs report reasoning/thinking tokens as a detail of output + # tokens. Protocol-specific normalizers can provide a larger total + # when a provider documents reasoning as a separate bucket. + self.total_tokens = self.input_tokens + self.output_tokens + + +@dataclass +class ReasoningBlock(ProtocolSerializable): + """Reasoning/thinking content and signatures preserved across protocols.""" + + type: str = "reasoning" + text: Optional[str] = None + signature: Optional[str] = None + redacted: bool = False + raw: Any = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ("type", "text", "signature", "redacted", "raw", "extra") + + +@dataclass +class ToolCall(ProtocolSerializable): + """Protocol-neutral tool/function call emitted by an assistant.""" + + id: Optional[str] = None + name: Optional[str] = None + arguments: Any = None + type: str = "function" + index: Optional[int] = None + raw: Any = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ("id", "name", "arguments", "type", "index", "raw", "extra") + + +@dataclass +class ToolResult(ProtocolSerializable): + """Protocol-neutral result associated with a prior tool call.""" + + tool_call_id: Optional[str] = None + content: Any = None + is_error: Optional[bool] = None + raw: Any = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ("tool_call_id", "content", "is_error", "raw", "extra") + + +@dataclass +class ToolDefinition(ProtocolSerializable): + """Protocol-neutral tool schema exposed to a model.""" + + name: str = "" + description: Optional[str] = None + input_schema: JsonObject = field(default_factory=dict) + type: str = "function" + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ("name", "description", "input_schema", "type", "extra") + + +@dataclass +class ContentBlock(ProtocolSerializable): + """A single message content block. + + ``type`` follows the source protocol when practical. The dedicated fields + cover common text, image, document, tool, and reasoning cases, while ``extra`` + keeps provider-specific payloads for later field-cache extraction. + """ + + type: str = "text" + text: Optional[str] = None + source: Any = None + tool_call: Optional[ToolCall] = None + tool_result: Optional[ToolResult] = None + reasoning: Optional[ReasoningBlock] = None + raw: Any = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "type", + "text", + "source", + "tool_call", + "tool_result", + "reasoning", + "raw", + "extra", + ) + + +@dataclass +class UnifiedMessage(ProtocolSerializable): + """A protocol-neutral chat/message turn.""" + + role: str + content: list[ContentBlock] = field(default_factory=list) + name: Optional[str] = None + tool_call_id: Optional[str] = None + tool_calls: list[ToolCall] = field(default_factory=list) + reasoning: list[ReasoningBlock] = field(default_factory=list) + raw: Any = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "role", + "content", + "name", + "tool_call_id", + "tool_calls", + "reasoning", + "raw", + "extra", + ) + + +@dataclass +class UnifiedRequest(ProtocolSerializable): + """A request after parsing from a client or provider protocol.""" + + operation: str = OPERATION_UNKNOWN + model: str = "" + messages: list[UnifiedMessage] = field(default_factory=list) + system: list[ContentBlock] = field(default_factory=list) + tools: list[ToolDefinition] = field(default_factory=list) + stream: bool = False + input: Any = None + modalities: list[str] = field(default_factory=list) + files: list[Any] = field(default_factory=list) + generation_params: JsonObject = field(default_factory=dict) + response_format: Any = None + previous_response_id: Optional[str] = None + metadata: JsonObject = field(default_factory=dict) + raw: Any = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "operation", + "model", + "messages", + "system", + "tools", + "stream", + "input", + "modalities", + "files", + "generation_params", + "response_format", + "previous_response_id", + "metadata", + "raw", + "extra", + ) + + +@dataclass +class UnifiedResponse(ProtocolSerializable): + """A complete provider/client response in protocol-neutral form.""" + + operation: str = OPERATION_UNKNOWN + id: Optional[str] = None + model: Optional[str] = None + messages: list[UnifiedMessage] = field(default_factory=list) + output: list[Any] = field(default_factory=list) + data: list[Any] = field(default_factory=list) + content_type: Optional[str] = None + stop_reason: Optional[str] = None + usage: Optional[Usage] = None + metadata: JsonObject = field(default_factory=dict) + raw: Any = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "operation", + "id", + "model", + "messages", + "output", + "data", + "content_type", + "stop_reason", + "usage", + "metadata", + "raw", + "extra", + ) + + +@dataclass +class UnifiedStreamEvent(ProtocolSerializable): + """A single protocol-neutral stream event. + + Future SSE and WebSocket transports should consume this type instead of raw + provider chunks so transport code can stay independent from protocol parsing. + """ + + type: str + operation: str = OPERATION_UNKNOWN + delta: Optional[UnifiedMessage] = None + message: Optional[UnifiedMessage] = None + tool_call: Optional[ToolCall] = None + usage: Optional[Usage] = None + error: Any = None + raw: Any = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "type", + "operation", + "delta", + "message", + "tool_call", + "usage", + "error", + "raw", + "extra", + ) + + +@dataclass +class ProtocolContext(ProtocolSerializable): + """Execution context passed through protocol methods. + + Only a small subset is needed in Phase 1, but the fields anticipate later + provider overrides, transaction tracing, field-cache scoping, and transport + selection without forcing those systems to exist yet. + """ + + provider: Optional[str] = None + model: Optional[str] = None + source_protocol: Optional[str] = None + target_protocol: Optional[str] = None + request_id: Optional[str] = None + session_id: Optional[str] = None + credential_stable_id: Optional[str] = None + transport: str = "http" + provider_options: JsonObject = field(default_factory=dict) + metadata: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "provider", + "model", + "source_protocol", + "target_protocol", + "request_id", + "session_id", + "credential_stable_id", + "transport", + "provider_options", + "metadata", + ) + + +class ProtocolError(ValueError): + """Error raised by protocol parsing/building passes.""" + + def __init__( + self, + message: str, + *, + protocol: str, + pass_name: str, + payload: Any = None, + ): + self.protocol = protocol + self.pass_name = pass_name + self.payload_preview = _payload_preview(payload) + details = f"{protocol}.{pass_name}: {message}" + if self.payload_preview is not None: + details = f"{details} | payload={self.payload_preview}" + super().__init__(details) + + +def _payload_preview(payload: Any, limit: int = 500) -> Optional[str]: + if payload is None: + return None + text = repr(serialize_value(payload)) + if len(text) > limit: + return text[: limit - 3] + "..." + return text + + +def text_blocks(text: Optional[str]) -> list[ContentBlock]: + """Return a single text block list for simple string content.""" + + if text is None: + return [] + return [ContentBlock(type="text", text=str(text))] + + +def first_text(blocks: Iterable[ContentBlock]) -> Optional[str]: + """Return concatenated text from content blocks, or ``None`` if absent.""" + + parts = [block.text for block in blocks if block.type == "text" and block.text] + return "".join(parts) if parts else None diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py new file mode 100644 index 000000000..5b871f34b --- /dev/null +++ b/src/rotator_library/providers/antigravity_provider.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Antigravity provider integration restored from safe retired pieces.""" + +from __future__ import annotations + +import os +from typing import Any, List, Optional + +import httpx + +from ..field_cache import FieldCacheInjection, FieldCacheRule +from .provider_interface import ProviderInterface + +BASE_URLS = [ + "https://daily-cloudcode-pa.sandbox.googleapis.com/v1internal", + "https://daily-cloudcode-pa.googleapis.com/v1internal", + "https://cloudcode-pa.googleapis.com/v1internal", +] +ANTIGRAVITY_HEADERS = { + "User-Agent": "antigravity/1.15.8 windows/amd64", + "X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1", + "Client-Metadata": '{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}', +} +AVAILABLE_MODELS = [ + "gemini-2.5-flash", + "gemini-2.5-flash-lite", + "gemini-3-pro-preview", + "gemini-3-flash", + "claude-sonnet-4.5", + "claude-opus-4.5", + "claude-opus-4.6", +] +MODEL_ALIAS_MAP = { + "rev19-uic3-1p": "gemini-2.5-computer-use-preview-10-2025", + "gemini-3-pro-image": "gemini-3-pro-image-preview", + "gemini-3-pro-low": "gemini-3-pro-preview", + "gemini-3-pro-high": "gemini-3-pro-preview", + "claude-sonnet-4-5": "claude-sonnet-4.5", + "claude-opus-4-5": "claude-opus-4.5", + "claude-opus-4-6": "claude-opus-4.6", +} +_AMBIGUOUS_REVERSE_ALIASES = {"gemini-3-pro-preview"} +MODEL_ALIAS_REVERSE = {public: internal for internal, public in MODEL_ALIAS_MAP.items() if public not in _AMBIGUOUS_REVERSE_ALIASES} +EXCLUDED_MODELS = {"chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-2.5-pro"} + + +class AntigravityProvider(ProviderInterface): + """Safe restored Antigravity integration path. + + The retired provider contained valuable model/header/endpoint knowledge mixed + with fragile device-profile and monolithic transform logic. This active + skeleton restores only stable declarations and helpers so native provider + work can proceed behind tests before any live routing is enabled. + """ + + provider_env_name = "antigravity" + protocol_name = "gemini" + adapter_names: tuple[str, ...] = ("antigravity_envelope",) + field_cache_rules = ( + FieldCacheRule( + name="antigravity_thought_signature", + source="response", + path="candidates.*.content.parts.*.thoughtSignature", + mode="all", + scope=("provider", "model", "credential", "session"), + inject=FieldCacheInjection(target="request", path="request.metadata.thoughtSignatures", as_list=True), + metadata={"purpose": "preserve Gemini thought signatures across Antigravity turns"}, + ), + ) + native_streaming_supported = False + model_quota_groups = { + "gemini_3_pro": ["gemini-3-pro-preview", "gemini-3-pro-low", "gemini-3-pro-high"], + "gemini_3_flash": ["gemini-3-flash"], + "gemini_2_5_flash": ["gemini-2.5-flash"], + "gemini_2_5_flash_lite": ["gemini-2.5-flash-lite"], + "claude_sonnet_4_5": ["claude-sonnet-4.5"], + "claude_opus_4_5": ["claude-opus-4.5"], + "claude_opus_4_6": ["claude-opus-4.6"], + } + default_rotation_mode = "sequential" + + async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: + """Fetch available Antigravity models or return the restored safe list.""" + + try: + response = await client.post(self.get_native_endpoint(operation="models"), headers=self.get_native_headers(api_key), json={}, timeout=30) + response.raise_for_status() + models = self._models_from_response(response.json()) + if models: + return [self._with_prefix(model) for model in models] + except Exception: + return [self._with_prefix(model) for model in AVAILABLE_MODELS] + return [self._with_prefix(model) for model in AVAILABLE_MODELS] + + def get_api_base(self) -> str: + """Return the first configured Antigravity base URL.""" + + configured = os.getenv("ANTIGRAVITY_API_BASE") + return configured.rstrip("/") if configured else BASE_URLS[0] + + def get_native_headers(self, credential_identifier: str, model: str = "", operation: str = "generate") -> dict[str, str]: + """Return static Antigravity headers plus bearer auth. + + Retired per-device fingerprint headers are intentionally not restored; + they are brittle and should only return if a current service requirement + is verified with tests. + """ + + headers = { + "Authorization": f"Bearer {credential_identifier}", + "Content-Type": "application/json", + **ANTIGRAVITY_HEADERS, + } + if operation == "stream_generate": + headers["Accept"] = "text/event-stream" + return headers + + def get_native_operation(self, model: str = "", request: dict[str, Any] | None = None, stream: bool = False) -> str: + """Return the Gemini generate operation used by Antigravity endpoints.""" + + return "stream_generate" if stream else "generate" + + def normalize_native_model(self, model: str) -> str: + """Strip the proxy prefix and map public aliases to upstream names.""" + + clean = model.split("/", 1)[1] if model.startswith("antigravity/") else model + return self._alias_to_internal(clean) + + def prepare_native_request(self, request: dict[str, Any], model: str = "", operation: str = "") -> dict[str, Any]: + """Return a request with the upstream model and Gemini contents shape. + + The provider intentionally keeps this to model alias and message-shape + handling only. Device profile and fingerprint behavior stays out of the + active integration until it is verified against current service behavior. + """ + + prepared = dict(request) + public_model = str(request.get("_proxy_model") or request.get("model") or "") + thinking_level = _thinking_level_from_model(public_model) + if model: + prepared["model"] = _model_with_thinking_variant(model, thinking_level) + prepared.pop("_proxy_model", None) + if thinking_level: + generation_config = prepared.setdefault("generationConfig", {}) + generation_config.setdefault("thinkingConfig", {})["thinkingLevel"] = thinking_level + prepared.setdefault("metadata", {})["thinking_level"] = thinking_level + if "contents" not in prepared and isinstance(prepared.get("messages"), list): + prepared["contents"] = [_message_to_gemini_content(message) for message in prepared.pop("messages")] + return prepared + + def supports_native_streaming(self, model: str = "", operation: str = "generate") -> bool: + """Return false until native stream wrapping is provider-safe.""" + + return False + + def get_native_endpoint(self, model: str = "", operation: str = "generate") -> str: + """Return Antigravity internal operation endpoints.""" + + if operation == "models": + return f"{self.get_api_base()}:fetchAvailableModels" + if operation == "stream_generate": + return f"{self.get_api_base()}:streamGenerateContent?alt=sse" + return f"{self.get_api_base()}:generateContent" + + def get_adapter_config(self, model: str = "") -> dict[str, dict[str, Any]]: + """Configure the safe Antigravity internal request envelope.""" + + return { + "antigravity_envelope": { + "project": os.getenv("ANTIGRAVITY_PROJECT", ""), + "user_agent": ANTIGRAVITY_HEADERS["User-Agent"], + "request_type": os.getenv("ANTIGRAVITY_REQUEST_TYPE", "CHAT_COMPLETION"), + } + } + + def get_model_tier_requirement(self, model: str) -> Optional[int]: + """Antigravity exposes no restored model-tier restriction.""" + + return None + + def normalize_model_for_tracking(self, model: str) -> str: + """Normalize internal names to public aliases while preserving prefix.""" + + if "/" in model: + provider, clean_model = model.split("/", 1) + return f"{provider}/{self._api_to_user_model(clean_model)}" + return self._api_to_user_model(model) + + @staticmethod + def _with_prefix(model: str) -> str: + if model.startswith("antigravity/"): + return model + return f"antigravity/{model}" + + @staticmethod + def _alias_to_internal(alias: str) -> str: + if alias in {"rev19-uic3-1p", "gemini-3-pro-image", "gemini-3-pro-low", "gemini-3-pro-high"}: + return MODEL_ALIAS_MAP.get(alias, alias) + return MODEL_ALIAS_REVERSE.get(alias, MODEL_ALIAS_MAP.get(alias, alias)) + + @staticmethod + def _api_to_user_model(internal: str) -> str: + return MODEL_ALIAS_MAP.get(internal, internal) + + def _models_from_response(self, payload: dict[str, Any]) -> list[str]: + raw_models: list[str] = [] + if isinstance(payload.get("models"), dict): + raw_models.extend(payload["models"].keys()) + if isinstance(payload.get("data"), list): + raw_models.extend(item.get("id") for item in payload["data"] if isinstance(item, dict) and item.get("id")) + result = [] + for model in raw_models: + public = self._api_to_user_model(str(model)) + if public not in EXCLUDED_MODELS and public in AVAILABLE_MODELS and public not in result: + result.append(public) + return result + + +def _message_to_gemini_content(message: Any) -> dict[str, Any]: + """Return a minimal Gemini content item from an OpenAI-style message.""" + + if not isinstance(message, dict): + return {"role": "user", "parts": [{"text": str(message)}]} + role = "model" if message.get("role") == "assistant" else "user" + content = message.get("content", "") + parts = content if isinstance(content, list) else [{"text": str(content)}] + return {"role": role, "parts": parts} + + +def _thinking_level_from_model(model: str) -> Optional[str]: + clean = model.split("/", 1)[1] if model.startswith("antigravity/") else model + if clean.endswith("-low"): + return "low" + if clean.endswith("-high"): + return "high" + return None + + +def _model_with_thinking_variant(model: str, thinking_level: Optional[str]) -> str: + if model == "gemini-3-pro-preview" and thinking_level in {"low", "high"}: + return f"gemini-3-pro-{thinking_level}" + return model diff --git a/src/rotator_library/providers/claude_code_provider.py b/src/rotator_library/providers/claude_code_provider.py new file mode 100644 index 000000000..40f5c373b --- /dev/null +++ b/src/rotator_library/providers/claude_code_provider.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Claude Code provider integration for native Anthropic Messages execution.""" + +from __future__ import annotations + +import os +from typing import Any, List + +import httpx + +from ..field_cache import FieldCacheInjection, FieldCacheRule +from .provider_interface import ProviderInterface + +DEFAULT_API_BASE = "https://api.anthropic.com" +FALLBACK_MODELS = ["claude_code/claude-sonnet-4-5", "claude_code/claude-opus-4-5"] + + +class ClaudeCodeProvider(ProviderInterface): + """Provider declaration for Claude Code style native requests. + + The provider starts as an explicit integration path rather than a guessed live + implementation. It declares protocol/adapters/cache rules and exposes mocked + auth/model helpers so later native wiring can use it without reintroducing a + monolithic provider transform. + """ + + provider_env_name = "claude_code" + protocol_name = "anthropic_messages" + adapter_names = ("suppress_developer_role",) + native_streaming_supported = False + field_cache_rules = ( + FieldCacheRule( + name="claude_code_thinking_signature", + source="response", + path="content.*.signature", + mode="all", + scope=("provider", "model", "credential", "session"), + inject=FieldCacheInjection(target="request", path="metadata.thinking_signatures", as_list=True), + metadata={"purpose": "preserve Claude thinking signatures for follow-up requests"}, + ), + ) + default_rotation_mode = "sequential" + + async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: + """Fetch provider models with a conservative fallback list. + + Model discovery is intentionally mock-friendly. If the configured service + does not expose a standard `/v1/models` response, provider work can later + override this without changing protocol declarations. + """ + + try: + response = await client.get(f"{self.get_api_base().rstrip('/')}/v1/models", headers=self.get_native_headers(api_key), timeout=30) + response.raise_for_status() + models = [item.get("id") for item in response.json().get("data", []) if isinstance(item, dict) and item.get("id")] + if models: + return [self._with_prefix(model) for model in models] + except Exception: + return list(FALLBACK_MODELS) + return list(FALLBACK_MODELS) + + def get_api_base(self) -> str: + """Return the configured Claude Code API base URL.""" + + return os.getenv("CLAUDE_CODE_API_BASE", DEFAULT_API_BASE) + + def get_native_headers(self, credential_identifier: str, model: str = "", operation: str = "messages") -> dict[str, str]: + """Return headers for native mocked HTTP requests.""" + + mode = os.getenv("CLAUDE_CODE_AUTH_HEADER", "auto").strip().lower() + use_api_key = mode == "x-api-key" or (mode == "auto" and credential_identifier.startswith("sk-ant-")) + headers = { + "anthropic-version": os.getenv("CLAUDE_CODE_ANTHROPIC_VERSION", "2023-06-01"), + "content-type": "application/json", + } + if use_api_key: + headers["x-api-key"] = credential_identifier + else: + headers["Authorization"] = f"Bearer {credential_identifier}" + return headers + + def get_native_operation(self, model: str = "", request: dict[str, Any] | None = None, stream: bool = False) -> str: + """Claude Code uses the Anthropic Messages operation for completions.""" + + return "messages" + + def normalize_native_model(self, model: str) -> str: + """Strip the proxy provider prefix before sending upstream.""" + + return model.split("/", 1)[1] if model.startswith("claude_code/") else model + + def supports_native_streaming(self, model: str = "", operation: str = "messages") -> bool: + """Return false until the generic native stream wrapper is compatible.""" + + return False + + def prepare_native_request(self, request: dict[str, Any], model: str = "", operation: str = "messages") -> dict[str, Any]: + """Ensure Anthropic Messages-required fields are present.""" + + prepared = dict(request) + prepared.setdefault("max_tokens", int(os.getenv("CLAUDE_CODE_MAX_TOKENS", "4096"))) + return prepared + + def get_native_endpoint(self, model: str = "", operation: str = "messages") -> str: + """Return the provider endpoint for a native operation.""" + + if operation == "models": + return f"{self.get_api_base().rstrip('/')}/v1/models" + return f"{self.get_api_base().rstrip('/')}/v1/messages" + + def get_adapter_config(self, model: str = "") -> dict[str, dict[str, Any]]: + """Configure adapters without hardcoding provider transforms.""" + + return {"suppress_developer_role": {"mode": "user"}} + + @staticmethod + def _with_prefix(model: str) -> str: + if model.startswith("claude_code/"): + return model + return f"claude_code/{model}" diff --git a/src/rotator_library/providers/codex_provider.py b/src/rotator_library/providers/codex_provider.py new file mode 100644 index 000000000..7a882a972 --- /dev/null +++ b/src/rotator_library/providers/codex_provider.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Codex provider integration for native Responses execution.""" + +from __future__ import annotations + +import os +from typing import Any, List + +import httpx + +from ..field_cache import FieldCacheInjection, FieldCacheRule +from .provider_interface import ProviderInterface + +DEFAULT_API_BASE = "https://api.openai.com" +FALLBACK_MODELS = ["codex/codex-mini-latest", "codex/gpt-5.1-codex"] + + +class CodexProvider(ProviderInterface): + """Provider declaration for Codex-style native Responses requests.""" + + provider_env_name = "codex" + protocol_name = "responses" + adapter_names: tuple[str, ...] = () + native_streaming_supported = False + field_cache_rules = ( + FieldCacheRule( + name="codex_previous_response_id", + source="response", + path="id", + scope=("provider", "model", "credential", "session"), + inject=FieldCacheInjection(target="request", path="previous_response_id", when_missing_only=True), + metadata={"purpose": "preserve Responses continuation IDs for Codex sessions"}, + ), + ) + default_rotation_mode = "sequential" + + async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: + """Fetch Codex-visible models with fallback names for offline tests.""" + + try: + response = await client.get(self.get_native_endpoint(operation="models"), headers=self.get_native_headers(api_key), timeout=30) + response.raise_for_status() + models = [item.get("id") for item in response.json().get("data", []) if isinstance(item, dict) and item.get("id")] + codex_models = [model for model in models if "codex" in model.lower()] + if codex_models: + return [self._with_prefix(model) for model in codex_models] + except Exception: + return list(FALLBACK_MODELS) + return list(FALLBACK_MODELS) + + def get_api_base(self) -> str: + """Return the configured Codex API base URL.""" + + return os.getenv("CODEX_API_BASE", DEFAULT_API_BASE) + + def get_native_headers(self, credential_identifier: str, model: str = "", operation: str = "responses") -> dict[str, str]: + """Return headers for Codex native HTTP calls.""" + + return {"Authorization": f"Bearer {credential_identifier}", "content-type": "application/json"} + + def get_native_operation(self, model: str = "", request: dict | None = None, stream: bool = False) -> str: + """Codex native calls use the Responses operation.""" + + return "responses" + + def normalize_native_model(self, model: str) -> str: + """Strip the proxy provider prefix before sending upstream.""" + + return model.split("/", 1)[1] if model.startswith("codex/") else model + + def supports_native_streaming(self, model: str = "", operation: str = "responses") -> bool: + """Return false until the generic native stream wrapper is compatible.""" + + return False + + def prepare_native_request(self, request: dict[str, Any], model: str = "", operation: str = "") -> dict[str, Any]: + """Convert chat-style proxy input into Responses API input. + + Direct `/v1/responses` calls already provide `input`; routed chat calls + usually provide `messages`. Converting here keeps the reusable Responses + protocol focused on its native shape while still making Codex mock-live + through the generic native executor. + """ + + prepared = dict(request) + if model: + prepared["model"] = model + if "input" not in prepared and isinstance(prepared.get("messages"), list): + prepared["input"] = [_message_to_responses_input(message) for message in prepared.pop("messages")] + return prepared + + def get_native_endpoint(self, model: str = "", operation: str = "responses") -> str: + """Return the native Codex endpoint for an operation.""" + + suffix = "/v1/models" if operation == "models" else "/v1/responses" + return f"{self.get_api_base().rstrip('/')}{suffix}" + + @staticmethod + def _with_prefix(model: str) -> str: + if model.startswith("codex/"): + return model + return f"codex/{model}" + + +def _message_to_responses_input(message: Any) -> dict[str, Any]: + """Return a minimal Responses input item from an OpenAI-style message.""" + + if not isinstance(message, dict): + return {"role": "user", "content": str(message)} + return {"role": message.get("role", "user"), "content": message.get("content", "")} diff --git a/src/rotator_library/providers/copilot_provider.py b/src/rotator_library/providers/copilot_provider.py new file mode 100644 index 000000000..9d48846d4 --- /dev/null +++ b/src/rotator_library/providers/copilot_provider.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Copilot provider integration for native OpenAI Chat execution.""" + +from __future__ import annotations + +import os +from typing import List + +import httpx + +from .provider_interface import ProviderInterface + +DEFAULT_API_BASE = "https://api.githubcopilot.com" +FALLBACK_MODELS = ["copilot/gpt-4.1", "copilot/claude-sonnet-4-5"] + + +class CopilotProvider(ProviderInterface): + """Provider declaration for Copilot-style OpenAI-compatible chat calls. + + The skeleton intentionally avoids inventing field-cache rules until a stable + provider session/conversation field is identified. This keeps Copilot native + support explicit and testable without guessing hidden behavior. + """ + + provider_env_name = "copilot" + protocol_name = "openai_chat" + adapter_names = ("suppress_developer_role",) + field_cache_rules: tuple = () + default_rotation_mode = "sequential" + native_streaming_supported = False + + async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: + """Fetch Copilot-visible models with a safe fallback list.""" + + try: + response = await client.get(self.get_native_endpoint(operation="models"), headers=self.get_native_headers(api_key), timeout=30) + response.raise_for_status() + models = [item.get("id") for item in response.json().get("data", []) if isinstance(item, dict) and item.get("id")] + if models: + return [self._with_prefix(model) for model in models] + except Exception: + return list(FALLBACK_MODELS) + return list(FALLBACK_MODELS) + + def get_api_base(self) -> str: + """Return the configured Copilot API base URL.""" + + return os.getenv("COPILOT_API_BASE", DEFAULT_API_BASE) + + def get_native_headers(self, credential_identifier: str, model: str = "", operation: str = "chat") -> dict[str, str]: + """Return headers for Copilot native HTTP calls.""" + + return { + "Authorization": f"Bearer {credential_identifier}", + "content-type": "application/json", + "Copilot-Integration-Id": os.getenv("COPILOT_INTEGRATION_ID", "llm-api-key-proxy"), + } + + def get_native_operation(self, model: str = "", request: dict | None = None, stream: bool = False) -> str: + """Copilot exposes an OpenAI-compatible chat operation.""" + + return "chat" + + def normalize_native_model(self, model: str) -> str: + """Strip the proxy provider prefix before sending upstream.""" + + return model.split("/", 1)[1] if model.startswith("copilot/") else model + + def supports_native_streaming(self, model: str = "", operation: str = "chat") -> bool: + """Return false until the generic native stream wrapper is compatible.""" + + return False + + def get_native_endpoint(self, model: str = "", operation: str = "chat") -> str: + """Return the Copilot endpoint for a native operation.""" + + suffix = "/models" if operation == "models" else "/chat/completions" + return f"{self.get_api_base().rstrip('/')}{suffix}" + + def get_adapter_config(self, model: str = "") -> dict[str, dict[str, str]]: + """Configure role suppression declaratively for OpenAI-compatible chat.""" + + return {"suppress_developer_role": {"mode": "system"}} + + @staticmethod + def _with_prefix(model: str) -> str: + if model.startswith("copilot/"): + return model + return f"copilot/{model}" diff --git a/src/rotator_library/providers/gemini_cli_provider.py b/src/rotator_library/providers/gemini_cli_provider.py index fe0080ca3..a861a095c 100644 --- a/src/rotator_library/providers/gemini_cli_provider.py +++ b/src/rotator_library/providers/gemini_cli_provider.py @@ -13,6 +13,7 @@ from .provider_interface import ProviderInterface, QuotaGroupMap, UsageResetConfigDef from .gemini_auth_base import GeminiAuthBase from .provider_cache import ProviderCache +from ..field_cache import FieldCacheInjection, FieldCacheRule from .utilities.gemini_cli_quota_tracker import GeminiCliQuotaTracker from .utilities.gemini_shared_utils import ( env_bool, @@ -156,6 +157,24 @@ class GeminiCliProvider( # Provider name for env var lookups (QUOTA_GROUPS_GEMINI_CLI_*) provider_env_name: str = "gemini_cli" + # Native protocol declarations for the Phase 5 opt-in execution seam. The + # existing custom Gemini CLI execution path remains active; these metadata + # hooks let future native routing reuse the Gemini protocol and field-cache + # engine without rewriting this provider first. + protocol_name: str = "gemini" + adapter_names: tuple[str, ...] = () + field_cache_rules = ( + FieldCacheRule( + name="gemini_cli_thought_signature", + source="response", + path="candidates.*.content.parts.*.thoughtSignature", + mode="all", + scope=("provider", "model", "credential", "session"), + inject=FieldCacheInjection(target="request", path="metadata.thoughtSignatures", as_list=True), + metadata={"purpose": "mirror existing Gemini CLI thoughtSignature preservation in the native cache seam"}, + ), + ) + # Tier name -> priority mapping (from centralized tier utilities) # Lower numbers = higher priority (ULTRA=1 > PRO=2 > FREE=3) tier_priorities = TIER_PRIORITIES diff --git a/src/rotator_library/providers/provider_interface.py b/src/rotator_library/providers/provider_interface.py index fb5811bf4..710b44546 100644 --- a/src/rotator_library/providers/provider_interface.py +++ b/src/rotator_library/providers/provider_interface.py @@ -270,6 +270,14 @@ class ProviderInterface(ABC, metaclass=SingletonABCMeta): Union[int, Tuple[int, ...], str], Dict[str, Dict[str, Any]] ] = {} + # Native protocol/adapter declarations introduced for the experimental + # protocol stack. Defaults are intentionally no-op so existing providers keep + # the LiteLLM-backed execution path until they opt into native protocols. + protocol_name: Optional[str] = None + adapter_names: Tuple[str, ...] = () + field_cache_rules: Tuple[Any, ...] = () + native_streaming_supported: bool = False + @abstractmethod async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: """ @@ -313,6 +321,132 @@ def get_session_tracking_hints( """ return None + def get_protocol_name(self, model: str = "") -> Optional[str]: + """Return the native protocol adapter name this provider prefers. + + Providers may override this method when protocol choice varies by model. + Returning ``None`` keeps the current fallback execution behavior. Later + provider phases will use this as the bridge from provider declarations to + native protocol parsing/building. + """ + + return self.protocol_name + + def get_adapter_names(self, model: str = "") -> Tuple[str, ...]: + """Return ordered adapter names for this provider/model. + + The order is significant and is preserved by the adapter chain runner. + Providers can override this for model-specific quirks without mutating the + global adapter registry. + """ + + return tuple(self.adapter_names) + + def get_adapter_config(self, model: str = "") -> Dict[str, Dict[str, Any]]: + """Return adapter-specific config keyed by adapter name. + + Config is intentionally a plain dict so custom providers can define it in + env/JSON later without importing adapter classes. Phase 10 will add formal + config loading and validation. + """ + + return {} + + def get_field_cache_rules(self, model: str = "") -> Tuple[Any, ...]: + """Return field-cache rules for provider-specific protocol state. + + Rules preserve provider state such as reasoning content, thought + signatures, prompt cache keys, provider session IDs, and response IDs. + They are not a replacement for ``SessionTracker``; session tracking still + decides continuity and credential affinity. + """ + + return tuple(self.field_cache_rules) + + def supports_native_streaming(self, model: str = "", operation: str = "chat") -> bool: + """Return whether this provider explicitly supports native streaming. + + The default is intentionally false. Phase 8 keeps live streaming + conservative: providers must opt in before routed streaming can use the + native stream executor instead of current custom/LiteLLM behavior. + """ + + return self.native_streaming_supported + + def supports_native_operation(self, model: str = "", operation: str = "chat") -> bool: + """Return whether this provider supports a native operation.""" + + protocol_name = self.get_protocol_name(model) + if not protocol_name: + return False + try: + from ..protocols import get_protocol + + return get_protocol(protocol_name).supports_operation(operation) + except Exception: + return False + + def should_use_native_protocol(self, model: str = "", operation: str = "chat", *, stream: bool = False, execution: str = "auto") -> bool: + """Return whether routing should use this provider's native protocol.""" + + if stream and not self.supports_native_streaming(model, operation): + return False + return bool(self.get_protocol_name(model) and self.supports_native_operation(model, operation)) + + def get_native_operation(self, model: str = "", request: Optional[Dict[str, Any]] = None, stream: bool = False) -> str: + """Return the provider-native operation for a request. + + Providers that expose native protocols often use operation names that are + not simply ``chat``: Anthropic-compatible providers use ``messages``, + Responses providers use ``responses``, and Gemini-style providers use a + generate operation. The default stays ``chat`` so existing OpenAI-chat + providers remain compatible unless they opt in to something richer. + """ + + return "chat" + + def get_native_endpoint(self, model: str = "", operation: str = "chat") -> str: + """Return the upstream endpoint for a native operation.""" + + raise NotImplementedError(f"{self.__class__.__name__} does not define a native endpoint") + + def get_native_headers(self, credential_identifier: str, model: str = "", operation: str = "chat") -> Dict[str, str]: + """Return non-payload HTTP headers for native requests.""" + + raise NotImplementedError(f"{self.__class__.__name__} does not define native headers") + + def normalize_native_model(self, model: str) -> str: + """Return the upstream model name for native provider calls. + + The proxy-facing model commonly includes a provider prefix such as + ``provider/model``. Native upstream APIs usually expect only ``model``. + Providers may override this for aliases, but stripping the first prefix + is the safe default for native execution. + """ + + return model.split("/", 1)[1] if "/" in model else model + + def prepare_native_request(self, request: Dict[str, Any], model: str = "", operation: str = "") -> Dict[str, Any]: + """Return a provider-adjusted native request payload. + + This hook is intentionally limited to payload shape/model aliases and is + called before protocol parsing. It must not add credentials; auth belongs + in ``get_native_headers()`` so traces never mix request payloads with + secrets. + """ + + return dict(request) + + def get_model_pricing(self, model: str = "") -> Optional[Any]: + """Return optional local pricing metadata for advisory cost tracking. + + Providers can return `usage.costs.ModelPricing` or a compatible dict. + The default is `None`, which lets cost accounting safely fall back to + LiteLLM model metadata or report pricing as unavailable. + """ + + return None + async def acompletion( self, client: httpx.AsyncClient, **kwargs ) -> Union[ diff --git a/src/rotator_library/responses/__init__.py b/src/rotator_library/responses/__init__.py new file mode 100644 index 000000000..0ba921406 --- /dev/null +++ b/src/rotator_library/responses/__init__.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Responses API service, storage, and streaming helpers.""" + +from .bridge import ResponsesBridge +from .service import ResponsesService, ResponsesServiceError +from .store import InMemoryResponsesStore, ProviderCacheResponsesStore, ResponsesStore, create_configured_responses_store +from .streaming import ResponsesSSEFormatter, ResponsesStreamEvent, ResponsesWebSocketFormatter +from .types import ResponsesStoreSettings, StoredResponse, generate_response_id + +__all__ = [ + "InMemoryResponsesStore", + "ProviderCacheResponsesStore", + "ResponsesBridge", + "ResponsesService", + "ResponsesServiceError", + "ResponsesStoreSettings", + "ResponsesSSEFormatter", + "ResponsesStreamEvent", + "ResponsesStore", + "ResponsesWebSocketFormatter", + "StoredResponse", + "create_configured_responses_store", + "generate_response_id", +] diff --git a/src/rotator_library/responses/bridge.py b/src/rotator_library/responses/bridge.py new file mode 100644 index 000000000..2dcda9c7b --- /dev/null +++ b/src/rotator_library/responses/bridge.py @@ -0,0 +1,348 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Bridge between Responses requests and the current chat-completions executor.""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Optional + +from ..protocols import ContentBlock, ToolDefinition, UnifiedMessage, UnifiedRequest, serialize_value +from ..protocols.responses import ResponsesProtocol +from .types import generate_response_id + +_CHAT_GENERATION_KEYS = { + "frequency_penalty", + "logit_bias", + "logprobs", + "max_completion_tokens", + "max_tokens", + "max_output_tokens", + "metadata", + "n", + "parallel_tool_calls", + "presence_penalty", + "reasoning", + "reasoning_effort", + "response_format", + "seed", + "stop", + "stream_options", + "temperature", + "tool_choice", + "top_logprobs", + "top_p", + "user", +} + + +class ResponsesBridge: + """Temporary compatibility bridge from Responses to chat completions. + + Later provider phases should call native Responses-capable providers directly. + Until then, this bridge makes `/v1/responses` useful while preserving enough + metadata and trace detail to debug fields that cannot be represented in chat. + """ + + def __init__(self, protocol: Optional[ResponsesProtocol] = None) -> None: + self.protocol = protocol or ResponsesProtocol() + + def to_chat_kwargs( + self, + unified: UnifiedRequest, + *, + parent_response: Optional[dict[str, Any]] = None, + parent_responses: Optional[list[dict[str, Any]]] = None, + ) -> dict[str, Any]: + """Convert a parsed Responses request to chat-completions kwargs.""" + + messages: list[dict[str, Any]] = [] + system_text = _blocks_to_text(unified.system) + if system_text: + messages.append({"role": "system", "content": system_text}) + if parent_responses is not None: + for parent in parent_responses: + messages.extend(_parent_request_to_messages(parent.get("request") or {})) + messages.extend(_parent_output_to_messages(parent.get("output") or parent.get("response", {}).get("output") or [])) + elif parent_response: + messages.extend(_parent_output_to_messages(parent_response.get("output") or [])) + messages.extend(_message_to_chat(message) for message in unified.messages) + kwargs: dict[str, Any] = { + "model": unified.model, + "messages": messages, + "stream": unified.stream, + } + if unified.tools: + kwargs["tools"] = [_tool_to_chat(tool) for tool in unified.tools] + for key, value in unified.generation_params.items(): + if key in _CHAT_GENERATION_KEYS: + kwargs[_chat_generation_key(key)] = deepcopy(value) + if unified.metadata: + kwargs.setdefault("metadata", deepcopy(unified.metadata)) + unsupported = { + "previous_response_id": unified.previous_response_id, + "extra": deepcopy(unified.extra), + } + kwargs["_responses_bridge"] = {k: v for k, v in unsupported.items() if v} + hints = responses_session_hints(unified.previous_response_id) + if hints: + kwargs["_session_tracking_hints"] = hints + return kwargs + + def from_chat_response( + self, + chat_response: Any, + unified_request: UnifiedRequest, + *, + response_id: Optional[str] = None, + ) -> dict[str, Any]: + """Convert a chat-completions response into a Responses object.""" + + response = _as_dict(chat_response) + output = [] + for index, choice in enumerate(response.get("choices") or []): + if not isinstance(choice, dict): + continue + message = choice.get("message") or {} + output.extend(_chat_message_to_output_items(message, index)) + responses_payload = { + "id": response_id or response.get("id") or generate_response_id(), + "object": "response", + "created_at": response.get("created") or response.get("created_at"), + "model": response.get("model") or unified_request.model, + "status": _status_from_chat(response), + "output": output, + "usage": _usage_to_responses(response), + } + if unified_request.metadata: + responses_payload["metadata"] = deepcopy(unified_request.metadata) + return self.protocol.format_response(self.protocol.parse_response(responses_payload)) + + +def _chat_generation_key(key: str) -> str: + if key == "max_output_tokens": + return "max_tokens" + return key + + +def responses_session_hints(previous_response_id: Optional[str], *, affinity_key: Optional[str] = None) -> dict[str, Any] | None: + """Return proxy-internal sticky routing evidence for Responses continuations.""" + + if not previous_response_id: + return None + anchor = f"responses_previous_response_id:{previous_response_id}" + return { + "strong_anchors": [anchor], + "affinity_key": affinity_key or anchor, + } + + +def _message_to_chat(message: UnifiedMessage) -> dict[str, Any]: + if message.role == "tool": + result = _tool_result_from_message(message) + payload = {"role": "tool", "content": result["content"]} + if result.get("tool_call_id"): + payload["tool_call_id"] = result["tool_call_id"] + return payload + payload = {"role": message.role, "content": _blocks_to_chat_content(message.content)} + if message.name: + payload["name"] = message.name + if message.tool_call_id: + payload["tool_call_id"] = message.tool_call_id + if message.tool_calls: + payload["tool_calls"] = [call.to_dict() for call in message.tool_calls] + return payload + + +def _tool_result_from_message(message: UnifiedMessage) -> dict[str, Any]: + for block in message.content: + if block.tool_result is not None: + content = block.tool_result.content + return { + "tool_call_id": message.tool_call_id or block.tool_result.tool_call_id, + "content": content if isinstance(content, str) else serialize_value(content), + } + return {"tool_call_id": message.tool_call_id, "content": _blocks_to_chat_content(message.content)} + + +def _blocks_to_chat_content(blocks: list[ContentBlock]) -> Any: + if not blocks: + return "" + if len(blocks) == 1 and blocks[0].text is not None: + return blocks[0].text + content = [] + for block in blocks: + if block.text is not None: + content.append({"type": "text", "text": block.text}) + elif block.source is not None: + content.append({"type": block.type, "source": deepcopy(block.source)}) + elif block.raw is not None: + content.append(deepcopy(block.raw)) + return content or "" + + +def _blocks_to_text(blocks: list[ContentBlock]) -> str: + return "".join(block.text or "" for block in blocks if block.text) + + +def _tool_to_chat(tool: ToolDefinition) -> dict[str, Any]: + if tool.type == "function": + return { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description or "", + "parameters": deepcopy(tool.input_schema), + }, + } + payload = {"type": tool.type, "name": tool.name} + payload.update(deepcopy(tool.extra)) + return payload + + +def _parent_output_to_messages(output: list[Any]) -> list[dict[str, Any]]: + messages = [] + for item in output: + if not isinstance(item, dict): + continue + if item.get("type") == "message": + text = _responses_content_to_text(item.get("content") or []) + if text: + messages.append({"role": item.get("role") or "assistant", "content": text}) + elif item.get("type") in {"function_call", "custom_tool_call"}: + tool_call = _responses_tool_call_to_chat(item) + if tool_call: + messages.append({"role": "assistant", "content": None, "tool_calls": [tool_call]}) + return messages + + +def _responses_tool_call_to_chat(item: dict[str, Any]) -> dict[str, Any] | None: + """Convert a stored Responses tool-call item to OpenAI chat shape.""" + + call_id = item.get("call_id") or item.get("id") + name = item.get("name") or item.get("tool_name") or item.get("type") + arguments = item.get("arguments", item.get("input", "")) + if call_id is None or name is None: + return None + return { + "id": call_id, + "type": "function", + "function": {"name": str(name), "arguments": arguments if isinstance(arguments, str) else serialize_value(arguments)}, + } + + +def _parent_request_to_messages(request: dict[str, Any]) -> list[dict[str, Any]]: + """Replay convertible parent Responses input items as chat messages.""" + + if not isinstance(request, dict): + return [] + protocol = ResponsesProtocol() + try: + unified = protocol.parse_request(request) + except Exception: + return _raw_input_to_messages(request.get("input")) + return [_message_to_chat(message) for message in unified.messages] + + +def _raw_input_to_messages(value: Any) -> list[dict[str, Any]]: + if value in (None, ""): + return [] + if isinstance(value, str): + return [{"role": "user", "content": value}] + if isinstance(value, list): + messages: list[dict[str, Any]] = [] + for item in value: + if isinstance(item, str): + messages.append({"role": "user", "content": item}) + elif isinstance(item, dict) and item.get("role") and item.get("content") is not None: + messages.append({"role": item.get("role"), "content": item.get("content")}) + return messages + return [] + + +def _chat_message_to_output_items(message: dict[str, Any], index: int) -> list[dict[str, Any]]: + items = [] + content = message.get("content") + if content is not None: + items.append( + { + "id": f"msg_{index}", + "type": "message", + "role": message.get("role") or "assistant", + "content": [{"type": "output_text", "text": content if isinstance(content, str) else serialize_value(content)}], + } + ) + for tool_call in message.get("tool_calls") or []: + function = tool_call.get("function") if isinstance(tool_call, dict) else None + if isinstance(function, dict): + items.append( + { + "id": tool_call.get("id"), + "type": "function_call", + "call_id": tool_call.get("id"), + "name": function.get("name"), + "arguments": function.get("arguments"), + } + ) + return items + + +def _responses_content_to_text(content: list[Any]) -> str: + parts = [] + for block in content: + if isinstance(block, dict): + text = block.get("text") + if text: + parts.append(str(text)) + elif isinstance(block, str): + parts.append(block) + return "".join(parts) + + +def _status_from_chat(response: dict[str, Any]) -> str: + for choice in response.get("choices") or []: + finish_reason = choice.get("finish_reason") if isinstance(choice, dict) else None + if finish_reason in {"length", "content_filter"}: + return "incomplete" + return "completed" + + +def _usage_to_responses(usage: Any) -> Any: + if isinstance(usage, dict) and isinstance(usage.get("usage"), dict): + nested = dict(usage["usage"]) + for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + if key in usage and key not in nested: + nested[key] = deepcopy(usage[key]) + usage = nested + if not isinstance(usage, dict): + return usage + result = { + "input_tokens": usage.get("prompt_tokens", usage.get("input_tokens", 0)), + "output_tokens": usage.get("completion_tokens", usage.get("output_tokens", 0)), + "total_tokens": usage.get("total_tokens", 0), + } + prompt_details = usage.get("prompt_tokens_details") or usage.get("input_tokens_details") + if isinstance(prompt_details, dict): + result["input_tokens_details"] = {"cached_tokens": prompt_details.get("cached_tokens", 0)} + completion_details = usage.get("completion_tokens_details") or usage.get("output_tokens_details") + if isinstance(completion_details, dict): + result["output_tokens_details"] = {"reasoning_tokens": completion_details.get("reasoning_tokens", 0)} + for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + if key in usage: + result[key] = deepcopy(usage[key]) + if "request_cost_usd" in result and "cost_details" not in result: + result["cost_details"] = {"request_cost_usd": result["request_cost_usd"]} + if "estimated_cost" in result and "cost_details" not in result: + result["cost_details"] = {"estimated_cost": result["estimated_cost"]} + return result + + +def _as_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return deepcopy(value) + if hasattr(value, "model_dump"): + return value.model_dump() + if hasattr(value, "dict"): + return value.dict() + return serialize_value(value) if isinstance(serialize_value(value), dict) else {} diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py new file mode 100644 index 000000000..e3c68008f --- /dev/null +++ b/src/rotator_library/responses/service.py @@ -0,0 +1,1021 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Service layer for the OpenAI-compatible Responses API.""" + +from __future__ import annotations + +import asyncio +import json +import time +from copy import deepcopy +from typing import Any, AsyncGenerator, Optional + +from ..protocols import ProtocolContext +from ..streaming import StreamEvent, StreamMonitor +from ..config.experimental import get_stream_runtime_settings +from ..usage.accounting import extract_usage_record +from ..usage.costs import CostCalculator +from ..protocols.responses import ResponsesProtocol +from .bridge import ResponsesBridge, responses_session_hints +from .store import InMemoryResponsesStore, ResponsesStore +from .streaming import ( + ResponsesSSEFormatter, + ResponsesStreamEvent, + ResponsesStreamState, + output_item_added_payload, + output_item_done_payload, + output_text_delta_payload, + parse_chat_sse_chunk, + response_completed_payload, + response_created_payload, + response_failed_payload, +) +from .types import ResponsesStoreSettings, StoredResponse +from .types import generate_response_id + + +class ResponsesServiceError(ValueError): + """Error with an HTTP-compatible status code for proxy routes.""" + + def __init__(self, message: str, *, status_code: int = 400, error_type: str = "invalid_request_error") -> None: + self.status_code = status_code + self.error_type = error_type + super().__init__(message) + + +class ResponsesService: + """Create, store, retrieve, and delete Responses API objects. + + Phase 4 deliberately bridges through the existing chat-completions execution + path. Native Responses-capable provider execution will replace the bridge for + covered providers in later phases without changing the route/storage surface. + """ + + def __init__( + self, + *, + protocol: Optional[ResponsesProtocol] = None, + bridge: Optional[ResponsesBridge] = None, + store: Optional[ResponsesStore] = None, + store_settings: Optional[ResponsesStoreSettings] = None, + ) -> None: + self.store_settings = store_settings or ResponsesStoreSettings() + self.protocol = protocol or ResponsesProtocol() + self.bridge = bridge or ResponsesBridge(self.protocol) + self.store = store or InMemoryResponsesStore(max_items=self.store_settings.max_items) + + async def create_response( + self, + raw_request: dict[str, Any], + client: Any, + *, + request: Optional[Any] = None, + transaction_logger: Optional[Any] = None, + ) -> dict[str, Any]: + """Create a non-streaming Responses object through the chat bridge.""" + + if not raw_request.get("model"): + raise ResponsesServiceError("'model' is required", status_code=400) + if raw_request.get("stream"): + raise ResponsesServiceError("Use stream_response for streaming requests", status_code=400) + + self._trace(transaction_logger, "responses_raw_request", raw_request, direction="request", stage="client") + try: + unified = self.protocol.parse_request(raw_request, ProtocolContext(source_protocol="responses")) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_parse_request", exc, raw_request) + raise + if transaction_logger: + self._trace(transaction_logger, "responses_parsed_request", unified.to_dict(), direction="request", stage="protocol") + + parent = await self._load_previous_response(unified.previous_response_id, transaction_logger) + try: + parent_lineage = await self._load_response_lineage(parent) + chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_responses=[stored.to_dict() for stored in parent_lineage] if parent_lineage else None) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_bridge_chat_request", exc, unified.to_dict()) + raise + bridge_metadata = chat_kwargs.pop("_responses_bridge", {}) + session_hints = chat_kwargs.pop("_session_tracking_hints", None) + session_hints = _responses_session_hints(unified.previous_response_id, parent, session_hints) + session_info: dict[str, Any] = {} + chat_kwargs.update(_internal_client_kwargs(client, session_hints, session_info)) + trace_chat_kwargs = _without_internal_kwargs(chat_kwargs) + self._trace( + transaction_logger, + "responses_bridge_chat_request", + trace_chat_kwargs, + direction="request", + stage="adapter", + metadata={"bridge_metadata": {**bridge_metadata, "has_session_hints": bool(session_hints)}}, + ) + + chat_response = await client.acompletion(request=request, **chat_kwargs) + if transaction_logger: + self._trace(transaction_logger, "responses_bridge_chat_response", self._response_to_dict(chat_response), direction="response", stage="provider") + + try: + response_payload = self.bridge.from_chat_response(chat_response, unified) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_bridge_chat_response", exc, self._response_to_dict(chat_response)) + raise + _record_responses_session_anchor(session_info, response_payload) + self._trace(transaction_logger, "responses_parsed_response", response_payload, direction="response", stage="protocol") + self._trace_responses_usage(transaction_logger, response_payload, unified.model, source="responses_response") + + if raw_request.get("store", True): + stored = self._stored_response(raw_request, response_payload, parent, session_info=session_info) + try: + await self.store.save(stored) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_store_response", exc, stored.to_dict()) + raise + self._trace(transaction_logger, "responses_stored_response", stored.to_dict(), direction="metadata", stage="final") + + self._trace(transaction_logger, "responses_final_response", response_payload, direction="response", stage="final") + return response_payload + + async def stream_response( + self, + raw_request: dict[str, Any], + client: Any, + *, + request: Optional[Any] = None, + transaction_logger: Optional[Any] = None, + transport: str = "sse", + ) -> AsyncGenerator[str, None]: + """Stream a Responses API request as HTTP SSE events.""" + + formatter = ResponsesSSEFormatter() + async for event in self.stream_events(raw_request, client, request=request, transaction_logger=transaction_logger, transport=transport): + formatted = formatter.format_stream_event(event) + self._trace( + transaction_logger, + "responses_sse_formatted_event", + formatted, + direction="stream", + stage="final", + metadata={"event_name": event.event_name, "terminal": event.terminal, "transport": transport}, + scrub_strings=True, + ) + if event.heartbeat: + self._trace( + transaction_logger, + "responses_sse_formatted_heartbeat", + formatted, + direction="stream", + stage="final", + metadata={"transport": transport}, + scrub_strings=True, + ) + yield formatted + + async def validate_stream_request(self, raw_request: dict[str, Any]) -> None: + """Validate stream-only preconditions before an HTTP response starts.""" + + if not raw_request.get("model"): + raise ResponsesServiceError("'model' is required", status_code=400) + previous_response_id = raw_request.get("previous_response_id") + if previous_response_id: + await self._load_previous_response(str(previous_response_id), None) + + async def stream_events( + self, + raw_request: dict[str, Any], + client: Any, + *, + request: Optional[Any] = None, + transaction_logger: Optional[Any] = None, + transport: str = "sse", + ) -> AsyncGenerator[ResponsesStreamEvent, None]: + """Yield transport-neutral Responses events for streaming transports.""" + + if not raw_request.get("model"): + raise ResponsesServiceError("'model' is required", status_code=400) + stream_request = dict(raw_request) + stream_request["stream"] = True + self._trace(transaction_logger, "responses_raw_request", stream_request, direction="request", stage="client") + try: + unified = self.protocol.parse_request(stream_request, ProtocolContext(source_protocol="responses", transport=transport)) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_parse_request", exc, stream_request) + raise + if transaction_logger: + self._trace(transaction_logger, "responses_parsed_request", unified.to_dict(), direction="request", stage="protocol") + parent = await self._load_previous_response(unified.previous_response_id, transaction_logger) + try: + parent_lineage = await self._load_response_lineage(parent) + chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_responses=[stored.to_dict() for stored in parent_lineage] if parent_lineage else None) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_bridge_chat_request", exc, unified.to_dict()) + raise + bridge_metadata = chat_kwargs.pop("_responses_bridge", {}) + session_hints = chat_kwargs.pop("_session_tracking_hints", None) + session_hints = _responses_session_hints(unified.previous_response_id, parent, session_hints) + session_info: dict[str, Any] = {} + chat_kwargs.update(_internal_client_kwargs(client, session_hints, session_info)) + chat_kwargs["stream"] = True + trace_chat_kwargs = _without_internal_kwargs(chat_kwargs) + self._trace( + transaction_logger, + "responses_bridge_chat_request", + trace_chat_kwargs, + direction="request", + stage="adapter", + metadata={"bridge_metadata": {**bridge_metadata, "has_session_hints": bool(session_hints)}, "transport": transport}, + ) + + response_id = generate_response_id() + state = ResponsesStreamState(response_id=response_id, model=unified.model) + usage = None + item_started = False + monitor = StreamMonitor(clock=time.monotonic) + stream_settings = get_stream_runtime_settings() + chat_stream = None + stream_iterator = None + upstream_closed = False + pending_next_task = None + pending_next_started_at = None + pending_next_last_heartbeat_at = None + acquire_task = None + acquire_started_at = None + acquire_last_heartbeat_at = None + ttfb_started_at = time.monotonic() + + async def cancel_task(task: Any) -> None: + """Cancel and await an in-flight stream task before closing its source.""" + + if task is None or task.done(): + return + task.cancel() + try: + await task + except (asyncio.CancelledError, StopAsyncIteration): + return + except Exception: + return + + async def close_upstream(reason: str) -> None: + """Best-effort close for upstream Responses bridge streams.""" + + nonlocal upstream_closed + if upstream_closed: + return + attempted = False + for candidate in (stream_iterator, chat_stream): + if candidate is None: + continue + attempted = True + try: + closer = getattr(candidate, "aclose", None) + if callable(closer): + await closer() + upstream_closed = True + self._trace(transaction_logger, "responses_stream_upstream_closed", {"reason": reason}, direction="stream", stage="provider", metadata={"transport": transport}) + return + closer = getattr(candidate, "close", None) + if callable(closer): + closer() + upstream_closed = True + self._trace(transaction_logger, "responses_stream_upstream_closed", {"reason": reason}, direction="stream", stage="provider", metadata={"transport": transport}) + return + except Exception as exc: + self._trace(transaction_logger, "responses_stream_upstream_close_failed", {"reason": reason, "error_type": type(exc).__name__}, direction="stream", stage="provider", metadata={"transport": transport}) + continue + if attempted: + self._trace(transaction_logger, "responses_stream_upstream_close_failed", {"reason": reason, "error_type": "no_close_method"}, direction="stream", stage="provider", metadata={"transport": transport}) + + async def next_upstream_chunk(*, first: bool) -> tuple[str, Any]: + """Return the next upstream chunk or a control marker.""" + + timeout = stream_settings.ttfb_timeout_seconds if first else stream_settings.stall_timeout_seconds + heartbeat = stream_settings.heartbeat_seconds + nonlocal pending_next_task, pending_next_started_at, pending_next_last_heartbeat_at + if pending_next_task is None: + pending_next_task = asyncio.create_task(stream_iterator.__anext__()) + pending_next_started_at = time.monotonic() + pending_next_last_heartbeat_at = pending_next_started_at + next_task = pending_next_task + started_at = ttfb_started_at if first else (pending_next_started_at or time.monotonic()) + while True: + if next_task.done(): + pending_next_task = None + pending_next_started_at = None + pending_next_last_heartbeat_at = None + return "chunk", next_task.result() + if request is not None and await request.is_disconnected(): + self._trace(transaction_logger, "responses_stream_disconnected", {"reason": "client_disconnected"}, direction="stream", stage="client", metadata={"transport": transport}) + if stream_settings.cancel_upstream_on_disconnect: + await cancel_task(next_task) + pending_next_task = None + await close_upstream("client_disconnected") + return "disconnect", None + elapsed = time.monotonic() - started_at + waits = [] + if timeout is not None: + remaining_timeout = timeout - elapsed + if remaining_timeout <= 0: + await cancel_task(next_task) + pending_next_task = None + await close_upstream("ttfb_timeout" if first else "stall_timeout") + raise ResponsesServiceError( + f"Responses stream {'TTFB' if first else 'stall'} timeout", + status_code=504, + error_type="api_connection", + ) + waits.append(remaining_timeout) + if heartbeat is not None: + last_heartbeat_at = pending_next_last_heartbeat_at or started_at + remaining_heartbeat = heartbeat - (time.monotonic() - last_heartbeat_at) + if remaining_heartbeat <= 0: + pending_next_last_heartbeat_at = time.monotonic() + return "heartbeat", None + waits.append(remaining_heartbeat) + wait_timeout = min(waits) if waits else None + if wait_timeout is None: + chunk = await next_task + pending_next_task = None + pending_next_started_at = None + pending_next_last_heartbeat_at = None + return "chunk", chunk + done, _ = await asyncio.wait({next_task}, timeout=wait_timeout) + if done: + pending_next_task = None + pending_next_started_at = None + pending_next_last_heartbeat_at = None + return "chunk", next_task.result() + last_heartbeat_at = pending_next_last_heartbeat_at or started_at + if heartbeat is not None and time.monotonic() - last_heartbeat_at >= heartbeat: + pending_next_last_heartbeat_at = time.monotonic() + return "heartbeat", None + + async def acquire_upstream_stream() -> tuple[str, Any]: + """Acquire the upstream stream under the same TTFB/disconnect policy.""" + + nonlocal acquire_task, acquire_started_at, acquire_last_heartbeat_at + if acquire_task is None: + acquire_task = asyncio.create_task(client.acompletion(request=request, **chat_kwargs)) + acquire_started_at = ttfb_started_at + acquire_last_heartbeat_at = acquire_started_at + task = acquire_task + started_at = acquire_started_at or time.monotonic() + timeout = stream_settings.ttfb_timeout_seconds + heartbeat = stream_settings.heartbeat_seconds + while True: + if task.done(): + acquire_task = None + acquire_started_at = None + acquire_last_heartbeat_at = None + return "stream", task.result() + if request is not None and await request.is_disconnected(): + self._trace(transaction_logger, "responses_stream_disconnected", {"reason": "client_disconnected", "phase": "acquire"}, direction="stream", stage="client", metadata={"transport": transport}) + await cancel_task(task) + acquire_task = None + acquire_started_at = None + acquire_last_heartbeat_at = None + return "disconnect", None + waits = [] + elapsed = time.monotonic() - started_at + if timeout is not None: + remaining_timeout = timeout - elapsed + if remaining_timeout <= 0: + await cancel_task(task) + acquire_task = None + acquire_started_at = None + acquire_last_heartbeat_at = None + raise ResponsesServiceError("Responses stream TTFB timeout", status_code=504, error_type="api_connection") + waits.append(remaining_timeout) + if heartbeat is not None: + last_heartbeat_at = acquire_last_heartbeat_at or started_at + remaining_heartbeat = heartbeat - (time.monotonic() - last_heartbeat_at) + if remaining_heartbeat <= 0: + acquire_last_heartbeat_at = time.monotonic() + return "heartbeat", None + waits.append(remaining_heartbeat) + wait_timeout = min(waits) if waits else None + if wait_timeout is None: + stream = await task + acquire_task = None + acquire_started_at = None + acquire_last_heartbeat_at = None + return "stream", stream + done, _ = await asyncio.wait({task}, timeout=wait_timeout) + if done: + acquire_task = None + acquire_started_at = None + acquire_last_heartbeat_at = None + return "stream", task.result() + last_heartbeat_at = acquire_last_heartbeat_at or started_at + if heartbeat is not None and time.monotonic() - last_heartbeat_at >= heartbeat: + acquire_last_heartbeat_at = time.monotonic() + return "heartbeat", None + + if transaction_logger: + self._trace( + transaction_logger, + "stream_started", + {"event": StreamEvent("started", protocol="responses").to_dict(), "metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="client", + metadata={"transport": transport}, + ) + created = response_created_payload(response_id, unified.model) + self._trace(transaction_logger, "responses_stream_event_created", created, direction="stream", stage="final", metadata={"transport": transport}) + await self._store_stream_current_state(stream_request, created, parent, transaction_logger=transaction_logger) + yield ResponsesStreamEvent("response.created", created) + try: + while chat_stream is None: + marker, acquired = await acquire_upstream_stream() + if marker == "disconnect": + return + if marker == "heartbeat": + self._trace(transaction_logger, "responses_stream_heartbeat", {"comment": "heartbeat", "phase": "acquire"}, direction="stream", stage="transport", metadata={"transport": transport}) + yield ResponsesStreamEvent("heartbeat", {"comment": "heartbeat"}) + continue + chat_stream = acquired + stream_iterator = chat_stream.__aiter__() + first_chunk = True + while True: + try: + marker, raw_chunk = await next_upstream_chunk(first=first_chunk) + except StopAsyncIteration: + break + if marker == "disconnect": + return + if marker == "heartbeat": + self._trace(transaction_logger, "responses_stream_heartbeat", {"comment": "heartbeat"}, direction="stream", stage="transport", metadata={"transport": transport}) + yield ResponsesStreamEvent("heartbeat", {"comment": "heartbeat"}) + continue + first_chunk = False + if monitor.metrics.first_byte_at is None: + monitor.record_event(StreamEvent("raw_chunk", protocol="responses", raw=raw_chunk)) + if transaction_logger: + self._trace( + transaction_logger, + "stream_first_byte", + {"metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="provider", + metadata={"transport": transport}, + ) + self._trace(transaction_logger, "raw_chat_bridge_stream_chunk", raw_chunk, direction="stream", stage="provider") + cost_usage = _responses_sse_cost_usage(raw_chunk) + if cost_usage is not None: + usage = _merge_responses_stream_usage(usage, cost_usage) + self._trace(transaction_logger, "responses_stream_cost_event", cost_usage, direction="stream", stage="metadata", metadata={"transport": transport}) + continue + chunk = parse_chat_sse_chunk(raw_chunk) + if not chunk or chunk.get("type") == "done": + continue + self._trace(transaction_logger, "parsed_unified_stream_event", chunk, direction="stream", stage="protocol") + if chunk.get("error") is not None or chunk.get("type") == "error": + raise ResponsesServiceError(_stream_error_message(chunk), status_code=502, error_type="upstream_error") + if chunk.get("usage"): + usage = _merge_responses_stream_usage(_responses_chunk_usage(chunk), usage) + delta = _chunk_text_delta(chunk) + if not delta: + continue + if not item_started: + item_started = True + added = output_item_added_payload(state) + self._trace(transaction_logger, "responses_stream_event_output_item_added", added, direction="stream", stage="final", metadata={"transport": transport}) + yield ResponsesStreamEvent("response.output_item.added", added) + state = ResponsesStreamState( + response_id=state.response_id, + model=state.model, + output_text=state.output_text + delta, + output_item_id=state.output_item_id, + ) + event = output_text_delta_payload(state, delta) + first_visible = monitor.metrics.first_visible_output_at is None + monitor.record_event( + StreamEvent( + "delta", + protocol="responses", + data=event, + visible_output=True, + ) + ) + if first_visible: + if transaction_logger: + self._trace( + transaction_logger, + "stream_first_visible_output", + {"event": event, "metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="final", + metadata={"transport": transport}, + ) + self._trace(transaction_logger, "responses_stream_event_output_text_delta", event, direction="stream", stage="final", metadata={"transport": transport}) + await self._store_stream_current_state(stream_request, _current_stream_payload(state), parent, transaction_logger=transaction_logger, session_info=session_info) + yield ResponsesStreamEvent("response.output_text.delta", event) + + if not item_started: + added = output_item_added_payload(state) + self._trace(transaction_logger, "responses_stream_event_output_item_added", added, direction="stream", stage="final", metadata={"transport": transport}) + yield ResponsesStreamEvent("response.output_item.added", added) + done_item = output_item_done_payload(state) + self._trace(transaction_logger, "responses_stream_event_output_item_done", done_item, direction="stream", stage="final", metadata={"transport": transport}) + yield ResponsesStreamEvent("response.output_item.done", done_item) + completed = response_completed_payload(state, _usage_to_responses_stream(usage)) + _record_responses_session_anchor(session_info, completed) + self._trace_responses_usage(transaction_logger, completed, unified.model, source="responses_stream") + stored = await self._store_stream_response(stream_request, completed, parent, transaction_logger=transaction_logger, session_info=session_info) + if stored: + self._trace(transaction_logger, "responses_stored_stream_response", completed, direction="metadata", stage="final") + else: + self._trace(transaction_logger, "responses_store_skipped", {"response_id": completed.get("id")}, direction="metadata", stage="final") + monitor.complete() + if transaction_logger: + self._trace( + transaction_logger, + "stream_completed", + {"metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="final", + metadata={"transport": transport}, + ) + self._trace( + transaction_logger, + "stream_metrics_final", + {"metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="final", + metadata={"transport": transport}, + ) + self._trace(transaction_logger, "responses_stream_event_completed", completed, direction="stream", stage="final", metadata={"transport": transport}) + yield ResponsesStreamEvent("response.completed", completed) + self._trace(transaction_logger, "stream_done_event", {"raw": "done"}, direction="stream", stage="final", metadata={"transport": transport}) + yield ResponsesStreamEvent("done", {}, terminal=True) + except Exception as exc: + monitor.record_event(StreamEvent("error", protocol="responses", data={"error_type": exc.__class__.__name__})) + failed = response_failed_payload(response_id, unified.model, _stream_failure_error(exc)) + if state.output_text: + failed["output"] = [output_item_done_payload(state)["item"]] + self._log_transform_error(transaction_logger, "responses_stream", exc, stream_request) + stored = await self._store_stream_response(stream_request, failed, parent, failed=True, transaction_logger=transaction_logger, session_info=session_info) + if stored: + self._trace(transaction_logger, "responses_stored_failed_stream_response", {"response_id": failed.get("id"), "status": "failed"}, direction="metadata", stage="final") + self._trace(transaction_logger, "responses_stream_event_failed", failed, direction="stream", stage="final", metadata={"transport": transport}, scrub_strings=True) + if transaction_logger: + self._trace( + transaction_logger, + "stream_metrics_final", + {"metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="final", + metadata={"transport": transport, "failed": True}, + ) + yield ResponsesStreamEvent("response.failed", failed) + self._trace(transaction_logger, "stream_done_event", {"raw": "done"}, direction="stream", stage="final", metadata={"transport": transport, "failed": True}) + yield ResponsesStreamEvent("done", {}, terminal=True) + finally: + if chat_stream is None and acquire_task is not None and acquire_task.done() and not acquire_task.cancelled(): + try: + chat_stream = acquire_task.result() + stream_iterator = chat_stream.__aiter__() + except Exception: + chat_stream = None + await cancel_task(pending_next_task) + await cancel_task(acquire_task) + if chat_stream is not None and not upstream_closed: + await close_upstream("wrapper_exit") + + async def get_response(self, response_id: str) -> dict[str, Any]: + """Return a stored response payload or raise a 404-compatible error.""" + + stored = await self.store.get(response_id) + if stored is None: + raise ResponsesServiceError(f"Response not found: {response_id}", status_code=404, error_type="not_found_error") + return deepcopy(stored.response) + + async def delete_response(self, response_id: str) -> dict[str, Any]: + """Delete a stored response and return a compatible deletion object.""" + + deleted = await self.store.delete(response_id) + if not deleted: + raise ResponsesServiceError(f"Response not found: {response_id}", status_code=404, error_type="not_found_error") + return {"id": response_id, "object": "response.deleted", "deleted": True} + + async def list_input_items(self, response_id: str) -> dict[str, Any]: + """Return stored input items for a response continuation.""" + + items = await self.store.list_input_items(response_id) + if items is None: + raise ResponsesServiceError(f"Response not found: {response_id}", status_code=404, error_type="not_found_error") + return {"object": "list", "data": items} + + async def _load_response_lineage(self, parent: Optional[StoredResponse], *, max_depth: int = 20) -> list[StoredResponse]: + """Return parent continuation lineage from oldest to newest.""" + + if parent is None: + return [] + lineage: list[StoredResponse] = [] + seen: set[str] = set() + current: Optional[StoredResponse] = parent + while current is not None and current.id not in seen and len(lineage) < max_depth: + seen.add(current.id) + lineage.append(current) + previous_id = current.request.get("previous_response_id") if isinstance(current.request, dict) else None + if not previous_id: + break + current = await self.store.get(str(previous_id)) + return list(reversed(lineage)) + + async def _load_previous_response(self, response_id: Optional[str], transaction_logger: Optional[Any]) -> Optional[StoredResponse]: + if not response_id: + return None + parent = await self.store.get(response_id) + if parent is None: + raise ResponsesServiceError(f"Previous response not found: {response_id}", status_code=404, error_type="not_found_error") + if transaction_logger: + self._trace( + transaction_logger, + "responses_previous_response_loaded", + parent.to_dict(), + direction="metadata", + stage="adapter", + metadata={ + "previous_response_id": response_id, + "output_count": len(parent.output_items), + "input_item_count": len(parent.input_items), + "bridge_context_expanded": True, + }, + ) + return parent + + def _stored_response( + self, + raw_request: dict[str, Any], + response_payload: dict[str, Any], + parent: Optional[StoredResponse], + *, + session_info: Optional[dict[str, Any]] = None, + ) -> StoredResponse: + session_info = session_info or {} + session_affinity_key = session_info.get("session_affinity_key") or (parent.metadata.get("session_affinity_key") if parent else None) + return StoredResponse( + id=str(response_payload["id"]), + created_at=float(response_payload.get("created_at") or time.time()), + model=str(response_payload.get("model") or raw_request.get("model") or ""), + status=str(response_payload.get("status") or "completed"), + request=deepcopy(raw_request), + response=deepcopy(response_payload), + input_items=_input_items(raw_request), + output_items=deepcopy(response_payload.get("output") or []), + usage=deepcopy(response_payload.get("usage")) if isinstance(response_payload.get("usage"), dict) else None, + metadata={ + "previous_response_id": parent.id if parent else raw_request.get("previous_response_id"), + "response_id": response_payload.get("id"), + "session_affinity_key": session_affinity_key, + }, + session_id=session_info.get("session_id") or (parent.session_id if parent else None), + scope_key=session_info.get("scope_key") or (parent.scope_key if parent else None), + classifier=session_info.get("classifier") or (parent.classifier if parent else None), + expires_at=_expires_at(self.store_settings), + ) + + @staticmethod + def _response_to_dict(response: Any) -> Any: + if isinstance(response, dict): + return deepcopy(response) + if hasattr(response, "model_dump"): + return response.model_dump() + if hasattr(response, "dict"): + return response.dict() + return repr(response) + + @staticmethod + def _trace( + transaction_logger: Optional[Any], + pass_name: str, + data: Any, + *, + direction: str, + stage: str, + metadata: Optional[dict[str, Any]] = None, + scrub_strings: bool = False, + ) -> None: + if not transaction_logger: + return + transaction_logger.log_transform_pass( + pass_name, + data, + direction=direction, + stage=stage, + protocol="responses", + metadata=metadata or {}, + scrub_strings=scrub_strings, + ) + + @staticmethod + def _log_transform_error(transaction_logger: Optional[Any], pass_name: str, error: BaseException, payload: Any) -> None: + if transaction_logger: + transaction_logger.log_transform_error(pass_name, error, payload=payload, stage="adapter", protocol="responses") + + def _trace_responses_usage( + self, + transaction_logger: Optional[Any], + response_payload: dict[str, Any], + model: str, + *, + source: str, + ) -> None: + """Trace normalized Responses usage without changing stored payloads.""" + + if not transaction_logger: + return + usage = response_payload.get("usage") if isinstance(response_payload, dict) else None + if not usage: + return + record = extract_usage_record(usage, provider="responses", model=model, source=source) + cost_breakdown = CostCalculator().calculate(record, model=model, provider="responses") + self._trace( + transaction_logger, + "usage_accounting_summary", + {"usage": record.to_dict(), "cost": cost_breakdown.to_dict()}, + direction="metadata", + stage="final", + metadata={"source": source, "pricing_source": cost_breakdown.pricing_source}, + ) + + async def _store_stream_response( + self, + raw_request: dict[str, Any], + response_payload: dict[str, Any], + parent: Optional[StoredResponse], + *, + failed: bool = False, + transaction_logger: Optional[Any] = None, + session_info: Optional[dict[str, Any]] = None, + ) -> bool: + if not raw_request.get("store", True): + return False + if failed and not self.store_settings.store_failed: + return False + stored = self._stored_response(raw_request, response_payload, parent, session_info=session_info) + try: + await self.store.save(stored) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_store_stream_response", exc, stored.to_dict()) + raise + return True + + async def _store_stream_current_state( + self, + raw_request: dict[str, Any], + response_payload: dict[str, Any], + parent: Optional[StoredResponse], + *, + transaction_logger: Optional[Any], + session_info: Optional[dict[str, Any]] = None, + ) -> bool: + """Optionally persist in-progress stream state for retrieval surfaces.""" + + if not self.store_settings.store_in_progress or not raw_request.get("store", True): + return False + stored = self._stored_response(raw_request, response_payload, parent, session_info=session_info) + try: + await self.store.save(stored) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_store_stream_current_state", exc, stored.to_dict()) + raise + self._trace( + transaction_logger, + "responses_stored_stream_current_state", + {"response_id": response_payload.get("id"), "status": response_payload.get("status")}, + direction="metadata", + stage="final", + ) + return True + + +def _input_items(raw_request: dict[str, Any]) -> list[Any]: + value = raw_request.get("input") + if value is None: + return [] + return deepcopy(value if isinstance(value, list) else [value]) + + +def _current_stream_payload(state: ResponsesStreamState) -> dict[str, Any]: + """Return a retrievable in-progress Responses object for stream state.""" + + payload = response_completed_payload(state) + payload["status"] = "in_progress" + return payload + + +def _expires_at(settings: ResponsesStoreSettings) -> Optional[float]: + """Return the expiration timestamp for a new stored response, if enabled.""" + + ttl = settings.ttl_seconds + if ttl is None or ttl <= 0: + return None + return time.time() + ttl + + +def _internal_client_kwargs(client: Any, session_hints: Any, session_info: dict[str, Any]) -> dict[str, Any]: + """Return hidden kwargs only for the internal RotatingClient path.""" + + if not _supports_internal_context_kwargs(client): + return {} + kwargs: dict[str, Any] = {"_request_context_callback": _capture_request_context(session_info)} + if session_hints: + kwargs["_session_tracking_hints"] = session_hints + return kwargs + + +def _supports_internal_context_kwargs(client: Any) -> bool: + """Return whether a client is the proxy's internal rotating client.""" + + return hasattr(client, "_request_builder") and hasattr(client, "_executor") + + +def _capture_request_context(session_info: dict[str, Any]): + """Build a callback that records non-secret request context metadata.""" + + def capture(context: Any) -> None: + session_info["session_id"] = getattr(context, "session_id", None) + session_info["session_affinity_key"] = getattr(context, "session_affinity_key", None) + session_info["scope_key"] = getattr(context, "usage_manager_key", None) + session_info["classifier"] = getattr(context, "classifier", None) + session_info["session_tracker"] = getattr(context, "session_tracker", None) + session_info["provider"] = getattr(context, "provider", None) + session_info["model"] = getattr(context, "model", None) + session_info["tracking_namespace"] = getattr(context, "session_tracking_namespace", None) + + return capture + + +def _without_internal_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: + """Return trace/provider-visible kwargs without proxy-internal controls.""" + + return {key: deepcopy(value) for key, value in kwargs.items() if not key.startswith("_")} + + +def _responses_session_hints(previous_response_id: Optional[str], parent: Optional[StoredResponse], fallback: Any) -> Any: + """Prefer parent stored affinity when building Responses continuation hints.""" + + if not previous_response_id: + return None + parent_affinity = parent.metadata.get("session_affinity_key") if parent else None + return responses_session_hints(previous_response_id, affinity_key=parent_affinity) or fallback + + +def _record_responses_session_anchor(session_info: dict[str, Any], response_payload: dict[str, Any]) -> None: + """Record emitted Responses IDs as response-derived session evidence.""" + + tracker = session_info.get("session_tracker") + session_id = session_info.get("session_id") + if not tracker or not session_id or not response_payload.get("id"): + return + tracker.record_response( + session_id, + provider=session_info.get("provider"), + model=session_info.get("model"), + scope_key=session_info.get("scope_key"), + tracking_namespace=session_info.get("tracking_namespace"), + response={"id": response_payload.get("id"), "object": "response"}, + ) + + +def _stream_error_message(chunk: dict[str, Any]) -> str: + """Return a compact, client-safe message for upstream stream error chunks.""" + + error = chunk.get("error") + if isinstance(error, dict): + message = error.get("message") or error.get("type") + if message: + return str(message) + message = chunk.get("message") + return str(message) if message else "Upstream stream error" + + +def _stream_failure_error(exc: Exception) -> dict[str, Any]: + """Return a client-safe Responses stream failure object.""" + + error_type = getattr(exc, "error_type", None) or exc.__class__.__name__ + result = {"message": str(exc), "type": str(error_type)} + if isinstance(exc, ResponsesServiceError) and error_type == "api_connection": + text = str(exc).lower() + if "ttfb" in text: + result["timeout_type"] = "ttfb" + elif "stall" in text: + result["timeout_type"] = "stall" + return result + + +def _responses_sse_cost_usage(chunk: Any) -> Optional[dict[str, Any]]: + """Extract Responses stream cost metadata from SSE comments/events.""" + + if not isinstance(chunk, str): + return None + payload = _responses_sse_cost_payload(chunk) + if payload is None: + return None + if isinstance(payload, (int, float, str)): + payload = {"provider_reported_cost": payload, "source": "responses_sse_cost"} + if not isinstance(payload, dict): + return None + cost = payload.get("provider_reported_cost", payload.get("request_cost_usd", payload.get("total_cost", payload.get("cost", payload.get("estimated_cost"))))) + if cost is None: + return None + return { + "provider_reported_cost": cost, + "currency": payload.get("currency", "USD"), + "cost_details": payload, + } + + +def _responses_chunk_usage(chunk: dict[str, Any]) -> Any: + """Return stream chunk usage with sibling cost metadata preserved.""" + + usage = chunk.get("usage") + if not isinstance(usage, dict): + return usage + merged = dict(usage) + for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + if key in chunk and key not in merged: + merged[key] = deepcopy(chunk[key]) + return merged + + +def _responses_sse_cost_payload(chunk: str) -> Any: + event_type = None + data_lines: list[str] = [] + for line in chunk.strip().splitlines(): + stripped = line.strip() + if stripped.startswith(":"): + comment = stripped[1:].strip() + if comment.startswith("cost"): + return _parse_cost_text(comment[4:].strip()) + continue + if stripped.startswith("event:"): + event_type = stripped[6:].strip() + continue + if stripped.startswith("data:"): + data_lines.append(stripped[5:].strip()) + if event_type == "cost" and data_lines: + return _parse_cost_text("\n".join(data_lines).strip()) + return None + + +def _parse_cost_text(text: str) -> Any: + if not text: + return None + try: + return json.loads(text) + except json.JSONDecodeError: + try: + return float(text) + except ValueError: + return None + + +def _merge_responses_stream_usage(primary: Any, fallback_cost: Any) -> Any: + """Merge earlier stream cost metadata into later token usage when needed.""" + + if not isinstance(primary, dict): + return fallback_cost if primary is None else primary + if not isinstance(fallback_cost, dict): + return primary + merged = deepcopy(primary) + has_cost = any(key in merged for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd")) + if has_cost: + return merged + for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd", "currency"): + if key in fallback_cost: + merged[key] = deepcopy(fallback_cost[key]) + return merged + + +def _chunk_text_delta(chunk: dict[str, Any]) -> str: + choices = chunk.get("choices") if isinstance(chunk.get("choices"), list) else [] + if not choices: + return "" + delta = choices[0].get("delta") if isinstance(choices[0], dict) else None + if not isinstance(delta, dict): + return "" + content = delta.get("content") + return content if isinstance(content, str) else "" + + +def _usage_to_responses_stream(usage: Any) -> Any: + if not isinstance(usage, dict): + return usage + result = { + "input_tokens": usage.get("prompt_tokens", usage.get("input_tokens", 0)), + "output_tokens": usage.get("completion_tokens", usage.get("output_tokens", 0)), + "total_tokens": usage.get("total_tokens", 0), + } + prompt_details = usage.get("prompt_tokens_details") or usage.get("input_tokens_details") + if isinstance(prompt_details, dict): + result["input_tokens_details"] = {"cached_tokens": prompt_details.get("cached_tokens", 0)} + completion_details = usage.get("completion_tokens_details") or usage.get("output_tokens_details") + if isinstance(completion_details, dict): + result["output_tokens_details"] = {"reasoning_tokens": completion_details.get("reasoning_tokens", 0)} + for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd", "currency"): + if key in usage: + result[key] = deepcopy(usage[key]) + return result diff --git a/src/rotator_library/responses/store.py b/src/rotator_library/responses/store.py new file mode 100644 index 000000000..e8e3f387d --- /dev/null +++ b/src/rotator_library/responses/store.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Storage backends for Responses API objects.""" + +from __future__ import annotations + +import json +from copy import deepcopy +from pathlib import Path +from typing import Any, Optional, Protocol + +from .types import StoredResponse + + +class ResponsesStore(Protocol): + """Minimal async store for response retrieval and continuation.""" + + async def save(self, response: StoredResponse) -> None: ... + + async def get(self, response_id: str) -> Optional[StoredResponse]: ... + + async def delete(self, response_id: str) -> bool: ... + + async def list_input_items(self, response_id: str) -> Optional[list[Any]]: ... + + +class InMemoryResponsesStore: + """Process-local Responses store. + + This is the Phase 4 default because it has no async lifecycle and avoids a + new persistence dependency. A provider-cache-backed store can be injected by + later configuration code when disk persistence is desired. + """ + + def __init__(self, *, max_items: int | None = None) -> None: + self._responses: dict[str, StoredResponse] = {} + self.max_items = max_items if max_items and max_items > 0 else None + + async def save(self, response: StoredResponse) -> None: + self._prune_expired() + self._responses[response.id] = StoredResponse.from_dict(response.to_dict()) + self._prune_overflow() + + async def get(self, response_id: str) -> Optional[StoredResponse]: + response = self._responses.get(response_id) + if response is None: + return None + if response.is_expired(): + self._responses.pop(response_id, None) + return None + return StoredResponse.from_dict(response.to_dict()) + + async def delete(self, response_id: str) -> bool: + return self._responses.pop(response_id, None) is not None + + async def list_input_items(self, response_id: str) -> Optional[list[Any]]: + response = await self.get(response_id) + if response is None: + return None + return deepcopy(response.input_items) + + def _prune_expired(self) -> None: + for response_id, response in list(self._responses.items()): + if response.is_expired(): + self._responses.pop(response_id, None) + + def _prune_overflow(self) -> None: + if not self.max_items: + return + while len(self._responses) > self.max_items: + oldest_id = min(self._responses.values(), key=lambda response: response.created_at).id + self._responses.pop(oldest_id, None) + + +class ProviderCacheResponsesStore: + """Responses store backed by an injected `ProviderCache` instance. + + The wrapper does not instantiate `ProviderCache`, because that class starts + background tasks. The caller owns cache lifecycle and shutdown. + """ + + def __init__(self, provider_cache: Any, *, prefix: str = "responses") -> None: + self._cache = provider_cache + self._prefix = prefix + + async def save(self, response: StoredResponse) -> None: + await self._cache.store_async(self._key(response.id), json.dumps(response.to_dict(), ensure_ascii=False)) + flush = getattr(self._cache, "_save_to_disk", None) + if callable(flush): + await flush() + + async def get(self, response_id: str) -> Optional[StoredResponse]: + raw = await self._cache.retrieve_async(self._key(response_id)) + if raw is None: + return None + response = StoredResponse.from_dict(json.loads(raw)) + if response.is_expired(): + await self.delete(response_id) + return None + return response + + async def delete(self, response_id: str) -> bool: + delete = getattr(self._cache, "delete_async", None) + if delete: + return bool(await delete(self._key(response_id))) + # ProviderCache currently exposes clear(), not key-level deletion. When + # key deletion is unavailable, avoid clearing unrelated provider state. + return False + + async def list_input_items(self, response_id: str) -> Optional[list[Any]]: + response = await self.get(response_id) + if response is None: + return None + return deepcopy(response.input_items) + + def _key(self, response_id: str) -> str: + safe_id = response_id.replace("/", "_").replace("\\", "_").replace(":", "_") + return f"{self._prefix}:{safe_id}" + + +def create_configured_responses_store(*, config: Any = None, env: Any = None) -> ResponsesStore: + """Create the configured Responses store backend. + + Memory remains the default. The provider-cache backend uses the existing JSON + cache implementation so durable storage does not require a new database. + """ + + from ..config.experimental import get_responses_store_runtime_settings, get_responses_store_settings + from ..providers.provider_cache import create_provider_cache + + runtime = get_responses_store_runtime_settings(config=config, env=env) + settings = get_responses_store_settings(config=config, env=env) + if runtime.backend == "memory": + return InMemoryResponsesStore(max_items=settings.max_items) + cache_dir = Path(runtime.cache_dir) if runtime.cache_dir else None + provider_cache = create_provider_cache( + runtime.cache_name, + cache_dir=cache_dir, + memory_ttl_seconds=runtime.cache_memory_ttl_seconds, + disk_ttl_seconds=runtime.cache_disk_ttl_seconds, + env_prefix=f"{runtime.cache_name.upper().replace('-', '_')}_CACHE", + ) + return ProviderCacheResponsesStore(provider_cache, prefix=runtime.cache_prefix) diff --git a/src/rotator_library/responses/streaming.py b/src/rotator_library/responses/streaming.py new file mode 100644 index 000000000..f76111134 --- /dev/null +++ b/src/rotator_library/responses/streaming.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""HTTP SSE formatting for the Responses API compatibility layer.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Any + +from ..protocols import serialize_value + + +@dataclass(frozen=True) +class ResponsesStreamState: + """Mutable-by-replacement state accumulated from chat stream chunks.""" + + response_id: str + model: str + output_text: str = "" + output_item_id: str = "msg_0" + + +@dataclass(frozen=True) +class ResponsesStreamEvent: + """Transport-neutral Responses stream event. + + Service code yields these events. Formatters decide how to serialize them for + SSE, WebSocket, or future transports without duplicating protocol logic. + """ + + event_name: str + payload: dict[str, Any] + terminal: bool = False + + @property + def heartbeat(self) -> bool: + """Return whether this event is transport metadata, not model output.""" + + return self.event_name == "heartbeat" + + +class ResponsesSSEFormatter: + """Format Responses API events as HTTP Server-Sent Events.""" + + transport = "sse" + + def format_event(self, event_name: str, payload: dict[str, Any]) -> str: + return f"event: {event_name}\ndata: {json.dumps(serialize_value(payload), ensure_ascii=False)}\n\n" + + def format_stream_event(self, event: ResponsesStreamEvent) -> str: + """Format one transport-neutral event for HTTP SSE.""" + + if event.heartbeat: + return self.format_heartbeat(str(event.payload.get("comment") or "heartbeat")) + if event.terminal: + return self.done() + return self.format_event(event.event_name, event.payload) + + def format_heartbeat(self, comment: str = "heartbeat") -> str: + """Return a non-visible SSE comment heartbeat frame.""" + + safe_comment = comment.replace("\r", " ").replace("\n", " ") + return f": {safe_comment}\n\n" + + def done(self) -> str: + """Return the final compatibility sentinel used by many SSE clients.""" + + return "data: [DONE]\n\n" + + +class ResponsesWebSocketFormatter: + """Placeholder transport seam for future WebSocket Responses support.""" + + transport = "websocket" + future_supported = True + + def format_event(self, event_name: str, payload: dict[str, Any]) -> str: + return json.dumps({"event": event_name, "data": serialize_value(payload)}, ensure_ascii=False) + + def format_stream_event(self, event: ResponsesStreamEvent) -> str: + """Format one transport-neutral event as a WebSocket message payload.""" + + if event.heartbeat: + return json.dumps({"event": "heartbeat", "data": {"visible_output": False}}, ensure_ascii=False) + if event.terminal: + return json.dumps({"event": "done", "data": {}}, ensure_ascii=False) + return self.format_event(event.event_name, event.payload) + + +def parse_chat_sse_chunk(chunk: Any) -> dict[str, Any] | None: + """Decode a chat-completions stream chunk into a dict if possible.""" + + if isinstance(chunk, dict): + return chunk + if not isinstance(chunk, str): + return None + text = chunk.strip() + if not text: + return None + event_name = None + data_lines: list[str] = [] + for line in text.splitlines(): + if line.startswith("event:"): + event_name = line[len("event:") :].strip() + elif line.startswith("data:"): + data_lines.append(line[len("data:") :].strip()) + if data_lines: + text = "\n".join(data_lines).strip() + if text == "[DONE]": + return {"type": "done"} + try: + payload = json.loads(text) + except json.JSONDecodeError: + return None + if isinstance(payload, dict) and event_name and payload.get("type") is None: + payload["type"] = event_name + return payload if isinstance(payload, dict) else None + + +def response_created_payload(response_id: str, model: str) -> dict[str, Any]: + return {"id": response_id, "object": "response", "status": "in_progress", "model": model, "output": []} + + +def output_item_added_payload(state: ResponsesStreamState) -> dict[str, Any]: + return { + "response_id": state.response_id, + "output_index": 0, + "item": {"id": state.output_item_id, "type": "message", "role": "assistant", "content": []}, + } + + +def output_text_delta_payload(state: ResponsesStreamState, delta: str) -> dict[str, Any]: + return { + "response_id": state.response_id, + "item_id": state.output_item_id, + "output_index": 0, + "content_index": 0, + "delta": delta, + } + + +def output_item_done_payload(state: ResponsesStreamState) -> dict[str, Any]: + return { + "response_id": state.response_id, + "output_index": 0, + "item": { + "id": state.output_item_id, + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": state.output_text}], + }, + } + + +def response_completed_payload(state: ResponsesStreamState, usage: Any = None) -> dict[str, Any]: + payload = { + "id": state.response_id, + "object": "response", + "status": "completed", + "model": state.model, + "output": [output_item_done_payload(state)["item"]], + } + if usage is not None: + payload["usage"] = usage + return payload + + +def response_failed_payload(response_id: str, model: str, error: Any) -> dict[str, Any]: + return {"id": response_id, "object": "response", "status": "failed", "model": model, "error": serialize_value(error)} diff --git a/src/rotator_library/responses/types.py b/src/rotator_library/responses/types.py new file mode 100644 index 000000000..9b2a91e29 --- /dev/null +++ b/src/rotator_library/responses/types.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Data types for the Responses API compatibility layer.""" + +from __future__ import annotations + +import secrets +import time +from dataclasses import dataclass, field +from typing import Any, Optional + +from ..protocols import serialize_value + + +def generate_response_id() -> str: + """Return a local Responses-compatible identifier. + + Upstream IDs are preserved when providers return them. This helper is only + used by the bridge path when the current chat-completions backend has no + native Responses ID to expose. + """ + + return f"resp_{secrets.token_urlsafe(18).replace('-', '').replace('_', '')[:24]}" + + +@dataclass +class ResponsesStoreSettings: + """Runtime policy for Responses object storage. + + The defaults preserve Phase 4 behavior. Operators can opt into TTL, bounded + memory, failed-response persistence, or in-progress updates without changing + the store interface or introducing a database. + """ + + ttl_seconds: Optional[int] = None + max_items: Optional[int] = None + store_failed: bool = True + store_in_progress: bool = False + + +@dataclass +class StoredResponse: + """Persisted response object used for retrieval and continuation. + + The shape stores both client-facing response data and enough request/session + metadata for `previous_response_id` debugging. It deliberately avoids storing + credential secrets; callers should pass only stable identifiers if they need + credential correlation later. + """ + + id: str + model: str + status: str + response: dict[str, Any] + request: dict[str, Any] = field(default_factory=dict) + input_items: list[Any] = field(default_factory=list) + output_items: list[Any] = field(default_factory=list) + usage: Optional[dict[str, Any]] = None + metadata: dict[str, Any] = field(default_factory=dict) + created_at: float = field(default_factory=time.time) + session_id: Optional[str] = None + scope_key: Optional[str] = None + classifier: Optional[str] = None + expires_at: Optional[float] = None + + def is_expired(self, now: Optional[float] = None) -> bool: + """Return whether the response should be treated as unavailable.""" + + return self.expires_at is not None and (now if now is not None else time.time()) >= self.expires_at + + def to_dict(self) -> dict[str, Any]: + """Serialize with JSON-safe values for disk/cache persistence.""" + + return serialize_value( + { + "id": self.id, + "created_at": self.created_at, + "model": self.model, + "status": self.status, + "request": self.request, + "response": self.response, + "input_items": self.input_items, + "output_items": self.output_items, + "usage": self.usage, + "metadata": self.metadata, + "session_id": self.session_id, + "scope_key": self.scope_key, + "classifier": self.classifier, + "expires_at": self.expires_at, + } + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "StoredResponse": + """Rehydrate a stored response from a JSON-compatible dict.""" + + return cls( + id=str(data["id"]), + created_at=float(data.get("created_at") or time.time()), + model=str(data.get("model") or ""), + status=str(data.get("status") or "completed"), + request=dict(data.get("request") or {}), + response=dict(data.get("response") or {}), + input_items=list(data.get("input_items") or []), + output_items=list(data.get("output_items") or []), + usage=data.get("usage") if isinstance(data.get("usage"), dict) else None, + metadata=dict(data.get("metadata") or {}), + session_id=data.get("session_id"), + scope_key=data.get("scope_key"), + classifier=data.get("classifier"), + expires_at=data.get("expires_at"), + ) diff --git a/src/rotator_library/retry_policy.py b/src/rotator_library/retry_policy.py new file mode 100644 index 000000000..313c35c53 --- /dev/null +++ b/src/rotator_library/retry_policy.py @@ -0,0 +1,261 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Retry, cooldown, and target-failover policy helpers. + +This module centralizes decisions that sit above the existing error classifier. +It deliberately delegates parsing and classification to `error_handler.py` so the +proxy keeps its current retry-after parser and credential-rotation semantics. +""" + +from __future__ import annotations + +import asyncio +import os +import time +from collections import deque +from dataclasses import dataclass +from typing import Any, Optional + +from .error_handler import ClassifiedError, classify_error, should_retry_same_key, should_rotate_on_error +from .routing import FallbackPolicy +from .routing.policy import normalize_route_error_type +from .routing.types import FallbackGroup + +DEFAULT_PROVIDER_COOLDOWN_DEFAULT_SECONDS = 30 + + +@dataclass(frozen=True) +class ProviderCooldownDecision: + """Decision describing whether a provider-level cooldown should start.""" + + should_start: bool + duration: int = 0 + reason: str = "not_applicable" + scope: str = "provider" + model: Optional[str] = None + backoff_level: int = 0 + + +@dataclass(frozen=True) +class FailureHistoryEntry: + """One sanitized provider/model failure event kept in memory only.""" + + timestamp: float + provider: str + model: Optional[str] + error_type: str + scope: str + duration: int + reason: str + + +def classify_route_error(error: BaseException, provider: Optional[str] = None) -> str: + """Map an exception into the vocabulary consumed by fallback policy.""" + + if isinstance(error, asyncio.CancelledError): + return "cancelled" + explicit = getattr(error, "error_type", None) + if explicit: + return normalize_route_error_type(str(explicit)) + return normalize_route_error_type(classify_error(error, provider).error_type) + + +def should_retry_same_credential(classified_error: ClassifiedError, small_cooldown_threshold: int) -> bool: + """Return whether the current credential should be retried before rotation.""" + + return should_retry_same_key(classified_error, small_cooldown_threshold) + + +def should_rotate_credential(classified_error: ClassifiedError) -> bool: + """Return whether a classified failure should rotate to another credential.""" + + return should_rotate_on_error(classified_error) + + +def decide_provider_cooldown( + classified_error: ClassifiedError, + *, + small_cooldown_threshold: int, + provider_cooldown_min_seconds: int, + default_duration: int = DEFAULT_PROVIDER_COOLDOWN_DEFAULT_SECONDS, + cooldown_on_quota: bool = False, + provider: Optional[str] = None, + model: Optional[str] = None, + original_error: Any = None, + failure_history: "FailureHistory | None" = None, +) -> ProviderCooldownDecision: + """Return whether a provider-wide cooldown should be activated. + + Small retry-after values are intentionally left to same-credential retry to + preserve cache/session locality. Larger retry-after values can indicate a + provider-wide or IP-level throttle and are therefore safe to coordinate via + the provider cooldown manager. + """ + + error_type = classified_error.error_type + if error_type == "quota_exceeded" and not cooldown_on_quota: + return ProviderCooldownDecision(False, reason="quota_cooldown_disabled") + if error_type not in {"rate_limit", "server_error", "api_connection", "quota_exceeded"}: + return ProviderCooldownDecision(False, reason="non_provider_cooldown_error") + scope = "model" if model and is_model_capacity_error(original_error or classified_error.original_exception) else "provider" + + retry_after = classified_error.retry_after + if retry_after is not None: + if retry_after <= 0: + return ProviderCooldownDecision(False, reason="non_positive_retry_after") + if retry_after < small_cooldown_threshold: + return ProviderCooldownDecision(False, reason="small_retry_after") + if retry_after < provider_cooldown_min_seconds: + return ProviderCooldownDecision(False, reason="below_provider_cooldown_minimum") + return ProviderCooldownDecision(True, duration=int(retry_after), reason="retry_after", scope=scope, model=model if scope == "model" else None) + + if error_type in {"server_error", "api_connection"} and default_duration >= provider_cooldown_min_seconds: + backoff_level = 0 + duration = int(default_duration) + if scope == "model": + return ProviderCooldownDecision(True, duration=duration, reason="model_capacity_cooldown", scope=scope, model=model, backoff_level=backoff_level) + if failure_history is None: + return ProviderCooldownDecision(False, reason="missing_failure_history", scope=scope, model=model if scope == "model" else None) + backoff = failure_history.backoff_for(provider=provider, error_type=error_type, scope=scope, model=model if scope == "model" else None, default_duration=duration) + duration = backoff.duration + backoff_level = backoff.level + if backoff_level <= 0: + return ProviderCooldownDecision(False, reason="transient_backoff_threshold_not_met", scope=scope, model=model if scope == "model" else None) + return ProviderCooldownDecision(True, duration=duration, reason="model_capacity_cooldown" if scope == "model" else "default_transient_cooldown", scope=scope, model=model if scope == "model" else None, backoff_level=backoff_level) + return ProviderCooldownDecision(False, reason="missing_retry_after") + + +@dataclass(frozen=True) +class BackoffDecision: + """Bounded backoff duration derived from recent transient failures.""" + + duration: int + level: int = 0 + + +class FailureHistory: + """Bounded in-memory provider/model failure history. + + This is intentionally process-local. It provides enough recent context for + conservative cooldown backoff and future observability without introducing a + persistence layer or changing credential usage accounting. + """ + + def __init__(self, *, max_entries: int | None = None, clock: Any = None) -> None: + settings = _retry_settings() + self.max_entries = max(1, max_entries if max_entries is not None else settings.failure_history_max_entries) + self._entries: deque[FailureHistoryEntry] = deque(maxlen=self.max_entries) + self._clock = clock or time.time + + def record(self, *, provider: str, model: Optional[str], error_type: str, scope: str, duration: int, reason: str) -> None: + """Record one sanitized cooldown/failure event.""" + + self._entries.append( + FailureHistoryEntry( + timestamp=float(self._clock()), + provider=provider, + model=model, + error_type=error_type, + scope=scope, + duration=duration, + reason=reason, + ) + ) + + def snapshot(self) -> tuple[FailureHistoryEntry, ...]: + """Return recent entries for tests and future read-only reporting.""" + + return tuple(self._entries) + + def clear(self, *, provider: str, model: Optional[str] = None, scope: Optional[str] = None, error_type: Optional[str] = None) -> None: + """Clear matching failure entries after a successful provider/model call.""" + + kept = [ + entry + for entry in self._entries + if not ( + entry.provider == provider + and (scope is None or entry.scope == scope) + and (error_type is None or entry.error_type == error_type) + and (entry.scope != "model" or model is None or entry.model == model) + ) + ] + self._entries.clear() + self._entries.extend(kept) + + def backoff_for(self, *, provider: Optional[str], error_type: str, scope: str, model: Optional[str], default_duration: int) -> BackoffDecision: + """Return bounded backoff for repeated transient failures.""" + + settings = _retry_settings() + window = settings.provider_backoff_window_seconds + threshold = max(1, settings.provider_backoff_threshold) + base = max(1, settings.provider_backoff_base_seconds or default_duration) + max_seconds = max(base, settings.provider_backoff_max_seconds) + now = float(self._clock()) + recent = [ + entry + for entry in self._entries + if now - entry.timestamp <= window + and entry.provider == provider + and entry.error_type == error_type + and entry.scope == scope + and (scope != "model" or entry.model == model) + ] + if len(recent) + 1 < threshold: + return BackoffDecision(default_duration, level=0) + level = len(recent) + 1 - threshold + 1 + return BackoffDecision(min(max_seconds, base * (2 ** (level - 1))), level=level) + + +def is_model_capacity_error(error: Any) -> bool: + """Return whether an error indicates model/deployment capacity exhaustion.""" + + if error is None: + return False + if isinstance(error, dict): + text = str(error).lower() + else: + parts = [str(error)] + response = getattr(error, "response", None) + if response is not None: + parts.append(str(getattr(response, "text", ""))) + body = getattr(error, "body", None) + if body is not None: + parts.append(str(body)) + text = " ".join(parts).lower() + return "model_capacity_exhausted" in text or "model capacity" in text or "capacity exhausted" in text + + +def provider_cooldown_env() -> tuple[int, int, bool]: + """Read provider-cooldown env controls with conservative defaults.""" + + settings = _retry_settings() + return settings.provider_cooldown_min_seconds, settings.provider_cooldown_default_seconds, settings.provider_cooldown_on_quota + + +def is_target_failover_eligible( + error_type: str, + *, + group: FallbackGroup | None = None, + stream: bool = False, + emitted_output: bool = False, +) -> bool: + """Return whether a target failure may advance to the next target.""" + + return FallbackPolicy().should_fallback(error_type, group=group, stream=stream, emitted_output=emitted_output) + + +def _env_int(name: str, default: int) -> int: + try: + return int(os.environ.get(name, default)) + except (TypeError, ValueError): + return default + + +def _retry_settings() -> Any: + """Load retry settings lazily to avoid config import cycles at startup.""" + + from .config.experimental import get_retry_runtime_settings + + return get_retry_runtime_settings() diff --git a/src/rotator_library/routing/__init__.py b/src/rotator_library/routing/__init__.py new file mode 100644 index 000000000..9892616bc --- /dev/null +++ b/src/rotator_library/routing/__init__.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Routing and fallback group primitives.""" + +from .config import RoutingConfigError, load_routing_config_from_env, parse_route_target +from .attempts import clone_context_for_target +from .executor import FallbackAttemptRunner, FallbackExhaustedError +from .policy import FallbackPolicy +from .resolver import FallbackResolver +from .types import FallbackGroup, RouteTarget, RoutingConfig, RoutingDecision, TargetGroup, TargetSelector + +__all__ = [ + "FallbackGroup", + "FallbackAttemptRunner", + "FallbackExhaustedError", + "FallbackPolicy", + "FallbackResolver", + "RouteTarget", + "RoutingConfig", + "RoutingConfigError", + "RoutingDecision", + "TargetGroup", + "TargetSelector", + "clone_context_for_target", + "load_routing_config_from_env", + "parse_route_target", +] diff --git a/src/rotator_library/routing/attempts.py b/src/rotator_library/routing/attempts.py new file mode 100644 index 000000000..bfdc5f695 --- /dev/null +++ b/src/rotator_library/routing/attempts.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Helpers for creating per-target request contexts.""" + +from __future__ import annotations + +from dataclasses import replace +from typing import Any, Sequence + +from ..core.types import RequestContext +from .types import RouteTarget + + +def clone_context_for_target( + context: RequestContext, + target: RouteTarget, + *, + credentials: Sequence[str] | None = None, + usage_manager_key: str | None = None, + provider_config: dict[str, Any] | None = None, + credential_secrets: dict[str, str] | None = None, + target_index: int = 0, +) -> RequestContext: + """Return a target-specific request context without mutating the original. + + Fallback routing must preserve the original request context for traceability + and for safe retry decisions. This helper copies the request kwargs, updates + the selected model/provider, and carries routing metadata for executor traces. + """ + + kwargs: dict[str, Any] = dict(context.kwargs) + kwargs["model"] = target.prefixed_model + next_usage_key = usage_manager_key if usage_manager_key is not None else target.provider + return replace( + context, + model=target.prefixed_model, + provider=target.provider, + kwargs=kwargs, + credentials=list(credentials) if credentials is not None else list(context.credentials), + usage_manager_key=next_usage_key, + provider_config=provider_config if provider_config is not None else context.provider_config, + credential_secrets=dict(credential_secrets) if credential_secrets is not None else dict(context.credential_secrets), + routing_target_index=target_index, + session_tracking_namespace=_namespace_for_target(context.session_tracking_namespace, target, scope_key=next_usage_key), + ) + + +def _namespace_for_target(namespace: str | None, target: RouteTarget, *, scope_key: str) -> str | None: + """Rewrite standard session namespaces for the fallback target provider/model.""" + + if not namespace or ":provider:" not in namespace or ":model:" not in namespace: + return namespace + prefix, _, rest = namespace.partition(":provider:") + if not prefix.startswith("scope:"): + return namespace + _provider, sep, _model = rest.partition(":model:") + if not sep: + return namespace + return f"scope:{scope_key}:provider:{target.provider}:model:{target.prefixed_model}" diff --git a/src/rotator_library/routing/config.py b/src/rotator_library/routing/config.py new file mode 100644 index 000000000..3d7bffb71 --- /dev/null +++ b/src/rotator_library/routing/config.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Environment parser for Phase 6 fallback routing configuration.""" + +from __future__ import annotations + +import os +from collections.abc import Mapping + +from .policy import normalize_route_error_set +from .types import DEFAULT_FAILOVER_ON, DEFAULT_STOP_ON, HARD_STOP_ON, FallbackGroup, RouteTarget, RoutingConfig + + +class RoutingConfigError(ValueError): + """Raised when fallback routing configuration is invalid.""" + + +def parse_route_target(spec: str) -> RouteTarget: + """Parse `provider/model` with optional `@execution` suffix. + + The suffix is intentionally small because Phase 6 prioritizes ordered + fallback groups. Rich selector syntax belongs to the config polish phase. + """ + + text = spec.strip() + if not text: + raise RoutingConfigError("route target cannot be empty") + target_text, _, execution = text.partition("@") + if "/" not in target_text: + raise RoutingConfigError(f"route target requires provider/model: {spec}") + provider, model = target_text.split("/", 1) + if not provider or not model: + raise RoutingConfigError(f"route target requires provider/model: {spec}") + return RouteTarget(provider=provider.strip(), model=model.strip(), execution=(execution.strip() or "auto")) + + +def load_routing_config_from_env(env: Mapping[str, str] | None = None, config: object | None = None) -> RoutingConfig: + """Load fallback groups and model-route aliases from JSON then environment. + + Environment variables intentionally remain the final override layer so the + existing `.env` deployment model keeps working exactly as before. The + optional JSON object is a convenience for structured routing plans. + """ + + source = env if env is not None else os.environ + if config is None: + from ..config.experimental import load_experimental_config + + config = load_experimental_config(env=source) + groups: dict[str, FallbackGroup] = _groups_from_json_config(config) + + group_names = _csv(source.get("FALLBACK_GROUPS", "")) + if len(group_names) != len(set(group_names)): + raise RoutingConfigError("fallback group names must be unique") + for name in group_names: + key = f"FALLBACK_GROUP_{_env_key(name)}" + target_specs = _csv(source.get(key, "")) + if not target_specs: + raise RoutingConfigError(f"fallback group {name} has no targets") + failover_on = _policy_set(source.get(f"{key}_FAILOVER_ON"), DEFAULT_FAILOVER_ON, name) + stop_on = _policy_set(source.get(f"{key}_STOP_ON"), DEFAULT_STOP_ON, name, allow_hard_stop=True) + streaming_policy = _streaming_policy(source.get(f"{key}_STREAMING_POLICY", "pre_output_only"), name) + groups[name] = FallbackGroup(name=name, targets=tuple(parse_route_target(spec) for spec in target_specs), failover_on=failover_on, stop_on=stop_on, streaming_policy=streaming_policy) + + model_routes: dict[str, str] = _model_routes_from_json_config(config) + for key, value in source.items(): + if not key.startswith("MODEL_ROUTE_"): + continue + model_alias = key[len("MODEL_ROUTE_") :].lower() + route = value.strip() + model_routes[model_alias] = route + _validate_model_routes(model_routes, groups) + return RoutingConfig(fallback_groups=groups, model_routes=model_routes) + + +def _groups_from_json_config(config: object) -> dict[str, FallbackGroup]: + routing = getattr(config, "routing", {}) + if not isinstance(routing, Mapping): + return {} + raw_groups = routing.get("fallback_groups", {}) + if not isinstance(raw_groups, Mapping): + raise RoutingConfigError("routing.fallback_groups must be an object") + groups: dict[str, FallbackGroup] = {} + for name, raw_group in raw_groups.items(): + if not isinstance(raw_group, Mapping): + raise RoutingConfigError(f"fallback group {name} must be an object") + raw_targets = raw_group.get("targets", []) + if not isinstance(raw_targets, list) or not raw_targets: + raise RoutingConfigError(f"fallback group {name} has no targets") + groups[str(name)] = FallbackGroup( + name=str(name), + targets=tuple(parse_route_target(str(spec)) for spec in raw_targets), + failover_on=_policy_set(raw_group.get("failover_on"), DEFAULT_FAILOVER_ON, str(name)), + stop_on=_policy_set(raw_group.get("stop_on"), DEFAULT_STOP_ON, str(name), allow_hard_stop=True), + streaming_policy=_streaming_policy(raw_group.get("streaming_policy", "pre_output_only"), str(name)), + max_targets=int(raw_group["max_targets"]) if raw_group.get("max_targets") is not None else None, + metadata=dict(raw_group.get("metadata", {})) if isinstance(raw_group.get("metadata", {}), Mapping) else {}, + ) + return groups + + +def _model_routes_from_json_config(config: object) -> dict[str, str]: + routing = getattr(config, "routing", {}) + if not isinstance(routing, Mapping): + return {} + raw_routes = routing.get("model_routes", {}) + if not isinstance(raw_routes, Mapping): + raise RoutingConfigError("routing.model_routes must be an object") + return {str(alias).lower(): str(route).strip() for alias, route in raw_routes.items()} + + +def _validate_model_routes(model_routes: Mapping[str, str], groups: Mapping[str, FallbackGroup]) -> None: + for key, route in model_routes.items(): + if route.startswith("group:") and route[len("group:") :] not in groups: + raise RoutingConfigError(f"model route {key} references unknown fallback group {route}") + + +def _string_set(value: object, default: frozenset[str]) -> frozenset[str]: + if value is None: + return default + if isinstance(value, str): + return frozenset(_csv(value)) + if isinstance(value, list): + return frozenset(str(item) for item in value) + raise RoutingConfigError("routing policy lists must be strings or arrays") + + +def _policy_set(value: object, default: frozenset[str], group_name: str, *, allow_hard_stop: bool = False) -> frozenset[str]: + values = normalize_route_error_set(_string_set(value, default)) + unsafe = values & HARD_STOP_ON + if unsafe and not allow_hard_stop: + raise RoutingConfigError(f"fallback group {group_name} failover_on cannot include hard-stop errors: {', '.join(sorted(unsafe))}") + return values + + +def _streaming_policy(value: object, group_name: str): + policy = str(value or "pre_output_only").strip().lower() + if policy not in {"pre_output_only", "never"}: + raise RoutingConfigError(f"fallback group {group_name} has unsupported streaming_policy {value!r}") + return policy + + +def _csv(value: str) -> list[str]: + return [part.strip() for part in value.split(",") if part.strip()] + + +def _env_key(value: str) -> str: + return value.upper().replace("-", "_") diff --git a/src/rotator_library/routing/executor.py b/src/rotator_library/routing/executor.py new file mode 100644 index 000000000..dce897e9c --- /dev/null +++ b/src/rotator_library/routing/executor.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Ordered fallback attempt runner.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any + +from .policy import FallbackPolicy, normalize_route_error_type +from .types import FallbackGroup, RouteAttemptResult, RouteTarget, RoutingDecision + +AttemptCallback = Callable[[RouteTarget, int], Awaitable[Any]] + + +class FallbackExhaustedError(Exception): + """Raised when every eligible fallback target fails.""" + + def __init__(self, decision: RoutingDecision, attempts: tuple[RouteAttemptResult, ...]) -> None: + self.decision = decision + self.attempts = attempts + summary = ", ".join(f"{attempt.target.name}:{attempt.error_type}" for attempt in attempts) + super().__init__(f"fallback group exhausted for {decision.requested_model}: {summary}") + + +class FallbackAttemptRunner: + """Run ordered route targets using an injected per-target attempt callback. + + The runner is independent from `RequestExecutor`. Phase 6 can unit-test + fallback control flow here, then wire the callback to existing per-target + credential retry logic without duplicating policy decisions. + """ + + def __init__(self, policy: FallbackPolicy | None = None) -> None: + self.policy = policy or FallbackPolicy() + + async def run(self, decision: RoutingDecision, attempt: AttemptCallback, *, stream: bool = False) -> Any: + """Try targets in order until one succeeds or fallback is exhausted.""" + + return await self.run_group(decision, decision.group, attempt, stream=stream) + + async def run_group( + self, + decision: RoutingDecision, + group: FallbackGroup | None, + attempt: AttemptCallback, + *, + stream: bool = False, + ) -> Any: + """Try targets while honoring optional group-specific policy overrides.""" + + attempts: list[RouteAttemptResult] = [] + for index, target in enumerate(decision.targets): + try: + return await attempt(target, index) + except Exception as exc: + error_type = _error_type(exc) + emitted_output = bool(getattr(exc, "emitted_output", False)) + attempts.append(RouteAttemptResult(target=target, success=False, error_type=error_type, emitted_output=emitted_output)) + has_next = index < len(decision.targets) - 1 + if not has_next or (stream and getattr(group, "streaming_policy", "pre_output_only") == "never") or not self.policy.should_fallback(error_type, group=group, emitted_output=emitted_output, stream=stream): + raise FallbackExhaustedError(decision, tuple(attempts)) from exc + raise FallbackExhaustedError(decision, tuple(attempts)) + + +def _error_type(error: BaseException) -> str: + return normalize_route_error_type(str(getattr(error, "error_type", error.__class__.__name__))) diff --git a/src/rotator_library/routing/policy.py b/src/rotator_library/routing/policy.py new file mode 100644 index 000000000..549ee114a --- /dev/null +++ b/src/rotator_library/routing/policy.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Fallback eligibility policy for ordered route chains.""" + +from __future__ import annotations + +from .types import DEFAULT_FAILOVER_ON, DEFAULT_STOP_ON, HARD_STOP_ON, FallbackGroup + + +_ALIASES = { + "auth": "authentication", + "invalid_api_key": "authentication", + "invalid_key": "authentication", + "unauthorized": "authentication", + "permission": "forbidden", + "permission_denied": "forbidden", + "access_denied": "forbidden", + "bad_request": "invalid_request", + "invalid_argument": "invalid_request", + "bad_argument": "invalid_request", + "validation": "invalid_request", + "permanent": "invalid_request", + "context_length": "context_window_exceeded", + "context_length_exceeded": "context_window_exceeded", + "context_window": "context_window_exceeded", + "too_many_tokens": "context_window_exceeded", + "max_tokens_exceeded": "context_window_exceeded", + "token_limit_exceeded": "context_window_exceeded", + "pre_request_callback": "pre_request_callback_error", + "configuration": "configuration_error", + "config": "configuration_error", + "quota": "quota_exceeded", + "resource_exhausted": "quota_exceeded", + "capacity": "rate_limit", + "rate_limited": "rate_limit", + "too_many_requests": "rate_limit", + "transient": "server_error", + "unavailable": "server_error", + "network": "api_connection", + "connection": "api_connection", + "deadline_exceeded": "api_connection", + "timeout": "api_connection", + "400": "invalid_request", + "401": "authentication", + "403": "forbidden", + "429": "rate_limit", + "500": "server_error", + "502": "server_error", + "503": "server_error", + "504": "server_error", +} + + +def normalize_route_error_type(error_type: str | None) -> str: + """Return the policy vocabulary for a raw classifier or config value. + + Routing config is user-facing, while the executor consumes classifier output. + Keeping normalization in one place prevents aliases such as `auth` from + bypassing non-overridable hard-stop categories. + """ + + normalized = str(error_type or "").strip().lower().replace("-", "_").replace(" ", "_") + return _ALIASES.get(normalized, normalized) + + +def normalize_route_error_set(values: frozenset[str]) -> frozenset[str]: + """Normalize a policy set while preserving unknown custom categories.""" + + return frozenset(normalize_route_error_type(value) for value in values) + + +class FallbackPolicy: + """Decide whether a failed target may advance to the next target.""" + + def should_fallback( + self, + error_type: str, + *, + group: FallbackGroup | None = None, + emitted_output: bool = False, + stream: bool = False, + ) -> bool: + """Return whether fallback is allowed for a classified failure.""" + + normalized = normalize_route_error_type(error_type) + if stream and emitted_output: + return False + if normalized in HARD_STOP_ON: + return False + active_stop = normalize_route_error_set(group.stop_on if group else DEFAULT_STOP_ON) + active_failover = normalize_route_error_set(group.failover_on if group else DEFAULT_FAILOVER_ON) + if normalized in active_stop: + return False + return normalized in active_failover diff --git a/src/rotator_library/routing/resolver.py b/src/rotator_library/routing/resolver.py new file mode 100644 index 000000000..53b66d58b --- /dev/null +++ b/src/rotator_library/routing/resolver.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Resolve requested models to direct targets or fallback groups.""" + +from __future__ import annotations + +from .config import RoutingConfigError, parse_route_target +from .types import FallbackGroup, RouteTarget, RoutingConfig, RoutingDecision + + +class FallbackResolver: + """Resolve a model name using deterministic fallback group rules.""" + + def __init__(self, config: RoutingConfig) -> None: + self.config = config + + def resolve(self, requested_model: str) -> RoutingDecision: + """Return the ordered targets for a requested model.""" + + route = self.config.model_routes.get(requested_model.lower()) + if route and route.startswith("group:"): + group_name = route[len("group:") :] + group = self.config.fallback_groups.get(group_name) + if not group: + raise RoutingConfigError(f"unknown fallback group {group_name}") + targets = _promote_requested_target(group, requested_model) + reason = "model_route_group_promoted" if targets != group.targets else "model_route_group" + return RoutingDecision(requested_model=requested_model, group_name=group.name, group=group, targets=targets, reason=reason) + if route: + return RoutingDecision(requested_model=requested_model, targets=(parse_route_target(route),), reason="model_route_target") + for group in self.config.fallback_groups.values(): + targets = _promote_requested_target(group, requested_model) + if targets != group.targets or any(_same_target(target, requested_model) for target in group.targets): + return RoutingDecision(requested_model=requested_model, group_name=group.name, group=group, targets=targets, reason="provider_model_group_promoted") + if "/" in requested_model: + return RoutingDecision(requested_model=requested_model, targets=(parse_route_target(requested_model),), reason="direct_provider_model") + raise RoutingConfigError(f"model {requested_model!r} is not provider-prefixed and has no route") + + +def _promote_requested_target(group: FallbackGroup, requested_model: str) -> tuple[RouteTarget, ...]: + """Return group targets with the requested provider/model attempted first.""" + + matching = [target for target in group.targets if _same_target(target, requested_model)] + if not matching: + return group.targets + selected = matching[0] + return (selected, *(target for target in group.targets if target is not selected)) + + +def _same_target(target: RouteTarget, requested_model: str) -> bool: + return target.prefixed_model.lower() == requested_model.lower() diff --git a/src/rotator_library/routing/types.py b/src/rotator_library/routing/types.py new file mode 100644 index 000000000..b0598e00b --- /dev/null +++ b/src/rotator_library/routing/types.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Typed route targets and fallback groups.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal, Protocol + +ExecutionMode = Literal["auto", "native", "custom", "litellm_fallback"] +StreamingFallbackPolicy = Literal["pre_output_only", "never"] + +DEFAULT_FAILOVER_ON = frozenset( + { + "rate_limit", + "quota_exceeded", + "server_error", + "api_connection", + "unsupported_operation", + # Human-friendly aliases for config files; classifier output uses the + # names above, but config authors should not need to know every internal + # error string. + "quota", + "capacity", + "transient", + } +) +DEFAULT_STOP_ON = frozenset( + { + "authentication", + "forbidden", + "invalid_request", + "context_window_exceeded", + "credential_reauth_needed", + "pre_request_callback_error", + "cancelled", + # Config aliases retained for readability. + "auth", + "validation", + "permanent", + "pre_request_callback", + } +) +HARD_STOP_ON = frozenset( + { + "authentication", + "forbidden", + "invalid_request", + "context_window_exceeded", + "credential_reauth_needed", + "pre_request_callback_error", + "cancelled", + "configuration_error", + } +) + + +@dataclass(frozen=True) +class RouteTarget: + """One concrete provider/model execution target in a fallback chain.""" + + provider: str + model: str + name: str = "" + protocol: str | None = None + execution: ExecutionMode = "auto" + priority: int | None = None + weight: float | None = None + conditions: dict[str, Any] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + if not self.provider or not self.model: + raise ValueError("route targets require provider and model") + if self.execution not in {"auto", "native", "custom", "litellm_fallback"}: + raise ValueError(f"unsupported execution mode: {self.execution}") + if not self.name: + object.__setattr__(self, "name", f"{self.provider}/{self.model}") + + @property + def prefixed_model(self) -> str: + """Return `provider/model` without double-prefixing an already-prefixed model.""" + + return self.model if self.model.startswith(f"{self.provider}/") else f"{self.provider}/{self.model}" + + +@dataclass(frozen=True) +class FallbackGroup: + """Deterministic ordered chain of route targets.""" + + name: str + targets: tuple[RouteTarget, ...] + failover_on: frozenset[str] = DEFAULT_FAILOVER_ON + stop_on: frozenset[str] = DEFAULT_STOP_ON + streaming_policy: StreamingFallbackPolicy = "pre_output_only" + max_targets: int | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + if not self.name: + raise ValueError("fallback group name is required") + if not self.targets: + raise ValueError("fallback groups require at least one target") + if self.max_targets is not None and self.max_targets <= 0: + raise ValueError("max_targets must be positive") + if self.max_targets is not None and len(self.targets) > self.max_targets: + raise ValueError("fallback group target count exceeds max_targets") + + +@dataclass(frozen=True) +class RoutingConfig: + """Routing configuration loaded from env or tests.""" + + fallback_groups: dict[str, FallbackGroup] = field(default_factory=dict) + model_routes: dict[str, str] = field(default_factory=dict) + + +@dataclass(frozen=True) +class RoutingDecision: + """Resolved routing plan for a requested model.""" + + requested_model: str + targets: tuple[RouteTarget, ...] + group_name: str | None = None + group: FallbackGroup | None = None + selected_target_index: int = 0 + reason: str = "direct" + + +@dataclass(frozen=True) +class RouteAttemptResult: + """Result summary for one attempted route target.""" + + target: RouteTarget + success: bool + error_type: str | None = None + emitted_output: bool = False + usage: dict[str, Any] = field(default_factory=dict) + + +class TargetSelector(Protocol): + """Future target-group selector seam; Phase 6 keeps ordered fallback only.""" + + def select(self, targets: tuple[RouteTarget, ...]) -> RouteTarget: + """Return one target from a richer target group.""" + + +@dataclass(frozen=True) +class TargetGroup: + """Future richer target group; not used for Phase 6 ordered fallback.""" + + name: str + targets: tuple[RouteTarget, ...] + selector: str = "ordered" diff --git a/src/rotator_library/session_tracking.py b/src/rotator_library/session_tracking.py index c30b3b738..f54588097 100644 --- a/src/rotator_library/session_tracking.py +++ b/src/rotator_library/session_tracking.py @@ -773,6 +773,21 @@ def _anchors_from_response(self, response: Any, namespace: str) -> List[SessionA data = response.model_dump() if hasattr(response, "model_dump") else response if not isinstance(data, dict): return [] + anchors: List[SessionAnchor] = [] + response_id = data.get("id") + if response_id: + # Responses API continuations identify their parent with + # previous_response_id. Record the emitted response id as strong + # response evidence so the next request can route back to the exact + # session/credential that produced it. + anchors.append( + SessionAnchor( + self._scoped(namespace, f"provider:responses_previous_response_id:{response_id}"), + "strong", + source="response", + group="responses_previous_response_id", + ) + ) messages: List[Dict[str, Any]] = [] for choice in data.get("choices") or []: if not isinstance(choice, dict): @@ -782,7 +797,9 @@ def _anchors_from_response(self, response: Any, namespace: str) -> List[SessionA response_message = dict(message) response_message.setdefault("role", "assistant") messages.append(response_message) - return self._anchors_from_messages(messages, namespace, source="response") if messages else [] + if messages: + anchors.extend(self._anchors_from_messages(messages, namespace, source="response")) + return anchors def _affinity_from_anchors( self, diff --git a/src/rotator_library/streaming/__init__.py b/src/rotator_library/streaming/__init__.py new file mode 100644 index 000000000..209c068a8 --- /dev/null +++ b/src/rotator_library/streaming/__init__.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Reusable streaming primitives shared by protocol and provider layers.""" + +from .events import StreamEvent, stream_event_from_sse_chunk +from .errors import StreamingErrorDecision, decide_streaming_error_action +from .metrics import StreamMetrics, StreamMonitor +from .policy import can_retry_stream_after_error, is_visible_stream_output +from .transport import JSONLineStreamFormatter, SSEStreamFormatter, WebSocketStreamFormatter + +__all__ = [ + "JSONLineStreamFormatter", + "SSEStreamFormatter", + "StreamEvent", + "StreamingErrorDecision", + "StreamMetrics", + "StreamMonitor", + "WebSocketStreamFormatter", + "can_retry_stream_after_error", + "decide_streaming_error_action", + "is_visible_stream_output", + "stream_event_from_sse_chunk", +] diff --git a/src/rotator_library/streaming/errors.py b/src/rotator_library/streaming/errors.py new file mode 100644 index 000000000..94467e228 --- /dev/null +++ b/src/rotator_library/streaming/errors.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Side-effect-free decisions for streaming failures.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from ..error_handler import ClassifiedError, classify_error, should_retry_same_key, should_rotate_on_error +from ..retry_policy import ProviderCooldownDecision, decide_provider_cooldown +from .policy import can_retry_stream_after_error, is_visible_stream_output + + +@dataclass(frozen=True) +class StreamingErrorDecision: + """Decision returned by streaming error policy before executor side effects.""" + + classified: ClassifiedError + action: str + start_provider_cooldown: bool = False + provider_cooldown_duration: int = 0 + provider_cooldown_scope: str = "provider" + provider_cooldown_model: str | None = None + reason: str = "" + + +def decide_streaming_error_action( + error: Exception, + *, + provider: str, + last_streamed_chunk: str | None, + attempt: int, + max_retries: int, + small_cooldown_threshold: int, + provider_cooldown_min_seconds: int, + provider_cooldown_default_seconds: int, + cooldown_on_quota: bool = False, + allow_reasoning_only_retry: bool = False, + model: str | None = None, + emitted_output: bool = False, +) -> StreamingErrorDecision: + """Classify a stream failure without sleeping or mutating state. + + The executor remains responsible for logging, credential state, sleeping, + provider cooldown mutation, and fallback. This helper only makes the same + decision consistently across streaming exception branches. + """ + + classified = classify_error(error, provider) + cooldown = _cooldown_decision( + classified, + small_cooldown_threshold=small_cooldown_threshold, + provider_cooldown_min_seconds=provider_cooldown_min_seconds, + provider_cooldown_default_seconds=provider_cooldown_default_seconds, + cooldown_on_quota=cooldown_on_quota, + last_streamed_chunk=last_streamed_chunk, + emitted_output=emitted_output, + provider=provider, + model=model, + original_error=error, + ) + if not should_rotate_on_error(classified): + return _decision(classified, "fail", cooldown, "non_rotatable") + if emitted_output or not can_retry_stream_after_error(last_streamed_chunk, allow_reasoning_only_retry): + return _decision(classified, "fallback_blocked_after_output", cooldown, "visible_output") + if should_retry_same_key(classified, small_cooldown_threshold) and attempt < max_retries - 1: + return _decision(classified, "retry_same", cooldown, "retry_same_credential") + return _decision(classified, "rotate", cooldown, "rotate_credential") + + +def _cooldown_decision( + classified: ClassifiedError, + *, + small_cooldown_threshold: int, + provider_cooldown_min_seconds: int, + provider_cooldown_default_seconds: int, + cooldown_on_quota: bool, + last_streamed_chunk: str | None, + emitted_output: bool, + provider: str, + model: str | None, + original_error: Exception, +) -> ProviderCooldownDecision: + if emitted_output or is_visible_stream_output(last_streamed_chunk): + return ProviderCooldownDecision(False, reason="visible_output") + return decide_provider_cooldown( + classified, + small_cooldown_threshold=small_cooldown_threshold, + provider_cooldown_min_seconds=provider_cooldown_min_seconds, + default_duration=provider_cooldown_default_seconds, + cooldown_on_quota=cooldown_on_quota, + provider=provider, + model=model, + original_error=original_error, + ) + + +def _decision( + classified: ClassifiedError, + action: str, + cooldown: ProviderCooldownDecision, + reason: str, +) -> StreamingErrorDecision: + return StreamingErrorDecision( + classified=classified, + action=action, + start_provider_cooldown=cooldown.should_start, + provider_cooldown_duration=cooldown.duration, + provider_cooldown_scope=cooldown.scope, + provider_cooldown_model=cooldown.model, + reason=reason if not cooldown.should_start else f"{reason};cooldown:{cooldown.reason}", + ) diff --git a/src/rotator_library/streaming/events.py b/src/rotator_library/streaming/events.py new file mode 100644 index 000000000..294b6b267 --- /dev/null +++ b/src/rotator_library/streaming/events.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Transport-neutral stream event model.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import UTC, datetime +import json +from typing import Any, Optional + +from ..protocols import serialize_value + + +@dataclass(frozen=True) +class StreamEvent: + """A normalized stream event before transport-specific formatting. + + The event model is intentionally broad so providers can attach native data + without inventing a new stream lifecycle object. Raw data is for trace/debug + use only and should still pass through transaction-log redaction when logged. + """ + + event_type: str + data: Any = field(default_factory=dict) + protocol: Optional[str] = None + transport: str = "sse" + raw: Any = None + metadata: dict[str, Any] = field(default_factory=dict) + visible_output: bool = False + timestamp_utc: str = field(default_factory=lambda: datetime.now(UTC).isoformat()) + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-safe representation for traces and transports.""" + + return { + "event_type": self.event_type, + "protocol": self.protocol, + "transport": self.transport, + "data": serialize_value(self.data), + "raw": serialize_value(self.raw), + "metadata": serialize_value(self.metadata), + "visible_output": self.visible_output, + "timestamp_utc": self.timestamp_utc, + } + + +def stream_event_from_sse_chunk(chunk: Any, *, protocol: str = "openai_chat") -> StreamEvent: + """Parse an SSE `data:` chunk into a normalized stream event. + + Malformed or non-JSON chunks are treated as metadata and fail closed for + visibility. This avoids accidental fallback after ambiguous client output. + """ + + if isinstance(chunk, dict): + data = chunk + raw = chunk + elif isinstance(chunk, str): + raw = chunk + text = chunk.strip() + if text.startswith("data:"): + text = text[len("data:") :].strip() + if text == "[DONE]": + return StreamEvent("completed", protocol=protocol, raw=chunk, data={"done": True}) + try: + data = json.loads(text) + except json.JSONDecodeError: + return StreamEvent("metadata", protocol=protocol, raw=chunk, data={"malformed": True}) + else: + return StreamEvent("metadata", protocol=protocol, raw=chunk, data={"unsupported_chunk": True}) + + if isinstance(data, dict) and data.get("error"): + return StreamEvent("error", protocol=protocol, raw=raw, data=data, visible_output=False) + visible = _openai_chat_visible(data) if protocol == "openai_chat" else False + event_type = "delta" if visible else "parsed_chunk" + return StreamEvent(event_type, protocol=protocol, raw=raw, data=data, visible_output=visible) + + +def _openai_chat_visible(data: Any) -> bool: + if not isinstance(data, dict): + return False + choices = data.get("choices") + if not isinstance(choices, list): + return False + for choice in choices: + if not isinstance(choice, dict): + continue + delta = choice.get("delta") or {} + message = choice.get("message") or {} + for source in (delta, message): + if not isinstance(source, dict): + continue + if _has_visible_text(source.get("content")) or _has_visible_text(source.get("text")): + return True + if source.get("tool_calls") or source.get("function_call"): + return True + return False + + +def _has_visible_text(value: Any) -> bool: + if isinstance(value, str): + return bool(value.strip()) + if isinstance(value, list): + for item in value: + if isinstance(item, str) and item.strip(): + return True + if isinstance(item, dict) and _has_visible_text(item.get("text")): + return True + return False diff --git a/src/rotator_library/streaming/metrics.py b/src/rotator_library/streaming/metrics.py new file mode 100644 index 000000000..64470d700 --- /dev/null +++ b/src/rotator_library/streaming/metrics.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Stream lifecycle metrics and stall monitoring.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Optional + +from .events import StreamEvent + + +Clock = Callable[[], float] + + +@dataclass +class StreamMetrics: + """Timing and lifecycle counters for one stream.""" + + started_at: float + first_byte_at: Optional[float] = None + first_visible_output_at: Optional[float] = None + last_chunk_at: Optional[float] = None + completed_at: Optional[float] = None + chunk_count: int = 0 + visible_chunk_count: int = 0 + error_count: int = 0 + cancelled: bool = False + + @property + def ttfb_seconds(self) -> Optional[float]: + return None if self.first_byte_at is None else self.first_byte_at - self.started_at + + @property + def ttft_seconds(self) -> Optional[float]: + return None if self.first_visible_output_at is None else self.first_visible_output_at - self.started_at + + @property + def duration_seconds(self) -> Optional[float]: + return None if self.completed_at is None else self.completed_at - self.started_at + + @property + def idle_seconds(self) -> Optional[float]: + if self.last_chunk_at is None or self.completed_at is None: + return None + return self.completed_at - self.last_chunk_at + + def to_dict(self) -> dict: + return { + "started_at": self.started_at, + "first_byte_at": self.first_byte_at, + "first_visible_output_at": self.first_visible_output_at, + "last_chunk_at": self.last_chunk_at, + "completed_at": self.completed_at, + "chunk_count": self.chunk_count, + "visible_chunk_count": self.visible_chunk_count, + "error_count": self.error_count, + "cancelled": self.cancelled, + "ttfb_seconds": self.ttfb_seconds, + "ttft_seconds": self.ttft_seconds, + "duration_seconds": self.duration_seconds, + "idle_seconds": self.idle_seconds, + } + + +class StreamMonitor: + """Record stream lifecycle events without changing stream behavior.""" + + def __init__(self, *, clock: Clock) -> None: + self._clock = clock + self.metrics = StreamMetrics(started_at=clock()) + + def record_event(self, event: StreamEvent) -> None: + now = self._clock() + if self.metrics.first_byte_at is None: + self.metrics.first_byte_at = now + self.metrics.last_chunk_at = now + self.metrics.chunk_count += 1 + if event.visible_output: + self.metrics.visible_chunk_count += 1 + if self.metrics.first_visible_output_at is None: + self.metrics.first_visible_output_at = now + if event.event_type == "error": + self.metrics.error_count += 1 + + def complete(self) -> None: + self.metrics.completed_at = self._clock() + + def cancel(self) -> None: + self.metrics.cancelled = True + self.metrics.completed_at = self._clock() + + def is_stalled(self, timeout_seconds: float) -> bool: + if timeout_seconds <= 0 or self.metrics.last_chunk_at is None: + return False + return self._clock() - self.metrics.last_chunk_at > timeout_seconds diff --git a/src/rotator_library/streaming/policy.py b/src/rotator_library/streaming/policy.py new file mode 100644 index 000000000..1dfb0ac55 --- /dev/null +++ b/src/rotator_library/streaming/policy.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Stream retry and visible-output policy.""" + +from __future__ import annotations + +import json +from typing import Any, Optional + +_REASONING_FIELDS = ( + "reasoning", + "reasoning_content", + "thinking", + "thinking_content", +) + + +def can_retry_stream_after_error(last_streamed_chunk: Optional[str], allow_reasoning_only_retry: bool) -> bool: + """Return whether an upstream stream can be retried after an error.""" + + if last_streamed_chunk is None: + return True + if is_stream_heartbeat_or_comment(last_streamed_chunk): + return True + metadata = _sse_json(last_streamed_chunk, malformed_is_visible=False) + if isinstance(metadata, dict) and (metadata.get("event_type") or metadata.get("type")) == "cost": + return True + if not allow_reasoning_only_retry: + return False + data = metadata + if data is None: + return False + + has_reasoning = False + choices = data.get("choices") + if not isinstance(choices, list): + return False + for choice in choices: + if not isinstance(choice, dict): + return False + for source in (choice, choice.get("delta"), choice.get("message")): + if not isinstance(source, dict): + continue + if _has_visible_text(source.get("content")) or _has_visible_text(source.get("text")): + return False + if source.get("tool_calls") or source.get("function_call"): + return False + if any(_has_visible_text(source.get(key)) for key in _REASONING_FIELDS): + has_reasoning = True + return has_reasoning + + +def is_visible_stream_output(chunk: Optional[str], *, protocol: str = "openai_chat") -> bool: + """Return whether a formatted stream chunk should block fallback. + + Malformed or ambiguous chunks fail closed by counting as visible output. This + preserves the existing safety rule that route fallback must not happen after + a client may have received model output. + """ + + if chunk is None: + return False + data = _sse_json(chunk, malformed_is_visible=True) + if data is _MALFORMED_VISIBLE: + return True + if data is None: + return False + if data.get("error"): + return False + event_type = data.get("event_type") or data.get("type") + if isinstance(event_type, str) and event_type.startswith("response."): + return _responses_visible(data) + if protocol == "responses": + return _responses_visible(data) + return _openai_chat_visible(data) + + +def is_stream_heartbeat_or_comment(chunk: Optional[str]) -> bool: + """Return true for SSE comment-only frames that must not affect retry state.""" + + if chunk is None: + return False + payload = chunk.strip() + return bool(payload) and all(line.startswith(":") for line in payload.splitlines() if line.strip()) + + +_MALFORMED_VISIBLE = object() + + +def _sse_json(chunk: str, *, malformed_is_visible: bool) -> dict[str, Any] | object | None: + payload = chunk.strip() + if not payload: + return None + if all(line.startswith(":") for line in payload.splitlines() if line.strip()): + return None + event_type = None + data_lines: list[str] = [] + for line in payload.splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith(":"): + continue + if stripped.startswith("event:"): + event_type = stripped[6:].strip() + continue + if stripped.startswith("data:"): + data_lines.append(stripped[5:].strip()) + continue + return _MALFORMED_VISIBLE if malformed_is_visible else None + if not data_lines: + if event_type in {"error", "response.failed"}: + return {"event_type": event_type} + return None + payload = "\n".join(data_lines).strip() + if not payload or payload == "[DONE]": + return None + try: + parsed = json.loads(payload) + except json.JSONDecodeError: + return _MALFORMED_VISIBLE if malformed_is_visible else None + if not isinstance(parsed, dict): + if event_type == "cost": + return {"event_type": "cost", "value": parsed} + return _MALFORMED_VISIBLE if malformed_is_visible else None + if event_type and "event_type" not in parsed: + parsed["event_type"] = event_type + return parsed + + +def _openai_chat_visible(data: dict[str, Any]) -> bool: + choices = data.get("choices") + if not isinstance(choices, list): + return False + for choice in choices: + if not isinstance(choice, dict): + continue + for source in (choice.get("delta"), choice.get("message")): + if not isinstance(source, dict): + continue + if _has_visible_text(source.get("content")) or _has_visible_text(source.get("text")): + return True + if source.get("tool_calls") or source.get("function_call"): + return True + return False + + +def _responses_visible(data: dict[str, Any]) -> bool: + event_type = data.get("event_type") or data.get("type") + if event_type == "response.output_text.delta": + return bool(str(data.get("delta", "")).strip()) + if isinstance(event_type, str) and ("function_call" in event_type or "tool_call" in event_type): + return _has_visible_text(data.get("delta")) or _has_visible_text(data.get("arguments")) or _has_visible_text(data.get("item")) + if event_type == "response.failed": + return False + return False + + +def _has_visible_text(value: Any) -> bool: + if isinstance(value, str): + return bool(value.strip()) + if isinstance(value, list): + for item in value: + if isinstance(item, str) and item.strip(): + return True + if isinstance(item, dict) and _has_visible_text(item.get("text")): + return True + return False diff --git a/src/rotator_library/streaming/transport.py b/src/rotator_library/streaming/transport.py new file mode 100644 index 000000000..b7ebd1b43 --- /dev/null +++ b/src/rotator_library/streaming/transport.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Transport formatters for normalized stream events.""" + +from __future__ import annotations + +import json + +from .events import StreamEvent + + +class SSEStreamFormatter: + """Format stream events as HTTP Server-Sent Events.""" + + transport = "sse" + + def format_event(self, event: StreamEvent) -> str: + return f"event: {event.event_type}\ndata: {json.dumps(event.to_dict(), ensure_ascii=False)}\n\n" + + def format_error(self, event: StreamEvent) -> str: + return self.format_event(event) + + def format_done(self) -> str: + return "data: [DONE]\n\n" + + def format_heartbeat(self, comment: str = "heartbeat") -> str: + """Return an SSE comment heartbeat frame. + + Comment frames keep HTTP connections active without becoming model + output, so routing/session code must continue treating them as + non-visible stream data. + """ + + safe_comment = comment.replace("\r", " ").replace("\n", " ") + return f": {safe_comment}\n\n" + + def is_terminal_event(self, event: StreamEvent) -> bool: + return event.event_type in {"completed", "cancelled", "error"} + + +class WebSocketStreamFormatter: + """Format stream events as WebSocket JSON messages without exposing a route.""" + + transport = "websocket" + future_supported = True + + def format_event(self, event: StreamEvent) -> dict: + return {"type": event.event_type, "payload": event.to_dict()} + + def format_error(self, event: StreamEvent) -> dict: + return self.format_event(event) + + def format_done(self) -> dict: + return {"type": "completed", "payload": {"done": True}} + + def format_heartbeat(self) -> dict: + """Return the future WebSocket heartbeat message shape.""" + + return {"type": "heartbeat", "payload": {"visible_output": False}} + + def is_terminal_event(self, event: StreamEvent) -> bool: + return event.event_type in {"completed", "cancelled", "error"} + + +class JSONLineStreamFormatter: + """Format stream events as newline-delimited JSON for provider tests.""" + + transport = "jsonl" + + def format_event(self, event: StreamEvent) -> str: + return json.dumps(event.to_dict(), ensure_ascii=False) + "\n" + + def format_error(self, event: StreamEvent) -> str: + return self.format_event(event) + + def format_done(self) -> str: + return json.dumps({"event_type": "completed", "done": True}) + "\n" + + def format_heartbeat(self) -> str: + """Return a JSONL heartbeat record for test transports.""" + + return json.dumps({"event_type": "heartbeat", "visible_output": False}) + "\n" + + def is_terminal_event(self, event: StreamEvent) -> bool: + return event.event_type in {"completed", "cancelled", "error"} diff --git a/src/rotator_library/transaction_logger.py b/src/rotator_library/transaction_logger.py index 1a5fa6df8..5a74e5a43 100644 --- a/src/rotator_library/transaction_logger.py +++ b/src/rotator_library/transaction_logger.py @@ -29,10 +29,11 @@ import time import uuid from dataclasses import dataclass -from datetime import datetime +from datetime import UTC, datetime from pathlib import Path from typing import Any, Dict, Optional, Union +from .transform_trace import TransformTraceWriter, provider_snapshot_namespace from .utils.paths import get_logs_dir lib_logger = logging.getLogger("rotator_library") @@ -58,6 +59,12 @@ def _get_transactions_dir() -> Path: return transactions_dir +def _utc_timestamp() -> str: + """Return an ISO-8601 UTC timestamp for log records.""" + + return datetime.now(UTC).isoformat() + + def _sanitize_name(name: str) -> str: """Sanitize a name for use in directory/file names.""" # Replace problematic characters with underscores @@ -90,6 +97,21 @@ class TransactionContext: model: str """Model name (sanitized for filesystem use).""" + trace_model: Optional[str] = None + """Exact model name used in transform trace entries.""" + + session_id: Optional[str] = None + """Inferred session id for trace correlation, when available.""" + + scope_key: Optional[str] = None + """Usage scope key for trace correlation, when available.""" + + classifier: Optional[str] = None + """Classifier/private routing label for trace correlation, when available.""" + + trace_enabled: bool = False + """Whether provider loggers should append transform trace entries.""" + class TransactionLogger: """ @@ -112,10 +134,15 @@ class TransactionLogger: "request_id", "provider", "model", + "trace_model", + "session_id", + "scope_key", + "classifier", "streaming", "api_format", "_dir_available", "_context", + "_trace_writer", ) def __init__( @@ -140,6 +167,10 @@ def __init__( self.start_time = time.time() self.request_id = str(uuid.uuid4())[:8] # 8-char short ID self.provider = provider + self.trace_model = model + self.session_id: Optional[str] = None + self.scope_key: Optional[str] = None + self.classifier: Optional[str] = None self.api_format = api_format # Strip provider prefix from model if present @@ -153,6 +184,7 @@ def __init__( self.log_dir: Optional[Path] = None self._dir_available = False self._context: Optional[TransactionContext] = None + self._trace_writer: Optional[TransformTraceWriter] = None if not enabled: return @@ -174,6 +206,14 @@ def __init__( try: self.log_dir.mkdir(parents=True, exist_ok=True) self._dir_available = True + self._trace_writer = TransformTraceWriter( + self.log_dir, + component="client", + provider=provider, + model=self.trace_model, + request_id=self.request_id, + enabled=True, + ) except Exception as e: lib_logger.error(f"TransactionLogger: Failed to create directory: {e}") self.enabled = False @@ -192,9 +232,104 @@ def get_context(self) -> TransactionContext: enabled=self.enabled, provider=self.provider, model=self.model, + trace_model=self.trace_model, + session_id=self.session_id, + scope_key=self.scope_key, + classifier=self.classifier, + trace_enabled=bool(self._trace_writer), ) return self._context + def set_trace_context( + self, + *, + session_id: Optional[str] = None, + scope_key: Optional[str] = None, + classifier: Optional[str] = None, + ) -> None: + """Attach routing/session metadata discovered after logger creation.""" + + if session_id is not None: + self.session_id = session_id + if scope_key is not None: + self.scope_key = scope_key + if classifier is not None: + self.classifier = classifier + if self._trace_writer: + self._trace_writer.update_context( + session_id=self.session_id, + scope_key=self.scope_key, + classifier=self.classifier, + ) + if self._context: + self._context.session_id = self.session_id + self._context.scope_key = self.scope_key + self._context.classifier = self.classifier + + def log_transform_pass( + self, + pass_name: str, + data: Any, + *, + direction: str, + stage: str, + protocol: Optional[str] = None, + credential_id: Optional[str] = None, + transport: Optional[str] = None, + changed_from_previous: Optional[bool] = None, + metadata: Optional[Dict[str, Any]] = None, + scrub_strings: bool = False, + snapshot: bool = True, + ) -> None: + """Record an additive transform trace entry if tracing is available.""" + + if not self.enabled or not self._dir_available or not self._trace_writer: + return + self._trace_writer.record( + pass_name, + data, + direction=direction, + stage=stage, + protocol=protocol, + credential_id=credential_id, + transport=transport, + changed_from_previous=changed_from_previous, + metadata=metadata, + scrub_strings=scrub_strings, + snapshot=snapshot, + ) + + def log_transform_error( + self, + failed_pass_name: str, + error: BaseException, + *, + payload: Any = None, + stage: str = "client", + protocol: Optional[str] = None, + transport: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Record a standardized transform/logging failure without raising.""" + + error_data = { + "failed_pass_name": failed_pass_name, + "error_type": type(error).__name__, + "message": str(error), + "payload": payload, + } + self.log_transform_pass( + "transform_log_error", + error_data, + direction="error", + stage=stage, + protocol=protocol, + transport=transport, + metadata=metadata, + scrub_strings=True, + snapshot=False, + ) + def log_request( self, request_data: Dict[str, Any], filename: str = "request.json" ) -> None: @@ -212,15 +347,25 @@ def log_request( data = { "request_id": self.request_id, - "timestamp_utc": datetime.utcnow().isoformat(), + "timestamp_utc": _utc_timestamp(), "data": request_data, } + self.log_transform_pass( + "raw_client_request", + request_data, + direction="request", + stage="client", + transport="sse" if self.streaming else "http", + ) self._write_json(filename, data) def log_transformed_request( self, transformed_data: Dict[str, Any], original_data: Dict[str, Any], + *, + credential_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, ) -> None: """ Log the transformed request if it differs from the original. @@ -239,18 +384,32 @@ def log_transformed_request( stripped_transformed = _strip_framework_keys(transformed_data) stripped_original = _strip_framework_keys(original_data) + changed_from_previous: Optional[bool] = None try: - if json.dumps(stripped_transformed, sort_keys=True, default=str) == json.dumps( + changed_from_previous = json.dumps(stripped_transformed, sort_keys=True, default=str) != json.dumps( stripped_original, sort_keys=True, default=str - ): - return + ) except (TypeError, ValueError): - pass + changed_from_previous = None + + self.log_transform_pass( + "prepared_provider_request", + transformed_data, + direction="request", + stage="client", + credential_id=credential_id, + transport="sse" if transformed_data.get("stream") else "http", + changed_from_previous=changed_from_previous, + metadata=metadata, + ) + + if changed_from_previous is False: + return logged = _strip_framework_keys(transformed_data) data = { "request_id": self.request_id, - "timestamp_utc": datetime.utcnow().isoformat(), + "timestamp_utc": _utc_timestamp(), "data": logged, } self._write_json("request_transformed.json", data) @@ -266,9 +425,17 @@ def log_stream_chunk(self, chunk: Dict[str, Any]) -> None: return log_entry = { - "timestamp_utc": datetime.utcnow().isoformat(), + "timestamp_utc": _utc_timestamp(), "chunk": chunk, } + self.log_transform_pass( + "parsed_stream_chunk", + chunk, + direction="stream", + stage="client", + transport="sse", + snapshot=False, + ) content = json.dumps(log_entry, ensure_ascii=False) + "\n" self._append_text("streaming_chunks.jsonl", content) @@ -296,12 +463,20 @@ def log_response( data = { "request_id": self.request_id, - "timestamp_utc": datetime.utcnow().isoformat(), + "timestamp_utc": _utc_timestamp(), "status_code": status_code, "duration_ms": round(duration_ms), "headers": dict(headers) if headers else None, "data": response_data, } + self.log_transform_pass( + "final_client_response", + response_data, + direction="response", + stage="final", + transport="sse" if self.streaming else "http", + metadata={"status_code": status_code, "headers": dict(headers) if headers else None}, + ) self._write_json(filename, data) # Also write metadata @@ -331,7 +506,7 @@ def _log_metadata( metadata = { "request_id": self.request_id, - "timestamp_utc": datetime.utcnow().isoformat(), + "timestamp_utc": _utc_timestamp(), "duration_ms": round(duration_ms), "status_code": status_code, "provider": self.provider, @@ -543,7 +718,7 @@ class ProviderLogger: or add custom methods for provider-specific logging needs. """ - __slots__ = ("enabled", "log_dir") + __slots__ = ("enabled", "log_dir", "_trace_writer") def __init__(self, context: Optional[TransactionContext]): """ @@ -554,6 +729,7 @@ def __init__(self, context: Optional[TransactionContext]): """ self.enabled = False self.log_dir: Optional[Path] = None + self._trace_writer: Optional[TransformTraceWriter] = None if context is None or not context.enabled: return @@ -563,10 +739,46 @@ def __init__(self, context: Optional[TransactionContext]): try: self.log_dir.mkdir(parents=True, exist_ok=True) + if getattr(context, "trace_enabled", False): + self._trace_writer = TransformTraceWriter( + context.log_dir, + component="provider", + provider=context.provider, + model=context.trace_model or context.model, + request_id=context.request_id, + session_id=context.session_id, + scope_key=context.scope_key, + classifier=context.classifier, + snapshot_namespace=provider_snapshot_namespace(), + enabled=True, + ) except Exception as e: lib_logger.error(f"ProviderLogger: Failed to create directory: {e}") self.enabled = False + def _log_transform_pass( + self, + pass_name: str, + data: Any, + *, + direction: str, + stage: str = "provider", + transport: Optional[str] = None, + scrub_strings: bool = False, + snapshot: bool = True, + ) -> None: + if not self.enabled or not self._trace_writer: + return + self._trace_writer.record( + pass_name, + data, + direction=direction, + stage=stage, + transport=transport, + scrub_strings=scrub_strings, + snapshot=snapshot, + ) + def log_request(self, payload: Dict[str, Any]) -> None: """ Log the request payload sent to the provider API. @@ -574,6 +786,12 @@ def log_request(self, payload: Dict[str, Any]) -> None: Args: payload: The transformed request payload """ + self._log_transform_pass( + "provider_request_payload", + payload, + direction="request", + transport="http", + ) self._write_json("request_payload.json", payload) def log_response_chunk(self, chunk: str) -> None: @@ -583,6 +801,13 @@ def log_response_chunk(self, chunk: str) -> None: Args: chunk: Raw chunk string from the stream """ + self._log_transform_pass( + "provider_raw_stream_chunk", + chunk, + direction="stream", + transport="sse", + snapshot=False, + ) self._append_text("response_stream.log", chunk + "\n") def log_final_response(self, response_data: Dict[str, Any]) -> None: @@ -592,6 +817,11 @@ def log_final_response(self, response_data: Dict[str, Any]) -> None: Args: response_data: The complete response data """ + self._log_transform_pass( + "provider_final_response", + response_data, + direction="response", + ) self._write_json("final_response.json", response_data) def log_error(self, error_message: str) -> None: @@ -601,7 +831,14 @@ def log_error(self, error_message: str) -> None: Args: error_message: The error message to log """ - timestamp = datetime.utcnow().isoformat() + timestamp = _utc_timestamp() + self._log_transform_pass( + "provider_error", + {"timestamp_utc": timestamp, "message": error_message}, + direction="error", + scrub_strings=True, + snapshot=False, + ) self._append_text("error.log", f"[{timestamp}] {error_message}\n") def log_extra(self, filename: str, data: Union[Dict[str, Any], str]) -> None: diff --git a/src/rotator_library/transform_trace.py b/src/rotator_library/transform_trace.py new file mode 100644 index 000000000..513e37a56 --- /dev/null +++ b/src/rotator_library/transform_trace.py @@ -0,0 +1,339 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Transform-pass trace primitives for transaction logging. + +The trace layer is observability-only: failures here must never change request +execution. Transaction logging uses this module to snapshot each meaningful +request, response, and stream state as later protocol/adapters mutate payloads. +""" + +from __future__ import annotations + +import json +import logging +import re +import uuid +from dataclasses import dataclass, field +from dataclasses import is_dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any, Mapping, Optional + +from .protocols import serialize_value + +lib_logger = logging.getLogger("rotator_library") + +REDACTED = "[REDACTED]" + +_SENSITIVE_KEYS = frozenset( + { + "api-key", + "credential-identifier", + "authorization", + "proxy-authorization", + "cookie", + "set-cookie", + "x-api-key", + "x-goog-api-key", + "openai-api-key", + "api-key", + "api-secret", + "access-token", + "refresh-token", + "id-token", + "client-secret", + "password", + "secret", + "token", + # Provider protocol-state fields can contain opaque signatures or + # reasoning continuations that field-cache rules preserve and re-inject. + "reasoning-content", + "reasoningcontent", + "thought-signature", + "thoughtsignature", + "signature", + "state", + "provider-state", + "provider-session-id", + "providersessionid", + "prompt-cache-key", + "promptcachekey", + "cache-key", + "cachekey", + "thinking-signatures", + "thinkingsignatures", + "thought-signatures", + "thoughtsignatures", + } +) + +_SENSITIVE_TEXT_RE = re.compile( + r"(?im)\b(authorization|proxy-authorization|x-api-key|x-goog-api-key|openai-api-key|api[_-]?key|access[_-]?token|refresh[_-]?token|client[_-]?secret|cookie|set-cookie)\b(['\"]?\s*[:=]\s*['\"]?)([^'\"\r\n,}]+)" +) + + +def _normalise_key(key: Any) -> str: + camel_split = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "-", str(key).strip()) + return camel_split.lower().replace("_", "-") + + +def scrub_sensitive_text(value: str) -> str: + """Scrub obvious credential-bearing header/query fragments from text. + + General trace payloads use key-based redaction to avoid hiding model text. + Error strings and object reprs are different: providers often embed HTTP + headers or query strings in exception text, so this targeted scrub only + applies when callers opt into string scrubbing. + """ + + return _SENSITIVE_TEXT_RE.sub(lambda match: f"{match.group(1)}{match.group(2)}{REDACTED}", value) + + +def _object_mapping(value: Any) -> Optional[dict[str, Any]]: + """Extract structured data from common SDK objects before using repr().""" + + for method_name in ("model_dump", "dict"): + method = getattr(value, method_name, None) + if callable(method): + try: + result = method() + if isinstance(result, Mapping): + return dict(result) + except Exception: + pass + structured: dict[str, Any] = {"type": f"{type(value).__module__}.{type(value).__name__}"} + found = False + for attr in ("status_code", "headers", "url", "method", "text", "content"): + if hasattr(value, attr): + try: + structured[attr] = getattr(value, attr) + found = True + except Exception: + pass + if found: + return structured + if hasattr(value, "__dict__") and not is_dataclass(value): + try: + return {"type": structured["type"], "attributes": vars(value)} + except Exception: + return None + return None + + +def sanitize_for_trace(value: Any, *, scrub_strings: bool = False) -> Any: + """Return a JSON-safe, recursively redacted value for trace files. + + Redaction is intentionally key-based rather than value-based. Model text may + legitimately mention tokens, passwords, or secrets; hiding by value would make + debugging transformations unreliable. Sensitive framework fields are redacted + only when their key name is known to carry credentials. + """ + + if isinstance(value, Mapping): + sanitized = {} + block_type = str(value.get("type", "")).lower() + source_field = str(value.get("source_field", "")).lower() + if not source_field and isinstance(value.get("extra"), Mapping): + source_field = str(value["extra"].get("source_field", "")).lower() + is_reasoning_block = "reasoning" in block_type or source_field in {"reasoning_content", "thoughtsignature", "thought_signature", "signature"} + for key, item in value.items(): + if _normalise_key(key) in _SENSITIVE_KEYS or (is_reasoning_block and _normalise_key(key) == "text"): + sanitized[str(key)] = REDACTED + else: + sanitized[str(key)] = sanitize_for_trace(item, scrub_strings=scrub_strings) + return sanitized + if isinstance(value, (list, tuple, set, frozenset)): + return [sanitize_for_trace(item, scrub_strings=scrub_strings) for item in value] + if isinstance(value, str): + return scrub_sensitive_text(value) if scrub_strings else value + + object_mapping = _object_mapping(value) + if object_mapping is not None: + return sanitize_for_trace(object_mapping, scrub_strings=scrub_strings) + + serialized = serialize_value(value) + if isinstance(serialized, dict): + return sanitize_for_trace(serialized, scrub_strings=scrub_strings) + if isinstance(serialized, list): + return [sanitize_for_trace(item, scrub_strings=scrub_strings) for item in serialized] + if isinstance(serialized, str): + return scrub_sensitive_text(serialized) if scrub_strings else serialized + return serialized + + +def sanitize_filename(value: str) -> str: + """Return a stable, filesystem-safe name component for snapshot files.""" + + safe = value.strip().lower().replace(" ", "_") or "trace" + for char in '/\\:*?"<>|': + safe = safe.replace(char, "_") + return "".join(char if char.isalnum() or char in "._-" else "_" for char in safe) + + +@dataclass +class TransformTraceEntry: + """A single transform-pass observation. + + Entries are deliberately protocol-neutral. Later phases can add protocol, + adapter, field-cache, routing, and transport passes without changing the file + format used by this phase. + """ + + sequence: int + component: str + pass_name: str + direction: str + stage: str + request_id: Optional[str] = None + timestamp_utc: str = field(default_factory=lambda: datetime.now(UTC).isoformat()) + protocol: Optional[str] = None + provider: Optional[str] = None + model: Optional[str] = None + credential_id: Optional[str] = None + transport: Optional[str] = None + changed_from_previous: Optional[bool] = None + session_id: Optional[str] = None + scope_key: Optional[str] = None + classifier: Optional[str] = None + metadata: dict[str, Any] = field(default_factory=dict) + data: Any = None + scrub_strings: bool = False + + def to_dict(self) -> dict[str, Any]: + return { + "sequence": self.sequence, + "component": self.component, + "pass_name": self.pass_name, + "direction": self.direction, + "stage": self.stage, + "request_id": self.request_id, + "timestamp_utc": self.timestamp_utc, + "protocol": self.protocol, + "provider": self.provider, + "model": self.model, + "credential_id": self.credential_id, + "transport": self.transport, + "changed_from_previous": self.changed_from_previous, + "session_id": self.session_id, + "scope_key": self.scope_key, + "classifier": self.classifier, + "metadata": sanitize_for_trace(self.metadata, scrub_strings=self.scrub_strings), + "data": sanitize_for_trace(self.data, scrub_strings=self.scrub_strings), + } + + +class TransformTraceWriter: + """Append-only writer for transform trace entries and snapshots. + + One writer owns one local sequence counter. Phase 2 intentionally does not + promise a global ordering across client and provider writers; entries include + component and timestamps so interleaved logs remain understandable. + """ + + def __init__( + self, + log_dir: Path, + *, + component: str, + provider: Optional[str] = None, + model: Optional[str] = None, + request_id: Optional[str] = None, + session_id: Optional[str] = None, + scope_key: Optional[str] = None, + classifier: Optional[str] = None, + snapshot_namespace: Optional[str] = None, + enabled: bool = True, + ) -> None: + self.log_dir = log_dir + self.component = component + self.provider = provider + self.model = model + self.request_id = request_id + self.session_id = session_id + self.scope_key = scope_key + self.classifier = classifier + self.snapshot_namespace = snapshot_namespace + self.enabled = enabled + self._sequence = 0 + self.trace_file = log_dir / "transform_trace.jsonl" + self.snapshot_dir = log_dir / "transforms" + + def update_context( + self, + *, + session_id: Optional[str] = None, + scope_key: Optional[str] = None, + classifier: Optional[str] = None, + ) -> None: + """Update immutable-ish correlation fields discovered after creation.""" + + if session_id is not None: + self.session_id = session_id + if scope_key is not None: + self.scope_key = scope_key + if classifier is not None: + self.classifier = classifier + + def record( + self, + pass_name: str, + data: Any, + *, + direction: str, + stage: str, + protocol: Optional[str] = None, + credential_id: Optional[str] = None, + transport: Optional[str] = None, + changed_from_previous: Optional[bool] = None, + session_id: Optional[str] = None, + scope_key: Optional[str] = None, + classifier: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, + scrub_strings: bool = False, + snapshot: bool = True, + ) -> Optional[TransformTraceEntry]: + """Record a transform pass, swallowing logging failures.""" + + if not self.enabled: + return None + self._sequence += 1 + entry = TransformTraceEntry( + sequence=self._sequence, + component=self.component, + pass_name=pass_name, + direction=direction, + stage=stage, + request_id=self.request_id, + protocol=protocol, + provider=self.provider, + model=self.model, + credential_id=credential_id, + transport=transport, + changed_from_previous=changed_from_previous, + session_id=session_id if session_id is not None else self.session_id, + scope_key=scope_key if scope_key is not None else self.scope_key, + classifier=classifier if classifier is not None else self.classifier, + metadata=metadata or {}, + data=data, + scrub_strings=scrub_strings, + ) + try: + self.log_dir.mkdir(parents=True, exist_ok=True) + with open(self.trace_file, "a", encoding="utf-8") as handle: + handle.write(json.dumps(entry.to_dict(), ensure_ascii=False) + "\n") + if snapshot and direction != "stream": + self.snapshot_dir.mkdir(parents=True, exist_ok=True) + namespace = f"{sanitize_filename(self.snapshot_namespace)}_" if self.snapshot_namespace else "" + snapshot_name = f"{entry.sequence:04d}_{namespace}{sanitize_filename(pass_name)}.json" + with open(self.snapshot_dir / snapshot_name, "w", encoding="utf-8") as handle: + json.dump(entry.to_dict(), handle, indent=2, ensure_ascii=False) + except Exception as exc: + lib_logger.debug("Transform trace write failed for %s: %s", pass_name, exc) + return entry + + +def provider_snapshot_namespace() -> str: + """Return a short namespace that prevents provider snapshot collisions.""" + + return f"provider_{uuid.uuid4().hex[:8]}" diff --git a/src/rotator_library/usage/__init__.py b/src/rotator_library/usage/__init__.py index f8ae947fe..2ab7f5bd8 100644 --- a/src/rotator_library/usage/__init__.py +++ b/src/rotator_library/usage/__init__.py @@ -56,11 +56,21 @@ # Main facade (imports components above) from .manager import UsageManager, CredentialContext +from .accounting import UsageRecord, extract_usage_record +from .costs import CostBreakdown, CostCalculator, ModelPricing +from .quota import QuotaSnapshot, build_quota_snapshots __all__ = [ # Main public API "UsageManager", "CredentialContext", + "UsageRecord", + "CostBreakdown", + "CostCalculator", + "ModelPricing", + "QuotaSnapshot", + "build_quota_snapshots", + "extract_usage_record", # Types "WindowStats", "TotalStats", diff --git a/src/rotator_library/usage/accounting.py b/src/rotator_library/usage/accounting.py new file mode 100644 index 000000000..1b1bdb122 --- /dev/null +++ b/src/rotator_library/usage/accounting.py @@ -0,0 +1,376 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Normalized usage accounting across provider protocols. + +The existing `UsageManager` remains the source of truth for persistence, +windows, and credential selection. This module only converts provider-specific +usage payloads into the numeric buckets that `CredentialContext.mark_success()` +already accepts. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional + +from ..protocols import serialize_value + + +@dataclass(frozen=True) +class UsageRecord: + """Provider-neutral token usage buckets. + + `completion_tokens` excludes `reasoning_tokens` when providers report hidden + reasoning separately. This preserves the current no-double-count behavior and + gives Phase 9+ cost logic clear input/output/reasoning buckets. + """ + + input_tokens: int = 0 + completion_tokens: int = 0 + reasoning_tokens: int = 0 + cache_read_tokens: int = 0 + cache_write_tokens: int = 0 + raw_total_tokens: int = 0 + request_count: int = 1 + source: str = "unknown" + provider: Optional[str] = None + model: Optional[str] = None + provider_reported_cost: Optional[float] = None + cost_currency: str = "USD" + cost_source: Optional[str] = None + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def output_tokens(self) -> int: + return self.completion_tokens + self.reasoning_tokens + + @property + def total_tokens(self) -> int: + return ( + self.input_tokens + + self.cache_read_tokens + + self.cache_write_tokens + + self.completion_tokens + + self.reasoning_tokens + ) + + @property + def prompt_tokens_for_mark_success(self) -> int: + """Return non-cache-read prompt tokens for existing usage storage.""" + + return self.input_tokens + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-safe summary for transform traces and tests.""" + + return { + "input_tokens": self.input_tokens, + "completion_tokens": self.completion_tokens, + "reasoning_tokens": self.reasoning_tokens, + "cache_read_tokens": self.cache_read_tokens, + "cache_write_tokens": self.cache_write_tokens, + "output_tokens": self.output_tokens, + "total_tokens": self.total_tokens, + "raw_total_tokens": self.raw_total_tokens, + "request_count": self.request_count, + "source": self.source, + "provider": self.provider, + "model": self.model, + "provider_reported_cost": self.provider_reported_cost, + "cost_currency": self.cost_currency, + "cost_source": self.cost_source, + "metadata": serialize_value(self.metadata), + } + + +def extract_usage_record( + response_or_usage: Any, + *, + provider: Optional[str] = None, + model: Optional[str] = None, + source: str = "response", +) -> UsageRecord: + """Extract normalized usage from a response object or usage payload.""" + + usage = _unwrap_usage(response_or_usage) + if usage is None: + return UsageRecord(provider=provider, model=model, source=source) + data = _as_dict(usage) + if not data: + return UsageRecord(provider=provider, model=model, source=source) + + if "usageMetadata" in data and isinstance(data["usageMetadata"], dict): + data = data["usageMetadata"] + + if _looks_like_gemini(data): + return _from_gemini_usage(data, provider=provider, model=model, source=source) + if _looks_like_anthropic(data): + return _from_anthropic_usage(data, provider=provider, model=model, source=source) + return _from_openai_like_usage(data, provider=provider, model=model, source=source) + + +def _unwrap_usage(value: Any) -> Any: + if value is None: + return None + if isinstance(value, dict): + if "usage" in value: + usage = _as_dict(value.get("usage")) + return _merge_top_level_cost_fields(usage, value) + return value + data = _as_dict(value) + if "usage" in data: + usage = _as_dict(data.get("usage")) + return _merge_top_level_cost_fields(usage, data) + return getattr(value, "usage", value) + + +def _merge_top_level_cost_fields(usage: dict[str, Any], response: dict[str, Any]) -> dict[str, Any]: + """Copy sibling cost metadata into a nested usage payload.""" + + merged = dict(usage) + for key in ("cost", "total_cost", "estimated_cost", "cost_details", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + if key in response and key not in merged: + merged[key] = response[key] + return merged + + +def _as_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return value + if hasattr(value, "to_dict"): + dumped = value.to_dict() + return dumped if isinstance(dumped, dict) else {} + if hasattr(value, "model_dump"): + dumped = value.model_dump() + return dumped if isinstance(dumped, dict) else {} + if hasattr(value, "dict"): + dumped = value.dict() + return dumped if isinstance(dumped, dict) else {} + result: dict[str, Any] = {} + for key in ( + "prompt_tokens", + "usage", + "completion_tokens", + "total_tokens", + "prompt_tokens_details", + "completion_tokens_details", + "cache_read_tokens", + "cache_creation_tokens", + "input_tokens", + "output_tokens", + "input_tokens_details", + "output_tokens_details", + "cache_creation_input_tokens", + "cache_read_input_tokens", + "cached_tokens", + "cache_creation_tokens", + "cache_write_tokens", + "reasoning_tokens", + "thinking_tokens", + "cost", + "cost_details", + "costMetadata", + "provider_reported_cost", + "request_cost_usd", + "estimated_cost", + "total_cost", + "currency", + ): + if hasattr(value, key): + result[key] = getattr(value, key) + return result + + +def _from_openai_like_usage(data: dict[str, Any], *, provider: Optional[str], model: Optional[str], source: str) -> UsageRecord: + prompt_tokens = _int(data.get("prompt_tokens", data.get("input_tokens", 0))) + completion_tokens = _int(data.get("completion_tokens", data.get("output_tokens", 0))) + raw_total = _int(data.get("total_tokens", data.get("raw_total_tokens", 0))) + + prompt_details = _as_dict(data.get("prompt_tokens_details") or data.get("input_tokens_details") or {}) + completion_details = _as_dict(data.get("completion_tokens_details") or data.get("output_tokens_details") or {}) + cache_read = _int( + data.get( + "cache_read_tokens", + data.get("cached_tokens", prompt_details.get("cached_tokens", 0)), + ) + ) + cache_write = _int( + data.get( + "cache_creation_tokens", + data.get("cache_write_tokens", prompt_details.get("cache_creation_tokens", 0)), + ) + ) + reasoning = _int( + data.get( + "reasoning_tokens", + completion_details.get("reasoning_tokens", completion_details.get("thinking_tokens", 0)), + ) + ) + if reasoning and completion_tokens >= reasoning: + completion_tokens -= reasoning + input_tokens = max(0, prompt_tokens - cache_read - cache_write) + cost = _extract_cost(data) + return UsageRecord( + input_tokens=input_tokens, + completion_tokens=completion_tokens, + reasoning_tokens=reasoning, + cache_read_tokens=cache_read, + cache_write_tokens=cache_write, + raw_total_tokens=raw_total, + source=source, + provider=provider, + model=model, + provider_reported_cost=cost[0], + cost_currency=cost[1], + cost_source=cost[2], + metadata={"shape": "openai_like"}, + ) + + +def _from_anthropic_usage(data: dict[str, Any], *, provider: Optional[str], model: Optional[str], source: str) -> UsageRecord: + cache_read = _int(data.get("cache_read_input_tokens", data.get("cache_read_tokens", 0))) + cache_write = _int(data.get("cache_creation_input_tokens", data.get("cache_creation_tokens", 0))) + input_tokens = max(0, _int(data.get("input_tokens", 0)) - cache_read - cache_write) + output_tokens = _int(data.get("output_tokens", data.get("completion_tokens", 0))) + reasoning = _int(data.get("reasoning_tokens", data.get("thinking_tokens", 0))) + if reasoning and output_tokens >= reasoning: + output_tokens -= reasoning + cost = _extract_cost(data) + return UsageRecord( + input_tokens=input_tokens, + completion_tokens=output_tokens, + reasoning_tokens=reasoning, + cache_read_tokens=cache_read, + cache_write_tokens=cache_write, + raw_total_tokens=_int(data.get("total_tokens", 0)), + source=source, + provider=provider, + model=model, + provider_reported_cost=cost[0], + cost_currency=cost[1], + cost_source=cost[2], + metadata={"shape": "anthropic"}, + ) + + +def _from_gemini_usage(data: dict[str, Any], *, provider: Optional[str], model: Optional[str], source: str) -> UsageRecord: + cache_read = _int(data.get("cachedContentTokenCount", data.get("cache_read_tokens", 0))) + prompt_tokens = _int(data.get("promptTokenCount", data.get("prompt_tokens", 0))) + reasoning = _int(data.get("thoughtsTokenCount", data.get("reasoning_tokens", 0))) + completion = _int(data.get("candidatesTokenCount", data.get("completion_tokens", 0))) + if reasoning and completion >= reasoning: + completion -= reasoning + cost = _extract_cost(data) + return UsageRecord( + input_tokens=max(0, prompt_tokens - cache_read), + completion_tokens=completion, + reasoning_tokens=reasoning, + cache_read_tokens=cache_read, + raw_total_tokens=_int(data.get("totalTokenCount", data.get("total_tokens", 0))), + source=source, + provider=provider, + model=model, + provider_reported_cost=cost[0], + cost_currency=cost[1], + cost_source=cost[2], + metadata={"shape": "gemini"}, + ) + + +def _extract_cost(data: dict[str, Any]) -> tuple[Optional[float], str, Optional[str]]: + """Extract actual provider-reported cost without guessing prices. + + Advisory pricing belongs in `usage.costs`. This helper only preserves cost + values explicitly reported by a provider or protocol adapter. + """ + + cost_payload = _as_dict(data.get("cost") or data.get("cost_details") or data.get("costMetadata") or {}) + raw_cost = ( + cost_payload.get("provider_reported_cost", cost_payload.get("request_cost_usd")) + if cost_payload + else None + ) + breakdown_used = False + if raw_cost is None and cost_payload: + raw_cost = cost_payload.get("total_cost", cost_payload.get("total")) + if raw_cost is None and cost_payload: + raw_cost = cost_payload.get("cost") + if raw_cost is None and cost_payload: + breakdown_total = _sum_cost_breakdown(cost_payload) + if breakdown_total is not None: + raw_cost = breakdown_total + breakdown_used = True + if raw_cost is None: + raw_cost = data.get("provider_reported_cost", data.get("request_cost_usd", data.get("total_cost", data.get("cost", data.get("estimated_cost"))))) + cost_value = _float_or_none(raw_cost) + currency = str(cost_payload.get("currency") or data.get("currency") or "USD") + source = cost_payload.get("source") or ("provider_reported_breakdown" if breakdown_used else ("provider_reported" if cost_value is not None else None)) + return cost_value, currency, str(source) if source else None + + +def _sum_cost_breakdown(payload: dict[str, Any]) -> Optional[float]: + """Sum structured provider cost buckets when no total is reported.""" + + total = 0.0 + found = False + for key in ( + "input_cost", + "prompt_cost", + "cache_read_cost", + "cache_write_cost", + "cached_input_cost", + "cache_write_input_cost", + "output_cost", + "completion_cost", + "completions_cost", + "reasoning_cost", + "thinking_cost", + "upstream_inference_cost", + "upstream_inference_prompt_cost", + "upstream_inference_completions_cost", + "upstream_inference_input_cost", + "upstream_inference_output_cost", + "image_input_cost", + "image_output_cost", + "audio_input_cost", + "audio_output_cost", + "data_storage_cost", + "estimated_cost", + "request_cost", + "web_search_cost", + "search_cost", + ): + value = _float_or_none(payload.get(key)) + if value is not None: + total += value + found = True + ticks = _float_or_none(payload.get("cost_in_usd_ticks")) + if ticks is not None: + total += ticks / 10_000_000_000 + found = True + return total if found else None + + +def _looks_like_gemini(data: dict[str, Any]) -> bool: + return any(key in data for key in ("promptTokenCount", "candidatesTokenCount", "thoughtsTokenCount", "cachedContentTokenCount")) + + +def _looks_like_anthropic(data: dict[str, Any]) -> bool: + return any(key in data for key in ("cache_creation_input_tokens", "cache_read_input_tokens")) and "input_tokens" in data + + +def _int(value: Any) -> int: + try: + return max(0, int(value or 0)) + except (TypeError, ValueError): + return 0 + + +def _float_or_none(value: Any) -> Optional[float]: + if value is None: + return None + try: + return max(0.0, float(value)) + except (TypeError, ValueError): + return None diff --git a/src/rotator_library/usage/costs.py b/src/rotator_library/usage/costs.py new file mode 100644 index 000000000..6a37aeb82 --- /dev/null +++ b/src/rotator_library/usage/costs.py @@ -0,0 +1,168 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Advisory cost calculation for normalized usage records.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional + +import litellm + +from ..protocols import serialize_value +from .accounting import UsageRecord + + +@dataclass(frozen=True) +class ModelPricing: + """Per-token pricing for one provider/model. + + Prices are advisory and local-only; this module never calls a network pricing + endpoint. Providers can return this object from `get_model_pricing()` later. + """ + + input_cost_per_token: float = 0.0 + cache_read_cost_per_token: float = 0.0 + cache_write_cost_per_token: float = 0.0 + output_cost_per_token: float = 0.0 + reasoning_cost_per_token: float = 0.0 + currency: str = "USD" + source: str = "explicit" + + +@dataclass(frozen=True) +class CostBreakdown: + """Advisory request cost split by normalized usage bucket.""" + + input_cost: float = 0.0 + cache_read_cost: float = 0.0 + cache_write_cost: float = 0.0 + output_cost: float = 0.0 + reasoning_cost: float = 0.0 + provider_reported_cost: float = 0.0 + currency: str = "USD" + pricing_source: str = "unavailable" + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def total_cost(self) -> float: + if self.provider_reported_cost: + return self.provider_reported_cost + return self.input_cost + self.cache_read_cost + self.cache_write_cost + self.output_cost + self.reasoning_cost + + def to_dict(self) -> dict[str, Any]: + return { + "input_cost": self.input_cost, + "cache_read_cost": self.cache_read_cost, + "cache_write_cost": self.cache_write_cost, + "output_cost": self.output_cost, + "reasoning_cost": self.reasoning_cost, + "provider_reported_cost": self.provider_reported_cost, + "total_cost": self.total_cost, + "currency": self.currency, + "pricing_source": self.pricing_source, + "metadata": serialize_value(self.metadata), + } + + +class CostCalculator: + """Calculate advisory costs without replacing usage tracking.""" + + def __init__(self, *, provider_plugin: Any = None, use_litellm_fallback: bool = True, config: Any = None, env: Any = None) -> None: + self.provider_plugin = provider_plugin + self.use_litellm_fallback = use_litellm_fallback + self.config = config + self.env = env + + def calculate(self, usage: UsageRecord, *, model: str, response: Any = None, provider: str | None = None) -> CostBreakdown: + """Return an advisory cost breakdown for a normalized usage record.""" + + if self.provider_plugin and getattr(self.provider_plugin, "skip_cost_calculation", False): + return CostBreakdown(pricing_source="skipped", metadata={"reason": "provider_skip_cost_calculation"}) + if usage.provider_reported_cost is not None: + return CostBreakdown( + provider_reported_cost=usage.provider_reported_cost, + currency=usage.cost_currency, + pricing_source=usage.cost_source or "provider_reported", + metadata={"actual_provider_reported": True}, + ) + pricing = self._provider_pricing(model) + if pricing: + return _calculate_from_pricing(usage, pricing) + pricing = self._configured_pricing(provider or usage.provider or _provider_from_model(model), model) + if pricing: + return _calculate_from_pricing(usage, pricing) + if self.use_litellm_fallback: + lite = self._litellm_cost(usage, model=model, response=response) + if lite: + return lite + return CostBreakdown(pricing_source="unavailable") + + def _provider_pricing(self, model: str) -> Optional[ModelPricing]: + if not self.provider_plugin: + return None + method = getattr(self.provider_plugin, "get_model_pricing", None) + if not method: + return None + pricing = method(model) + if isinstance(pricing, ModelPricing): + return pricing + if isinstance(pricing, dict): + return ModelPricing(**pricing) + return None + + def _configured_pricing(self, provider: str | None, model: str) -> Optional[ModelPricing]: + if not provider: + return None + from ..config.experimental import get_configured_model_pricing + + return get_configured_model_pricing(provider, _model_without_provider(provider, model), config=self.config, env=self.env) + + @staticmethod + def _litellm_cost(usage: UsageRecord, *, model: str, response: Any = None) -> Optional[CostBreakdown]: + if response is not None: + try: + cost = litellm.completion_cost(completion_response=response, model=model) + if cost is not None: + return CostBreakdown(output_cost=float(cost), pricing_source="litellm_completion_cost") + except Exception: + pass + try: + model_info = litellm.get_model_info(model) + except Exception: + return None + input_price = float(model_info.get("input_cost_per_token") or 0.0) + output_price = float(model_info.get("output_cost_per_token") or 0.0) + if input_price == 0.0 and output_price == 0.0: + return None + pricing = ModelPricing( + input_cost_per_token=input_price, + cache_read_cost_per_token=float(model_info.get("cache_read_input_token_cost") or input_price), + cache_write_cost_per_token=float(model_info.get("cache_creation_input_token_cost") or input_price), + output_cost_per_token=output_price, + reasoning_cost_per_token=output_price, + source="litellm_model_info", + ) + return _calculate_from_pricing(usage, pricing) + + +def _calculate_from_pricing(usage: UsageRecord, pricing: ModelPricing) -> CostBreakdown: + return CostBreakdown( + input_cost=usage.input_tokens * pricing.input_cost_per_token, + cache_read_cost=usage.cache_read_tokens * pricing.cache_read_cost_per_token, + cache_write_cost=usage.cache_write_tokens * pricing.cache_write_cost_per_token, + output_cost=usage.completion_tokens * pricing.output_cost_per_token, + reasoning_cost=usage.reasoning_tokens * (pricing.reasoning_cost_per_token or pricing.output_cost_per_token), + currency=pricing.currency, + pricing_source=pricing.source, + ) + + +def _provider_from_model(model: str) -> Optional[str]: + return model.split("/", 1)[0] if "/" in model else None + + +def _model_without_provider(provider: str, model: str) -> str: + prefix = f"{provider}/" + return model[len(prefix) :] if model.startswith(prefix) else model diff --git a/src/rotator_library/usage/quota.py b/src/rotator_library/usage/quota.py new file mode 100644 index 000000000..7d490c2a0 --- /dev/null +++ b/src/rotator_library/usage/quota.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Read-only quota snapshot helpers built from existing usage state.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Mapping, Optional + +from ..error_handler import mask_credential +from ..protocols import serialize_value +from .types import CredentialState, WindowStats + + +@dataclass(frozen=True) +class QuotaSnapshot: + """Client-safe view of one usage window. + + Snapshots are reporting-only. They intentionally do not participate in limit + checks, selection, or persistence so existing quota behavior remains owned by + `UsageManager`, `TrackingEngine`, and `LimitEngine`. + """ + + provider: str + model: Optional[str] + quota_group: Optional[str] + credential_id: Optional[str] + window_name: str + limit: Optional[int] + used: int + remaining: Optional[int] + reset_at: Optional[float] + source: str + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return { + "provider": self.provider, + "model": self.model, + "quota_group": self.quota_group, + "credential_id": self.credential_id, + "window_name": self.window_name, + "limit": self.limit, + "used": self.used, + "remaining": self.remaining, + "reset_at": self.reset_at, + "source": self.source, + "metadata": serialize_value(self.metadata), + } + + +def build_quota_snapshots( + *, + provider: str, + states: Mapping[str, CredentialState], + model: Optional[str] = None, + quota_group: Optional[str] = None, + include_credentials: bool = True, +) -> list[QuotaSnapshot]: + """Build read-only request/token quota snapshots from credential states. + + The current usage state stores request/token windows, not a reliable + provider-cost ledger. Snapshots therefore avoid inventing cost totals; cost + reporting can be added later only if the underlying state owns that data. + """ + + snapshots: list[QuotaSnapshot] = [] + for stable_id, state in states.items(): + credential_id = mask_credential(stable_id) if include_credentials else None + if model: + model_stats = state.get_model_stats(model, create=False) + if model_stats: + snapshots.extend( + _snapshots_for_windows( + provider=provider, + model=model, + quota_group=None, + credential_id=credential_id, + windows=model_stats.windows, + source="model", + ) + ) + if quota_group: + group_stats = state.get_group_stats(quota_group, create=False) + if group_stats: + snapshots.extend( + _snapshots_for_windows( + provider=provider, + model=model, + quota_group=quota_group, + credential_id=credential_id, + windows=group_stats.windows, + source="group", + ) + ) + return snapshots + + +def _snapshots_for_windows( + *, + provider: str, + model: Optional[str], + quota_group: Optional[str], + credential_id: Optional[str], + windows: Mapping[str, WindowStats], + source: str, +) -> list[QuotaSnapshot]: + return [ + QuotaSnapshot( + provider=provider, + model=model, + quota_group=quota_group, + credential_id=credential_id, + window_name=window.name, + limit=window.limit, + used=window.request_count, + remaining=window.remaining, + reset_at=window.reset_at, + source=source, + metadata={"scope": "request_token_window"}, + ) + for window in windows.values() + ] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..df150dc59 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import sys +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] +SRC = ROOT / "src" + +if str(SRC) not in sys.path: + sys.path.insert(0, str(SRC)) diff --git a/tests/test_adapter_registry.py b/tests/test_adapter_registry.py new file mode 100644 index 000000000..cef410c85 --- /dev/null +++ b/tests/test_adapter_registry.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.adapters import ( + AdapterContext, + PayloadAdapter, + get_adapter, + get_adapter_class, + list_adapters, + register_adapter, + resolve_adapter_name, + run_adapter_chain, +) +from rotator_library.transaction_logger import TransactionLogger + + +def _trace_entries(log_dir): + return [json.loads(line) for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + + +def test_adapter_registry_auto_discovers_builtins_and_aliases() -> None: + adapters = list_adapters() + + assert "noop" in adapters + assert "model_override" in adapters + assert "field_rename" in adapters + assert resolve_adapter_name("passthrough") == "noop" + assert resolve_adapter_name("field_copy") == "field_rename" + assert get_adapter_class("override_model").name == "model_override" + assert get_adapter("none") is get_adapter("noop") + + +def test_adapter_registry_rejects_duplicate_names_and_alias_collisions() -> None: + class DuplicateNoop(PayloadAdapter): + name = "noop" + + class AliasCollision(PayloadAdapter): + name = "custom_alias_collision" + aliases = ("noop",) + + with pytest.raises(ValueError): + register_adapter(DuplicateNoop) + with pytest.raises(ValueError): + register_adapter(AliasCollision) + + +@pytest.mark.asyncio +async def test_noop_adapter_preserves_payload_without_mutating_original() -> None: + payload = {"model": "original", "messages": []} + + result = await run_adapter_chain([get_adapter("noop")], payload, AdapterContext(), stage="request") + + assert result == payload + assert result is not payload + + +@pytest.mark.asyncio +async def test_model_override_adapter_changes_model_from_config() -> None: + payload = {"model": "public-name", "messages": []} + context = AdapterContext(adapter_config={"model_override": {"model": "native-name"}}) + + result = await run_adapter_chain([get_adapter("model_override")], payload, context, stage="request") + + assert result["model"] == "native-name" + assert payload["model"] == "public-name" + + +@pytest.mark.asyncio +async def test_suppress_developer_role_converts_or_drops_messages() -> None: + payload = {"messages": [{"role": "developer", "content": "rules"}, {"role": "user", "content": "hi"}]} + + system_result = await run_adapter_chain([get_adapter("suppress_developer_role")], payload, AdapterContext(), stage="request") + drop_result = await run_adapter_chain( + [get_adapter("suppress_developer_role")], + payload, + AdapterContext(adapter_config={"suppress_developer_role": {"mode": "drop"}}), + stage="request", + ) + + assert system_result["messages"][0]["role"] == "system" + assert [message["role"] for message in drop_result["messages"]] == ["user"] + + +@pytest.mark.asyncio +async def test_reasoning_content_adapter_copies_common_reasoning_field() -> None: + payload = {"choices": [{"message": {"role": "assistant", "reasoning": "hidden"}}]} + + result = await run_adapter_chain([get_adapter("reasoning_content")], payload, AdapterContext(), stage="response") + + assert result["choices"][0]["message"]["reasoning_content"] == "hidden" + + +@pytest.mark.asyncio +async def test_field_rename_adapter_copies_and_moves_configured_fields() -> None: + payload = {"old": {"field": "value"}, "messages": [{}]} + context = AdapterContext( + adapter_config={ + "field_rename": { + "rules": [ + {"source_path": "old.field", "target_path": "messages[-1].new_field", "move": True} + ] + } + } + ) + + result = await run_adapter_chain([get_adapter("field_rename")], payload, context, stage="request") + + assert result["messages"][-1]["new_field"] == "value" + assert "field" not in result["old"] + assert payload["old"]["field"] == "value" + + +@pytest.mark.asyncio +async def test_antigravity_envelope_wraps_user_request_key() -> None: + adapter = get_adapter("antigravity_envelope") + + result = await adapter.transform_request( + {"model": "gemini-3-flash", "request": {"user_supplied": True}, "contents": [{"parts": [{"text": "hi"}]}]}, + AdapterContext(adapter_config={"antigravity_envelope": {"request_type": "CHAT_COMPLETION", "user_agent": "test-agent"}}), + ) + + assert result["requestType"] == "CHAT_COMPLETION" + assert result["requestId"] + assert result["request"]["request"] == {"user_supplied": True} + assert result["request"]["contents"][0]["parts"][0]["text"] == "hi" + + +@pytest.mark.asyncio +async def test_antigravity_envelope_is_idempotent_for_controlled_envelope() -> None: + adapter = get_adapter("antigravity_envelope") + payload = {"model": "gemini-3-flash", "request": {"contents": []}, "requestType": "CHAT_COMPLETION", "requestId": "id"} + + result = await adapter.transform_request(payload, AdapterContext()) + + assert result == payload + + +@pytest.mark.asyncio +async def test_adapter_chain_order_is_preserved() -> None: + payload = {"model": "public", "messages": [{"role": "developer", "content": "rules"}]} + context = AdapterContext(adapter_config={"model_override": {"model": "native"}}) + + result = await run_adapter_chain( + [get_adapter("model_override"), get_adapter("suppress_developer_role")], + payload, + context, + stage="request", + ) + + assert result["model"] == "native" + assert result["messages"][0]["role"] == "system" + + +@pytest.mark.asyncio +async def test_adapter_chain_traces_final_summary(tmp_path) -> None: + logger = TransactionLogger("native", "native/test", parent_dir=tmp_path) + payload = {"model": "public", "messages": []} + context = AdapterContext( + adapter_config={"model_override": {"model": "native"}}, + transaction_logger=logger, + protocol="openai_chat", + ) + + result = await run_adapter_chain([get_adapter("model_override")], payload, context, stage="request") + + assert result["model"] == "native" + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert pass_names == ["before_adapter_chain", "after_adapter", "after_adapter_chain"] + assert entries[-1]["metadata"]["adapter_count"] == 1 + assert entries[-1]["metadata"]["changed"] is True diff --git a/tests/test_anthropic_transform_tracing.py b/tests/test_anthropic_transform_tracing.py new file mode 100644 index 000000000..1f1c7ffbe --- /dev/null +++ b/tests/test_anthropic_transform_tracing.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import json + +import pytest + +import rotator_library.client.anthropic as anthropic_client_module +from rotator_library.anthropic_compat import AnthropicMessagesRequest +from rotator_library.anthropic_compat.streaming import anthropic_streaming_wrapper +from rotator_library.client.anthropic import AnthropicHandler +from rotator_library.transaction_logger import TransactionLogger + + +def _trace_entries(log_dir): + return [json.loads(line) for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + + +class FakeRotatingClient: + enable_request_logging = True + + async def acompletion(self, **kwargs): + return { + "id": "chat_1", + "model": kwargs["model"], + "choices": [{"message": {"role": "assistant", "content": "ok"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + } + + +@pytest.mark.asyncio +async def test_anthropic_handler_traces_conversion_boundaries(tmp_path, monkeypatch) -> None: + created: list[TransactionLogger] = [] + + def logger_factory(provider, model, enabled=True, api_format="ant", parent_dir=None): + logger = TransactionLogger(provider, model, enabled=enabled, api_format=api_format, parent_dir=tmp_path) + created.append(logger) + return logger + + monkeypatch.setattr(anthropic_client_module, "TransactionLogger", logger_factory) + request = AnthropicMessagesRequest(model="openai/gpt-test", max_tokens=16, messages=[{"role": "user", "content": "hi"}]) + + response = await AnthropicHandler(FakeRotatingClient()).messages(request) + + assert response["content"][0]["text"] == "ok" + pass_names = [entry["pass_name"] for entry in _trace_entries(created[0].log_dir)] + assert "anthropic_raw_request" in pass_names + assert "anthropic_to_openai_request" in pass_names + assert "anthropic_openai_response" in pass_names + assert "openai_to_anthropic_response" in pass_names + assert "anthropic_final_response" in pass_names + + +class ClosingOpenAIStream: + def __init__(self) -> None: + self.closed = False + self._chunks = iter(['data: {"choices":[{"delta":{"content":"hi"}}]}\n\n']) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._chunks) + except StopIteration: + raise StopAsyncIteration + + async def aclose(self) -> None: + self.closed = True + + +class IteratorOnlyCloseStream: + def __init__(self) -> None: + self.iterator = IteratorOnlyCloseStreamIterator() + + def __aiter__(self): + return self.iterator + + +class IteratorOnlyCloseStreamIterator: + def __init__(self) -> None: + self.closed = False + self._chunks = iter(['data: {"choices":[{"delta":{"content":"hi"}}]}\n\n']) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._chunks) + except StopIteration: + raise StopAsyncIteration + + async def aclose(self) -> None: + self.closed = True + + +@pytest.mark.asyncio +async def test_anthropic_stream_traces_and_closes_on_disconnect(tmp_path) -> None: + logger = TransactionLogger("anthropic", "claude-test", parent_dir=tmp_path) + stream = ClosingOpenAIStream() + + async def disconnected() -> bool: + return True + + chunks = [chunk async for chunk in anthropic_streaming_wrapper(stream, "claude-test", is_disconnected=disconnected, transaction_logger=logger)] + + assert chunks == [] + assert stream.closed is True + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "anthropic_stream_source_chunk" in pass_names + assert "anthropic_stream_disconnected" in pass_names + assert "anthropic_stream_upstream_closed" in pass_names + + +@pytest.mark.asyncio +async def test_anthropic_stream_closes_iterator_only_upstream() -> None: + stream = IteratorOnlyCloseStream() + + async def disconnected() -> bool: + return True + + chunks = [chunk async for chunk in anthropic_streaming_wrapper(stream, "claude-test", is_disconnected=disconnected)] + + assert chunks == [] + assert stream.iterator.closed is True + + +@pytest.mark.asyncio +async def test_anthropic_stream_traces_emitted_frames(tmp_path) -> None: + logger = TransactionLogger("anthropic", "claude-test", parent_dir=tmp_path) + + async def stream(): + yield 'data: {"choices":[{"delta":{"content":"hi"}}]}\n\n' + yield "data: [DONE]\n\n" + + chunks = [chunk async for chunk in anthropic_streaming_wrapper(stream(), "claude-test", transaction_logger=logger)] + + assert any("event: message_start" in chunk for chunk in chunks) + assert any("event: message_stop" in chunk for chunk in chunks) + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "anthropic_stream_message_start" in pass_names + assert "anthropic_stream_content_delta" in pass_names + assert "anthropic_stream_message_delta" in pass_names + assert "anthropic_stream_message_stop" in pass_names diff --git a/tests/test_antigravity_provider_restore.py b/tests/test_antigravity_provider_restore.py new file mode 100644 index 000000000..ce5dbba87 --- /dev/null +++ b/tests/test_antigravity_provider_restore.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import pytest + +from rotator_library.providers import PROVIDER_PLUGINS +from rotator_library.providers.antigravity_provider import AntigravityProvider + + +class FakeResponse: + def __init__(self, payload): + self.payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self.payload + + +class FakeClient: + def __init__(self, payload): + self.payload = payload + self.calls = [] + + async def post(self, url, *, headers, json, timeout): + self.calls.append({"url": url, "headers": headers, "json": json, "timeout": timeout}) + return FakeResponse(self.payload) + + +def test_antigravity_provider_is_discovered() -> None: + assert "antigravity" in PROVIDER_PLUGINS + + +def test_antigravity_provider_restores_safe_declarations() -> None: + provider = AntigravityProvider() + + assert provider.get_protocol_name("gemini-3-flash") == "gemini" + assert provider.get_adapter_names("gemini-3-flash") == ("antigravity_envelope",) + assert provider.get_adapter_config("gemini-3-flash")["antigravity_envelope"]["request_type"] == "CHAT_COMPLETION" + assert provider.get_model_tier_requirement("antigravity/gemini-3-flash") is None + rules = provider.get_field_cache_rules("gemini-3-flash") + assert rules[0].name == "antigravity_thought_signature" + assert rules[0].scope == ("provider", "model", "credential", "session") + + +def test_antigravity_quota_groups_are_model_family_scoped() -> None: + groups = AntigravityProvider.model_quota_groups + + assert groups["gemini_3_pro"] == ["gemini-3-pro-preview", "gemini-3-pro-low", "gemini-3-pro-high"] + assert groups["gemini_3_flash"] == ["gemini-3-flash"] + assert groups["gemini_2_5_flash"] == ["gemini-2.5-flash"] + assert groups["gemini_2_5_flash_lite"] == ["gemini-2.5-flash-lite"] + assert groups["claude_sonnet_4_5"] == ["claude-sonnet-4.5"] + + +def test_antigravity_provider_builds_static_headers_without_device_profile(monkeypatch) -> None: + monkeypatch.setenv("ANTIGRAVITY_API_BASE", "https://antigravity.test/v1internal") + provider = AntigravityProvider() + + headers = provider.get_native_headers("token") + + assert provider.get_native_endpoint(operation="generate") == "https://antigravity.test/v1internal:generateContent" + assert provider.get_native_endpoint(operation="stream_generate") == "https://antigravity.test/v1internal:streamGenerateContent?alt=sse" + assert provider.get_native_endpoint(operation="models") == "https://antigravity.test/v1internal:fetchAvailableModels" + assert headers["Authorization"] == "Bearer token" + assert headers["X-Goog-Api-Client"] == "google-cloud-sdk vscode_cloudshelleditor/0.1" + assert "Accept" not in headers + assert provider.get_native_headers("token", operation="stream_generate")["Accept"] == "text/event-stream" + assert "X-Client-Device-Id" not in headers + + +def test_antigravity_model_aliases_and_tracking_normalization() -> None: + provider = AntigravityProvider() + + assert provider._alias_to_internal("claude-sonnet-4.5") == "claude-sonnet-4-5" + assert provider.normalize_native_model("antigravity/claude-sonnet-4.5") == "claude-sonnet-4-5" + assert provider.normalize_native_model("antigravity/gemini-3-pro-preview") == "gemini-3-pro-preview" + assert provider.normalize_model_for_tracking("antigravity/claude-sonnet-4-5") == "antigravity/claude-sonnet-4.5" + + +def test_antigravity_native_operation_model_and_stream_support() -> None: + provider = AntigravityProvider() + + assert provider.get_native_operation("gemini-3-flash", {}, stream=False) == "generate" + assert provider.get_native_operation("gemini-3-flash", {}, stream=True) == "stream_generate" + assert provider.supports_native_streaming("gemini-3-flash", operation="stream_generate") is False + assert provider.supports_native_streaming("gemini-3-flash", operation="generate") is False + prepared = provider.prepare_native_request({"model": "antigravity/gemini-3-pro-low"}, model="gemini-3-pro-preview", operation="generate") + assert prepared["model"] == "gemini-3-pro-low" + assert prepared["generationConfig"]["thinkingConfig"]["thinkingLevel"] == "low" + + +def test_antigravity_prepare_native_request_converts_messages_to_gemini_contents() -> None: + provider = AntigravityProvider() + + prepared = provider.prepare_native_request( + {"messages": [{"role": "user", "content": "hello"}, {"role": "assistant", "content": "hi"}]}, + model="gemini-3-flash", + operation="generate", + ) + + assert prepared["model"] == "gemini-3-flash" + assert prepared["contents"] == [ + {"role": "user", "parts": [{"text": "hello"}]}, + {"role": "model", "parts": [{"text": "hi"}]}, + ] + assert "messages" not in prepared + + +@pytest.mark.asyncio +async def test_antigravity_get_models_filters_and_aliases_mocked_response(monkeypatch) -> None: + monkeypatch.setenv("ANTIGRAVITY_API_BASE", "https://antigravity.test/v1internal") + client = FakeClient({"models": {"gemini-3-pro-low": {}, "chat_20706": {}, "claude-opus-4-6": {}}}) + provider = AntigravityProvider() + + models = await provider.get_models("token", client) + + assert models == ["antigravity/gemini-3-pro-preview", "antigravity/claude-opus-4.6"] + assert client.calls[0]["url"] == "https://antigravity.test/v1internal:fetchAvailableModels" + assert client.calls[0]["headers"]["Authorization"] == "Bearer token" + + +@pytest.mark.asyncio +async def test_antigravity_get_models_falls_back_on_errors() -> None: + class BrokenClient: + async def post(self, *args, **kwargs): + raise RuntimeError("offline") + + models = await AntigravityProvider().get_models("token", BrokenClient()) + + assert "antigravity/gemini-3-flash" in models + assert "antigravity/claude-opus-4.6" in models diff --git a/tests/test_claude_code_provider.py b/tests/test_claude_code_provider.py new file mode 100644 index 000000000..cf47902cf --- /dev/null +++ b/tests/test_claude_code_provider.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import pytest + +from rotator_library.adapters import AdapterContext, get_adapter, run_adapter_chain +from rotator_library.providers import PROVIDER_PLUGINS +from rotator_library.providers.claude_code_provider import ClaudeCodeProvider + + +class FakeResponse: + def __init__(self, payload): + self.payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self.payload + + +class FakeClient: + def __init__(self, payload): + self.payload = payload + self.calls = [] + + async def get(self, url, *, headers, timeout): + self.calls.append({"url": url, "headers": headers, "timeout": timeout}) + return FakeResponse(self.payload) + + +def test_claude_code_provider_is_discovered() -> None: + assert "claude_code" in PROVIDER_PLUGINS + + +def test_claude_code_provider_declares_native_protocol_adapters_and_cache_rules() -> None: + provider = ClaudeCodeProvider() + + assert provider.get_protocol_name("claude-sonnet-4-5") == "anthropic_messages" + assert provider.get_adapter_names("claude-sonnet-4-5") == ("suppress_developer_role",) + assert provider.get_adapter_config("claude-sonnet-4-5") == {"suppress_developer_role": {"mode": "user"}} + rules = provider.get_field_cache_rules("claude-sonnet-4-5") + assert rules[0].name == "claude_code_thinking_signature" + assert rules[0].scope == ("provider", "model", "credential", "session") + + +@pytest.mark.asyncio +async def test_claude_code_adapter_config_converts_developer_role_to_user() -> None: + provider = ClaudeCodeProvider() + + result = await run_adapter_chain( + [get_adapter("suppress_developer_role")], + {"messages": [{"role": "developer", "content": "rules"}]}, + AdapterContext(adapter_config=provider.get_adapter_config("claude-sonnet-4-5")), + stage="request", + ) + + assert result["messages"][0]["role"] == "user" + + +def test_claude_code_provider_builds_native_headers_and_endpoint(monkeypatch) -> None: + monkeypatch.setenv("CLAUDE_CODE_API_BASE", "https://claude-code.test") + monkeypatch.setenv("CLAUDE_CODE_ANTHROPIC_VERSION", "2099-01-01") + provider = ClaudeCodeProvider() + + assert provider.get_native_endpoint(operation="messages") == "https://claude-code.test/v1/messages" + assert provider.get_native_headers("token") == { + "Authorization": "Bearer token", + "anthropic-version": "2099-01-01", + "content-type": "application/json", + } + + +def test_claude_code_native_operation_model_and_stream_support() -> None: + provider = ClaudeCodeProvider() + + assert provider.get_native_operation("claude-sonnet-4-5", {}, stream=False) == "messages" + assert provider.normalize_native_model("claude_code/claude-sonnet-4-5") == "claude-sonnet-4-5" + assert provider.supports_native_streaming("claude-sonnet-4-5", operation="messages") is False + assert provider.supports_native_streaming("claude-sonnet-4-5", operation="chat") is False + + +@pytest.mark.asyncio +async def test_claude_code_provider_get_models_uses_mocked_models_endpoint(monkeypatch) -> None: + monkeypatch.setenv("CLAUDE_CODE_API_BASE", "https://claude-code.test") + client = FakeClient({"data": [{"id": "claude-sonnet-test"}]}) + provider = ClaudeCodeProvider() + + models = await provider.get_models("token", client) + + assert models == ["claude_code/claude-sonnet-test"] + assert client.calls[0]["url"] == "https://claude-code.test/v1/models" + assert client.calls[0]["headers"]["Authorization"] == "Bearer token" + + +@pytest.mark.asyncio +async def test_claude_code_provider_get_models_falls_back_on_errors() -> None: + class BrokenClient: + async def get(self, *args, **kwargs): + raise RuntimeError("offline") + + assert await ClaudeCodeProvider().get_models("token", BrokenClient()) == [ + "claude_code/claude-sonnet-4-5", + "claude_code/claude-opus-4-5", + ] diff --git a/tests/test_codex_provider.py b/tests/test_codex_provider.py new file mode 100644 index 000000000..60ba5790e --- /dev/null +++ b/tests/test_codex_provider.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import pytest + +from rotator_library.providers import PROVIDER_PLUGINS +from rotator_library.providers.codex_provider import CodexProvider + + +class FakeResponse: + def __init__(self, payload): + self.payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self.payload + + +class FakeClient: + def __init__(self, payload): + self.payload = payload + self.calls = [] + + async def get(self, url, *, headers, timeout): + self.calls.append({"url": url, "headers": headers, "timeout": timeout}) + return FakeResponse(self.payload) + + +def test_codex_provider_is_discovered() -> None: + assert "codex" in PROVIDER_PLUGINS + + +def test_codex_provider_declares_responses_protocol_and_cache_rule() -> None: + provider = CodexProvider() + + assert provider.get_protocol_name("codex-mini-latest") == "responses" + assert provider.get_adapter_names("codex-mini-latest") == () + rules = provider.get_field_cache_rules("codex-mini-latest") + assert rules[0].name == "codex_previous_response_id" + assert rules[0].inject.path == "previous_response_id" + + +def test_codex_provider_builds_native_headers_and_endpoint(monkeypatch) -> None: + monkeypatch.setenv("CODEX_API_BASE", "https://codex.test") + provider = CodexProvider() + + assert provider.get_native_endpoint(operation="responses") == "https://codex.test/v1/responses" + assert provider.get_native_endpoint(operation="models") == "https://codex.test/v1/models" + assert provider.get_native_headers("token") == {"Authorization": "Bearer token", "content-type": "application/json"} + + +def test_codex_native_operation_model_and_stream_support() -> None: + provider = CodexProvider() + + assert provider.get_native_operation("gpt-5.1-codex", {}, stream=False) == "responses" + assert provider.normalize_native_model("codex/gpt-5.1-codex") == "gpt-5.1-codex" + assert provider.supports_native_streaming("gpt-5.1-codex", operation="responses") is False + assert provider.supports_native_streaming("gpt-5.1-codex", operation="chat") is False + + +def test_codex_prepare_native_request_converts_messages_to_responses_input() -> None: + provider = CodexProvider() + + prepared = provider.prepare_native_request( + {"model": "codex/gpt-5.1-codex", "messages": [{"role": "user", "content": "hello"}]}, + model="gpt-5.1-codex", + operation="responses", + ) + + assert prepared["model"] == "gpt-5.1-codex" + assert prepared["input"] == [{"role": "user", "content": "hello"}] + assert "messages" not in prepared + + +@pytest.mark.asyncio +async def test_codex_provider_get_models_filters_codex_models(monkeypatch) -> None: + monkeypatch.setenv("CODEX_API_BASE", "https://codex.test") + client = FakeClient({"data": [{"id": "gpt-5.1-codex"}, {"id": "gpt-5.1"}]}) + provider = CodexProvider() + + models = await provider.get_models("token", client) + + assert models == ["codex/gpt-5.1-codex"] + assert client.calls[0]["url"] == "https://codex.test/v1/models" + + +@pytest.mark.asyncio +async def test_codex_provider_get_models_falls_back_on_errors() -> None: + class BrokenClient: + async def get(self, *args, **kwargs): + raise RuntimeError("offline") + + assert await CodexProvider().get_models("token", BrokenClient()) == ["codex/codex-mini-latest", "codex/gpt-5.1-codex"] diff --git a/tests/test_config_pricing.py b/tests/test_config_pricing.py new file mode 100644 index 000000000..39269b4b5 --- /dev/null +++ b/tests/test_config_pricing.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from rotator_library.config.experimental import env_price_key, load_config_from_mapping +from rotator_library.usage.accounting import UsageRecord +from rotator_library.usage.costs import CostCalculator, ModelPricing + + +class ExplicitPricingProvider: + def get_model_pricing(self, model: str) -> ModelPricing: + return ModelPricing(input_cost_per_token=9.0, source="provider") + + +class SkipCostProvider: + skip_cost_calculation = True + + +def _usage() -> UsageRecord: + return UsageRecord(input_tokens=10, completion_tokens=5, reasoning_tokens=2, cache_read_tokens=3, cache_write_tokens=4, provider="openai", model="gpt-test") + + +def test_json_pricing_calculates_all_buckets() -> None: + config = load_config_from_mapping( + {"pricing": {"openai": {"gpt-test": {"input": 1.0, "output": 2.0, "reasoning": 3.0, "cache_read": 0.5, "cache_write": 0.75}}}} + ) + + cost = CostCalculator(config=config, use_litellm_fallback=False).calculate(_usage(), model="openai/gpt-test") + + assert cost.pricing_source == "json_config" + assert cost.input_cost == 10.0 + assert cost.output_cost == 10.0 + assert cost.reasoning_cost == 6.0 + assert cost.cache_read_cost == 1.5 + assert cost.cache_write_cost == 3.0 + + +def test_env_pricing_overrides_json_pricing() -> None: + config = load_config_from_mapping({"pricing": {"openai": {"gpt-test": {"input": 1.0}}}}) + env = {env_price_key("openai", "gpt-test", "input"): "4.0"} + + cost = CostCalculator(config=config, env=env, use_litellm_fallback=False).calculate(_usage(), model="openai/gpt-test") + + assert cost.pricing_source == "env" + assert cost.input_cost == 40.0 + + +def test_provider_pricing_beats_env_and_json_pricing() -> None: + config = load_config_from_mapping({"pricing": {"openai": {"gpt-test": {"input": 1.0}}}}) + env = {env_price_key("openai", "gpt-test", "input"): "4.0"} + + cost = CostCalculator(provider_plugin=ExplicitPricingProvider(), config=config, env=env, use_litellm_fallback=False).calculate(_usage(), model="openai/gpt-test") + + assert cost.pricing_source == "provider" + assert cost.input_cost == 90.0 + + +def test_skip_cost_provider_beats_all_config_pricing() -> None: + config = load_config_from_mapping({"pricing": {"openai": {"gpt-test": {"input": 1.0}}}}) + + cost = CostCalculator(provider_plugin=SkipCostProvider(), config=config, use_litellm_fallback=False).calculate(_usage(), model="openai/gpt-test") + + assert cost.pricing_source == "skipped" + assert cost.total_cost == 0.0 + + +def test_missing_pricing_remains_unavailable() -> None: + cost = CostCalculator(use_litellm_fallback=False).calculate(_usage(), model="openai/gpt-test") + + assert cost.pricing_source == "unavailable" diff --git a/tests/test_config_routing_json.py b/tests/test_config_routing_json.py new file mode 100644 index 000000000..06c90c7b6 --- /dev/null +++ b/tests/test_config_routing_json.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import pytest + +from rotator_library.config.experimental import load_config_from_mapping +from rotator_library.routing import RoutingConfigError, load_routing_config_from_env + + +def test_json_routing_config_loads_fallback_group_and_route() -> None: + experimental = load_config_from_mapping( + { + "routing": { + "fallback_groups": { + "code_chain": { + "targets": ["codex/gpt-5.1@native", "openai/gpt-5.1@litellm_fallback"], + "failover_on": ["rate_limit"], + "stop_on": ["authentication"], + } + }, + "model_routes": {"code": "group:code_chain"}, + } + } + ) + + config = load_routing_config_from_env({}, config=experimental) + + assert config.fallback_groups["code_chain"].targets[0].execution == "native" + assert config.fallback_groups["code_chain"].failover_on == frozenset({"rate_limit"}) + assert config.model_routes["code"] == "group:code_chain" + + +def test_env_group_overrides_json_group_targets() -> None: + experimental = load_config_from_mapping( + {"routing": {"fallback_groups": {"chain": {"targets": ["codex/gpt"]}}, "model_routes": {"code": "group:chain"}}} + ) + + config = load_routing_config_from_env( + {"FALLBACK_GROUPS": "chain", "FALLBACK_GROUP_CHAIN": "openai/gpt-5.1", "MODEL_ROUTE_CODE": "group:chain"}, + config=experimental, + ) + + assert config.fallback_groups["chain"].targets[0].provider == "openai" + + +def test_env_route_can_reference_json_group() -> None: + experimental = load_config_from_mapping({"routing": {"fallback_groups": {"chain": {"targets": ["codex/gpt"]}}}}) + + config = load_routing_config_from_env({"MODEL_ROUTE_CODE": "group:chain"}, config=experimental) + + assert config.model_routes["code"] == "group:chain" + + +def test_json_route_rejects_unknown_group() -> None: + experimental = load_config_from_mapping({"routing": {"model_routes": {"code": "group:missing"}}}) + + with pytest.raises(RoutingConfigError): + load_routing_config_from_env({}, config=experimental) diff --git a/tests/test_config_stream_settings.py b/tests/test_config_stream_settings.py new file mode 100644 index 000000000..5da2c5141 --- /dev/null +++ b/tests/test_config_stream_settings.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import pytest + +from rotator_library.config.experimental import ExperimentalConfigError, get_stream_runtime_settings, load_config_from_mapping + + +def test_stream_settings_parse_json_values() -> None: + config = load_config_from_mapping( + {"streaming": {"ttfb_timeout_seconds": 5, "stall_timeout_seconds": 30, "heartbeat_interval_seconds": 10, "cancel_upstream_on_disconnect": False, "trace_metrics": False}} + ) + + settings = get_stream_runtime_settings(config=config, env={}) + + assert settings.ttfb_timeout_seconds == 5 + assert settings.stall_timeout_seconds == 30 + assert settings.heartbeat_seconds == 10 + assert settings.cancel_upstream_on_disconnect is False + assert settings.trace_metrics is False + + +def test_stream_settings_env_overrides_json_values() -> None: + config = load_config_from_mapping({"streaming": {"trace_metrics": True, "heartbeat_seconds": 10}}) + + settings = get_stream_runtime_settings(config=config, env={"STREAM_TRACE_METRICS": "false", "STREAM_HEARTBEAT_INTERVAL_SECONDS": "2"}) + + assert settings.trace_metrics is False + assert settings.heartbeat_seconds == 2 + + +def test_stream_settings_accept_legacy_heartbeat_env_name() -> None: + settings = get_stream_runtime_settings(env={"STREAM_HEARTBEAT_SECONDS": "3"}) + + assert settings.heartbeat_seconds == 3 + + +def test_stream_settings_invalid_boolean_fails_clearly() -> None: + with pytest.raises(ExperimentalConfigError): + get_stream_runtime_settings(env={"STREAM_TRACE_METRICS": "maybe"}) diff --git a/tests/test_cooldown_activation.py b/tests/test_cooldown_activation.py new file mode 100644 index 000000000..b7f0490a3 --- /dev/null +++ b/tests/test_cooldown_activation.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +import asyncio + +import pytest + +from rotator_library.client.executor import RequestExecutor, RoutingExecutionError, _can_start_stream_provider_cooldown +from rotator_library.cooldown_manager import CooldownManager +from rotator_library.core.types import RequestContext +from rotator_library.error_handler import ClassifiedError +from rotator_library.transaction_logger import TransactionLogger + + +class FakeCooldown: + def __init__(self) -> None: + self.started = [] + self.scoped_started = [] + self.waits = [] + + async def start_cooldown(self, provider, duration): + self.started.append((provider, duration)) + + async def start_scoped_cooldown(self, provider, duration, *, model=None, scope="provider", reason=None): + self.scoped_started.append((provider, duration, model, scope, reason)) + + async def get_max_remaining(self, provider, *, model=None): + self.waits.append((provider, model)) + return 0 + + +class BudgetCooldown(FakeCooldown): + def __init__(self, remaining: float) -> None: + super().__init__() + self.remaining = remaining + + async def get_max_remaining(self, provider, *, model=None): + self.waits.append((provider, model)) + return self.remaining + + +@pytest.mark.asyncio +async def test_start_cooldown_extends_but_does_not_shorten() -> None: + manager = CooldownManager() + + await manager.start_cooldown("provider", 30) + initial = await manager.get_remaining_cooldown("provider") + await asyncio.sleep(0.01) + await manager.start_cooldown("provider", 1) + after_shorter = await manager.get_remaining_cooldown("provider") + await manager.start_cooldown("provider", 60) + after_longer = await manager.get_remaining_cooldown("provider") + + assert after_shorter > 25 + assert after_shorter <= initial + assert after_longer > after_shorter + + +@pytest.mark.asyncio +async def test_model_cooldown_is_independent_from_provider_cooldown() -> None: + manager = CooldownManager() + + await manager.start_scoped_cooldown("provider", 30, model="model-a", scope="model", reason="capacity") + + assert await manager.is_scoped_cooling_down("provider", model="model-a", scope="model") is True + assert await manager.is_scoped_cooling_down("provider", model="model-b", scope="model") is False + assert await manager.is_cooling_down("provider") is False + + +@pytest.mark.asyncio +async def test_max_remaining_uses_provider_or_model_scope() -> None: + manager = CooldownManager() + + await manager.start_cooldown("provider", 1) + await manager.start_scoped_cooldown("provider", 30, model="model-a", scope="model") + + assert await manager.get_max_remaining("provider", model="model-a") > 20 + assert 0 < await manager.get_max_remaining("provider", model="model-b") <= 1 + + +@pytest.mark.asyncio +async def test_cooldown_snapshot_reports_scopes() -> None: + manager = CooldownManager() + + await manager.start_scoped_cooldown("provider", 30, model="model-a", scope="model", reason="capacity") + snapshot = await manager.snapshot() + + assert snapshot[0].provider == "provider" + assert snapshot[0].scope == "model" + assert snapshot[0].model == "model-a" + assert snapshot[0].reason == "capacity" + + +def _classified(error_type: str, retry_after=None) -> ClassifiedError: + return ClassifiedError(error_type, original_exception=Exception(error_type), retry_after=retry_after) + + +def _executor(cooldown) -> RequestExecutor: + executor = RequestExecutor.__new__(RequestExecutor) + executor._cooldown = cooldown + from rotator_library.retry_policy import FailureHistory + + executor._failure_history = FailureHistory() + return executor + + +def _context(logger) -> RequestContext: + return RequestContext( + model="openai/gpt-test", + provider="openai", + kwargs={"model": "openai/gpt-test"}, + streaming=False, + credentials=["cred"], + deadline=9999999999.0, + transaction_logger=logger, + ) + + +@pytest.mark.asyncio +async def test_large_retry_after_starts_provider_cooldown_and_traces(tmp_path, monkeypatch) -> None: + monkeypatch.setenv("SMALL_COOLDOWN_RETRY_THRESHOLD", "10") + monkeypatch.setenv("PROVIDER_COOLDOWN_MIN_SECONDS", "10") + cooldown = FakeCooldown() + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + + await _executor(cooldown)._maybe_start_provider_cooldown( + "openai", + _classified("rate_limit", retry_after=60), + context=_context(logger), + ) + + assert cooldown.scoped_started == [("openai", 60, None, "provider", "retry_after")] + trace = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "provider_cooldown_started" in trace + + +@pytest.mark.asyncio +async def test_small_retry_after_skips_provider_cooldown(monkeypatch) -> None: + monkeypatch.setenv("SMALL_COOLDOWN_RETRY_THRESHOLD", "10") + cooldown = FakeCooldown() + + await _executor(cooldown)._maybe_start_provider_cooldown( + "openai", + _classified("rate_limit", retry_after=3), + context=None, + ) + + assert cooldown.started == [] + assert cooldown.scoped_started == [] + + +@pytest.mark.asyncio +async def test_model_capacity_starts_model_scoped_cooldown(monkeypatch) -> None: + monkeypatch.setenv("PROVIDER_COOLDOWN_DEFAULT_SECONDS", "30") + monkeypatch.setenv("PROVIDER_COOLDOWN_MIN_SECONDS", "10") + cooldown = FakeCooldown() + executor = _executor(cooldown) + + await executor._maybe_start_provider_cooldown( + "openai", + _classified("server_error"), + context=None, + model="gpt-5", + original_error=Exception("MODEL_CAPACITY_EXHAUSTED"), + ) + + assert cooldown.scoped_started == [("openai", 30, "gpt-5", "model", "model_capacity_cooldown")] + assert executor._failure_history.snapshot()[0].scope == "model" + + +@pytest.mark.asyncio +async def test_wait_for_cooldown_uses_model_scope_when_available() -> None: + cooldown = FakeCooldown() + + await _executor(cooldown)._wait_for_cooldown("openai", 9999999999.0, model="gpt-5") + + assert cooldown.waits == [("openai", "gpt-5")] + + +@pytest.mark.asyncio +async def test_wait_for_cooldown_exceeding_budget_fails_fast() -> None: + cooldown = BudgetCooldown(remaining=60) + + with pytest.raises(RoutingExecutionError) as exc: + await _executor(cooldown)._wait_for_cooldown("openai", 1.0, model="gpt-5") + + assert exc.value.error_type == "rate_limit" + assert cooldown.waits == [("openai", "gpt-5")] + + +@pytest.mark.asyncio +async def test_generic_transient_records_history_before_starting_cooldown(monkeypatch) -> None: + monkeypatch.setenv("PROVIDER_BACKOFF_THRESHOLD", "2") + monkeypatch.setenv("PROVIDER_COOLDOWN_DEFAULT_SECONDS", "10") + monkeypatch.setenv("PROVIDER_COOLDOWN_MIN_SECONDS", "10") + cooldown = FakeCooldown() + executor = _executor(cooldown) + + await executor._maybe_start_provider_cooldown("openai", _classified("server_error"), context=None, model="gpt-5") + assert cooldown.scoped_started == [] + assert executor._failure_history.snapshot()[0].reason == "transient_backoff_threshold_not_met" + + await executor._maybe_start_provider_cooldown("openai", _classified("server_error"), context=None, model="gpt-5") + assert cooldown.scoped_started == [("openai", 10, None, "provider", "default_transient_cooldown")] + + +def test_streaming_provider_cooldown_gate_allows_only_pre_output_failures() -> None: + assert _can_start_stream_provider_cooldown(None) is True + assert _can_start_stream_provider_cooldown('data: {"error":{"type":"rate_limit"}}\n\n') is True + assert _can_start_stream_provider_cooldown('data: {"choices":[{"delta":{"content":"visible"}}]}\n\n') is False + assert _can_start_stream_provider_cooldown('data: {"usage":{"total_tokens":1}}\n\n', emitted_output=True) is False diff --git a/tests/test_copilot_provider.py b/tests/test_copilot_provider.py new file mode 100644 index 000000000..3586c0995 --- /dev/null +++ b/tests/test_copilot_provider.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import pytest + +from rotator_library.adapters import AdapterContext, get_adapter, run_adapter_chain +from rotator_library.providers import PROVIDER_PLUGINS +from rotator_library.providers.copilot_provider import CopilotProvider + + +class FakeResponse: + def __init__(self, payload): + self.payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self.payload + + +class FakeClient: + def __init__(self, payload): + self.payload = payload + self.calls = [] + + async def get(self, url, *, headers, timeout): + self.calls.append({"url": url, "headers": headers, "timeout": timeout}) + return FakeResponse(self.payload) + + +def test_copilot_provider_is_discovered() -> None: + assert "copilot" in PROVIDER_PLUGINS + + +def test_copilot_provider_declares_openai_chat_protocol_and_adapter() -> None: + provider = CopilotProvider() + + assert provider.get_protocol_name("gpt-4.1") == "openai_chat" + assert provider.get_adapter_names("gpt-4.1") == ("suppress_developer_role",) + assert provider.get_adapter_config("gpt-4.1") == {"suppress_developer_role": {"mode": "system"}} + assert provider.get_field_cache_rules("gpt-4.1") == () + + +@pytest.mark.asyncio +async def test_copilot_adapter_config_converts_developer_role_to_system() -> None: + provider = CopilotProvider() + + result = await run_adapter_chain( + [get_adapter("suppress_developer_role")], + {"messages": [{"role": "developer", "content": "rules"}]}, + AdapterContext(adapter_config=provider.get_adapter_config("gpt-4.1")), + stage="request", + ) + + assert result["messages"][0]["role"] == "system" + + +def test_copilot_provider_builds_native_headers_and_endpoint(monkeypatch) -> None: + monkeypatch.setenv("COPILOT_API_BASE", "https://copilot.test") + monkeypatch.setenv("COPILOT_INTEGRATION_ID", "proxy-test") + provider = CopilotProvider() + + assert provider.get_native_endpoint(operation="chat") == "https://copilot.test/chat/completions" + assert provider.get_native_endpoint(operation="models") == "https://copilot.test/models" + assert provider.get_native_headers("token") == { + "Authorization": "Bearer token", + "content-type": "application/json", + "Copilot-Integration-Id": "proxy-test", + } + + +def test_copilot_native_operation_model_and_stream_support() -> None: + provider = CopilotProvider() + + assert provider.get_native_operation("gpt-4.1", {}, stream=False) == "chat" + assert provider.normalize_native_model("copilot/gpt-4.1") == "gpt-4.1" + assert provider.supports_native_streaming("gpt-4.1", operation="chat") is False + assert provider.supports_native_streaming("gpt-4.1", operation="responses") is False + + +@pytest.mark.asyncio +async def test_copilot_provider_get_models_uses_mocked_endpoint(monkeypatch) -> None: + monkeypatch.setenv("COPILOT_API_BASE", "https://copilot.test") + client = FakeClient({"data": [{"id": "gpt-4.1"}]}) + provider = CopilotProvider() + + models = await provider.get_models("token", client) + + assert models == ["copilot/gpt-4.1"] + assert client.calls[0]["url"] == "https://copilot.test/models" + + +@pytest.mark.asyncio +async def test_copilot_provider_get_models_falls_back_on_errors() -> None: + class BrokenClient: + async def get(self, *args, **kwargs): + raise RuntimeError("offline") + + assert await CopilotProvider().get_models("token", BrokenClient()) == ["copilot/gpt-4.1", "copilot/claude-sonnet-4-5"] diff --git a/tests/test_env_example_experimental_config.py b/tests/test_env_example_experimental_config.py new file mode 100644 index 000000000..cac0f0d61 --- /dev/null +++ b/tests/test_env_example_experimental_config.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from pathlib import Path + + +def test_env_example_documents_experimental_config_knobs() -> None: + text = Path(".env.example").read_text(encoding="utf-8") + + for key in ( + "LLM_PROXY_CONFIG_FILE", + "FALLBACK_GROUPS", + "FALLBACK_GROUP_CODE_CHAIN", + "MODEL_ROUTE_CODE", + "PROVIDER_COOLDOWN_MIN_SECONDS", + "PROVIDER_COOLDOWN_DEFAULT_SECONDS", + "PROVIDER_COOLDOWN_ON_QUOTA", + "PROVIDER_BACKOFF_WINDOW_SECONDS", + "PROVIDER_BACKOFF_THRESHOLD", + "PROVIDER_BACKOFF_BASE_SECONDS", + "PROVIDER_BACKOFF_MAX_SECONDS", + "FAILURE_HISTORY_MAX_ENTRIES", + "RESPONSES_STORE_TTL_SECONDS", + "RESPONSES_STORE_MAX_ITEMS", + "RESPONSES_STORE_FAILED", + "RESPONSES_STORE_IN_PROGRESS", + "STREAM_TRACE_METRICS", + "STREAM_TTFB_TIMEOUT_SECONDS", + "STREAM_STALL_TIMEOUT_SECONDS", + "STREAM_HEARTBEAT_INTERVAL_SECONDS", + "STREAM_HEARTBEAT_SECONDS", + "STREAM_CANCEL_UPSTREAM_ON_DISCONNECT", + "MODEL_PRICE_OPENAI_GPT_5_1_INPUT", + "MODEL_PRICE_OPENAI_GPT_5_1_REASONING", + ): + assert key in text + + assert "Do not put API keys" in text + for default_line in ( + "PROVIDER_COOLDOWN_MIN_SECONDS=10", + "PROVIDER_COOLDOWN_DEFAULT_SECONDS=30", + "PROVIDER_COOLDOWN_ON_QUOTA=false", + "PROVIDER_BACKOFF_WINDOW_SECONDS=60", + "PROVIDER_BACKOFF_THRESHOLD=3", + "PROVIDER_BACKOFF_BASE_SECONDS=0", + "PROVIDER_BACKOFF_MAX_SECONDS=300", + "FAILURE_HISTORY_MAX_ENTRIES=200", + "RESPONSES_STORE_TTL_SECONDS=0", + "RESPONSES_STORE_MAX_ITEMS=0", + "RESPONSES_STORE_FAILED=true", + "RESPONSES_STORE_IN_PROGRESS=false", + "STREAM_TRACE_METRICS=true", + "STREAM_TTFB_TIMEOUT_SECONDS=0", + "STREAM_STALL_TIMEOUT_SECONDS=0", + "STREAM_HEARTBEAT_INTERVAL_SECONDS=0", + "STREAM_CANCEL_UPSTREAM_ON_DISCONNECT=true", + ): + assert default_line in text diff --git a/tests/test_executor_usage_accounting.py b/tests/test_executor_usage_accounting.py new file mode 100644 index 000000000..1f28cc537 --- /dev/null +++ b/tests/test_executor_usage_accounting.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace + +from rotator_library.client.executor import RequestExecutor +from rotator_library.core.types import RequestContext +from rotator_library.transaction_logger import TransactionLogger + + +def _executor() -> RequestExecutor: + return RequestExecutor({}, None, None, None, {}, None) + + +def test_executor_accounts_for_non_streaming_usage_and_cost_trace(tmp_path, monkeypatch) -> None: + executor = _executor() + response = SimpleNamespace( + usage=SimpleNamespace( + prompt_tokens=100, + completion_tokens=30, + prompt_tokens_details={"cached_tokens": 40}, + completion_tokens_details={"reasoning_tokens": 10}, + ) + ) + logger = TransactionLogger("openai", "gpt-test", parent_dir=tmp_path) + context = RequestContext( + provider="openai", + model="gpt-test", + kwargs={}, + streaming=False, + credentials=[], + deadline=0, + transaction_logger=logger, + ) + monkeypatch.setattr( + "rotator_library.usage.costs.litellm.get_model_info", + lambda model: {"input_cost_per_token": 0.001, "output_cost_per_token": 0.002}, + ) + + usage, cost = executor._account_for_response_usage("openai", "gpt-test", response, context) + + assert usage.input_tokens == 60 + assert usage.cache_read_tokens == 40 + assert usage.completion_tokens == 20 + assert usage.reasoning_tokens == 10 + assert cost.total_cost > 0 + entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + assert entries[-1]["pass_name"] == "usage_accounting_summary" + assert entries[-1]["data"]["usage"]["total_tokens"] == usage.total_tokens + + +def test_executor_accounting_uses_configured_env_pricing(tmp_path, monkeypatch) -> None: + monkeypatch.setenv("MODEL_PRICE_OPENAI_GPT_TEST_INPUT", "2.0") + executor = _executor() + response = SimpleNamespace(usage=SimpleNamespace(prompt_tokens=3, completion_tokens=0)) + context = RequestContext( + provider="openai", + model="gpt-test", + kwargs={}, + streaming=False, + credentials=[], + deadline=0, + transaction_logger=TransactionLogger("openai", "gpt-test", parent_dir=tmp_path), + ) + + _, cost = executor._account_for_response_usage("openai", "gpt-test", response, context) + + assert cost.pricing_source == "env" + assert cost.input_cost == 6.0 + + +def test_normalize_response_usage_handles_dict_responses() -> None: + response = { + "usage": { + "prompt_tokens": 4, + "completion_tokens": 1, + "completion_tokens_details": {"reasoning_tokens": 3}, + } + } + + result = RequestExecutor._normalize_response_usage(response, "gpt-test") + + assert result is response + assert response["usage"]["completion_tokens"] == 4 + assert response["usage"]["total_tokens"] == 8 diff --git a/tests/test_experimental_config.py b/tests/test_experimental_config.py new file mode 100644 index 000000000..083b4faf1 --- /dev/null +++ b/tests/test_experimental_config.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +import pytest + +from rotator_library.config.experimental import ( + ExperimentalConfigError, + as_int, + env_price_key, + get_responses_store_settings, + get_responses_store_runtime_settings, + get_retry_runtime_settings, + get_stream_runtime_settings, + load_config_from_mapping, + load_experimental_config, + parse_field_cache_rules, +) + + +def test_missing_config_file_returns_empty(tmp_path) -> None: + config = load_experimental_config(tmp_path / "missing.json", env={}) + + assert config.is_empty + + +def test_loads_config_from_env_path(tmp_path) -> None: + path = tmp_path / "config.json" + path.write_text('{"routing":{"model_routes":{"code":"group:chain"}},"extra":{}}', encoding="utf-8") + + config = load_experimental_config(env={"LLM_PROXY_CONFIG_FILE": str(path)}) + + assert config.routing["model_routes"]["code"] == "group:chain" + assert config.unknown_sections == {"extra": {}} + assert config.warnings + + +def test_rejects_secret_like_json_keys() -> None: + with pytest.raises(ExperimentalConfigError): + load_config_from_mapping({"providers": {"openai": {"api_key": "hidden"}}}) + + +def test_invalid_json_raises(tmp_path) -> None: + path = tmp_path / "config.json" + path.write_text("{", encoding="utf-8") + + with pytest.raises(ExperimentalConfigError): + load_experimental_config(path, env={}) + + +def test_stream_runtime_settings_env_overrides_json() -> None: + config = load_config_from_mapping({"streaming": {"trace_metrics": True, "stall_timeout_seconds": 30}}) + + settings = get_stream_runtime_settings(config=config, env={"STREAM_TRACE_METRICS": "false"}) + + assert settings.trace_metrics is False + assert settings.stall_timeout_seconds == 30 + + +def test_retry_runtime_settings_env_overrides_json() -> None: + config = load_config_from_mapping( + { + "retry": { + "provider_cooldown": {"provider_cooldown_min_seconds": 20, "provider_cooldown_on_quota": True}, + "backoff": {"provider_backoff_threshold": 4, "failure_history_max_entries": 50}, + } + } + ) + + settings = get_retry_runtime_settings(config=config, env={"PROVIDER_COOLDOWN_MIN_SECONDS": "5"}) + + assert settings.provider_cooldown_min_seconds == 5 + assert settings.provider_cooldown_on_quota is True + assert settings.provider_backoff_threshold == 4 + assert settings.failure_history_max_entries == 50 + + +def test_responses_store_settings_env_overrides_json() -> None: + config = load_config_from_mapping( + { + "responses": { + "store": { + "ttl_seconds": 60, + "max_items": 10, + "store_failed": False, + "store_in_progress": True, + } + } + } + ) + + settings = get_responses_store_settings(config=config, env={"RESPONSES_STORE_MAX_ITEMS": "5", "RESPONSES_STORE_FAILED": "true"}) + + assert settings.ttl_seconds == 60 + assert settings.max_items == 5 + assert settings.store_failed is True + assert settings.store_in_progress is True + + +def test_responses_store_runtime_settings_env_overrides_json(tmp_path) -> None: + config = load_config_from_mapping( + { + "responses": { + "store": { + "backend": "memory", + "cache_name": "json_responses", + "cache_prefix": "json_prefix", + "cache_dir": str(tmp_path / "json-cache"), + "cache_memory_ttl_seconds": 10, + "cache_disk_ttl_seconds": 20, + } + } + } + ) + + settings = get_responses_store_runtime_settings( + config=config, + env={"RESPONSES_STORE_BACKEND": "provider_cache", "RESPONSES_STORE_CACHE_MEMORY_TTL_SECONDS": "30"}, + ) + + assert settings.backend == "provider_cache" + assert settings.cache_name == "json_responses" + assert settings.cache_prefix == "json_prefix" + assert settings.cache_dir == str(tmp_path / "json-cache") + assert settings.cache_memory_ttl_seconds == 30 + assert settings.cache_disk_ttl_seconds == 20 + + +def test_responses_store_runtime_rejects_unknown_backend() -> None: + with pytest.raises(ExperimentalConfigError): + get_responses_store_runtime_settings(config=load_config_from_mapping({"responses": {"store": {"backend": "sqlite"}}}), env={}) + + +def test_new_config_sections_still_reject_secret_like_keys() -> None: + with pytest.raises(ExperimentalConfigError): + load_config_from_mapping({"responses": {"store": {"authorization": "hidden"}}}) + + +@pytest.mark.parametrize("secret_key", ["secret_key", "secret-key", "apiKey", "client-secret", "oauth_token", "oauthToken", "oauth-token", "id_token", "oauth_token_secret"]) +def test_secret_key_variants_are_rejected(secret_key: str) -> None: + with pytest.raises(ExperimentalConfigError): + load_config_from_mapping({"retry": {secret_key: "hidden"}}) + + +def test_retry_runtime_settings_malformed_env_preserves_defaults() -> None: + config = load_config_from_mapping({"retry": {"provider_cooldown_default_seconds": 45}}) + + settings = get_retry_runtime_settings( + config=config, + env={ + "PROVIDER_COOLDOWN_DEFAULT_SECONDS": "not-an-int", + "PROVIDER_COOLDOWN_ON_QUOTA": "not-a-bool", + "PROVIDER_BACKOFF_BASE_SECONDS": "not-an-int", + }, + ) + + assert settings.provider_cooldown_default_seconds == 30 + assert settings.provider_cooldown_on_quota is False + assert settings.provider_backoff_base_seconds is None + + +def test_field_cache_rules_parse_wildcard_then_model_specific() -> None: + config = load_config_from_mapping( + { + "field_cache": { + "gemini_cli": { + "*": [{"name": "thought", "source": "response", "path": "$.thought", "target_path": "$.cached_thought"}], + "gemini-3": [{"name": "signature", "source": "response", "path": "$.sig", "scope": ["provider", "model"]}], + } + } + } + ) + + rules = parse_field_cache_rules(config, "gemini_cli", "gemini-3") + + assert [rule.name for rule in rules] == ["thought", "signature"] + assert rules[0].inject is not None + + +def test_field_cache_rules_match_unprefixed_model_alias() -> None: + config = load_config_from_mapping( + { + "field_cache": { + "gemini_cli": { + "gemini-3": [{"name": "signature", "source": "response", "path": "sig"}], + } + } + } + ) + + rules = parse_field_cache_rules(config, "gemini_cli", "gemini_cli/gemini-3") + + assert [rule.name for rule in rules] == ["signature"] + + +def test_field_cache_rule_parses_ttl_metadata_and_insert_injection() -> None: + config = load_config_from_mapping( + { + "field_cache": { + "provider": { + "*": [ + { + "name": "tool_state", + "source": "stream_event", + "path": "raw.tool.state", + "mode": "per_tool_call", + "ttl_seconds": 120, + "metadata": {"tool_container_path": "tools"}, + "inject": {"target": "request", "path": "metadata.tool_state", "insert": True}, + } + ] + } + } + } + ) + + rule = parse_field_cache_rules(config, "provider", "model")[0] + + assert rule.ttl_seconds == 120 + assert rule.metadata == {"tool_container_path": "tools"} + assert rule.inject is not None + assert rule.inject.insert is True + + +def test_field_cache_rule_rejects_invalid_config_values() -> None: + config = load_config_from_mapping( + { + "field_cache": { + "provider": { + "*": [ + {"name": "bad_source", "source": "not_a_source", "path": "x"}, + ] + } + } + } + ) + + with pytest.raises(ValueError, match="Unsupported field-cache source"): + parse_field_cache_rules(config, "provider", "model") + + +def test_field_cache_rules_reject_malformed_shapes() -> None: + config = load_config_from_mapping({"field_cache": {"provider": {"*": ["not-a-rule"]}}}) + + with pytest.raises(ExperimentalConfigError, match="rule entries"): + parse_field_cache_rules(config, "provider", "model") + + +def test_field_cache_rules_reject_non_list_model_rules() -> None: + config = load_config_from_mapping({"field_cache": {"provider": {"*": {"name": "bad"}}}}) + + with pytest.raises(ExperimentalConfigError, match="model rules"): + parse_field_cache_rules(config, "provider", "model") + + +def test_field_cache_rules_reject_malformed_nested_shapes() -> None: + bad_inject = load_config_from_mapping({"field_cache": {"provider": {"*": [{"name": "bad", "source": "response", "path": "x", "inject": "not-object"}]}}}) + + with pytest.raises(ExperimentalConfigError, match="inject"): + parse_field_cache_rules(bad_inject, "provider", "model") + + bad_metadata = load_config_from_mapping({"field_cache": {"provider": {"*": [{"name": "bad", "source": "response", "path": "x", "metadata": "not-object"}]}}}) + + with pytest.raises(ExperimentalConfigError, match="metadata"): + parse_field_cache_rules(bad_metadata, "provider", "model") + + +def test_env_price_key_sanitizes_provider_and_model() -> None: + assert env_price_key("openai", "gpt-5.1-mini", "cache_read") == "MODEL_PRICE_OPENAI_GPT_5_1_MINI_CACHE_READ" + + +def test_as_int_parses_integers_with_redacted_errors() -> None: + assert as_int("5", name="TEST_INT") == 5 + with pytest.raises(ExperimentalConfigError, match="TEST_INT"): + as_int("not-secret-value", name="TEST_INT") diff --git a/tests/test_fallback_attempt_runner.py b/tests/test_fallback_attempt_runner.py new file mode 100644 index 000000000..167e5698d --- /dev/null +++ b/tests/test_fallback_attempt_runner.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import pytest + +from rotator_library.routing import FallbackAttemptRunner, FallbackExhaustedError, RoutingDecision, parse_route_target +from rotator_library.routing.types import FallbackGroup + + +class ClassifiedFailure(Exception): + def __init__(self, error_type: str, *, emitted_output: bool = False) -> None: + super().__init__(error_type) + self.error_type = error_type + self.emitted_output = emitted_output + + +def _decision() -> RoutingDecision: + targets = (parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")) + return RoutingDecision( + requested_model="code", + group_name="code_chain", + targets=targets, + group=FallbackGroup(name="code_chain", targets=targets), + reason="model_route_group", + ) + + +@pytest.mark.asyncio +async def test_attempt_runner_returns_first_success_without_fallback() -> None: + calls = [] + + async def attempt(target, index): + calls.append(target.prefixed_model) + return {"target": target.prefixed_model} + + result = await FallbackAttemptRunner().run(_decision(), attempt) + + assert result == {"target": "codex/gpt-5.1-codex"} + assert calls == ["codex/gpt-5.1-codex"] + + +@pytest.mark.asyncio +async def test_attempt_runner_falls_back_on_retryable_error() -> None: + calls = [] + + async def attempt(target, index): + calls.append(target.prefixed_model) + if index == 0: + raise ClassifiedFailure("rate_limit") + return {"target": target.prefixed_model} + + result = await FallbackAttemptRunner().run(_decision(), attempt) + + assert result == {"target": "openai/gpt-5.1"} + assert calls == ["codex/gpt-5.1-codex", "openai/gpt-5.1"] + + +@pytest.mark.asyncio +async def test_attempt_runner_stops_on_permanent_error() -> None: + async def attempt(target, index): + raise ClassifiedFailure("validation") + + with pytest.raises(FallbackExhaustedError) as exc: + await FallbackAttemptRunner().run(_decision(), attempt) + + assert len(exc.value.attempts) == 1 + assert exc.value.attempts[0].error_type == "invalid_request" + + +@pytest.mark.asyncio +async def test_attempt_runner_blocks_stream_fallback_after_output() -> None: + async def attempt(target, index): + raise ClassifiedFailure("rate_limit", emitted_output=True) + + with pytest.raises(FallbackExhaustedError) as exc: + await FallbackAttemptRunner().run(_decision(), attempt, stream=True) + + assert len(exc.value.attempts) == 1 + assert exc.value.attempts[0].emitted_output is True + + +@pytest.mark.asyncio +async def test_attempt_runner_hard_stops_group_policy_overrides() -> None: + group = FallbackGroup( + name="custom", + targets=_decision().targets, + failover_on=frozenset({"authentication"}), + stop_on=frozenset({"validation"}), + ) + calls = [] + + async def attempt(target, index): + calls.append(index) + if index == 0: + raise ClassifiedFailure("authentication") + return {"target": target.prefixed_model} + + with pytest.raises(FallbackExhaustedError): + await FallbackAttemptRunner().run_group(_decision(), group, attempt) + + assert calls == [0] + + +@pytest.mark.asyncio +async def test_attempt_runner_respects_never_streaming_policy() -> None: + group = FallbackGroup( + name="custom", + targets=_decision().targets, + failover_on=frozenset({"rate_limit"}), + streaming_policy="never", + ) + calls = [] + + async def attempt(target, index): + calls.append(index) + raise ClassifiedFailure("rate_limit") + + with pytest.raises(FallbackExhaustedError): + await FallbackAttemptRunner().run_group(_decision(), group, attempt, stream=True) + + assert calls == [0] + + +@pytest.mark.asyncio +async def test_attempt_runner_run_uses_decision_group_streaming_policy() -> None: + decision = _decision() + never_group = FallbackGroup(name="code_chain", targets=decision.targets, failover_on=frozenset({"rate_limit"}), streaming_policy="never") + decision = RoutingDecision(requested_model=decision.requested_model, group_name=decision.group_name, targets=decision.targets, group=never_group) + calls = [] + + async def attempt(target, index): + calls.append(index) + raise ClassifiedFailure("rate_limit") + + with pytest.raises(FallbackExhaustedError): + await FallbackAttemptRunner().run(decision, attempt, stream=True) + + assert calls == [0] diff --git a/tests/test_fallback_groups.py b/tests/test_fallback_groups.py new file mode 100644 index 000000000..66f24bdb0 --- /dev/null +++ b/tests/test_fallback_groups.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import pytest + +from rotator_library.routing import FallbackResolver, FallbackGroup, RouteTarget, RoutingConfig +from rotator_library.routing.config import RoutingConfigError, load_routing_config_from_env, parse_route_target +from rotator_library.routing.executor import FallbackAttemptRunner + + +def test_parse_route_target_supports_execution_suffix() -> None: + target = parse_route_target("openai/gpt-5@litellm_fallback") + + assert target.provider == "openai" + assert target.model == "gpt-5" + assert target.execution == "litellm_fallback" + + +def test_env_fallback_group_config_loads_current_routing_types() -> None: + config = load_routing_config_from_env( + { + "FALLBACK_GROUPS": "main", + "FALLBACK_GROUP_MAIN": "openai/gpt-5@litellm_fallback,anthropic/claude@native", + "MODEL_ROUTE_GPT5": "group:main", + }, + config=RoutingConfig(), + ) + + group = config.fallback_groups["main"] + assert [target.prefixed_model for target in group.targets] == ["openai/gpt-5", "anthropic/claude"] + assert group.targets[0].execution == "litellm_fallback" + assert group.targets[1].execution == "native" + assert config.model_routes["gpt5"] == "group:main" + + +def test_resolver_promotes_requested_provider_model_inside_group() -> None: + group = FallbackGroup( + name="main", + targets=( + RouteTarget("openai", "gpt-5"), + RouteTarget("anthropic", "claude"), + RouteTarget("google", "gemini"), + ), + ) + decision = FallbackResolver(RoutingConfig(fallback_groups={"main": group})).resolve("anthropic/claude") + + assert decision.reason == "provider_model_group_promoted" + assert [target.prefixed_model for target in decision.targets] == ["anthropic/claude", "openai/gpt-5", "google/gemini"] + + +def test_resolver_promotes_requested_model_for_group_route_alias() -> None: + group = FallbackGroup( + name="main", + targets=(RouteTarget("openai", "gpt-5"), RouteTarget("anthropic", "claude")), + ) + config = RoutingConfig(fallback_groups={"main": group}, model_routes={"anthropic/claude": "group:main"}) + + decision = FallbackResolver(config).resolve("anthropic/claude") + + assert decision.reason == "model_route_group_promoted" + assert [target.prefixed_model for target in decision.targets] == ["anthropic/claude", "openai/gpt-5"] + + +def test_resolver_rejects_missing_group_route() -> None: + with pytest.raises(RoutingConfigError): + FallbackResolver(RoutingConfig(model_routes={"alias": "group:missing"})).resolve("alias") + + +@pytest.mark.asyncio +async def test_fallback_runner_uses_decision_group_policy() -> None: + group = FallbackGroup(name="main", targets=(RouteTarget("a", "one"), RouteTarget("b", "two")), failover_on=frozenset({"rate_limit"})) + decision = FallbackResolver(RoutingConfig(fallback_groups={"main": group}, model_routes={"alias": "group:main"})).resolve("alias") + attempts: list[str] = [] + + class RateLimitError(RuntimeError): + error_type = "rate_limit" + + async def attempt(target, index): + attempts.append(target.prefixed_model) + if target.provider == "a": + raise RateLimitError("rate limit") + return "ok" + + result = await FallbackAttemptRunner().run(decision, attempt) + + assert result == "ok" + assert attempts == ["a/one", "b/two"] diff --git a/tests/test_fallback_policy.py b/tests/test_fallback_policy.py new file mode 100644 index 000000000..74adaeaab --- /dev/null +++ b/tests/test_fallback_policy.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from rotator_library.routing import FallbackPolicy, parse_route_target +from rotator_library.routing.policy import normalize_route_error_type +from rotator_library.routing.types import FallbackGroup + + +def test_policy_falls_back_on_retryable_categories() -> None: + policy = FallbackPolicy() + + assert policy.should_fallback("rate_limit") is True + assert policy.should_fallback("quota_exceeded") is True + assert policy.should_fallback("server_error") is True + assert policy.should_fallback("api_connection") is True + + +def test_policy_stops_on_permanent_categories() -> None: + policy = FallbackPolicy() + + assert policy.should_fallback("authentication") is False + assert policy.should_fallback("forbidden") is False + assert policy.should_fallback("invalid_request") is False + assert policy.should_fallback("context_window_exceeded") is False + assert policy.should_fallback("credential_reauth_needed") is False + assert policy.should_fallback("pre_request_callback_error") is False + assert policy.should_fallback("cancelled") is False + + +def test_policy_blocks_stream_fallback_after_visible_output() -> None: + assert FallbackPolicy().should_fallback("rate_limit", stream=True, emitted_output=True) is False + + +def test_policy_allows_stream_fallback_before_visible_output() -> None: + assert FallbackPolicy().should_fallback("rate_limit", stream=True, emitted_output=False) is True + + +def test_policy_respects_safe_group_overrides() -> None: + group = FallbackGroup( + name="auth_safe", + targets=(parse_route_target("a/model"), parse_route_target("b/model")), + failover_on=frozenset({"network"}), + stop_on=frozenset({"validation"}), + ) + + assert FallbackPolicy().should_fallback("api_connection", group=group) is True + assert FallbackPolicy().should_fallback("validation", group=group) is False + + +def test_policy_hard_stops_cannot_be_overridden_by_group_failover() -> None: + group = FallbackGroup( + name="unsafe", + targets=(parse_route_target("a/model"), parse_route_target("b/model")), + failover_on=frozenset({"auth", "configuration"}), + stop_on=frozenset(), + ) + + assert FallbackPolicy().should_fallback("authentication", group=group) is False + assert FallbackPolicy().should_fallback("configuration_error", group=group) is False + + +def test_policy_normalizes_user_facing_aliases() -> None: + assert normalize_route_error_type("auth") == "authentication" + assert normalize_route_error_type("permission-denied") == "forbidden" + assert normalize_route_error_type("bad request") == "invalid_request" + assert normalize_route_error_type("context_length_exceeded") == "context_window_exceeded" + assert FallbackPolicy().should_fallback("network") is True + assert FallbackPolicy().should_fallback("validation") is False + + +def test_policy_normalizes_common_structured_provider_aliases() -> None: + assert normalize_route_error_type("invalid_api_key") == "authentication" + assert normalize_route_error_type("unauthorized") == "authentication" + assert normalize_route_error_type("invalid_argument") == "invalid_request" + assert normalize_route_error_type("max_tokens_exceeded") == "context_window_exceeded" + assert normalize_route_error_type("rate_limited") == "rate_limit" + assert normalize_route_error_type("too_many_requests") == "rate_limit" + assert normalize_route_error_type("resource_exhausted") == "quota_exceeded" + assert normalize_route_error_type("unavailable") == "server_error" + assert normalize_route_error_type("deadline_exceeded") == "api_connection" diff --git a/tests/test_fallback_resolver.py b/tests/test_fallback_resolver.py new file mode 100644 index 000000000..674a5d559 --- /dev/null +++ b/tests/test_fallback_resolver.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import pytest + +from rotator_library.routing import FallbackResolver, RoutingConfigError, load_routing_config_from_env + + +def test_resolver_maps_alias_to_fallback_group_in_order() -> None: + config = load_routing_config_from_env( + { + "FALLBACK_GROUPS": "code_chain", + "FALLBACK_GROUP_CODE_CHAIN": "codex/gpt-5.1-codex,openai/gpt-5.1", + "MODEL_ROUTE_CODEX": "group:code_chain", + } + ) + + decision = FallbackResolver(config).resolve("codex") + + assert decision.group_name == "code_chain" + assert [target.prefixed_model for target in decision.targets] == ["codex/gpt-5.1-codex", "openai/gpt-5.1"] + assert decision.reason == "model_route_group" + + +def test_resolver_keeps_provider_prefixed_model_as_direct_target() -> None: + decision = FallbackResolver(load_routing_config_from_env({})).resolve("openai/gpt-5.1") + + assert decision.group_name is None + assert decision.targets[0].provider == "openai" + assert decision.targets[0].model == "gpt-5.1" + assert decision.reason == "direct_provider_model" + + +def test_resolver_maps_alias_to_single_target() -> None: + config = load_routing_config_from_env({"MODEL_ROUTE_FAST": "openai/gpt-5.1"}) + + decision = FallbackResolver(config).resolve("fast") + + assert decision.targets[0].prefixed_model == "openai/gpt-5.1" + assert decision.reason == "model_route_target" + + +def test_resolver_rejects_unprefixed_model_without_route() -> None: + with pytest.raises(RoutingConfigError): + FallbackResolver(load_routing_config_from_env({})).resolve("gpt-5.1") diff --git a/tests/test_field_cache_engine.py b/tests/test_field_cache_engine.py new file mode 100644 index 000000000..842a34801 --- /dev/null +++ b/tests/test_field_cache_engine.py @@ -0,0 +1,559 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.field_cache import ( + FieldCacheContext, + FieldCacheEngine, + FieldCacheInjection, + FieldCacheRule, + InMemoryFieldCacheStore, + ProviderCacheFieldStore, + build_cache_key, +) +from rotator_library.transaction_logger import TransactionLogger + + +def _trace_entries(log_dir): + return [json.loads(line) for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + + +def _reasoning_rule(mode: str = "last", scope=("provider", "model", "session")) -> FieldCacheRule: + return FieldCacheRule( + name="reasoning_content", + source="response", + path="choices.*.message.reasoning_content", + mode=mode, + scope=scope, + inject=FieldCacheInjection(target="request", path="messages[-1].reasoning_content"), + ) + + +def _context(**overrides) -> FieldCacheContext: + values = {"provider": "openai", "model": "gpt-test", "session_id": "session-a", "classifier": "global"} + values.update(overrides) + return FieldCacheContext(**values) + + +@pytest.mark.asyncio +async def test_extract_response_value_and_inject_into_next_request() -> None: + engine = FieldCacheEngine([_reasoning_rule()]) + response = {"choices": [{"message": {"reasoning_content": "hidden"}}]} + request = {"messages": [{"role": "user", "content": "hi"}]} + + operations = await engine.extract("response", response, _context()) + updated, injection_operations = await engine.inject("request", request, _context()) + + assert operations[0].matched == 1 + assert injection_operations[0].hit is True + assert updated["messages"][-1]["reasoning_content"] == "hidden" + assert "reasoning_content" not in request["messages"][-1] + + +@pytest.mark.asyncio +async def test_last_mode_overwrites_prior_value() -> None: + engine = FieldCacheEngine([_reasoning_rule()]) + + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "first"}}]}, _context()) + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "second"}}]}, _context()) + updated, _ = await engine.inject("request", {"messages": [{}]}, _context()) + + assert updated["messages"][-1]["reasoning_content"] == "second" + + +@pytest.mark.asyncio +async def test_all_mode_appends_values() -> None: + engine = FieldCacheEngine([_reasoning_rule(mode="all")]) + + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "first"}}]}, _context()) + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "second"}}]}, _context()) + updated, _ = await engine.inject("request", {"messages": [{}]}, _context()) + + assert updated["messages"][-1]["reasoning_content"] == ["first", "second"] + + +@pytest.mark.asyncio +async def test_scope_isolation_by_session_and_classifier() -> None: + rule = _reasoning_rule(scope=("provider", "model", "session", "classifier")) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "a"}}]}, _context(session_id="session-a")) + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "b"}}]}, _context(session_id="session-b")) + + updated, _ = await engine.inject("request", {"messages": [{}]}, _context(session_id="session-a")) + + assert updated["messages"][-1]["reasoning_content"] == "a" + assert build_cache_key(rule, _context(session_id="session-a")) != build_cache_key(rule, _context(session_id="session-b")) + + +@pytest.mark.asyncio +async def test_scope_isolation_by_credential_and_provider() -> None: + rule = _reasoning_rule(scope=("provider", "model", "credential")) + engine = FieldCacheEngine([rule]) + + await engine.extract( + "response", + {"choices": [{"message": {"reasoning_content": "cred-a"}}]}, + _context(provider="openai", credential_id="credential-a"), + ) + await engine.extract( + "response", + {"choices": [{"message": {"reasoning_content": "cred-b"}}]}, + _context(provider="openai", credential_id="credential-b"), + ) + + updated, _ = await engine.inject("request", {"messages": [{}]}, _context(provider="openai", credential_id="credential-a")) + + assert updated["messages"][-1]["reasoning_content"] == "cred-a" + assert build_cache_key(rule, _context(provider="openai", credential_id="credential-a")) != build_cache_key( + rule, _context(provider="other", credential_id="credential-a") + ) + + +@pytest.mark.asyncio +async def test_missing_session_scope_skips_by_default() -> None: + engine = FieldCacheEngine([_reasoning_rule()]) + + operations = await engine.extract("response", {"choices": [{"message": {"reasoning_content": "x"}}]}, _context(session_id=None)) + + assert operations[0].skipped is True + assert operations[0].reason == "missing_required_scope" + + +@pytest.mark.asyncio +async def test_missing_credential_scope_skips_instead_of_sharing_none_bucket() -> None: + rule = _reasoning_rule(scope=("provider", "model", "credential")) + engine = FieldCacheEngine([rule]) + + operations = await engine.extract( + "response", + {"choices": [{"message": {"reasoning_content": "x"}}]}, + _context(credential_id=None), + ) + updated, injection_operations = await engine.inject("request", {"messages": [{}]}, _context(credential_id=None)) + + assert operations[0].skipped is True + assert operations[0].reason == "missing_required_scope" + assert injection_operations[0].skipped is True + assert injection_operations[0].reason == "missing_required_scope" + assert updated == {"messages": [{}]} + assert build_cache_key(rule, _context(credential_id=None)) is None + + +@pytest.mark.asyncio +async def test_missing_path_is_noop() -> None: + engine = FieldCacheEngine([_reasoning_rule()]) + + operations = await engine.extract("response", {"choices": []}, _context()) + + assert operations[0].matched == 0 + assert operations[0].changed is False + + +@pytest.mark.asyncio +async def test_stream_event_extraction() -> None: + rule = FieldCacheRule( + name="provider_session_id", + source="stream_event", + path="metadata.provider_session_id", + scope=("provider", "model", "session"), + inject=FieldCacheInjection(target="request", path="metadata.provider_session_id"), + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("stream_event", {"metadata": {"provider_session_id": "sid_1"}}, _context()) + updated, _ = await engine.inject("request", {"metadata": {}}, _context()) + + assert updated["metadata"]["provider_session_id"] == "sid_1" + + +@pytest.mark.asyncio +async def test_trace_sample_values_are_truncated() -> None: + rule = _reasoning_rule() + engine = FieldCacheEngine([rule]) + long_value = "x" * 700 + + operations = await engine.extract( + "response", + {"choices": [{"message": {"reasoning_content": long_value}}]}, + _context(), + ) + + assert operations[0].sample_values[0].endswith("...") + + +@pytest.mark.asyncio +async def test_field_cache_trace_omits_raw_sample_values(tmp_path) -> None: + logger = TransactionLogger("openai", "gpt-test", parent_dir=tmp_path) + rule = _reasoning_rule() + engine = FieldCacheEngine([rule]) + + await engine.extract( + "response", + {"choices": [{"message": {"reasoning_content": "provider-signature-secret"}}]}, + _context(), + transaction_logger=logger, + ) + + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "provider-signature-secret" not in trace_text + entries = _trace_entries(logger.log_dir) + after_entry = next(entry for entry in entries if entry["pass_name"] == "after_field_cache_extraction") + assert after_entry["metadata"]["sample_value_count"] == 1 + assert after_entry["metadata"]["sample_value_types"] == ["str"] + + +@pytest.mark.asyncio +async def test_field_cache_error_trace_omits_raw_payload_values(tmp_path) -> None: + class FailingStore(InMemoryFieldCacheStore): + async def set(self, key, value, *, ttl_seconds=None): + raise RuntimeError("store failed") + + logger = TransactionLogger("openai", "gpt-test", parent_dir=tmp_path) + engine = FieldCacheEngine([_reasoning_rule()], store=FailingStore()) + + with pytest.raises(RuntimeError): + await engine.extract( + "response", + {"choices": [{"message": {"reasoning_content": "provider-signature-secret"}}]}, + _context(), + transaction_logger=logger, + ) + + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "provider-signature-secret" not in trace_text + assert "payload_type" in trace_text + + +@pytest.mark.asyncio +async def test_field_cache_traces_start_and_complete_even_without_matching_rules(tmp_path) -> None: + logger = TransactionLogger("openai", "gpt-test", parent_dir=tmp_path) + engine = FieldCacheEngine([]) + + operations = await engine.extract("response", {"choices": []}, _context(), transaction_logger=logger) + updated, injection_operations = await engine.inject("request", {"messages": []}, _context(), transaction_logger=logger) + + assert operations == [] + assert injection_operations == [] + assert updated == {"messages": []} + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert pass_names == [ + "field_cache_extraction_start", + "field_cache_extraction_complete", + "field_cache_injection_start", + "field_cache_injection_complete", + ] + assert entries[1]["metadata"]["rule_count"] == 0 + assert entries[-1]["metadata"]["operation_count"] == 0 + + +def test_per_tool_call_requires_tool_call_id_path() -> None: + with pytest.raises(ValueError): + FieldCacheEngine([ + FieldCacheRule(name="tool_state", source="response", path="tool_calls.*", mode="per_tool_call") + ]) + + +@pytest.mark.asyncio +async def test_provider_cache_field_store_wraps_json_string_cache() -> None: + class FakeProviderCache: + def __init__(self) -> None: + self.values = {} + + async def retrieve_async(self, key: str): + return self.values.get(key) + + async def store_async(self, key: str, value: str) -> None: + json.loads(value) + self.values[key] = value + + async def clear(self) -> None: + self.values.clear() + + store = ProviderCacheFieldStore(FakeProviderCache()) + + await store.set("key", {"value": 1}) + assert await store.get("key") == {"value": 1} + await store.append("key", [{"value": 2}]) + assert await store.get("key") == [{"value": 2}] + + +@pytest.mark.asyncio +async def test_in_memory_store_expires_ttl_values() -> None: + now = 100.0 + store = InMemoryFieldCacheStore(clock=lambda: now) + + await store.set("key", "value", ttl_seconds=5) + assert await store.get("key") == "value" + now = 106.0 + assert await store.get("key") is None + + +@pytest.mark.asyncio +async def test_provider_cache_field_store_expires_ttl_values(monkeypatch) -> None: + class FakeProviderCache: + def __init__(self) -> None: + self.values = {} + + async def retrieve_async(self, key: str): + return self.values.get(key) + + async def store_async(self, key: str, value: str) -> None: + json.loads(value) + self.values[key] = value + + async def clear(self) -> None: + self.values.clear() + + now = 100.0 + monkeypatch.setattr("rotator_library.field_cache.store.time.time", lambda: now) + store = ProviderCacheFieldStore(FakeProviderCache()) + + await store.set("key", "value", ttl_seconds=5) + assert await store.get("key") == "value" + now = 106.0 + assert await store.get("key") is None + + +@pytest.mark.asyncio +async def test_in_memory_store_returns_deep_copies() -> None: + store = InMemoryFieldCacheStore() + await store.set("key", {"nested": []}) + value = await store.get("key") + value["nested"].append("mutated") + + assert await store.get("key") == {"nested": []} + + +@pytest.mark.asyncio +async def test_last_user_turn_uses_latest_user_message() -> None: + rule = FieldCacheRule( + name="user_signature", + source="request", + path="messages.*.metadata.signature", + mode="last_user_turn", + inject=FieldCacheInjection(target="request", path="metadata.signature"), + allow_missing_session=True, + ) + engine = FieldCacheEngine([rule]) + + operations = await engine.extract( + "request", + { + "messages": [ + {"role": "user", "metadata": {"signature": "first-user"}}, + {"role": "assistant", "metadata": {"signature": "assistant"}}, + {"role": "user", "metadata": {"signature": "last-user"}}, + ] + }, + _context(session_id=None), + ) + updated, _ = await engine.inject("request", {"metadata": {}}, _context(session_id=None)) + + assert operations[0].changed is True + assert updated["metadata"]["signature"] == "last-user" + + +@pytest.mark.asyncio +async def test_last_mode_preserves_list_valued_field() -> None: + rule = FieldCacheRule( + name="list_value", + source="response", + path="metadata.signatures", + inject=FieldCacheInjection(target="request", path="metadata.signatures"), + allow_missing_session=True, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"metadata": {"signatures": ["a", "b"]}}, _context(session_id=None)) + updated, _ = await engine.inject("request", {"metadata": {}}, _context(session_id=None)) + + assert updated["metadata"]["signatures"] == ["a", "b"] + + +@pytest.mark.asyncio +async def test_as_list_unwraps_last_mode_value_envelope() -> None: + rule = FieldCacheRule( + name="value", + source="response", + path="value", + inject=FieldCacheInjection(target="request", path="metadata.values", as_list=True), + allow_missing_session=True, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"value": "sig"}, _context(session_id=None)) + updated, _ = await engine.inject("request", {"metadata": {}}, _context(session_id=None)) + + assert updated["metadata"]["values"] == ["sig"] + + +@pytest.mark.asyncio +async def test_last_assistant_turn_skips_without_turn_context() -> None: + rule = FieldCacheRule( + name="assistant_signature", + source="response", + path="choices.*.message.signature", + mode="last_assistant_turn", + inject=FieldCacheInjection(target="request", path="metadata.signature"), + allow_missing_session=True, + ) + engine = FieldCacheEngine([rule]) + + operations = await engine.extract("response", {"choices": [{"message": {"signature": "sig"}}]}, _context(session_id=None)) + + assert operations[0].skipped is True + assert operations[0].reason == "turn_context_not_found" + + +@pytest.mark.asyncio +async def test_turn_mode_uses_metadata_configured_relative_paths() -> None: + rule = FieldCacheRule( + name="assistant_signature", + source="response", + path="unused.global.path", + mode="last_assistant_turn", + inject=FieldCacheInjection(target="request", path="metadata.signature"), + allow_missing_session=True, + metadata={"turn_container_path": "messages", "turn_role_path": "kind", "turn_value_path": "parts.*.signature"}, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract( + "response", + {"messages": [{"kind": "assistant", "parts": [{"signature": "first"}]}, {"kind": "assistant", "parts": [{"signature": "second"}]}]}, + _context(session_id=None), + ) + updated, _ = await engine.inject("request", {"metadata": {}}, _context(session_id=None)) + + assert updated["metadata"]["signature"] == "second" + + +@pytest.mark.asyncio +async def test_per_tool_call_correlates_sibling_id_and_value_for_injection() -> None: + rule = FieldCacheRule( + name="tool_signature", + source="response", + path="tool_calls.*.signature", + mode="per_tool_call", + inject=FieldCacheInjection(target="request", path="metadata.signature"), + allow_missing_session=True, + metadata={ + "tool_container_path": "tool_calls", + "tool_call_id_path": "id", + "tool_value_path": "signature", + }, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"tool_calls": [{"id": "call_a", "signature": "sig-a"}, {"id": "call_b", "signature": "sig-b"}]}, _context(session_id=None)) + updated, operations = await engine.inject("request", {"metadata": {}}, _context(session_id=None, metadata={"tool_call_id": "call_b"})) + + assert operations[0].hit is True + assert updated["metadata"]["signature"] == "sig-b" + + +@pytest.mark.asyncio +async def test_per_tool_call_as_list_injects_matching_values() -> None: + rule = FieldCacheRule( + name="tool_signature", + source="response", + path="tool_calls.*", + mode="per_tool_call", + inject=FieldCacheInjection(target="request", path="metadata.signatures", as_list=True), + allow_missing_session=True, + metadata={"tool_call_id_path": "id", "inject_tool_call_id_path": "tool_ids.*"}, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"tool_calls": [{"id": "a", "signature": "sig-a"}, {"id": "b", "signature": "sig-b"}]}, _context(session_id=None)) + updated, _ = await engine.inject("request", {"metadata": {}, "tool_ids": ["a", "b"]}, _context(session_id=None)) + + assert updated["metadata"]["signatures"] == [{"id": "a", "signature": "sig-a"}, {"id": "b", "signature": "sig-b"}] + + +@pytest.mark.asyncio +async def test_per_tool_call_preserves_list_valued_match() -> None: + rule = FieldCacheRule( + name="tool_signatures", + source="response", + path="tool_calls.*.signatures", + mode="per_tool_call", + inject=FieldCacheInjection(target="request", path="metadata.signatures"), + allow_missing_session=True, + metadata={"tool_container_path": "tool_calls", "tool_call_id_path": "id", "tool_value_path": "signatures"}, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"tool_calls": [{"id": "call", "signatures": ["a", "b"]}]}, _context(session_id=None)) + updated, _ = await engine.inject("request", {"metadata": {}}, _context(session_id=None, metadata={"tool_call_id": "call"})) + + assert updated["metadata"]["signatures"] == ["a", "b"] + + +@pytest.mark.asyncio +async def test_engine_supports_legacy_store_without_ttl_keyword() -> None: + class LegacyStore: + def __init__(self) -> None: + self.values = {} + + async def get(self, key): + return self.values.get(key) + + async def set(self, key, value): + self.values[key] = value + + async def append(self, key, values): + self.values.setdefault(key, []).extend(values) + return self.values[key] + + async def clear(self): + self.values.clear() + + engine = FieldCacheEngine([_reasoning_rule()], store=LegacyStore()) + + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "legacy"}}]}, _context()) + updated, _ = await engine.inject("request", {"messages": [{}]}, _context()) + + assert updated["messages"][-1]["reasoning_content"] == "legacy" + + +@pytest.mark.asyncio +async def test_per_tool_call_skips_when_current_tool_id_is_ambiguous() -> None: + rule = FieldCacheRule( + name="tool_signature", + source="response", + path="tool_calls.*", + mode="per_tool_call", + inject=FieldCacheInjection(target="request", path="metadata.signature"), + allow_missing_session=True, + metadata={"tool_call_id_path": "id"}, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"tool_calls": [{"id": "call_a", "signature": "sig-a"}]}, _context(session_id=None)) + updated, operations = await engine.inject("request", {"metadata": {}}, _context(session_id=None)) + + assert operations[0].skipped is True + assert operations[0].reason == "tool_call_id_not_found" + assert updated == {"metadata": {}} + + +@pytest.mark.asyncio +async def test_insert_injection_adds_list_entry() -> None: + rule = FieldCacheRule( + name="prefix_message", + source="response", + path="message", + inject=FieldCacheInjection(target="request", path="messages.0", insert=True), + allow_missing_session=True, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"message": {"role": "system", "content": "cached"}}, _context(session_id=None)) + updated, _ = await engine.inject("request", {"messages": [{"role": "user", "content": "hi"}]}, _context(session_id=None)) + + assert updated["messages"] == [{"role": "system", "content": "cached"}, {"role": "user", "content": "hi"}] diff --git a/tests/test_field_cache_paths.py b/tests/test_field_cache_paths.py new file mode 100644 index 000000000..8e916b7df --- /dev/null +++ b/tests/test_field_cache_paths.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import pytest + +from rotator_library.field_cache.paths import FieldCachePathError, extract_path, inject_path, parse_path +from rotator_library.field_cache.types import FieldCacheInjection, FieldCacheRule + + +def test_parse_path_supports_keys_indexes_wildcards_and_tail_index() -> None: + tokens = parse_path("choices.*.message.parts[-1]") + + assert [token.kind for token in tokens] == ["key", "wildcard", "key", "key", "index"] + assert tokens[-1].value == -1 + + +def test_extract_path_handles_nested_dict_list_and_wildcard() -> None: + payload = { + "choices": [ + {"message": {"reasoning_content": "a"}}, + {"message": {"reasoning_content": "b"}}, + ] + } + + assert extract_path(payload, "choices.*.message.reasoning_content") == ["a", "b"] + assert extract_path(payload, "choices.1.message.reasoning_content") == ["b"] + + +def test_extract_path_missing_values_are_noop() -> None: + assert extract_path({"choices": []}, "choices.*.message.reasoning_content") == [] + assert extract_path({}, "missing.path") == [] + + +def test_extract_path_tail_index() -> None: + payload = {"messages": [{"content": "first"}, {"content": "last"}]} + + assert extract_path(payload, "messages[-1].content") == ["last"] + + +def test_inject_path_creates_dict_containers() -> None: + payload = {"messages": [{"role": "assistant"}]} + + changed = inject_path(payload, "messages[-1].reasoning_content", "hidden") + + assert changed is True + assert payload["messages"][-1]["reasoning_content"] == "hidden" + + +def test_inject_path_respects_when_missing_only() -> None: + payload = {"metadata": {"prompt_cache_key": "existing"}} + + changed = inject_path(payload, "metadata.prompt_cache_key", "new", when_missing_only=True) + + assert changed is False + assert payload["metadata"]["prompt_cache_key"] == "existing" + + +def test_inject_path_can_insert_at_final_list_index() -> None: + payload = {"messages": [{"role": "user"}]} + + changed = inject_path(payload, "messages.0", {"role": "system"}, insert=True) + + assert changed is True + assert payload["messages"] == [{"role": "system"}, {"role": "user"}] + + +def test_inject_path_can_insert_into_empty_list() -> None: + payload = {"messages": []} + + changed = inject_path(payload, "messages.0", {"role": "system"}, insert=True) + + assert changed is True + assert payload["messages"] == [{"role": "system"}] + + +def test_inject_path_rejects_wildcards_and_missing_lists() -> None: + with pytest.raises(FieldCachePathError): + inject_path({"choices": []}, "choices.*.message.reasoning_content", "x") + with pytest.raises(FieldCachePathError): + inject_path({}, "messages[-1].reasoning_content", "x") + + +def test_malformed_paths_and_rules_raise_useful_errors() -> None: + with pytest.raises(FieldCachePathError): + parse_path("choices..message") + with pytest.raises(FieldCachePathError): + parse_path("messages[abc]") + with pytest.raises(ValueError): + FieldCacheRule(name="bad/name", source="response", path="x") + assert FieldCacheRule( + name="reasoning_content", + source="response", + path="choices.*.message.reasoning_content", + inject=FieldCacheInjection(target="request", path="messages[-1].reasoning_content"), + ).scope == ("provider", "model", "classifier", "session") diff --git a/tests/test_field_cache_trace.py b/tests/test_field_cache_trace.py new file mode 100644 index 000000000..8dd42e718 --- /dev/null +++ b/tests/test_field_cache_trace.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.adapters import AdapterContext, get_adapter, run_adapter_chain +from rotator_library.field_cache import FieldCacheContext, FieldCacheEngine, FieldCacheInjection, FieldCacheRule +from rotator_library.transaction_logger import TransactionLogger + + +def _trace_entries(log_dir): + return [json.loads(line) for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + + +@pytest.mark.asyncio +async def test_adapter_chain_emits_before_after_trace_entries(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + context = AdapterContext( + provider="openai", + model="gpt-test", + protocol="openai_chat", + credential_id="cred_1", + transport="http", + transaction_logger=logger, + adapter_config={"model_override": {"model": "native"}}, + ) + + result = await run_adapter_chain([get_adapter("model_override")], {"model": "public"}, context, stage="request") + + entries = _trace_entries(logger.log_dir) + assert result["model"] == "native" + assert [entry["pass_name"] for entry in entries] == ["before_adapter_chain", "after_adapter", "after_adapter_chain"] + assert entries[1]["metadata"]["adapter"] == "model_override" + assert entries[1]["metadata"]["changed"] is True + assert entries[1]["credential_id"] == "cred_1" + + +@pytest.mark.asyncio +async def test_field_cache_extract_and_inject_emit_before_after_trace_entries(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + rule = FieldCacheRule( + name="reasoning_content", + source="response", + path="choices.*.message.reasoning_content", + inject=FieldCacheInjection(target="request", path="messages[-1].reasoning_content"), + ) + engine = FieldCacheEngine([rule]) + context = FieldCacheContext(provider="openai", model="gpt-test", session_id="session_1", classifier="global") + + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "hidden"}}]}, context, transaction_logger=logger) + updated, _ = await engine.inject("request", {"messages": [{"role": "user"}]}, context, transaction_logger=logger) + + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert updated["messages"][-1]["reasoning_content"] == "hidden" + assert pass_names == [ + "field_cache_extraction_start", + "before_field_cache_extraction", + "after_field_cache_extraction", + "field_cache_extraction_complete", + "field_cache_injection_start", + "before_field_cache_injection", + "after_field_cache_injection", + "field_cache_injection_complete", + ] + assert entries[2]["metadata"]["rule_name"] == "reasoning_content" + assert entries[2]["metadata"]["matched"] == 1 + assert entries[6]["metadata"]["hit"] is True + assert entries[6]["metadata"]["changed"] is True + + +@pytest.mark.asyncio +async def test_stream_sourced_rule_injection_trace_uses_request_direction(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + rule = FieldCacheRule( + name="provider_session_id", + source="stream_event", + path="metadata.provider_session_id", + inject=FieldCacheInjection(target="request", path="metadata.provider_session_id"), + ) + engine = FieldCacheEngine([rule]) + context = FieldCacheContext(provider="openai", model="gpt-test", session_id="session_1", classifier="global") + + await engine.extract("stream_event", {"metadata": {"provider_session_id": "sid_1"}}, context, transaction_logger=logger) + await engine.inject("request", {"metadata": {}}, context, transaction_logger=logger) + + entries = _trace_entries(logger.log_dir) + injection_entries = [entry for entry in entries if "injection" in entry["pass_name"]] + assert {entry["direction"] for entry in injection_entries} == {"request"} + + +@pytest.mark.asyncio +async def test_field_cache_errors_emit_transform_log_error(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + rule = FieldCacheRule( + name="bad_injection", + source="response", + path="choices.*.message.reasoning_content", + inject=FieldCacheInjection(target="request", path="messages.*.reasoning_content"), + ) + engine = FieldCacheEngine([rule]) + context = FieldCacheContext(provider="openai", model="gpt-test", session_id="session_1", classifier="global") + + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "hidden"}}]}, context, transaction_logger=logger) + with pytest.raises(Exception): + await engine.inject("request", {"messages": [{"role": "user"}]}, context, transaction_logger=logger) + + entries = _trace_entries(logger.log_dir) + error_entry = next(entry for entry in entries if entry["pass_name"] == "transform_log_error") + assert error_entry["data"]["failed_pass_name"] == "field_cache_inject" + assert error_entry["metadata"]["rule_name"] == "bad_injection" diff --git a/tests/test_gemini_cli_protocol_declarations.py b/tests/test_gemini_cli_protocol_declarations.py new file mode 100644 index 000000000..9364d35df --- /dev/null +++ b/tests/test_gemini_cli_protocol_declarations.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import pytest + +from rotator_library.providers.gemini_cli_provider import GeminiCliProvider + + +@pytest.mark.asyncio +async def test_gemini_cli_declares_gemini_protocol_without_changing_custom_logic() -> None: + provider = GeminiCliProvider() + + assert provider.has_custom_logic() is True + assert provider.get_protocol_name("gemini-3-flash-preview") == "gemini" + assert provider.get_adapter_names("gemini-3-flash-preview") == () + assert provider.supports_native_streaming("gemini-3-flash-preview", operation="generate") is False + assert provider.normalize_native_model("gemini_cli/gemini-3-flash-preview") == "gemini-3-flash-preview" + + +@pytest.mark.asyncio +async def test_gemini_cli_declares_thought_signature_cache_rule() -> None: + provider = GeminiCliProvider() + + rules = provider.get_field_cache_rules("gemini-3-flash-preview") + + assert len(rules) == 1 + assert rules[0].name == "gemini_cli_thought_signature" + assert rules[0].path == "candidates.*.content.parts.*.thoughtSignature" + assert rules[0].inject.path == "metadata.thoughtSignatures" + assert rules[0].scope == ("provider", "model", "credential", "session") diff --git a/tests/test_native_provider_executor.py b/tests/test_native_provider_executor.py new file mode 100644 index 000000000..3db53fd73 --- /dev/null +++ b/tests/test_native_provider_executor.py @@ -0,0 +1,459 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.adapters import PayloadAdapter, register_adapter +from rotator_library.field_cache import FieldCacheInjection, FieldCacheRule +from rotator_library.native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor +from rotator_library.protocols import ProtocolError +from rotator_library.transaction_logger import TransactionLogger + + +class FakeHTTPResponse: + def __init__(self, payload): + self.payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self.payload + + +class FakeHTTPClient: + def __init__(self, response): + self.response = response + self.calls = [] + + async def post(self, endpoint, *, headers, json): + self.calls.append({"endpoint": endpoint, "headers": headers, "json": json}) + return FakeHTTPResponse(self.response) + + +def _trace_entries(log_dir): + return [json.loads(line) for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + + +@pytest.mark.asyncio +async def test_native_provider_executor_runs_protocol_adapter_cache_and_trace(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + rule = FieldCacheRule( + name="reasoning", + source="response", + path="choices.0.message.reasoning_content", + inject=FieldCacheInjection(target="request", path="metadata.reasoning_content"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + headers={"authorization": "Bearer test"}, + adapter_names=("model_override",), + adapter_config={"model_override": {"model": "provider/gpt-test"}}, + field_cache_rules=(rule,), + transaction_logger=logger, + ) + response = { + "id": "chat_1", + "model": "provider/gpt-test", + "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "hidden"}}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_cost": 0.01}, + } + client = FakeHTTPClient(response) + + result = await NativeProviderExecutor().execute( + {"model": "gpt-test", "messages": [{"role": "user", "content": "hi"}]}, + context, + NativeHTTPTransport(client), + ) + + assert result["id"] == "chat_1" + assert client.calls[0]["json"]["model"] == "provider/gpt-test" + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "hidden" not in trace_text + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "native_protocol_selected" in pass_names + assert "raw_native_client_request" in pass_names + assert "parsed_native_unified_request" in pass_names + assert "built_native_provider_request" in pass_names + assert "after_request_adapter_chain" in pass_names + assert "field_cache_injection_start" in pass_names + assert "after_field_cache_injection" in pass_names + assert "field_cache_injection_complete" in pass_names + assert "native_provider_request" in pass_names + assert "raw_native_provider_response" in pass_names + assert "parsed_native_unified_response" in pass_names + assert "formatted_native_response" in pass_names + assert "after_response_adapter_chain" in pass_names + assert "field_cache_extraction_start" in pass_names + assert "after_field_cache_extraction" in pass_names + assert "field_cache_extraction_complete" in pass_names + usage_entries = [entry for entry in _trace_entries(logger.log_dir) if entry["pass_name"] == "usage_accounting_summary"] + assert usage_entries[-1]["data"]["cost"]["provider_reported_cost"] == 0.01 + assert "final_client_response" in pass_names + + +@pytest.mark.asyncio +async def test_native_provider_default_field_cache_persists_across_requests() -> None: + rule = FieldCacheRule( + name="reasoning", + source="response", + path="choices.0.message.reasoning_content", + inject=FieldCacheInjection(target="request", path="metadata.reasoning_content"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=(rule,), + ) + first_client = FakeHTTPClient( + { + "id": "chat_1", + "model": "gpt-test", + "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "cached-state"}}], + } + ) + second_client = FakeHTTPClient( + { + "id": "chat_2", + "model": "gpt-test", + "choices": [{"message": {"role": "assistant", "content": "ok"}}], + } + ) + executor = NativeProviderExecutor() + + await executor.execute({"model": "gpt-test", "messages": [{"role": "user", "content": "hi"}]}, context, NativeHTTPTransport(first_client)) + await executor.execute({"model": "gpt-test", "messages": [{"role": "user", "content": "again"}]}, context, NativeHTTPTransport(second_client)) + + assert second_client.calls[0]["json"]["metadata"]["reasoning_content"] == "cached-state" + + +@pytest.mark.asyncio +async def test_native_provider_trace_redacts_configured_injection_paths(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + rule = FieldCacheRule( + name="state", + source="response", + path="choices.0.message.reasoning_content", + inject=FieldCacheInjection(target="request", path="metadata.state"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=(rule,), + transaction_logger=logger, + ) + executor = NativeProviderExecutor() + + await executor.execute( + {"model": "gpt-test", "messages": [{"role": "user", "content": "hi"}]}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "chat_1", "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "opaque-state"}}]})), + ) + await executor.execute( + {"model": "gpt-test", "messages": [{"role": "user", "content": "again"}]}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "chat_2", "choices": [{"message": {"role": "assistant", "content": "ok"}}]})), + ) + + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "opaque-state" not in trace_text + assert '"state": "[REDACTED]"' in trace_text + + +@pytest.mark.asyncio +async def test_native_provider_executor_logs_transform_errors(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + adapter_names=("missing_adapter",), + transaction_logger=logger, + ) + + with pytest.raises(KeyError): + await NativeProviderExecutor().execute({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(FakeHTTPClient({}))) + + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "transform_log_error" in pass_names + + +@pytest.mark.asyncio +async def test_native_provider_executor_rejects_unsupported_operation_before_transport() -> None: + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + operation="embeddings", + endpoint="https://example.test/chat", + ) + client = FakeHTTPClient({"id": "should_not_call"}) + + with pytest.raises(ProtocolError, match="unsupported operation"): + await NativeProviderExecutor().execute( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(client), + ) + + assert client.calls == [] + + +@pytest.mark.asyncio +async def test_native_runtime_executes_unified_request_source_and_target() -> None: + rule = FieldCacheRule( + name="unified_request_state", + source="unified_request", + path="metadata.client_state", + inject=FieldCacheInjection(target="unified_request", path="metadata.cached_client_state"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=(rule,), + ) + executor = NativeProviderExecutor() + + await executor.execute( + {"model": "gpt-test", "messages": [], "metadata": {"client_state": "state-a"}}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "chat_1", "choices": []})), + ) + second_client = FakeHTTPClient({"id": "chat_2", "choices": []}) + await executor.execute( + {"model": "gpt-test", "messages": [], "metadata": {}}, + context, + NativeHTTPTransport(second_client), + ) + + assert second_client.calls[0]["json"]["metadata"]["cached_client_state"] == "state-a" + + +@pytest.mark.asyncio +async def test_native_runtime_executes_unified_response_source() -> None: + rule = FieldCacheRule( + name="response_object", + source="unified_response", + path="metadata.object", + inject=FieldCacheInjection(target="request", path="metadata.cached_object"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=(rule,), + ) + executor = NativeProviderExecutor() + + await executor.execute({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(FakeHTTPClient({"id": "chat_1", "object": "chat.completion", "choices": []}))) + second_client = FakeHTTPClient({"id": "chat_2", "choices": []}) + await executor.execute({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(second_client)) + + assert second_client.calls[0]["json"]["metadata"]["cached_object"] == "chat.completion" + + +@pytest.mark.asyncio +async def test_native_response_source_extracts_raw_provider_response_before_client_formatting() -> None: + rule = FieldCacheRule( + name="anthropic_signature", + source="response", + path="content.*.signature", + mode="all", + inject=FieldCacheInjection(target="request", path="metadata.signatures", as_list=True), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="claude_code", + model="claude-sonnet-4-5", + protocol_name="anthropic_messages", + client_protocol_name="openai_chat", + endpoint="https://example.test/messages", + operation="messages", + field_cache_rules=(rule,), + ) + executor = NativeProviderExecutor() + + await executor.execute( + {"model": "claude-sonnet-4-5", "messages": [], "max_tokens": 1}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "msg_1", "type": "message", "role": "assistant", "content": [{"type": "text", "text": "ok", "signature": "sig-1"}]})), + ) + second_client = FakeHTTPClient({"id": "msg_2", "type": "message", "role": "assistant", "content": []}) + await executor.execute({"model": "claude-sonnet-4-5", "messages": [], "max_tokens": 1}, context, NativeHTTPTransport(second_client)) + + assert second_client.calls[0]["json"]["metadata"]["signatures"] == ["sig-1"] + + +@pytest.mark.asyncio +async def test_native_runtime_executes_metadata_target_before_adapters() -> None: + class MetadataEchoAdapter(PayloadAdapter): + name = "test_metadata_echo" + supported_stages = ("request",) + + async def transform_request(self, payload, context): + payload.setdefault("metadata", {})["adapter_seen_state"] = context.metadata.get("cached_state") + return payload + + register_adapter(MetadataEchoAdapter, replace=True) + rule = FieldCacheRule( + name="metadata_state", + source="response", + path="choices.0.message.reasoning_content", + inject=FieldCacheInjection(target="metadata", path="cached_state"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + adapter_names=("test_metadata_echo",), + field_cache_rules=(rule,), + ) + executor = NativeProviderExecutor() + + await executor.execute({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(FakeHTTPClient({"id": "chat_1", "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "meta-state"}}]}))) + second_client = FakeHTTPClient({"id": "chat_2", "choices": []}) + await executor.execute({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(second_client)) + + assert second_client.calls[0]["json"]["metadata"]["adapter_seen_state"] == "meta-state" + + +@pytest.mark.asyncio +async def test_native_adapter_generic_traces_are_suppressed_for_field_cache_safety(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + rule = FieldCacheRule( + name="hidden_state", + source="response", + path="choices.0.message.reasoning_content", + inject=FieldCacheInjection(target="request", path="metadata.hidden_state"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + adapter_names=("model_override",), + adapter_config={"model_override": {"model": "provider/gpt-test"}}, + field_cache_rules=(rule,), + transaction_logger=logger, + ) + + await NativeProviderExecutor().execute( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "chat_1", "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "opaque-state"}}]})), + ) + + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert "before_adapter_chain" not in pass_names + assert "after_adapter" not in pass_names + assert "after_request_adapter_chain" in pass_names + assert "opaque-state" not in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + + +@pytest.mark.asyncio +async def test_native_runtime_executes_request_source() -> None: + rule = FieldCacheRule( + name="request_state", + source="request", + path="metadata.outgoing_state", + inject=FieldCacheInjection(target="unified_request", path="metadata.reused_request_state"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=(rule,), + ) + executor = NativeProviderExecutor() + + await executor.execute( + {"model": "gpt-test", "messages": [], "metadata": {"outgoing_state": "request-state"}}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "chat_1", "choices": []})), + ) + second_client = FakeHTTPClient({"id": "chat_2", "choices": []}) + await executor.execute({"model": "gpt-test", "messages": [], "metadata": {}}, context, NativeHTTPTransport(second_client)) + + assert second_client.calls[0]["json"]["metadata"]["reused_request_state"] == "request-state" + + +@pytest.mark.asyncio +async def test_native_metadata_injection_trace_redacts_configured_paths(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + rule = FieldCacheRule( + name="metadata_secret", + source="response", + path="choices.0.message.reasoning_content", + inject=FieldCacheInjection(target="metadata", path="cached_blob"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=(rule,), + transaction_logger=logger, + ) + executor = NativeProviderExecutor() + + await executor.execute( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "chat_1", "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "metadata-secret"}}]})), + ) + await executor.execute( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "chat_2", "choices": []})), + ) + + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "metadata-secret" not in trace_text + metadata_entries = [entry for entry in _trace_entries(logger.log_dir) if entry["pass_name"] == "after_metadata_field_cache_injection"] + assert metadata_entries[-1]["data"]["cached_blob"] == "[REDACTED]" + + +@pytest.mark.asyncio +async def test_native_provider_stream_rejects_unsupported_operation_before_transport() -> None: + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + operation="embeddings", + endpoint="https://example.test/chat", + ) + client = FakeHTTPClient({"id": "should_not_call"}) + + with pytest.raises(ProtocolError, match="unsupported operation"): + async for _ in NativeProviderExecutor().stream( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(client), + ): + pass + + assert client.calls == [] diff --git a/tests/test_native_provider_streaming.py b/tests/test_native_provider_streaming.py new file mode 100644 index 000000000..de3d0a5c8 --- /dev/null +++ b/tests/test_native_provider_streaming.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.adapters import PayloadAdapter, register_adapter +from rotator_library.field_cache import FieldCacheInjection, FieldCacheRule +from rotator_library.native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor +from rotator_library.transaction_logger import TransactionLogger + + +class FakeStreamingClient: + def __init__(self, chunks): + self.chunks = chunks + self.calls = [] + + async def stream_json_lines(self, endpoint, *, headers, json): + self.calls.append({"endpoint": endpoint, "headers": headers, "json": json}) + for chunk in self.chunks: + yield chunk + + +def _trace_entries(log_dir): + return [json.loads(line) for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + + +@pytest.mark.asyncio +async def test_native_provider_stream_traces_and_yields_formatted_events(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=( + FieldCacheRule(name="stream_reasoning", source="stream_event", path="raw.choices.0.delta.reasoning_content", allow_missing_session=True), + FieldCacheRule(name="stream_vendor_state", source="stream_event", path="raw.choices.0.delta.vendor_state", allow_missing_session=True), + ), + transaction_logger=logger, + ) + chunks = [ + {"choices": [{"delta": {"content": "hi", "reasoning_content": "hidden", "vendor_state": "opaque-vendor-state"}}]}, + "[DONE]", + ] + client = FakeStreamingClient(chunks) + + events = [event async for event in NativeProviderExecutor().stream({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(client))] + + assert events == [chunks[0]] + assert client.calls[0]["json"]["stream"] is True + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "native_provider_stream_request" in pass_names + assert pass_names.count("raw_native_provider_stream_chunk") == 2 + assert pass_names.count("parsed_native_stream_event") == 2 + assert "after_field_cache_extraction" in pass_names + assert "after_field_cache_stream_extraction" in pass_names + assert pass_names.count("formatted_client_stream_event") == 1 + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "opaque-vendor-state" not in trace_text + + +@pytest.mark.asyncio +async def test_native_provider_stream_traces_usage_accounting_summary(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + transaction_logger=logger, + ) + chunks = [ + { + "choices": [], + "usage": { + "prompt_tokens": 2, + "completion_tokens": 3, + "cost_details": {"total_cost": 0.04, "currency": "USD", "source": "stream_usage"}, + }, + }, + "[DONE]", + ] + + _ = [event async for event in NativeProviderExecutor().stream({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(FakeStreamingClient(chunks)))] + + entries = _trace_entries(logger.log_dir) + summaries = [entry for entry in entries if entry["pass_name"] == "usage_accounting_summary"] + assert summaries + assert summaries[-1]["data"]["usage"]["input_tokens"] == 2 + assert summaries[-1]["data"]["usage"]["completion_tokens"] == 3 + assert summaries[-1]["data"]["usage"]["provider_reported_cost"] == 0.04 + assert summaries[-1]["data"]["cost"]["provider_reported_cost"] == 0.04 + + +@pytest.mark.asyncio +async def test_native_provider_stream_preserves_earlier_cost_when_later_usage_arrives(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + transaction_logger=logger, + ) + chunks = [ + {"choices": [], "usage": {"cost_details": {"total_cost": 0.07, "source": "early_cost"}}}, + {"choices": [], "usage": {"prompt_tokens": 2, "completion_tokens": 3}}, + "[DONE]", + ] + + _ = [event async for event in NativeProviderExecutor().stream({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(FakeStreamingClient(chunks)))] + + summaries = [entry for entry in _trace_entries(logger.log_dir) if entry["pass_name"] == "usage_accounting_summary"] + assert summaries[-1]["data"]["usage"]["input_tokens"] == 2 + assert summaries[-1]["data"]["usage"]["completion_tokens"] == 3 + assert summaries[-1]["data"]["usage"]["provider_reported_cost"] == 0.07 + + +@pytest.mark.asyncio +async def test_native_provider_stream_preserves_cost_when_later_raw_usage_arrives(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + transaction_logger=logger, + ) + chunks = [ + {"choices": [], "usage": {"cost_details": {"total_cost": 0.08, "source": "early_cost"}}}, + {"choices": [], "prompt_tokens": 2, "completion_tokens": 3}, + "[DONE]", + ] + + _ = [event async for event in NativeProviderExecutor().stream({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(FakeStreamingClient(chunks)))] + + summaries = [entry for entry in _trace_entries(logger.log_dir) if entry["pass_name"] == "usage_accounting_summary"] + assert summaries[-1]["data"]["usage"]["input_tokens"] == 2 + assert summaries[-1]["data"]["usage"]["completion_tokens"] == 3 + assert summaries[-1]["data"]["usage"]["provider_reported_cost"] == 0.08 + + +@pytest.mark.asyncio +async def test_native_provider_stream_logs_errors(tmp_path) -> None: + class BrokenClient: + async def stream_json_lines(self, endpoint, *, headers, json): + raise RuntimeError("broken stream") + yield None + + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext(provider="native", model="gpt-test", protocol_name="openai_chat", endpoint="https://example.test/chat", transaction_logger=logger) + + with pytest.raises(RuntimeError): + [event async for event in NativeProviderExecutor().stream({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(BrokenClient()))] + + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "transform_log_error" in pass_names + + +@pytest.mark.asyncio +async def test_native_provider_stream_runs_stream_event_adapter_chain(tmp_path) -> None: + class StreamTextAdapter(PayloadAdapter): + name = "test_stream_text_adapter" + supported_stages = ("stream_event",) + + async def transform_stream_event(self, payload, context): + payload.delta.content[0].text = "adapted" + return payload + + register_adapter(StreamTextAdapter, replace=True) + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + adapter_names=("test_stream_text_adapter",), + transaction_logger=logger, + ) + + events = [ + event + async for event in NativeProviderExecutor().stream( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(FakeStreamingClient([{"choices": [{"delta": {"content": "before"}}]}, "[DONE]"])), + ) + ] + + assert events[0]["choices"][0]["delta"]["content"] == "adapted" + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "after_stream_event_adapter_chain" in pass_names + + +@pytest.mark.asyncio +async def test_native_cross_protocol_stream_formats_openai_chat_sse() -> None: + context = NativeProviderContext( + provider="claude_code", + model="claude-sonnet-4-5", + protocol_name="anthropic_messages", + client_protocol_name="openai_chat", + endpoint="https://example.test/messages", + operation="messages", + ) + chunks = [ + {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "hi"}}, + "[DONE]", + ] + + events = [event async for event in NativeProviderExecutor().stream({"model": "claude-sonnet-4-5", "messages": [], "max_tokens": 1}, context, NativeHTTPTransport(FakeStreamingClient(chunks)))] + + assert events[0].startswith("data: ") + payload = json.loads(events[0][len("data: ") :].strip()) + assert payload["object"] == "chat.completion.chunk" + assert payload["choices"][0]["delta"]["content"] == "hi" + assert "content_block_delta" not in events[0] + + +@pytest.mark.asyncio +async def test_native_provider_stream_extracts_unified_stream_events_for_later_requests() -> None: + rule = FieldCacheRule( + name="unified_stream_text", + source="unified_stream_event", + path="delta.content.0.text", + inject=FieldCacheInjection(target="request", path="metadata.cached_stream_text"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=(rule,), + ) + executor = NativeProviderExecutor() + + _ = [ + event + async for event in executor.stream( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(FakeStreamingClient([{"choices": [{"delta": {"content": "stream-state"}}]}, "[DONE]"])), + ) + ] + second_client = FakeStreamingClient(["[DONE]"]) + _ = [ + event + async for event in executor.stream( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(second_client), + ) + ] + + assert second_client.calls[0]["json"]["metadata"]["cached_stream_text"] == "stream-state" diff --git a/tests/test_native_streaming_transport_seam.py b/tests/test_native_streaming_transport_seam.py new file mode 100644 index 000000000..cdb5fde1b --- /dev/null +++ b/tests/test_native_streaming_transport_seam.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import pytest + +from rotator_library.native_provider.http import NativeHTTPTransport +from rotator_library.native_provider.streaming import native_stream_event_from_formatted, provider_supports_native_streaming +from rotator_library.providers.provider_interface import ProviderInterface + + +class DefaultProvider(ProviderInterface): + async def get_models(self, api_key, client): + return [] + + +class StreamingProvider(DefaultProvider): + native_streaming_supported = True + + +def test_provider_native_streaming_support_defaults_false() -> None: + assert provider_supports_native_streaming(DefaultProvider(), model="gpt-test") is False + assert provider_supports_native_streaming(StreamingProvider(), model="gpt-test") is True + + +def test_native_formatted_sse_chunk_uses_common_stream_event_seam() -> None: + event = native_stream_event_from_formatted('data: {"choices":[{"delta":{"content":"hi"}}]}\n\n') + + assert event.visible_output is True + assert event.event_type == "delta" + + +class FakeStreamResponse: + def __init__(self, lines): + self.lines = lines + self.raised = False + + def raise_for_status(self): + self.raised = True + + async def aiter_lines(self): + for line in self.lines: + yield line + + +class FakeStreamContext: + def __init__(self, response): + self.response = response + + async def __aenter__(self): + return self.response + + async def __aexit__(self, exc_type, exc, tb): + return False + + +class FakeHttpxClient: + def __init__(self, response): + self.response = response + self.calls = [] + + def stream(self, method, endpoint, *, headers, json): + self.calls.append((method, endpoint, headers, json)) + return FakeStreamContext(self.response) + + +@pytest.mark.asyncio +async def test_native_http_transport_streams_httpx_lines() -> None: + response = FakeStreamResponse(["", ": heartbeat", 'data: {"delta":"hi"}', "data: [DONE]"]) + client = FakeHttpxClient(response) + + chunks = [chunk async for chunk in NativeHTTPTransport(client).stream_json_lines("https://example.test/stream", headers={"h": "v"}, payload={"p": True})] + + assert client.calls == [("POST", "https://example.test/stream", {"h": "v"}, {"p": True})] + assert response.raised is True + assert chunks == [{"delta": "hi"}, "[DONE]"] + + +class FakeByteResponse: + def raise_for_status(self): + pass + + async def aiter_bytes(self): + yield b'data: {"a": 1}\n\n' + yield b'data: [DONE]\n' + + +@pytest.mark.asyncio +async def test_native_http_transport_streams_httpx_bytes() -> None: + chunks = [chunk async for chunk in NativeHTTPTransport(FakeHttpxClient(FakeByteResponse())).stream_json_lines("url", headers={}, payload={})] + + assert chunks == [{"a": 1}, "[DONE]"] diff --git a/tests/test_native_usage_accounting.py b/tests/test_native_usage_accounting.py new file mode 100644 index 000000000..62116a6d4 --- /dev/null +++ b/tests/test_native_usage_accounting.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor +from rotator_library.transaction_logger import TransactionLogger + + +class FakeNativeTransport(NativeHTTPTransport): + def __init__(self): + pass + + async def post_json(self, endpoint, *, headers, payload): + return { + "id": "chatcmpl_1", + "model": "gpt-test", + "choices": [{"message": {"role": "assistant", "content": "hi"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 10, "completion_tokens": 6, "completion_tokens_details": {"reasoning_tokens": 2}}, + } + + +@pytest.mark.asyncio +async def test_native_executor_traces_normalized_usage(tmp_path) -> None: + logger = TransactionLogger("openai", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext( + provider="openai", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + headers={}, + transaction_logger=logger, + ) + + await NativeProviderExecutor().execute({"model": "gpt-test", "messages": [{"role": "user", "content": "hi"}]}, context, FakeNativeTransport()) + + entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + usage_entries = [entry for entry in entries if entry["pass_name"] == "usage_accounting_summary"] + assert usage_entries[-1]["data"]["usage"]["completion_tokens"] == 4 + assert usage_entries[-1]["data"]["usage"]["reasoning_tokens"] == 2 diff --git a/tests/test_protocol_anthropic_messages.py b/tests/test_protocol_anthropic_messages.py new file mode 100644 index 000000000..9cc0aaf54 --- /dev/null +++ b/tests/test_protocol_anthropic_messages.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import json + +from rotator_library.protocols import get_protocol, list_protocols + + +def test_anthropic_build_uses_mutated_unified_thinking_signature() -> None: + adapter = get_protocol("anthropic_messages") + raw = { + "model": "anthropic/claude-test", + "messages": [{"role": "assistant", "content": [{"type": "thinking", "thinking": "before", "signature": "sig_1"}]}], + } + + unified = adapter.parse_request(raw) + unified.messages[0].content[0].reasoning.text = "after" + rebuilt = adapter.build_request(unified) + + assert rebuilt["messages"][0]["content"][0]["thinking"] == "after" + assert rebuilt["messages"][0]["content"][0]["signature"] == "sig_1" + + +def test_anthropic_messages_protocol_is_discovered_with_aliases() -> None: + assert "anthropic_messages" in list_protocols() + assert get_protocol("anthropic") is get_protocol("anthropic_messages") + assert get_protocol("messages") is get_protocol("anthropic_messages") + + +def test_anthropic_request_round_trip_preserves_thinking_tools_and_cache_metadata() -> None: + adapter = get_protocol("anthropic_messages") + raw = { + "model": "anthropic/claude-test", + "system": [{"type": "text", "text": "system"}], + "max_tokens": 100, + "stream": True, + "messages": [ + {"role": "user", "content": "hello"}, + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "worked", "signature": "sig_1"}, + {"type": "text", "text": "answer"}, + {"type": "tool_use", "id": "toolu_1", "name": "lookup", "input": {"q": "x"}}, + ], + }, + {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "toolu_1", "content": "result"}]}, + ], + "tools": [{"name": "lookup", "description": "Lookup", "input_schema": {"type": "object"}}], + "vendor_extension": {"kept": True}, + } + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + + assert unified.model == "anthropic/claude-test" + assert unified.system[0].text == "system" + assert unified.messages[1].reasoning[0].signature == "sig_1" + assert unified.messages[1].tool_calls[0].name == "lookup" + assert unified.messages[2].content[0].tool_result.tool_call_id == "toolu_1" + assert unified.tools[0].input_schema == {"type": "object"} + assert rebuilt["vendor_extension"] == {"kept": True} + assert isinstance(rebuilt["system"], list) + assert rebuilt["system"][0]["text"] == "system" + assert rebuilt["messages"][1]["content"][0]["signature"] == "sig_1" + + +def test_anthropic_system_preserves_block_metadata() -> None: + adapter = get_protocol("anthropic_messages") + raw = { + "model": "anthropic/claude-test", + "system": [{"type": "text", "text": "system", "cache_control": {"type": "ephemeral"}}], + "messages": [{"role": "user", "content": "hello"}], + } + + unified = adapter.parse_request(raw) + unified.system[0].text = "updated" + rebuilt = adapter.build_request(unified) + + assert rebuilt["system"] == [{"type": "text", "text": "updated", "cache_control": {"type": "ephemeral"}}] + + +def test_anthropic_response_extracts_content_and_usage() -> None: + adapter = get_protocol("anthropic_messages") + raw = { + "id": "msg_1", + "type": "message", + "role": "assistant", + "model": "anthropic/claude-test", + "content": [ + {"type": "redacted_thinking", "signature": "sig_2"}, + {"type": "text", "text": "answer"}, + ], + "stop_reason": "end_turn", + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "cache_creation_input_tokens": 2, + "cache_read_input_tokens": 3, + }, + } + + unified = adapter.parse_response(raw) + + assert unified.id == "msg_1" + assert unified.messages[0].content[1].text == "answer" + assert unified.messages[0].reasoning[0].signature == "sig_2" + assert unified.messages[0].reasoning[0].redacted is True + assert unified.usage is not None + assert unified.usage.input_tokens == 10 + assert unified.usage.output_tokens == 5 + assert unified.usage.cache_write_tokens == 2 + assert unified.usage.cache_read_tokens == 3 + + +def test_anthropic_stream_event_parses_text_delta_and_usage() -> None: + adapter = get_protocol("anthropic_messages") + text_event = {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hi"}} + usage_event = {"type": "message_delta", "delta": {"stop_reason": "end_turn"}, "usage": {"output_tokens": 7}} + + parsed_text = adapter.parse_stream_event(f"data: {json.dumps(text_event)}\n\n") + parsed_usage = adapter.parse_stream_event(usage_event) + + assert parsed_text.type == "content_block_delta" + assert parsed_text.delta is not None + assert parsed_text.delta.content[0].text == "Hi" + assert parsed_usage.usage is not None + assert parsed_usage.usage.output_tokens == 7 diff --git a/tests/test_protocol_gemini.py b/tests/test_protocol_gemini.py new file mode 100644 index 000000000..709a8afab --- /dev/null +++ b/tests/test_protocol_gemini.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import json + +from rotator_library.protocols import get_protocol, list_protocols + + +def test_gemini_protocol_is_discovered_with_aliases() -> None: + assert "gemini" in list_protocols() + assert get_protocol("google_gemini") is get_protocol("gemini") + assert get_protocol("generate_content") is get_protocol("gemini") + + +def test_gemini_request_round_trip_preserves_parts_tools_and_settings() -> None: + adapter = get_protocol("gemini") + raw = { + "model": "gemini/gemini-test", + "systemInstruction": {"parts": [{"text": "system"}]}, + "contents": [ + {"role": "user", "parts": [{"text": "hello"}, {"inlineData": {"mimeType": "image/png", "data": "abc"}}]}, + { + "role": "model", + "parts": [ + {"text": "thought", "thought": True, "thoughtSignature": "sig_1"}, + {"functionCall": {"name": "lookup", "args": {"q": "x"}}}, + ], + }, + {"role": "user", "parts": [{"functionResponse": {"name": "lookup", "response": {"value": 1}}}]}, + ], + "tools": [{"functionDeclarations": [{"name": "lookup", "description": "Lookup", "parameters": {"type": "object"}}]}], + "generationConfig": {"temperature": 0.3}, + "safetySettings": [{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}], + "vendor_extension": {"kept": True}, + } + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + + assert unified.model == "gemini/gemini-test" + assert unified.system[0].text == "system" + assert unified.messages[1].role == "assistant" + assert unified.messages[1].reasoning[0].signature == "sig_1" + assert unified.messages[1].tool_calls[0].name == "lookup" + assert unified.messages[2].content[0].tool_result.content == {"value": 1} + assert unified.tools[0].name == "lookup" + assert rebuilt["generationConfig"] == {"temperature": 0.3} + assert rebuilt["safetySettings"][0]["threshold"] == "BLOCK_NONE" + assert rebuilt["vendor_extension"] == {"kept": True} + + +def test_gemini_parses_multiple_function_declarations() -> None: + adapter = get_protocol("gemini") + raw = { + "contents": [], + "tools": [ + { + "functionDeclarations": [ + {"name": "one", "parameters": {"type": "object"}}, + {"name": "two", "parameters": {"type": "object"}}, + ] + } + ], + } + + unified = adapter.parse_request(raw) + unified.tools[1].description = "Second tool" + rebuilt = adapter.build_request(unified) + + assert [tool.name for tool in unified.tools] == ["one", "two"] + assert len(rebuilt["tools"]) == 1 + assert [declaration["name"] for declaration in rebuilt["tools"][0]["functionDeclarations"]] == ["one", "two"] + assert rebuilt["tools"][0]["functionDeclarations"][1]["description"] == "Second tool" + + +def test_gemini_response_extracts_usage_and_thought_signature() -> None: + adapter = get_protocol("gemini") + raw = { + "responseId": "resp_1", + "modelVersion": "gemini-test-001", + "candidates": [ + { + "finishReason": "STOP", + "content": {"role": "model", "parts": [{"text": "answer", "thought": True, "thoughtSignature": "sig_2"}]}, + } + ], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 5, + "thoughtsTokenCount": 2, + "cachedContentTokenCount": 3, + "totalTokenCount": 17, + }, + } + + unified = adapter.parse_response(raw) + + assert unified.id == "resp_1" + assert unified.stop_reason == "STOP" + assert unified.messages[0].content[0].text == "answer" + assert unified.messages[0].reasoning[0].signature == "sig_2" + assert unified.usage is not None + assert unified.usage.input_tokens == 10 + assert unified.usage.output_tokens == 5 + assert unified.usage.reasoning_tokens == 2 + assert unified.usage.cache_read_tokens == 3 + + +def test_gemini_stream_event_parses_candidate_delta() -> None: + adapter = get_protocol("gemini") + event = { + "candidates": [{"content": {"role": "model", "parts": [{"text": "Hi"}]}, "finishReason": None}], + "usageMetadata": {"promptTokenCount": 1, "candidatesTokenCount": 1, "totalTokenCount": 2}, + } + + parsed = adapter.parse_stream_event(f"data: {json.dumps(event)}\n\n") + + assert parsed.type == "message_delta" + assert parsed.delta is not None + assert parsed.delta.content[0].text == "Hi" + assert parsed.usage is not None + assert parsed.usage.total_tokens == 2 diff --git a/tests/test_protocol_ollama_mcp.py b/tests/test_protocol_ollama_mcp.py new file mode 100644 index 000000000..7c859203b --- /dev/null +++ b/tests/test_protocol_ollama_mcp.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from rotator_library.protocols import OPERATION_EMBEDDINGS, OPERATION_MCP, OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, ProtocolContext, get_protocol + + +def test_ollama_chat_generate_and_stream_shapes() -> None: + adapter = get_protocol("ollama") + chat = adapter.parse_request({"model": "llama3", "messages": [{"role": "user", "content": "hi"}], "stream": True}) + generate = adapter.parse_request({"model": "llama3", "prompt": "write", "options": {"temperature": 0.1}}) + explicit_stream_false = adapter.parse_request({"model": "llama3", "prompt": "write", "stream": False}) + embeddings = adapter.parse_request({"model": "llama3", "operation": OPERATION_EMBEDDINGS, "prompt": "embed this"}) + contextual_embeddings = adapter.parse_request( + {"model": "llama3", "prompt": "embed via context"}, + ProtocolContext(provider_options={"operation": OPERATION_EMBEDDINGS}), + ) + new_embeddings = adapter.parse_request({"model": "llama3", "input": "embed input"}) + final_chat = adapter.parse_response({"model": "llama3", "message": {"role": "assistant", "content": "hello"}, "done": True}) + final_generate = adapter.parse_response({"model": "llama3", "response": "generated", "done": True}) + embeddings_response = adapter.parse_response({"model": "llama3", "embeddings": [[0.1, 0.2]]}) + chunk = adapter.parse_stream_event('{"model":"llama3","response":"he","done":false,"eval_count":2}') + + assert chat.operation == OPERATION_OLLAMA_CHAT + assert adapter.build_request(chat)["messages"][0]["content"] == "hi" + assert generate.operation == OPERATION_OLLAMA_GENERATE + assert adapter.build_request(generate)["prompt"] == "write" + assert "stream" not in adapter.build_request(generate) + assert adapter.build_request(explicit_stream_false)["stream"] is False + assert embeddings.operation == OPERATION_EMBEDDINGS + assert adapter.build_request(embeddings)["prompt"] == "embed this" + assert contextual_embeddings.operation == OPERATION_EMBEDDINGS + assert adapter.build_request(contextual_embeddings)["prompt"] == "embed via context" + assert adapter.build_request(new_embeddings)["input"] == "embed input" + assert final_chat.messages[0].content[0].text == "hello" + assert final_generate.output == ["generated"] + assert final_generate.messages == [] + assert embeddings_response.data == [[0.1, 0.2]] + assert chunk.delta is not None + assert chunk.delta.content[0].text == "he" + assert chunk.usage is not None + assert chunk.usage.output_tokens == 2 + + +def test_ollama_format_response_uses_mutated_unified_fields() -> None: + adapter = get_protocol("ollama") + chat = adapter.parse_response({"model": "llama3", "message": {"role": "assistant", "content": "before"}, "done": True}) + chat.messages[0].content[0].text = "after" + + generate = adapter.parse_response({"model": "llama3", "response": "before", "eval_count": 2}) + generate.output[0] = "after" + generate.usage.output_tokens = 5 + + embeddings = adapter.parse_response({"model": "llama3", "embedding": [0.1, 0.2]}) + embeddings.data = [0.3, 0.4] + + assert adapter.format_response(chat)["message"]["content"] == "after" + assert adapter.format_response(generate)["response"] == "after" + assert adapter.format_response(generate)["eval_count"] == 5 + assert adapter.format_response(embeddings)["embedding"] == [0.3, 0.4] + + +def test_mcp_jsonrpc_round_trip_and_error_preservation() -> None: + adapter = get_protocol("mcp") + request = {"jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": {"name": "lookup"}} + unified = adapter.parse_request(request) + rebuilt = adapter.build_request(unified) + response = adapter.parse_response({"jsonrpc": "2.0", "id": 1, "result": {"content": []}}) + error = adapter.parse_response({"jsonrpc": "2.0", "id": 1, "error": {"code": -1, "message": "failed"}}) + + assert unified.operation == OPERATION_MCP + assert rebuilt == request + assert adapter.format_response(response)["result"] == {"content": []} + assert adapter.format_response(error)["error"]["message"] == "failed" + + +def test_mcp_preserves_notifications_and_falsey_params() -> None: + adapter = get_protocol("mcp") + notification = {"jsonrpc": "2.0", "method": "notifications/initialized"} + false_params = {"jsonrpc": "2.0", "id": 0, "method": "tools/call", "params": False} + + assert adapter.build_request(adapter.parse_request(notification)) == notification + assert adapter.build_request(adapter.parse_request(false_params)) == false_params + + +def test_mcp_preserves_jsonrpc_batches() -> None: + adapter = get_protocol("mcp") + batch = [ + {"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}}, + {"jsonrpc": "2.0", "id": 2, "method": "resources/list", "params": None}, + ] + response_batch = [ + {"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}, + {"jsonrpc": "2.0", "id": 2, "error": {"code": -1, "message": "failed"}}, + ] + + assert adapter.build_request(adapter.parse_request(batch)) == batch + assert adapter.format_response(adapter.parse_response(response_batch)) == response_batch diff --git a/tests/test_protocol_openai_chat.py b/tests/test_protocol_openai_chat.py new file mode 100644 index 000000000..7bbb5dfab --- /dev/null +++ b/tests/test_protocol_openai_chat.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import json + +from rotator_library.protocols import get_protocol, list_protocols + + +def test_openai_chat_protocol_is_discovered_with_aliases() -> None: + assert "openai_chat" in list_protocols() + assert get_protocol("openai") is get_protocol("openai_chat") + assert get_protocol("chat_completions") is get_protocol("openai_chat") + + +def test_litellm_fallback_protocol_preserves_raw_payload() -> None: + adapter = get_protocol("litellm_fallback") + raw = {"model": "custom/model", "messages": [{"role": "user", "content": "hi"}], "vendor_flag": True} + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + + assert unified.extra["messages"] == raw["messages"] + assert rebuilt == raw + + +def test_openai_chat_request_round_trip_preserves_tools_and_reasoning() -> None: + adapter = get_protocol("openai_chat") + raw = { + "model": "openai/gpt-test", + "stream": True, + "temperature": 0.2, + "messages": [ + {"role": "system", "content": "system"}, + {"role": "developer", "content": "dev"}, + {"role": "user", "content": [{"type": "text", "text": "hello"}]}, + { + "role": "assistant", + "content": "thinking done", + "reasoning_content": "internal chain", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "lookup", "arguments": "{\"q\":\"x\"}"}, + } + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "result"}, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "lookup", + "description": "Lookup a value", + "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}, + }, + } + ], + "vendor_extension": {"kept": True}, + } + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + + assert unified.model == "openai/gpt-test" + assert unified.stream is True + assert unified.generation_params["temperature"] == 0.2 + assert unified.messages[3].reasoning[0].text == "internal chain" + assert unified.messages[3].tool_calls[0].name == "lookup" + assert unified.tools[0].input_schema["properties"]["q"]["type"] == "string" + assert rebuilt["vendor_extension"] == {"kept": True} + assert rebuilt["messages"][3]["reasoning_content"] == "internal chain" + + +def test_openai_chat_build_uses_mutated_unified_fields_and_preserves_block_extras() -> None: + adapter = get_protocol("openai_chat") + raw = { + "model": "openai/gpt-test", + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "before", "cache_control": {"type": "ephemeral"}}], + } + ], + } + + unified = adapter.parse_request(raw) + unified.messages[0].content[0].text = "after" + rebuilt = adapter.build_request(unified) + + assert isinstance(rebuilt["messages"][0]["content"], list) + assert rebuilt["messages"][0]["content"][0]["text"] == "after" + assert rebuilt["messages"][0]["content"][0]["cache_control"] == {"type": "ephemeral"} + + +def test_openai_chat_response_extracts_usage_cost_and_reasoning() -> None: + adapter = get_protocol("openai_chat") + raw = { + "id": "chatcmpl_1", + "object": "chat.completion", + "created": 123, + "model": "openai/gpt-test", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "answer", + "reasoning_content": "reasoned", + }, + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 18, + "prompt_tokens_details": {"cached_tokens": 3, "cache_creation_tokens": 2}, + "completion_tokens_details": {"reasoning_tokens": 3}, + "cost_details": {"total_cost": 0.01, "currency": "USD"}, + }, + } + + unified = adapter.parse_response(raw) + + assert unified.id == "chatcmpl_1" + assert unified.messages[0].content[0].text == "answer" + assert unified.messages[0].reasoning[0].text == "reasoned" + assert unified.usage is not None + assert unified.usage.input_tokens == 10 + assert unified.usage.output_tokens == 5 + assert unified.usage.reasoning_tokens == 3 + assert unified.usage.cache_read_tokens == 3 + assert unified.usage.cache_write_tokens == 2 + assert unified.usage.cost is not None + assert unified.usage.cost.provider_reported_cost == 0.01 + + formatted = adapter.format_response(unified) + assert formatted["usage"]["prompt_tokens"] == 10 + assert formatted["usage"]["completion_tokens"] == 5 + assert formatted["usage"]["total_tokens"] == 18 + assert formatted["usage"]["prompt_tokens_details"] == {"cached_tokens": 3, "cache_creation_tokens": 2} + assert formatted["usage"]["completion_tokens_details"] == {"reasoning_tokens": 3} + assert formatted["usage"]["cost_details"]["total_cost"] == 0.01 + assert "input_tokens" not in formatted["usage"] + assert "output_tokens" not in formatted["usage"] + + +def test_openai_legacy_function_call_is_unified_and_round_trips() -> None: + adapter = get_protocol("openai_chat") + raw = { + "model": "openai/gpt-test", + "messages": [ + { + "role": "assistant", + "content": None, + "function_call": {"name": "lookup", "arguments": "{\"q\":\"x\"}"}, + } + ], + } + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + + assert unified.messages[0].tool_calls[0].name == "lookup" + assert unified.messages[0].tool_calls[0].extra["legacy_function_call"] is True + unified.messages[0].tool_calls[0].arguments = {"q": "changed"} + assert rebuilt["messages"][0]["function_call"] == {"name": "lookup", "arguments": "{\"q\":\"x\"}"} + assert "tool_calls" not in rebuilt["messages"][0] + + rebuilt = adapter.build_request(unified) + assert rebuilt["messages"][0]["function_call"] == {"name": "lookup", "arguments": "{\"q\":\"changed\"}"} + + +def test_openai_chat_stream_event_parses_sse_delta_and_done() -> None: + adapter = get_protocol("openai_chat") + event = { + "id": "chunk_1", + "model": "openai/gpt-test", + "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hel"}, "finish_reason": None}], + } + + parsed = adapter.parse_stream_event(f"data: {json.dumps(event)}\n\n") + done = adapter.parse_stream_event("data: [DONE]\n\n") + + assert parsed.type == "message_delta" + assert parsed.delta is not None + assert parsed.delta.content[0].text == "Hel" + assert done.type == "done" diff --git a/tests/test_protocol_openai_embeddings.py b/tests/test_protocol_openai_embeddings.py new file mode 100644 index 000000000..f398f7a33 --- /dev/null +++ b/tests/test_protocol_openai_embeddings.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from rotator_library.protocols import OPERATION_EMBEDDINGS, get_protocol, list_protocols + + +def test_openai_embeddings_protocol_round_trip_and_usage() -> None: + adapter = get_protocol("openai_embeddings") + raw = {"model": "text-embedding-test", "input": ["one", "two"], "dimensions": 128, "custom": True} + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + response = adapter.parse_response( + { + "model": "text-embedding-test", + "data": [{"object": "embedding", "embedding": [0.1, 0.2], "index": 0}], + "usage": {"prompt_tokens": 4, "total_tokens": 4}, + } + ) + + assert "openai_embeddings" in list_protocols() + assert get_protocol("embeddings") is adapter + assert adapter.supports_operation(OPERATION_EMBEDDINGS) + assert unified.operation == OPERATION_EMBEDDINGS + assert unified.input == ["one", "two"] + assert rebuilt == raw + assert response.data[0]["embedding"] == [0.1, 0.2] + assert response.usage is not None + assert response.usage.input_tokens == 4 + assert adapter.format_response(response)["data"] == response.data + assert adapter.format_response(response)["usage"]["total_tokens"] == 4 diff --git a/tests/test_protocol_openai_images_audio.py b/tests/test_protocol_openai_images_audio.py new file mode 100644 index 000000000..4fd0cd3ce --- /dev/null +++ b/tests/test_protocol_openai_images_audio.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from rotator_library.protocols import ( + OPERATION_AUDIO_TRANSCRIPTION, + OPERATION_AUDIO_TRANSLATION, + OPERATION_IMAGE_EDIT, + OPERATION_IMAGE_GENERATION, + OPERATION_IMAGE_VARIATION, + OPERATION_SPEECH, + ProtocolContext, + get_protocol, +) + + +def test_openai_images_generation_and_edit_shapes() -> None: + adapter = get_protocol("openai_images") + generation = adapter.parse_request({"model": "gpt-image-test", "prompt": "draw a red cube", "size": "1024x1024"}) + edit = adapter.parse_request({"model": "gpt-image-test", "prompt": "make it blue", "image": "image-ref", "mask": "mask-ref"}) + variation = adapter.parse_request({"model": "gpt-image-test", "image": "image-ref"}) + response = adapter.parse_response({"data": [{"url": "https://example.test/image.png", "revised_prompt": "cube"}]}) + edit_response = adapter.parse_response( + {"data": [{"url": "https://example.test/edit.png"}]}, + ProtocolContext(provider_options={"operation": OPERATION_IMAGE_EDIT}), + ) + + assert generation.operation == OPERATION_IMAGE_GENERATION + assert adapter.build_request(generation)["prompt"] == "draw a red cube" + assert edit.operation == OPERATION_IMAGE_EDIT + assert variation.operation == OPERATION_IMAGE_VARIATION + assert {entry["field"] for entry in edit.files} == {"image", "mask"} + assert adapter.build_request(edit)["image"] == "image-ref" + assert response.data[0]["revised_prompt"] == "cube" + assert edit_response.operation == OPERATION_IMAGE_EDIT + + +def test_openai_audio_transcription_and_speech_shapes() -> None: + adapter = get_protocol("openai_audio") + transcription = adapter.parse_request({"model": "whisper-test", "file": "audio-ref", "language": "en"}) + translation = adapter.parse_request( + {"model": "whisper-test", "file": "audio-ref", "prompt": "domain hint"}, + ProtocolContext(provider_options={"operation": OPERATION_AUDIO_TRANSLATION}), + ) + speech = adapter.parse_request({"model": "tts-test", "input": "hello", "voice": "alloy"}) + text_response = adapter.parse_response({"text": "hello world"}) + translation_response = adapter.parse_response( + {"text": "bonjour"}, + ProtocolContext(provider_options={"operation": "audio_translation"}), + ) + binary_response = adapter.parse_response(b"RIFF") + + assert transcription.operation == OPERATION_AUDIO_TRANSCRIPTION + assert transcription.files[0]["value"] == "audio-ref" + assert translation.operation == OPERATION_AUDIO_TRANSLATION + translation.input = "mutated hint" + assert adapter.build_request(translation)["prompt"] == "mutated hint" + assert speech.operation == OPERATION_SPEECH + assert adapter.build_request(speech)["voice"] == "alloy" + assert text_response.operation == OPERATION_AUDIO_TRANSCRIPTION + assert text_response.output == ["hello world"] + assert adapter.format_response(text_response)["text"] == "hello world" + assert translation_response.operation == "audio_translation" + assert binary_response.content_type == "application/octet-stream" + assert adapter.format_response(binary_response) == b"RIFF" diff --git a/tests/test_protocol_operation_model.py b/tests/test_protocol_operation_model.py new file mode 100644 index 000000000..c2c5d0e93 --- /dev/null +++ b/tests/test_protocol_operation_model.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from rotator_library.protocols import ( + OPERATION_CHAT, + OPERATION_COUNT_TOKENS, + OPERATION_EMBEDDINGS, + OPERATION_MESSAGES, + OPERATION_RESPONSES, + OPERATION_UNKNOWN, + ProtocolAdapter, + ProtocolContext, + UnifiedRequest, + UnifiedResponse, + get_protocol, + normalize_operation, +) + + +def test_operation_names_are_extensible_strings() -> None: + assert normalize_operation(" Chat ") == OPERATION_CHAT + assert normalize_operation(None) == OPERATION_UNKNOWN + assert normalize_operation("custom_gateway_op") == "custom_gateway_op" + + +def test_unified_request_response_carry_non_chat_operation_fields() -> None: + request = UnifiedRequest( + operation=OPERATION_EMBEDDINGS, + model="text-embedding-test", + input=["a", "b"], + modalities=["text"], + files=[{"name": "audio.wav", "content_type": "audio/wav"}], + ) + response = UnifiedResponse( + operation=OPERATION_EMBEDDINGS, + model="text-embedding-test", + data=[{"embedding": [0.1, 0.2]}], + content_type="application/json", + ) + + assert request.to_dict()["operation"] == OPERATION_EMBEDDINGS + assert request.to_dict()["input"] == ["a", "b"] + assert response.to_dict()["data"][0]["embedding"] == [0.1, 0.2] + assert response.to_dict()["content_type"] == "application/json" + + +def test_base_adapter_preserves_operation_fields() -> None: + adapter = ProtocolAdapter() + raw = { + "operation": "custom_op", + "model": "provider/model", + "input": "payload", + "modalities": ["text"], + "files": [{"name": "file.bin"}], + "custom": True, + } + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + + assert unified.operation == "custom_op" + assert unified.input == "payload" + assert rebuilt == raw + + +def test_protocols_advertise_supported_operations() -> None: + assert get_protocol("openai_chat").supports_operation(OPERATION_CHAT) + assert not get_protocol("openai_chat").supports_operation(OPERATION_EMBEDDINGS) + assert get_protocol("litellm_fallback").supports_operation(OPERATION_UNKNOWN) + + +def test_core_protocols_stamp_parsed_operation_fields() -> None: + chat = get_protocol("openai_chat").parse_request({"model": "m", "messages": []}) + messages = get_protocol("anthropic_messages").parse_request({"model": "m", "messages": []}) + responses = get_protocol("responses").parse_request({"model": "m", "input": "hi"}) + gemini = get_protocol("gemini").parse_request({"contents": []}) + + assert chat.operation == OPERATION_CHAT + assert messages.operation == OPERATION_MESSAGES + assert responses.operation == OPERATION_RESPONSES + assert gemini.operation == OPERATION_CHAT + + +def test_count_tokens_operation_can_be_context_selected() -> None: + context = ProtocolContext(provider_options={"operation": OPERATION_COUNT_TOKENS}) + anthropic = get_protocol("anthropic_messages").parse_response({"input_tokens": 12}, context) + gemini = get_protocol("gemini").parse_response({"totalTokens": 14}, context) + + assert anthropic.operation == OPERATION_COUNT_TOKENS + assert anthropic.usage is not None + assert anthropic.usage.input_tokens == 12 + assert get_protocol("anthropic_messages").format_response(anthropic) == {"input_tokens": 12} + assert gemini.operation == OPERATION_COUNT_TOKENS + assert gemini.usage is not None + assert gemini.usage.total_tokens == 14 + assert get_protocol("gemini").format_response(gemini) == {"totalTokens": 14} + + anthropic.usage.input_tokens = 13 + gemini.usage.total_tokens = 15 + assert get_protocol("anthropic_messages").format_response(anthropic)["input_tokens"] == 13 + assert get_protocol("gemini").format_response(gemini)["totalTokens"] == 15 + + +def test_context_operation_helpers_reject_unsupported_operations() -> None: + context = ProtocolContext(provider_options={"operation": OPERATION_EMBEDDINGS}) + + anthropic = get_protocol("anthropic_messages").parse_request({"model": "m", "messages": []}, context) + gemini = get_protocol("gemini").parse_request({"contents": []}, context) + + assert anthropic.operation == OPERATION_MESSAGES + assert gemini.operation == OPERATION_CHAT diff --git a/tests/test_protocol_registry.py b/tests/test_protocol_registry.py new file mode 100644 index 000000000..0b449eaf6 --- /dev/null +++ b/tests/test_protocol_registry.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +import json +from datetime import datetime +from decimal import Decimal + +import pytest + +from rotator_library.protocols import ( + ContentBlock, + ProtocolAdapter, + ProtocolContext, + ProtocolError, + ToolCall, + UnifiedMessage, + UnifiedRequest, + get_protocol, + get_protocol_class, + list_protocols, + register_protocol, + serialize_value, +) + + +class ExampleProtocol(ProtocolAdapter): + name = "example_test_protocol" + aliases = ("example_test_alias",) + + +def test_protocol_serialization_preserves_nested_values() -> None: + request = UnifiedRequest( + model="provider/model", + messages=[ + UnifiedMessage( + role="assistant", + content=[ContentBlock(type="text", text="hello")], + tool_calls=[ToolCall(id="call_1", name="lookup", arguments={"q": "x"})], + extra={"provider_field": {"kept": True}}, + ) + ], + metadata={"request_id": "req_1"}, + ) + + serialized = request.to_dict() + + assert serialized["messages"][0]["content"][0]["text"] == "hello" + assert serialized["messages"][0]["tool_calls"][0]["arguments"] == {"q": "x"} + assert serialized["messages"][0]["extra"]["provider_field"] == {"kept": True} + json.dumps(serialize_value(request)) + + +def test_base_protocol_preserves_unknown_request_fields() -> None: + adapter = ProtocolAdapter() + raw = {"model": "provider/model", "stream": True, "custom": {"value": 1}} + + unified = adapter.parse_request(raw, ProtocolContext(provider="provider")) + rebuilt = adapter.build_request(unified) + + assert unified.model == "provider/model" + assert unified.stream is True + assert unified.extra == {"custom": {"value": 1}} + assert rebuilt == raw + + +def test_register_protocol_resolves_alias_and_reuses_instance() -> None: + register_protocol(ExampleProtocol, replace=True) + + assert get_protocol_class("example_test_protocol") is ExampleProtocol + assert get_protocol_class("example_test_alias") is ExampleProtocol + assert get_protocol("example_test_protocol") is get_protocol("example_test_alias") + assert "example_test_protocol" in list_protocols() + + +def test_register_protocol_rejects_duplicate_alias() -> None: + class FirstProtocol(ProtocolAdapter): + name = "duplicate_alias_first" + aliases = ("duplicate_alias",) + + class SecondProtocol(ProtocolAdapter): + name = "duplicate_alias_second" + aliases = ("duplicate_alias",) + + register_protocol(FirstProtocol, replace=True) + + with pytest.raises(ValueError, match="alias already registered"): + register_protocol(SecondProtocol) + + +def test_register_protocol_rejects_name_that_conflicts_with_existing_alias() -> None: + class AliasOwnerProtocol(ProtocolAdapter): + name = "alias_owner_protocol" + aliases = ("reserved_protocol_name",) + + class ConflictingNameProtocol(ProtocolAdapter): + name = "reserved_protocol_name" + + register_protocol(AliasOwnerProtocol, replace=True) + + with pytest.raises(ValueError, match="name conflicts with registered alias"): + register_protocol(ConflictingNameProtocol) + + +def test_protocol_serialization_handles_common_non_json_values() -> None: + value = serialize_value( + { + "bytes": b"hello", + "decimal": Decimal("1.25"), + "datetime": datetime(2026, 1, 2, 3, 4, 5), + "object": object(), + } + ) + + assert value["bytes"] == "hello" + assert value["decimal"] == 1.25 + assert value["datetime"] == "2026-01-02T03:04:05" + json.dumps(value) + + +def test_usage_total_does_not_double_count_reasoning_by_default() -> None: + from rotator_library.protocols import Usage + + usage = Usage(input_tokens=10, output_tokens=5, reasoning_tokens=3) + + assert usage.total_tokens == 15 + + +def test_protocol_error_includes_pass_and_payload_preview() -> None: + error = ProtocolError( + "failed", + protocol="example", + pass_name="parse_request", + payload={"secret": "not-redacted-here", "value": 1}, + ) + + assert "example.parse_request" in str(error) + assert "value" in str(error) diff --git a/tests/test_protocol_responses.py b/tests/test_protocol_responses.py new file mode 100644 index 000000000..77476adf3 --- /dev/null +++ b/tests/test_protocol_responses.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import json + +from rotator_library.protocols import get_protocol, list_protocols + + +def test_responses_protocol_is_discovered_with_aliases_and_websocket_support() -> None: + adapter = get_protocol("responses") + + assert "responses" in list_protocols() + assert get_protocol("openai_responses") is adapter + assert adapter.supports_transport("websocket") is False + assert adapter.is_future_transport("websocket") is True + + +def test_responses_request_round_trip_preserves_previous_response_and_tools() -> None: + adapter = get_protocol("responses") + raw = { + "model": "openai/gpt-test", + "instructions": "system", + "previous_response_id": "resp_prev", + "stream": True, + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hello"}]}, + {"type": "function_call_output", "call_id": "call_1", "output": "result"}, + ], + "tools": [{"type": "function", "name": "lookup", "description": "Lookup", "parameters": {"type": "object"}}], + "reasoning": {"effort": "medium"}, + "vendor_extension": {"kept": True}, + } + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + + assert unified.previous_response_id == "resp_prev" + assert unified.system[0].text == "system" + assert unified.messages[0].content[0].text == "hello" + assert unified.messages[1].tool_call_id == "call_1" + assert unified.tools[0].name == "lookup" + assert rebuilt["previous_response_id"] == "resp_prev" + assert rebuilt["reasoning"] == {"effort": "medium"} + assert rebuilt["vendor_extension"] == {"kept": True} + + +def test_responses_response_extracts_output_items_reasoning_calls_and_usage() -> None: + adapter = get_protocol("responses") + raw = { + "id": "resp_1", + "object": "response", + "created_at": 123, + "model": "openai/gpt-test", + "status": "completed", + "output": [ + {"id": "msg_1", "type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "answer"}]}, + {"id": "rs_1", "type": "reasoning", "summary": [{"type": "summary_text", "text": "reasoned"}]}, + {"id": "fc_1", "type": "function_call", "call_id": "call_1", "name": "lookup", "arguments": "{\"q\":\"x\"}"}, + ], + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 18, + "input_tokens_details": {"cached_tokens": 3, "cache_creation_tokens": 2}, + "output_tokens_details": {"reasoning_tokens": 3}, + "cost_details": {"total_cost": 0.02, "currency": "USD"}, + }, + } + + unified = adapter.parse_response(raw) + + assert unified.id == "resp_1" + assert len(unified.output) == 3 + assert unified.messages[0].content[0].text == "answer" + assert unified.messages[1].reasoning[0].text == "reasoned" + assert unified.messages[2].tool_calls[0].name == "lookup" + assert unified.usage is not None + assert unified.usage.input_tokens == 10 + assert unified.usage.reasoning_tokens == 3 + assert unified.usage.cache_read_tokens == 3 + assert unified.usage.cache_write_tokens == 2 + assert unified.usage.cost is not None + assert unified.usage.cost.provider_reported_cost == 0.02 + + formatted = adapter.format_response(unified) + assert formatted["usage"]["input_tokens"] == 10 + assert formatted["usage"]["output_tokens"] == 5 + assert formatted["usage"]["total_tokens"] == 18 + assert formatted["usage"]["input_tokens_details"] == {"cached_tokens": 3, "cache_creation_tokens": 2} + assert formatted["usage"]["output_tokens_details"] == {"reasoning_tokens": 3} + assert formatted["usage"]["cost_details"]["total_cost"] == 0.02 + assert "raw" not in formatted["usage"] + assert "extra" not in formatted["usage"] + + +def test_responses_format_preserves_unknown_output_items() -> None: + adapter = get_protocol("responses") + raw = { + "id": "resp_1", + "model": "openai/gpt-test", + "output": [ + {"id": "msg_1", "type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "before"}]}, + {"id": "future_1", "type": "future_item", "custom": {"kept": True}}, + ], + } + + unified = adapter.parse_response(raw) + unified.messages[0].content[0].text = "after" + rebuilt = adapter.format_response(unified) + + assert rebuilt["output"][0]["content"][0]["text"] == "after" + assert rebuilt["output"][1] == {"id": "future_1", "type": "future_item", "custom": {"kept": True}} + + +def test_responses_stream_event_parses_text_delta_and_completed_response() -> None: + adapter = get_protocol("responses") + delta = {"type": "response.output_text.delta", "output_index": 0, "content_index": 0, "delta": "Hi"} + completed = { + "type": "response.completed", + "response": { + "id": "resp_1", + "model": "openai/gpt-test", + "output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Hi"}]}], + "usage": {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + }, + } + + parsed_delta = adapter.parse_stream_event(f"data: {json.dumps(delta)}\n\n") + parsed_completed = adapter.parse_stream_event(completed) + + assert parsed_delta.type == "message_delta" + assert parsed_delta.delta is not None + assert parsed_delta.delta.content[0].text == "Hi" + assert parsed_completed.type == "response.completed" + assert parsed_completed.usage is not None + assert parsed_completed.usage.total_tokens == 2 diff --git a/tests/test_provider_protocol_declarations.py b/tests/test_provider_protocol_declarations.py new file mode 100644 index 000000000..a16ee56f3 --- /dev/null +++ b/tests/test_provider_protocol_declarations.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from rotator_library.field_cache import FieldCacheRule +from rotator_library.providers.provider_interface import ProviderInterface + + +class DeclarationProvider(ProviderInterface): + protocol_name = "openai_chat" + adapter_names = ("model_override", "suppress_developer_role") + field_cache_rules = ( + FieldCacheRule(name="reasoning_content", source="response", path="choices.*.message.reasoning_content"), + ) + + async def get_models(self, api_key, client): + return [] + + +class BareProvider(ProviderInterface): + async def get_models(self, api_key, client): + return [] + + +def test_provider_interface_plain_defaults_are_noop() -> None: + provider = BareProvider() + + assert provider.get_protocol_name("model") is None + assert provider.get_adapter_names("model") == () + assert provider.get_adapter_config("model") == {} + assert provider.get_field_cache_rules("model") == () + + +def test_provider_interface_defaults_are_noop_for_protocol_stack() -> None: + provider = DeclarationProvider() + + assert provider.get_protocol_name("model") == "openai_chat" + assert provider.get_adapter_names("model") == ("model_override", "suppress_developer_role") + assert provider.get_adapter_config("model") == {} + assert provider.get_field_cache_rules("model")[0].name == "reasoning_content" + + +def test_provider_interface_methods_can_be_model_specific() -> None: + class ModelSpecificProvider(DeclarationProvider): + def get_protocol_name(self, model: str = ""): + return "responses" if "response" in model else super().get_protocol_name(model) + + def get_adapter_config(self, model: str = ""): + return {"model_override": {"model": f"native/{model}"}} + + provider = ModelSpecificProvider() + + assert provider.get_protocol_name("response-model") == "responses" + assert provider.get_protocol_name("chat-model") == "openai_chat" + assert provider.get_adapter_config("chat-model") == {"model_override": {"model": "native/chat-model"}} diff --git a/tests/test_request_builder_routing.py b/tests/test_request_builder_routing.py new file mode 100644 index 000000000..53ff9e3a4 --- /dev/null +++ b/tests/test_request_builder_routing.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import pytest + +from rotator_library.client.request_builder import RequestContextBuilder + + +class FakeModelResolver: + def resolve_model_id(self, model, provider): + return model + + +class FakeSession: + session_id = "session" + affinity_key = "affinity" + possible_compaction = False + lineage_parent_session_id = None + tracking_namespace = "namespace" + + +class FakeSessionTracker: + def __init__(self): + self.calls = [] + + def infer_session(self, *args, **kwargs): + self.calls.append((args, kwargs)) + return FakeSession() + + +async def _scope(provider, classifier, request_api_keys, request_providers, private): + return { + "credentials": [f"{provider}-cred"], + "usage_manager_key": provider, + "provider_config": {"provider": provider}, + "credential_secrets": {f"{provider}-cred": f"{provider}-secret"}, + "classifier": classifier or "global", + } + + +def _builder(session_tracker=None) -> RequestContextBuilder: + return RequestContextBuilder( + resolve_scope_for_provider=_scope, + model_resolver=FakeModelResolver(), + session_tracker=session_tracker or FakeSessionTracker(), + get_global_timeout=lambda: 30, + get_enable_request_logging=lambda: False, + ) + + +@pytest.mark.asyncio +async def test_request_builder_leaves_no_config_provider_model_unrouted(monkeypatch) -> None: + monkeypatch.delenv("FALLBACK_GROUPS", raising=False) + context = await _builder().build_completion_context(None, None, {"model": "openai/gpt-5.1", "messages": []}) + + assert context.routing_targets is None + assert context.provider == "openai" + assert context.credentials == ["openai-cred"] + + +@pytest.mark.asyncio +async def test_request_builder_populates_fallback_group_targets_from_env(monkeypatch) -> None: + monkeypatch.setenv("FALLBACK_GROUPS", "code_chain") + monkeypatch.setenv("FALLBACK_GROUP_CODE_CHAIN", "codex/gpt-5.1-codex,openai/gpt-5.1") + monkeypatch.setenv("MODEL_ROUTE_CODEX", "group:code_chain") + + context = await _builder().build_completion_context(None, None, {"model": "codex", "messages": []}) + + assert context.provider == "codex" + assert context.model == "codex/gpt-5.1-codex" + assert context.routing_group_name == "code_chain" + assert context.routing_group is not None + assert context.routing_group.name == "code_chain" + assert [target.prefixed_model for target in context.routing_targets] == ["codex/gpt-5.1-codex", "openai/gpt-5.1"] + assert context.routing_targets[1].metadata["request_scope"]["credentials"] == ["openai-cred"] + + +@pytest.mark.asyncio +async def test_request_builder_rejects_unprefixed_model_without_route(monkeypatch) -> None: + monkeypatch.delenv("FALLBACK_GROUPS", raising=False) + + with pytest.raises(ValueError): + await _builder().build_completion_context(None, None, {"model": "gpt-5.1", "messages": []}) + + +@pytest.mark.asyncio +async def test_request_builder_consumes_internal_session_tracking_hints(monkeypatch) -> None: + monkeypatch.delenv("FALLBACK_GROUPS", raising=False) + tracker = FakeSessionTracker() + kwargs = { + "model": "openai/gpt-5.1", + "messages": [], + "_session_tracking_hints": {"strong_anchors": ["responses_previous_response_id:resp_parent"], "affinity_key": "responses_previous_response_id:resp_parent"}, + } + + context = await _builder(tracker).build_completion_context(None, None, kwargs) + + assert "_session_tracking_hints" not in context.kwargs + hints = tracker.calls[0][1]["hints"] + assert hints.strong_anchors == ["responses_previous_response_id:resp_parent"] + assert hints.affinity_key == "responses_previous_response_id:resp_parent" diff --git a/tests/test_request_executor_fallback_error_summary.py b/tests/test_request_executor_fallback_error_summary.py new file mode 100644 index 000000000..bab3ea55b --- /dev/null +++ b/tests/test_request_executor_fallback_error_summary.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from types import MethodType + +import pytest + +from rotator_library.client.executor import RequestExecutor +from rotator_library.core.types import RequestContext +from rotator_library.routing import parse_route_target + + +def _context() -> RequestContext: + return RequestContext( + model="code", + provider="requested", + kwargs={"model": "code", "messages": []}, + streaming=False, + credentials=["cred-a"], + deadline=9999999999.0, + routing_targets=(parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")), + routing_group_name="code_chain", + ) + + +@pytest.mark.asyncio +async def test_fallback_summary_includes_all_structured_target_failures_without_credentials() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_execute(self, context): + attempts.append(context.provider) + return { + "error": { + "type": "proxy_all_credentials_exhausted", + "message": f"{context.provider} failed for cred-secret-value", + "details": {"normal_error_summary": "1 rate_limit"}, + } + } + + executor._execute_non_streaming = MethodType(fake_execute, executor) + + response = await executor._execute_non_streaming_with_fallback(_context()) + + fallback_targets = response["error"]["details"]["fallback_targets"] + assert attempts == ["codex", "openai"] + assert [failure["provider"] for failure in fallback_targets] == ["codex", "openai"] + assert [failure["error_type"] for failure in fallback_targets] == ["rate_limit", "rate_limit"] + assert "credential" not in str(fallback_targets).lower() + assert "cred-secret-value" not in str(fallback_targets) diff --git a/tests/test_request_executor_fallback_groups.py b/tests/test_request_executor_fallback_groups.py new file mode 100644 index 000000000..6779d22ed --- /dev/null +++ b/tests/test_request_executor_fallback_groups.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +import json +from types import MethodType + +import pytest + +from rotator_library.client.executor import RequestExecutor +from rotator_library.core.types import RequestContext +from rotator_library.routing import parse_route_target +from rotator_library.routing.types import FallbackGroup +from rotator_library.transaction_logger import TransactionLogger + + +class ClassifiedFailure(Exception): + def __init__(self, error_type: str) -> None: + super().__init__(error_type) + self.error_type = error_type + + +def _context(*, routing_targets=None, logger=None, routing_group=None) -> RequestContext: + return RequestContext( + model="code", + provider="requested", + kwargs={"model": "code", "messages": []}, + streaming=False, + credentials=["cred-a"], + deadline=9999999999.0, + transaction_logger=logger, + routing_targets=routing_targets, + routing_group_name="code_chain" if routing_targets else None, + routing_group=routing_group, + ) + + +def _executor_with_attempts(attempts): + executor = RequestExecutor.__new__(RequestExecutor) + + async def fake_execute(self, context): + attempts.append(context) + if len(attempts) == 1: + raise ClassifiedFailure("rate_limit") + return {"id": "ok", "model": context.model} + + executor._execute_non_streaming = MethodType(fake_execute, executor) + return executor + + +@pytest.mark.asyncio +async def test_non_streaming_fallback_group_tries_next_target_on_retryable_error() -> None: + attempts = [] + targets = (parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")) + + context = _context(routing_targets=targets) + result = await _executor_with_attempts(attempts)._execute_non_streaming_with_fallback(context) + + assert result == {"id": "ok", "model": "openai/gpt-5.1"} + assert [attempt.provider for attempt in attempts] == ["codex", "openai"] + assert [attempt.kwargs["model"] for attempt in attempts] == ["codex/gpt-5.1-codex", "openai/gpt-5.1"] + assert context.routing_attempt_history[0]["error_type"] == "rate_limit" + assert context.routing_attempt_history[0]["fallback_allowed"] is True + assert context.routing_attempt_history[1]["success"] is True + + +@pytest.mark.asyncio +async def test_non_streaming_fallback_group_falls_back_on_quota_exceeded() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_execute(self, context): + attempts.append(context.provider) + if len(attempts) == 1: + raise ClassifiedFailure("quota_exceeded") + return {"id": "ok", "model": context.model} + + executor._execute_non_streaming = MethodType(fake_execute, executor) + + result = await executor._execute_non_streaming_with_fallback( + _context(routing_targets=(parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1"))) + ) + + assert result == {"id": "ok", "model": "openai/gpt-5.1"} + assert attempts == ["codex", "openai"] + + +@pytest.mark.asyncio +async def test_non_streaming_fallback_group_stops_on_permanent_error() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_execute(self, context): + attempts.append(context) + raise ClassifiedFailure("validation") + + executor._execute_non_streaming = MethodType(fake_execute, executor) + + with pytest.raises(ClassifiedFailure): + await executor._execute_non_streaming_with_fallback( + _context(routing_targets=(parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1"))) + ) + + assert len(attempts) == 1 + + +@pytest.mark.asyncio +async def test_non_streaming_fallback_group_hard_stops_group_override() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_execute(self, context): + attempts.append(context.provider) + if len(attempts) == 1: + raise ClassifiedFailure("authentication") + return {"id": "ok", "model": context.model} + + executor._execute_non_streaming = MethodType(fake_execute, executor) + + targets = (parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")) + with pytest.raises(ClassifiedFailure): + await executor._execute_non_streaming_with_fallback( + _context( + routing_targets=targets, + routing_group=FallbackGroup( + name="code_chain", + targets=targets, + failover_on=frozenset({"authentication"}), + stop_on=frozenset({"validation"}), + ), + ) + ) + + assert attempts == ["codex"] + + +@pytest.mark.asyncio +async def test_non_streaming_fallback_group_handles_structured_error_response() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_execute(self, context): + attempts.append(context) + if len(attempts) == 1: + return {"error": {"type": "proxy_all_credentials_exhausted", "details": {"normal_error_summary": "2 rate_limit"}}} + return {"id": "ok", "model": context.model} + + executor._execute_non_streaming = MethodType(fake_execute, executor) + + result = await executor._execute_non_streaming_with_fallback( + _context(routing_targets=(parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1"))) + ) + + assert result == {"id": "ok", "model": "openai/gpt-5.1"} + + +@pytest.mark.asyncio +async def test_non_streaming_fallback_group_emits_routing_trace(tmp_path) -> None: + attempts = [] + logger = TransactionLogger("routing", "code", parent_dir=tmp_path) + targets = (parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")) + + await _executor_with_attempts(attempts)._execute_non_streaming_with_fallback(_context(routing_targets=targets, logger=logger)) + + pass_names = [json.loads(line)["pass_name"] for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + assert "routing_decision" in pass_names + assert pass_names.count("routing_target_attempt_started") == 2 + assert "routing_target_attempt_failed" in pass_names + assert "routing_fallback_selected" in pass_names + assert "routing_target_attempt_succeeded" in pass_names + + +@pytest.mark.asyncio +async def test_non_streaming_fallback_group_records_usage_only_on_successful_target() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + recorded = [] + + async def fake_execute(self, context): + attempts.append(context.provider) + if len(attempts) == 1: + raise ClassifiedFailure("rate_limit") + recorded.append(context.provider) + return {"id": "ok", "model": context.model} + + executor._execute_non_streaming = MethodType(fake_execute, executor) + + result = await executor._execute_non_streaming_with_fallback( + _context(routing_targets=(parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1"))) + ) + + assert result == {"id": "ok", "model": "openai/gpt-5.1"} + assert attempts == ["codex", "openai"] + assert recorded == ["openai"] diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py new file mode 100644 index 000000000..ac7c0a7bd --- /dev/null +++ b/tests/test_request_executor_native_routing.py @@ -0,0 +1,564 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.client import executor as executor_module +from rotator_library.client.executor import RequestExecutor, RoutingExecutionError +from rotator_library.core.types import RequestContext +from rotator_library.field_cache import FieldCacheInjection, FieldCacheRule +from rotator_library.providers.antigravity_provider import AntigravityProvider +from rotator_library.providers.claude_code_provider import ClaudeCodeProvider +from rotator_library.providers.codex_provider import CodexProvider +from rotator_library.providers.copilot_provider import CopilotProvider +from rotator_library.routing import parse_route_target + + +class FakeNativeResponse: + def __init__(self, payload): + self.payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self.payload + + +class FakeHTTPClient: + def __init__(self): + self.calls = [] + + async def post(self, endpoint, *, headers, json): + self.calls.append({"endpoint": endpoint, "headers": headers, "json": json}) + return FakeNativeResponse({"id": "chat_native", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}) + + +class SequencedHTTPClient: + def __init__(self, responses): + self.responses = list(responses) + self.calls = [] + + async def post(self, endpoint, *, headers, json): + self.calls.append({"endpoint": endpoint, "headers": headers, "json": json}) + return FakeNativeResponse(self.responses.pop(0)) + + +class NativePlugin: + def has_custom_logic(self): + return False + + def get_protocol_name(self, model=""): + return "openai_chat" + + def get_native_endpoint(self, model="", operation="chat"): + return "https://native.test/chat" + + def get_native_headers(self, credential_identifier, model="", operation="chat"): + return {"Authorization": f"Bearer {credential_identifier}", "X-Operation": operation, "X-Model": model} + + def get_native_operation(self, model="", request=None, stream=False): + return "messages" if stream else "chat" + + def normalize_native_model(self, model=""): + return model.split("/", 1)[1] if "/" in model else model + + def get_adapter_names(self, model=""): + return () + + def get_adapter_config(self, model=""): + return {} + + def get_field_cache_rules(self, model=""): + return () + + +class NativePluginWithRule(NativePlugin): + def get_field_cache_rules(self, model=""): + return ( + FieldCacheRule( + name="state", + source="response", + path="choices.0.message.reasoning_content", + inject=FieldCacheInjection(target="request", path="metadata.state"), + allow_missing_session=True, + ), + ) + + +class NativePluginWithVendorRule(NativePlugin): + def get_field_cache_rules(self, model=""): + return ( + FieldCacheRule( + name="vendor_state", + source="response", + path="choices.0.message.vendor_state", + inject=FieldCacheInjection(target="request", path="metadata.vendor_state"), + allow_missing_session=True, + ), + ) + + +class NativePluginWithStreamVendorRule(NativePlugin): + def get_field_cache_rules(self, model=""): + return ( + FieldCacheRule( + name="vendor_state", + source="stream_event", + path="raw.choices.0.message.vendor_state", + allow_missing_session=True, + ), + ) + + +class NativeOptOutPlugin(NativePlugin): + def should_use_native_protocol(self, model="", operation="chat", *, stream=False, execution="auto"): + return False + + +class CustomPlugin: + def __init__(self): + self.calls = [] + + def has_custom_logic(self): + return True + + async def acompletion(self, client, **kwargs): + self.calls.append(kwargs) + return {"id": "custom"} + + def get_protocol_name(self, model=""): + return "gemini" + + def get_native_endpoint(self, model="", operation="chat"): + return "https://native.test/should-not-run" + + def get_native_headers(self, credential_identifier, model="", operation="chat"): + return {"Authorization": f"Bearer {credential_identifier}"} + + +def _context(target=None) -> RequestContext: + return RequestContext( + model="provider/gpt-test", + provider="provider", + kwargs={"model": "provider/gpt-test", "messages": [{"role": "user", "content": "hi"}]}, + streaming=False, + credentials=["cred"], + deadline=9999999999.0, + routing_targets=(target,) if target else None, + ) + + +def _provider_context(provider: str, model: str, kwargs: dict, target=None) -> RequestContext: + return RequestContext( + model=model, + provider=provider, + kwargs=kwargs, + streaming=False, + credentials=["cred"], + deadline=9999999999.0, + routing_targets=(target,) if target else None, + ) + + +def _executor(http_client=None) -> RequestExecutor: + executor = RequestExecutor.__new__(RequestExecutor) + executor._http_client = http_client or FakeHTTPClient() + executor._apply_litellm_logger = lambda kwargs: None + return executor + + +@pytest.mark.asyncio +async def test_native_declared_provider_uses_native_executor_in_auto_mode() -> None: + http_client = FakeHTTPClient() + target = parse_route_target("provider/gpt-test") + context = _context(target) + context.routing_target_index = 0 + + response = await _executor(http_client)._execute_provider_request( + "provider", "provider/gpt-test", NativePlugin(), "secret", "stable", dict(context.kwargs), context + ) + + assert response["id"] == "chat_native" + assert http_client.calls[0]["endpoint"] == "https://native.test/chat" + assert http_client.calls[0]["headers"]["Authorization"] == "Bearer secret" + assert http_client.calls[0]["headers"]["X-Operation"] == "chat" + assert http_client.calls[0]["headers"]["X-Model"] == "gpt-test" + assert http_client.calls[0]["json"]["model"] == "gpt-test" + + +@pytest.mark.asyncio +async def test_auto_mode_prefers_custom_logic_over_native_declaration() -> None: + http_client = FakeHTTPClient() + plugin = CustomPlugin() + context = _context(parse_route_target("provider/gpt-test")) + context.routing_target_index = 0 + + response = await _executor(http_client)._execute_provider_request( + "provider", "provider/gpt-test", plugin, "secret", "stable", dict(context.kwargs), context + ) + + assert response == {"id": "custom"} + assert plugin.calls[0]["credential_identifier"] == "secret" + assert http_client.calls == [] + + +@pytest.mark.asyncio +async def test_request_executor_reuses_native_field_cache_store() -> None: + http_client = SequencedHTTPClient( + [ + {"id": "chat_1", "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "cached"}}]}, + {"id": "chat_2", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}, + ] + ) + executor = _executor(http_client) + context = _context(parse_route_target("provider/gpt-test")) + context.routing_target_index = 0 + + await executor._execute_provider_request("provider", "provider/gpt-test", NativePluginWithRule(), "secret", "stable", dict(context.kwargs), context) + await executor._execute_provider_request("provider", "provider/gpt-test", NativePluginWithRule(), "secret", "stable", dict(context.kwargs), context) + + assert http_client.calls[1]["json"]["metadata"]["state"] == "cached" + + +def test_native_context_merges_json_field_cache_rules(monkeypatch, tmp_path) -> None: + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "field_cache": { + "provider": { + "*": [ + { + "name": "state", + "source": "response", + "path": "json.path", + "inject": {"target": "request", "path": "metadata.state"}, + }, + { + "name": "extra", + "source": "response", + "path": "json.extra", + "inject": {"target": "request", "path": "metadata.extra"}, + }, + ] + } + } + } + ), + encoding="utf-8", + ) + monkeypatch.setenv("LLM_PROXY_CONFIG_FILE", str(config_path)) + + native_context = _executor()._build_native_provider_context( + "provider", + "provider/gpt-test", + NativePluginWithRule(), + "secret", + "stable", + _context(), + None, + ) + + assert [rule.name for rule in native_context.field_cache_rules] == ["state", "extra"] + assert native_context.field_cache_rules[0].path == "json.path" + + +@pytest.mark.asyncio +async def test_claude_code_provider_runs_mock_live_native_request(monkeypatch) -> None: + monkeypatch.setenv("CLAUDE_CODE_API_BASE", "https://claude-code.test") + http_client = SequencedHTTPClient([ + {"id": "msg_1", "type": "message", "role": "assistant", "content": [{"type": "text", "text": "ok"}], "usage": {"input_tokens": 1, "output_tokens": 1}} + ]) + provider = ClaudeCodeProvider() + target = parse_route_target("claude_code/claude-sonnet-4-5") + context = _provider_context( + "claude_code", + "claude_code/claude-sonnet-4-5", + {"model": "claude_code/claude-sonnet-4-5", "messages": [{"role": "developer", "content": "rules"}, {"role": "user", "content": "hi"}]}, + target, + ) + context.routing_target_index = 0 + + response = await _executor(http_client)._execute_provider_request("claude_code", context.model, provider, "secret", "stable", dict(context.kwargs), context) + + assert response["id"] == "msg_1" + assert response["choices"][0]["message"]["content"] == "ok" + assert http_client.calls[0]["endpoint"] == "https://claude-code.test/v1/messages" + assert http_client.calls[0]["json"]["model"] == "claude-sonnet-4-5" + assert http_client.calls[0]["json"]["max_tokens"] == 4096 + assert http_client.calls[0]["json"]["messages"][0]["role"] == "user" + + +@pytest.mark.asyncio +async def test_codex_provider_runs_mock_live_native_request(monkeypatch) -> None: + monkeypatch.setenv("CODEX_API_BASE", "https://codex.test") + http_client = SequencedHTTPClient([ + {"id": "resp_1", "object": "response", "output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "ok"}]}]} + ]) + provider = CodexProvider() + target = parse_route_target("codex/gpt-5.1-codex") + context = _provider_context("codex", "codex/gpt-5.1-codex", {"model": "codex/gpt-5.1-codex", "messages": [{"role": "user", "content": "hi"}]}, target) + context.routing_target_index = 0 + + response = await _executor(http_client)._execute_provider_request("codex", context.model, provider, "secret", "stable", dict(context.kwargs), context) + + assert response["id"] == "resp_1" + assert response["choices"][0]["message"]["content"] == "ok" + assert http_client.calls[0]["endpoint"] == "https://codex.test/v1/responses" + assert http_client.calls[0]["json"]["model"] == "gpt-5.1-codex" + assert http_client.calls[0]["json"]["input"][0]["content"] == [{"type": "text", "text": "hi"}] + assert "messages" not in http_client.calls[0]["json"] + + +@pytest.mark.asyncio +async def test_copilot_provider_runs_mock_live_native_request(monkeypatch) -> None: + monkeypatch.setenv("COPILOT_API_BASE", "https://copilot.test") + http_client = SequencedHTTPClient([ + {"id": "chat_1", "choices": [{"message": {"role": "assistant", "content": "ok"}}]} + ]) + provider = CopilotProvider() + target = parse_route_target("copilot/gpt-4.1") + context = _provider_context("copilot", "copilot/gpt-4.1", {"model": "copilot/gpt-4.1", "messages": [{"role": "developer", "content": "rules"}, {"role": "user", "content": "hi"}]}, target) + context.routing_target_index = 0 + + response = await _executor(http_client)._execute_provider_request("copilot", context.model, provider, "secret", "stable", dict(context.kwargs), context) + + assert response["id"] == "chat_1" + assert http_client.calls[0]["endpoint"] == "https://copilot.test/chat/completions" + assert http_client.calls[0]["json"]["model"] == "gpt-4.1" + assert http_client.calls[0]["json"]["messages"][0]["role"] == "system" + + +@pytest.mark.asyncio +async def test_antigravity_provider_runs_mock_live_native_request(monkeypatch) -> None: + monkeypatch.setenv("ANTIGRAVITY_API_BASE", "https://antigravity.test/v1internal") + http_client = SequencedHTTPClient([ + {"candidates": [{"content": {"role": "model", "parts": [{"text": "ok"}]}, "finishReason": "STOP"}], "usageMetadata": {"totalTokenCount": 2}} + ]) + provider = AntigravityProvider() + target = parse_route_target("antigravity/claude-sonnet-4.5") + context = _provider_context("antigravity", "antigravity/claude-sonnet-4.5", {"model": "antigravity/claude-sonnet-4.5", "messages": [{"role": "user", "content": "hi"}]}, target) + context.routing_target_index = 0 + + response = await _executor(http_client)._execute_provider_request("antigravity", context.model, provider, "secret", "stable", dict(context.kwargs), context) + + assert response["choices"][0]["message"]["content"] == "ok" + assert http_client.calls[0]["endpoint"] == "https://antigravity.test/v1internal:generateContent" + assert http_client.calls[0]["json"]["model"] == "claude-sonnet-4-5" + assert http_client.calls[0]["json"]["request"]["contents"][0]["parts"][0]["text"] == "hi" + assert http_client.calls[0]["json"]["requestType"] == "CHAT_COMPLETION" + assert "requestId" in http_client.calls[0]["json"] + + +def test_claude_code_header_modes(monkeypatch) -> None: + provider = ClaudeCodeProvider() + + monkeypatch.setenv("CLAUDE_CODE_AUTH_HEADER", "auto") + assert provider.get_native_headers("sk-ant-test")["x-api-key"] == "sk-ant-test" + assert provider.get_native_headers("oauth-token")["Authorization"] == "Bearer oauth-token" + monkeypatch.setenv("CLAUDE_CODE_AUTH_HEADER", "x-api-key") + assert provider.get_native_headers("any-token")["x-api-key"] == "any-token" + + +def test_antigravity_alias_normalization_preserves_thinking_level() -> None: + provider = AntigravityProvider() + request = {"_proxy_model": "antigravity/gemini-3-pro-low", "model": "gemini-3-pro-preview", "messages": [{"role": "user", "content": "hi"}]} + + prepared = provider.prepare_native_request(request, model=provider.normalize_native_model("antigravity/gemini-3-pro-low"), operation="generate") + + assert provider.normalize_native_model("antigravity/gemini-3-pro-low") == "gemini-3-pro-preview" + assert prepared["model"] == "gemini-3-pro-low" + assert prepared["generationConfig"]["thinkingConfig"]["thinkingLevel"] == "low" + assert prepared["metadata"]["thinking_level"] == "low" + assert "_proxy_model" not in prepared + assert "messages" not in prepared + assert prepared["contents"][0]["parts"][0]["text"] == "hi" + + +def test_auto_native_selection_honors_provider_opt_out() -> None: + assert executor_module._should_use_native_protocol( + NativeOptOutPlugin(), + "provider/gpt-test", + None, + {"model": "provider/gpt-test", "messages": []}, + stream=False, + execution="auto", + ) is False + + +def test_auto_native_streaming_selection_honors_provider_opt_out() -> None: + assert executor_module._should_use_native_streaming( + NativeOptOutPlugin(), + "provider/gpt-test", + parse_route_target("provider/gpt-test"), + "auto", + "provider", + ) is False + + +def test_native_request_payload_drops_litellm_only_fields() -> None: + payload = executor_module._native_request_payload( + { + "model": "provider/gpt-test", + "messages": [{"role": "user", "content": "hi"}], + "custom_llm_provider": "openai", + "api_base": "https://litellm-only.test", + "transaction_context": {"id": "trace"}, + "litellm_call_id": "call", + } + ) + + assert payload == {"model": "provider/gpt-test", "messages": [{"role": "user", "content": "hi"}]} + + +def test_route_error_type_from_response_hard_stop_wins_over_retry_summary() -> None: + response = { + "error": { + "type": "authentication", + "message": "provider said quota secret-token", + "details": {"normal_error_summary": "rate_limit quota capacity", "status_code": 401}, + } + } + + assert executor_module._route_error_type_from_response(response) == "authentication" + + +def test_route_error_type_from_response_uses_structured_status_codes() -> None: + assert executor_module._route_error_type_from_response({"error": {"code": 403}}) == "forbidden" + assert executor_module._route_error_type_from_response({"error": {"status": 401}}) == "authentication" + assert executor_module._route_error_type_from_response({"error": {"details": {"status": 403}}}) == "forbidden" + + +def test_route_error_type_from_response_uses_structured_aliases() -> None: + assert executor_module._route_error_type_from_response({"error": {"type": "invalid_api_key"}}) == "authentication" + assert executor_module._route_error_type_from_response({"error": {"code": "invalid_argument"}}) == "invalid_request" + assert executor_module._route_error_type_from_response({"error": {"code": "resource_exhausted"}}) == "quota_exceeded" + assert executor_module._route_error_type_from_response({"error": {"code": "unavailable"}}) == "server_error" + assert executor_module._route_error_type_from_response({"error": {"code": "deadline_exceeded"}}) == "api_connection" + assert executor_module._route_error_type_from_response({"error": {"details": {"status_code": 503}}}) == "server_error" + + +def test_route_error_type_from_response_reads_abnormal_errors_before_proxy_summary() -> None: + response = { + "error": { + "type": "proxy_all_credentials_exhausted", + "details": { + "normal_error_summary": "rate_limit quota", + "abnormal_errors": [{"error_type": "authentication", "status_code": 401}], + }, + } + } + + assert executor_module._route_error_type_from_response(response) == "authentication" + + +def test_stream_chunk_error_type_detects_terminal_error_frames() -> None: + assert executor_module._stream_chunk_error_type('data: {"error":{"type":"rate_limit"}}\n\n') == "rate_limit" + assert executor_module._stream_chunk_error_type('event: response.failed\ndata: {"error":{"type":"authentication"}}\n\n') == "authentication" + assert executor_module._stream_chunk_error_type('event: error\ndata: {"type":"error","code":429}\n\n') == "rate_limit" + assert executor_module._stream_chunk_error_type('event: error\ndata: {"type":"error","code":"context_length_exceeded"}\n\n') == "context_window_exceeded" + + +def test_target_failure_summary_is_structural_and_sanitized() -> None: + summary = executor_module._target_failure_summary(parse_route_target("openai/gpt"), "rate-limit", status_code=429) + + assert summary["error_type"] == "rate_limit" + assert summary["status_code"] == 429 + assert summary["message"] == "" + + +def test_explicit_native_streaming_fails_when_provider_does_not_support_it() -> None: + target = parse_route_target("provider/gpt-test@native") + + with pytest.raises(RoutingExecutionError) as exc: + executor_module._should_use_native_streaming(NativePlugin(), "provider/gpt-test", target, "native", "provider") + + assert exc.value.error_type == "configuration_error" + + +def test_antigravity_cache_injection_targets_safe_envelope() -> None: + provider = AntigravityProvider() + + rules = provider.get_field_cache_rules("gemini-3-flash") + + assert rules[0].inject.path == "request.metadata.thoughtSignatures" + + +def test_native_context_raises_on_invalid_field_cache_config(monkeypatch, tmp_path) -> None: + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps({"field_cache": {"provider": {"*": [{"name": "bad", "source": "response"}]}}}), encoding="utf-8") + monkeypatch.setenv("LLM_PROXY_CONFIG_FILE", str(config_path)) + + with pytest.raises(RoutingExecutionError) as exc: + _executor()._build_native_provider_context("provider", "provider/gpt-test", NativePlugin(), "secret", "stable", _context(), None) + + assert exc.value.error_type == "configuration_error" + + +def test_executor_trace_redaction_uses_native_field_cache_response_paths() -> None: + context = _context(parse_route_target("provider/gpt-test")) + payload = {"choices": [{"message": {"role": "assistant", "content": "ok", "vendor_state": "opaque-vendor-state"}}]} + + redacted = executor_module._redact_context_field_cache_paths(payload, context, "response", NativePluginWithVendorRule()) + + assert redacted["choices"][0]["message"]["vendor_state"] == "[REDACTED]" + assert payload["choices"][0]["message"]["vendor_state"] == "opaque-vendor-state" + + +def test_executor_stream_trace_redaction_uses_native_field_cache_paths() -> None: + context = _context(parse_route_target("provider/gpt-test")) + sse_line = 'data: {"choices":[{"delta":{"content":"ok"},"message":{"vendor_state":"opaque-vendor-state"}}]}\n\n' + + redacted = executor_module._redact_stream_sse_for_trace(sse_line, context, NativePluginWithStreamVendorRule()) + + parsed = json.loads(redacted[6:].strip()) + + assert "opaque-vendor-state" not in redacted + assert parsed["choices"][0]["message"]["vendor_state"] == "[REDACTED]" + + +@pytest.mark.asyncio +async def test_litellm_fallback_execution_is_explicit(monkeypatch) -> None: + calls = [] + + async def fake_acompletion(**kwargs): + calls.append(kwargs) + return {"id": "litellm"} + + monkeypatch.setattr(executor_module.litellm, "acompletion", fake_acompletion) + target = parse_route_target("provider/gpt-test@litellm_fallback") + context = _context(target) + context.routing_target_index = 0 + + response = await _executor()._execute_provider_request("provider", "provider/gpt-test", NativePlugin(), "secret", "stable", dict(context.kwargs), context) + + assert response == {"id": "litellm"} + assert calls[0]["api_key"] == "secret" + + +@pytest.mark.asyncio +async def test_custom_execution_mode_requires_custom_plugin() -> None: + target = parse_route_target("provider/gpt-test@custom") + context = _context(target) + context.routing_target_index = 0 + + plugin = CustomPlugin() + + response = await _executor()._execute_provider_request("provider", "provider/gpt-test", plugin, "secret", "stable", dict(context.kwargs), context) + + assert response == {"id": "custom"} + assert plugin.calls[0]["credential_identifier"] == "secret" + + +@pytest.mark.asyncio +async def test_native_execution_mode_fails_when_provider_has_no_native_declaration() -> None: + target = parse_route_target("provider/gpt-test@native") + context = _context(target) + context.routing_target_index = 0 + + with pytest.raises(RoutingExecutionError) as exc: + await _executor()._execute_provider_request("provider", "provider/gpt-test", None, "secret", "stable", dict(context.kwargs), context) + + assert exc.value.error_type == "configuration_error" diff --git a/tests/test_request_executor_stream_metrics.py b/tests/test_request_executor_stream_metrics.py new file mode 100644 index 000000000..45d94acdc --- /dev/null +++ b/tests/test_request_executor_stream_metrics.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +import json +import asyncio + +import pytest + +from rotator_library.client.streaming import StreamingHandler +from rotator_library.core.errors import StreamedAPIError +from rotator_library.transaction_logger import TransactionLogger + + +async def _chunks(): + yield {"id": "chunk_1", "choices": [{"delta": {"content": "hi"}}]} + yield {"id": "chunk_2", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}} + + +class HangingStream: + def __init__(self) -> None: + self.closed = False + + def __aiter__(self): + return self + + async def __anext__(self): + await asyncio.sleep(1) + return {"choices": [{"delta": {"content": "late"}}]} + + async def aclose(self) -> None: + self.closed = True + + +class DelayedStream: + def __init__(self, delay: float = 0.03) -> None: + self.delay = delay + self.index = 0 + self.closed = False + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index == 0: + self.index += 1 + await asyncio.sleep(self.delay) + return {"id": "chunk_1", "choices": [{"delta": {"content": "hi"}}]} + if self.index == 1: + self.index += 1 + return {"id": "chunk_2", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}} + raise StopAsyncIteration + + async def aclose(self) -> None: + self.closed = True + + +class FirstThenHangStream(HangingStream): + def __init__(self) -> None: + super().__init__() + self.index = 0 + + async def __anext__(self): + if self.index == 0: + self.index += 1 + return {"id": "chunk_1", "choices": [{"delta": {}}]} + return await super().__anext__() + + +class DisconnectedRequest: + async def is_disconnected(self) -> bool: + return True + + +class DelayedDisconnectedRequest: + def __init__(self) -> None: + self.calls = 0 + + async def is_disconnected(self) -> bool: + self.calls += 1 + await asyncio.sleep(0.01) + return self.calls >= 1 + + +def _trace_passes(log_dir): + return [json.loads(line)["pass_name"] for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + + +@pytest.mark.asyncio +async def test_streaming_handler_emits_lifecycle_metrics_without_changing_output(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(_chunks(), "cred", "openai/gpt-test", transaction_logger=logger)] + + assert chunks[0].startswith("data: ") + assert chunks[-1] == "data: [DONE]\n\n" + pass_names = _trace_passes(logger.log_dir) + assert "stream_started" in pass_names + assert "stream_first_byte" in pass_names + assert "stream_first_visible_output" in pass_names + assert "stream_completed" in pass_names + assert "stream_metrics_final" in pass_names + + +@pytest.mark.asyncio +async def test_stream_trace_metrics_can_be_disabled_without_changing_output(tmp_path, monkeypatch) -> None: + monkeypatch.setenv("STREAM_TRACE_METRICS", "false") + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(_chunks(), "cred", "openai/gpt-test", transaction_logger=logger)] + + assert chunks[0].startswith("data: ") + assert chunks[-1] == "data: [DONE]\n\n" + if (logger.log_dir / "transform_trace.jsonl").exists(): + pass_names = _trace_passes(logger.log_dir) + assert "stream_started" in pass_names + assert "stream_metrics_final" in pass_names + + +@pytest.mark.asyncio +async def test_streaming_handler_closes_upstream_on_client_disconnect(monkeypatch) -> None: + monkeypatch.delenv("STREAM_TTFB_TIMEOUT_SECONDS", raising=False) + stream = HangingStream() + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(stream, "cred", "openai/gpt-test", request=DisconnectedRequest())] + + assert chunks == [] + assert stream.closed is True + + +@pytest.mark.asyncio +async def test_streaming_handler_closes_upstream_when_disconnect_happens_during_wait(monkeypatch) -> None: + monkeypatch.delenv("STREAM_TTFB_TIMEOUT_SECONDS", raising=False) + monkeypatch.delenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", raising=False) + stream = HangingStream() + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(stream, "cred", "openai/gpt-test", request=DelayedDisconnectedRequest())] + + assert chunks == [] + assert stream.closed is True + + +@pytest.mark.asyncio +async def test_streaming_handler_emits_configured_heartbeats(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + monkeypatch.delenv("STREAM_TTFB_TIMEOUT_SECONDS", raising=False) + monkeypatch.delenv("STREAM_STALL_TIMEOUT_SECONDS", raising=False) + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(DelayedStream(), "cred", "openai/gpt-test")] + + assert any(chunk.startswith(": heartbeat") for chunk in chunks) + assert chunks[-1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_streaming_handler_ttfb_timeout_closes_upstream(monkeypatch) -> None: + monkeypatch.setenv("STREAM_TTFB_TIMEOUT_SECONDS", "0.01") + monkeypatch.delenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", raising=False) + stream = HangingStream() + + with pytest.raises(StreamedAPIError) as exc: + _ = [chunk async for chunk in StreamingHandler().wrap_stream(stream, "cred", "openai/gpt-test")] + + assert stream.closed is True + assert exc.value.data["error"]["details"]["timeout_type"] == "ttfb" + + +@pytest.mark.asyncio +async def test_stream_timeout_closes_upstream_even_when_disconnect_close_disabled(monkeypatch) -> None: + monkeypatch.setenv("STREAM_TTFB_TIMEOUT_SECONDS", "0.01") + monkeypatch.setenv("STREAM_CANCEL_UPSTREAM_ON_DISCONNECT", "false") + stream = HangingStream() + + with pytest.raises(StreamedAPIError): + _ = [chunk async for chunk in StreamingHandler().wrap_stream(stream, "cred", "openai/gpt-test")] + + assert stream.closed is True + + +@pytest.mark.asyncio +async def test_streaming_handler_stall_timeout_after_first_byte(monkeypatch) -> None: + monkeypatch.setenv("STREAM_STALL_TIMEOUT_SECONDS", "0.01") + monkeypatch.delenv("STREAM_TTFB_TIMEOUT_SECONDS", raising=False) + stream = FirstThenHangStream() + chunks = [] + + with pytest.raises(StreamedAPIError) as exc: + async for chunk in StreamingHandler().wrap_stream(stream, "cred", "openai/gpt-test"): + chunks.append(chunk) + + assert chunks and chunks[0].startswith("data: ") + assert stream.closed is True + assert exc.value.data["error"]["details"]["timeout_type"] == "stall" + + +@pytest.mark.asyncio +async def test_streaming_handler_passes_through_formatted_sse_chunks(monkeypatch) -> None: + async def formatted_stream(): + yield 'data: {"choices":[{"delta":{"content":"hi"}}]}\n\n' + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(formatted_stream(), "cred", "openai/gpt-test")] + + assert chunks[0] == 'data: {"choices":[{"delta":{"content":"hi"}}]}\n\n' + assert chunks[-1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_streaming_handler_does_not_duplicate_direct_done_sentinel(monkeypatch) -> None: + async def done_stream(): + yield "data: [DONE]\n\n" + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(done_stream(), "cred", "openai/gpt-test")] + + assert chunks == ["data: [DONE]\n\n"] diff --git a/tests/test_responses_bridge.py b/tests/test_responses_bridge.py new file mode 100644 index 000000000..a5dfe9706 --- /dev/null +++ b/tests/test_responses_bridge.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +from rotator_library.protocols.responses import ResponsesProtocol +from rotator_library.responses import ResponsesBridge + + +def test_bridge_converts_responses_request_to_chat_kwargs() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request( + { + "model": "gpt-test", + "instructions": "Follow rules.", + "input": "Hello", + "max_output_tokens": 20, + "metadata": {"trace": "yes"}, + } + ) + + kwargs = bridge.to_chat_kwargs(unified) + + assert kwargs["model"] == "gpt-test" + assert kwargs["messages"] == [ + {"role": "system", "content": "Follow rules."}, + {"role": "user", "content": "Hello"}, + ] + assert kwargs["max_tokens"] == 20 + assert kwargs["metadata"] == {"trace": "yes"} + + +def test_bridge_adds_parent_response_messages_for_previous_response_id() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Continue", "previous_response_id": "resp_parent"}) + parent = {"output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Earlier"}]}]} + + kwargs = bridge.to_chat_kwargs(unified, parent_response=parent) + + assert kwargs["messages"] == [ + {"role": "assistant", "content": "Earlier"}, + {"role": "user", "content": "Continue"}, + ] + assert kwargs["_responses_bridge"]["previous_response_id"] == "resp_parent" + assert kwargs["_session_tracking_hints"]["strong_anchors"] == ["responses_previous_response_id:resp_parent"] + assert kwargs["_session_tracking_hints"]["affinity_key"] == "responses_previous_response_id:resp_parent" + assert "session_scope" not in kwargs["_session_tracking_hints"] + + +def test_bridge_replays_parent_input_and_output_lineage() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Now"}) + lineage = [ + { + "request": {"model": "gpt-test", "input": "First user"}, + "response": {"output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "First assistant"}]}]}, + }, + { + "request": {"model": "gpt-test", "input": "Second user", "previous_response_id": "resp_1"}, + "response": {"output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Second assistant"}]}]}, + }, + ] + + kwargs = bridge.to_chat_kwargs(unified, parent_responses=lineage) + + assert kwargs["messages"] == [ + {"role": "user", "content": "First user"}, + {"role": "assistant", "content": "First assistant"}, + {"role": "user", "content": "Second user"}, + {"role": "assistant", "content": "Second assistant"}, + {"role": "user", "content": "Now"}, + ] + + +def test_bridge_replays_parent_tool_call_outputs() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Continue"}) + lineage = [ + { + "request": {"model": "gpt-test", "input": "Use tool"}, + "response": {"output": [{"id": "call_1", "type": "function_call", "call_id": "call_1", "name": "lookup", "arguments": "{}"}]}, + } + ] + + kwargs = bridge.to_chat_kwargs(unified, parent_responses=lineage) + + assert kwargs["messages"] == [ + {"role": "user", "content": "Use tool"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}}]}, + {"role": "user", "content": "Continue"}, + ] + + +def test_bridge_replays_parent_tool_result_inputs_as_tool_messages() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Continue"}) + lineage = [ + { + "request": { + "model": "gpt-test", + "input": [ + {"type": "message", "role": "user", "content": "Use tool"}, + {"type": "function_call_output", "call_id": "call_1", "output": "tool result"}, + ], + }, + "response": {"output": []}, + } + ] + + kwargs = bridge.to_chat_kwargs(unified, parent_responses=lineage) + + assert kwargs["messages"] == [ + {"role": "user", "content": "Use tool"}, + {"role": "tool", "content": "tool result", "tool_call_id": "call_1"}, + {"role": "user", "content": "Continue"}, + ] + + +def test_bridge_preserves_tool_definitions() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request( + { + "model": "gpt-test", + "input": "Use tool", + "tools": [{"type": "function", "name": "lookup", "description": "Lookup", "parameters": {"type": "object"}}], + } + ) + + kwargs = bridge.to_chat_kwargs(unified) + + assert kwargs["tools"] == [ + {"type": "function", "function": {"name": "lookup", "description": "Lookup", "parameters": {"type": "object"}}} + ] + + +def test_bridge_preserves_unsupported_fields_for_trace_metadata() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Hi", "custom_unsupported": 42}) + + kwargs = bridge.to_chat_kwargs(unified) + + assert kwargs["_responses_bridge"]["extra"]["custom_unsupported"] == 42 + + +def test_bridge_converts_chat_response_to_responses_payload() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Hello"}) + chat_response = { + "id": "chat_1", + "model": "gpt-test", + "created": 123, + "choices": [{"message": {"role": "assistant", "content": "Hi"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, + } + + response = bridge.from_chat_response(chat_response, unified) + + assert response["id"] == "chat_1" + assert response["object"] == "response" + assert response["status"] == "completed" + assert response["output"][0]["content"] == [{"type": "output_text", "text": "Hi"}] + assert response["usage"]["input_tokens"] == 1 + assert response["usage"]["output_tokens"] == 2 + assert response["usage"]["total_tokens"] == 3 + + +def test_bridge_preserves_chat_response_top_level_cost_with_usage() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Hello"}) + chat_response = { + "id": "chat_1", + "model": "gpt-test", + "choices": [{"message": {"role": "assistant", "content": "Hi"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 2}, + "request_cost_usd": 0.012, + } + + response = bridge.from_chat_response(chat_response, unified) + + assert response["usage"]["cost_details"]["request_cost_usd"] == 0.012 + assert response["usage"]["cost_details"]["total_cost"] == 0.012 + + +def test_bridge_preserves_chat_response_estimated_cost_with_usage() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Hello"}) + chat_response = { + "id": "chat_1", + "model": "gpt-test", + "choices": [{"message": {"role": "assistant", "content": "Hi"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 2}, + "estimated_cost": 0.013, + } + + response = bridge.from_chat_response(chat_response, unified) + + assert response["usage"]["cost_details"]["estimated_cost"] == 0.013 + + +def test_bridge_converts_chat_tool_calls_to_responses_output_items() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Hello"}) + chat_response = { + "model": "gpt-test", + "choices": [ + { + "message": { + "role": "assistant", + "tool_calls": [{"id": "call_1", "function": {"name": "lookup", "arguments": "{}"}}], + } + } + ], + } + + response = bridge.from_chat_response(chat_response, unified, response_id="resp_tool") + + assert response["id"] == "resp_tool" + assert response["output"] == [ + {"id": "call_1", "type": "function_call", "call_id": "call_1", "name": "lookup", "arguments": "{}"} + ] diff --git a/tests/test_responses_routes.py b/tests/test_responses_routes.py new file mode 100644 index 000000000..bf1e47188 --- /dev/null +++ b/tests/test_responses_routes.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from fastapi.testclient import TestClient + +from proxy_app import main as proxy_main +from rotator_library.responses import InMemoryResponsesStore, ResponsesService + + +class FakeClient: + async def acompletion(self, **kwargs): + if kwargs.get("stream"): + async def chunks(): + yield 'data: {"choices":[{"delta":{"content":"route"}}]}\n\n' + yield 'data: {"choices":[{"delta":{"content":" stream"}}]}\n\n' + yield "data: [DONE]\n\n" + + return chunks() + return { + "id": "chat_route_1", + "model": kwargs["model"], + "choices": [{"message": {"role": "assistant", "content": "route ok"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, + } + + +def _client() -> TestClient: + proxy_main.PROXY_API_KEY = None + proxy_main.ENABLE_RAW_LOGGING = False + proxy_main.app.state.rotating_client = FakeClient() + proxy_main.app.state.responses_service = ResponsesService(store=InMemoryResponsesStore()) + return TestClient(proxy_main.app) + + +def test_post_responses_non_stream_success() -> None: + client = _client() + + response = client.post("/v1/responses", json={"model": "gpt-test", "input": "hello"}) + + assert response.status_code == 200 + body = response.json() + assert body["id"] == "chat_route_1" + assert body["object"] == "response" + assert body["output"][0]["content"][0]["text"] == "route ok" + + +def test_post_responses_missing_model_returns_400() -> None: + client = _client() + + response = client.post("/v1/responses", json={"input": "hello"}) + + assert response.status_code == 400 + assert response.json()["error"]["type"] == "invalid_request_error" + + +def test_post_responses_stream_missing_model_returns_400_before_sse() -> None: + client = _client() + + response = client.post("/v1/responses", json={"input": "hello", "stream": True}) + + assert response.status_code == 400 + assert response.json()["error"]["type"] == "invalid_request_error" + + +def test_post_responses_stream_missing_previous_response_returns_404_before_sse() -> None: + client = _client() + + response = client.post("/v1/responses", json={"model": "gpt-test", "input": "hello", "stream": True, "previous_response_id": "missing"}) + + assert response.status_code == 404 + assert response.json()["error"]["type"] == "not_found_error" + + +def test_get_delete_and_input_items_routes() -> None: + client = _client() + created = client.post("/v1/responses", json={"model": "gpt-test", "input": ["hello"]}).json() + + get_response = client.get(f"/v1/responses/{created['id']}") + input_items = client.get(f"/v1/responses/{created['id']}/input_items") + deleted = client.delete(f"/v1/responses/{created['id']}") + missing = client.get(f"/v1/responses/{created['id']}") + + assert get_response.status_code == 200 + assert get_response.json()["id"] == created["id"] + assert input_items.status_code == 200 + assert input_items.json() == {"object": "list", "data": ["hello"]} + assert deleted.status_code == 200 + assert deleted.json() == {"id": created["id"], "object": "response.deleted", "deleted": True} + assert missing.status_code == 404 + assert missing.json()["error"]["type"] == "not_found_error" + + +def test_post_responses_stream_returns_sse_events() -> None: + client = _client() + + response = client.post("/v1/responses", json={"model": "gpt-test", "input": "hello", "stream": True}) + + assert response.status_code == 200 + assert "event: response.created" in response.text + assert "event: response.completed" in response.text diff --git a/tests/test_responses_service.py b/tests/test_responses_service.py new file mode 100644 index 000000000..2c0bcf1c0 --- /dev/null +++ b/tests/test_responses_service.py @@ -0,0 +1,371 @@ +from __future__ import annotations + +import json + +import pytest + +import rotator_library.responses.service as responses_service_module +from rotator_library.responses import InMemoryResponsesStore, ResponsesService, ResponsesServiceError, ResponsesStoreSettings, StoredResponse +from rotator_library.transaction_logger import TransactionLogger + + +class FakeClient: + def __init__(self) -> None: + self.calls = [] + + async def acompletion(self, **kwargs): + self.calls.append(kwargs) + return { + "id": "chat_response_1", + "model": kwargs["model"], + "choices": [{"message": {"role": "assistant", "content": "Hello back"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, + } + + +class FakeCostClient(FakeClient): + async def acompletion(self, **kwargs): + response = await super().acompletion(**kwargs) + response["usage"]["cost_details"] = {"total_cost": 0.033, "source": "responses_provider"} + return response + + +class FakeInternalClient(FakeClient): + def __init__(self) -> None: + super().__init__() + self._request_builder = object() + self._executor = object() + + async def acompletion(self, **kwargs): + callback = kwargs.pop("_request_context_callback", None) + hints = kwargs.pop("_session_tracking_hints", None) + if callback: + callback( + type( + "Context", + (), + { + "session_id": "session-parent", + "session_affinity_key": "affinity-parent", + "usage_manager_key": "scope-parent", + "classifier": "global", + "session_tracker": None, + "provider": "openai", + "model": "gpt-test", + "session_tracking_namespace": "namespace", + }, + )() + ) + self.internal_hints = hints + return await super().acompletion(**kwargs) + + +def _trace_entries(log_dir): + return [json.loads(line) for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + + +@pytest.mark.asyncio +async def test_create_response_stores_non_streaming_response() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + client = FakeClient() + + response = await service.create_response({"model": "gpt-test", "input": "Hello"}, client) + + assert response["id"] == "chat_response_1" + assert response["output"][0]["content"][0]["text"] == "Hello back" + assert (await store.get("chat_response_1")) is not None + assert client.calls[0]["messages"] == [{"role": "user", "content": "Hello"}] + + +@pytest.mark.asyncio +async def test_store_false_does_not_persist_response() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + response = await service.create_response({"model": "gpt-test", "input": "Hello", "store": False}, FakeClient()) + + assert await store.get(response["id"]) is None + + +@pytest.mark.asyncio +async def test_create_response_applies_storage_ttl() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store, store_settings=ResponsesStoreSettings(ttl_seconds=60)) + + response = await service.create_response({"model": "gpt-test", "input": "Hello"}, FakeClient()) + stored = await store.get(response["id"]) + + assert stored is not None + assert stored.expires_at is not None + assert stored.expires_at > stored.created_at + assert stored.metadata["response_id"] == response["id"] + + +@pytest.mark.asyncio +async def test_service_default_store_honors_max_items() -> None: + class SequencedClient(FakeClient): + def __init__(self) -> None: + super().__init__() + self.index = 0 + + async def acompletion(self, **kwargs): + self.index += 1 + response = await super().acompletion(**kwargs) + response["id"] = f"chat_response_{self.index}" + return response + + service = ResponsesService(store_settings=ResponsesStoreSettings(max_items=1)) + client = SequencedClient() + + first = await service.create_response({"model": "gpt-test", "input": "one"}, client) + second = await service.create_response({"model": "gpt-test", "input": "two"}, client) + + assert await service.store.get(first["id"]) is None + assert await service.store.get(second["id"]) is not None + + +@pytest.mark.asyncio +async def test_previous_response_id_loads_parent_context() -> None: + store = InMemoryResponsesStore() + await store.save( + StoredResponse( + id="resp_parent", + model="gpt-test", + status="completed", + request={"input": "Earlier"}, + response={ + "id": "resp_parent", + "object": "response", + "model": "gpt-test", + "status": "completed", + "output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Earlier"}]}], + }, + input_items=["Earlier"], + output_items=[{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Earlier"}]}], + ) + ) + client = FakeClient() + service = ResponsesService(store=store) + + await service.create_response({"model": "gpt-test", "input": "Continue", "previous_response_id": "resp_parent"}, client) + + assert client.calls[0]["messages"] == [ + {"role": "user", "content": "Earlier"}, + {"role": "assistant", "content": "Earlier"}, + {"role": "user", "content": "Continue"}, + ] + + +@pytest.mark.asyncio +async def test_previous_response_id_loads_full_lineage_oldest_first() -> None: + store = InMemoryResponsesStore() + await store.save( + StoredResponse( + id="resp_grandparent", + model="gpt-test", + status="completed", + request={"model": "gpt-test", "input": "First"}, + response={"id": "resp_grandparent", "object": "response", "output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "First answer"}]}]}, + ) + ) + await store.save( + StoredResponse( + id="resp_parent", + model="gpt-test", + status="completed", + request={"model": "gpt-test", "input": "Second", "previous_response_id": "resp_grandparent"}, + response={"id": "resp_parent", "object": "response", "output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Second answer"}]}]}, + ) + ) + client = FakeClient() + + await ResponsesService(store=store).create_response({"model": "gpt-test", "input": "Third", "previous_response_id": "resp_parent"}, client) + + assert client.calls[0]["messages"] == [ + {"role": "user", "content": "First"}, + {"role": "assistant", "content": "First answer"}, + {"role": "user", "content": "Second"}, + {"role": "assistant", "content": "Second answer"}, + {"role": "user", "content": "Third"}, + ] + + +@pytest.mark.asyncio +async def test_internal_session_hints_do_not_leak_to_direct_clients_or_traces(tmp_path) -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + await store.save( + StoredResponse( + id="resp_parent", + model="gpt-test", + status="completed", + response={"id": "resp_parent", "object": "response", "output": []}, + metadata={"session_affinity_key": "affinity-parent"}, + ) + ) + client = FakeClient() + + await service.create_response({"model": "gpt-test", "input": "Continue", "previous_response_id": "resp_parent"}, client, transaction_logger=logger) + + assert "_session_tracking_hints" not in client.calls[0] + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "_session_tracking_hints" not in trace_text + assert "has_session_hints" in trace_text + + +@pytest.mark.asyncio +async def test_internal_client_context_metadata_is_stored_with_response() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + await store.save( + StoredResponse( + id="resp_parent", + model="gpt-test", + status="completed", + response={"id": "resp_parent", "object": "response", "output": []}, + metadata={"session_affinity_key": "affinity-parent"}, + ) + ) + client = FakeInternalClient() + + response = await service.create_response({"model": "gpt-test", "input": "Continue", "previous_response_id": "resp_parent"}, client) + stored = await store.get(response["id"]) + + assert client.internal_hints["affinity_key"] == "affinity-parent" + assert stored is not None + assert stored.session_id == "session-parent" + assert stored.metadata["session_affinity_key"] == "affinity-parent" + + +@pytest.mark.asyncio +async def test_responses_service_records_response_id_session_anchor() -> None: + class Tracker: + def __init__(self) -> None: + self.calls = [] + + def record_response(self, *args, **kwargs): + self.calls.append((args, kwargs)) + + tracker = Tracker() + + class Client(FakeInternalClient): + async def acompletion(self, **kwargs): + callback = kwargs.pop("_request_context_callback", None) + kwargs.pop("_session_tracking_hints", None) + if callback: + callback( + type( + "Context", + (), + { + "session_id": "session-parent", + "session_affinity_key": "affinity-parent", + "usage_manager_key": "scope-parent", + "classifier": "global", + "session_tracker": tracker, + "provider": "openai", + "model": "gpt-test", + "session_tracking_namespace": "namespace", + }, + )() + ) + self.calls.append(kwargs) + return { + "id": "resp_parent", + "model": kwargs["model"], + "choices": [{"message": {"role": "assistant", "content": "Hello back"}, "finish_reason": "stop"}], + } + + await ResponsesService(store=InMemoryResponsesStore()).create_response({"model": "gpt-test", "input": "Hello"}, Client()) + + assert tracker.calls[0][0][0] == "session-parent" + assert tracker.calls[0][1]["response"] == {"id": "resp_parent", "object": "response"} + + +@pytest.mark.asyncio +async def test_missing_previous_response_id_raises_not_found() -> None: + service = ResponsesService(store=InMemoryResponsesStore()) + + with pytest.raises(ResponsesServiceError) as exc_info: + await service.create_response({"model": "gpt-test", "input": "Continue", "previous_response_id": "missing"}, FakeClient()) + + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_get_delete_and_list_input_items() -> None: + service = ResponsesService(store=InMemoryResponsesStore()) + response = await service.create_response({"model": "gpt-test", "input": ["Hello"]}, FakeClient()) + + assert (await service.get_response(response["id"]))["id"] == response["id"] + assert await service.list_input_items(response["id"]) == {"object": "list", "data": ["Hello"]} + assert await service.delete_response(response["id"]) == {"id": response["id"], "object": "response.deleted", "deleted": True} + with pytest.raises(ResponsesServiceError): + await service.get_response(response["id"]) + + +@pytest.mark.asyncio +async def test_service_emits_transform_trace_passes(tmp_path) -> None: + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + service = ResponsesService(store=InMemoryResponsesStore()) + + await service.create_response({"model": "gpt-test", "input": "Hello"}, FakeClient(), transaction_logger=logger) + + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert pass_names == [ + "responses_raw_request", + "responses_parsed_request", + "responses_bridge_chat_request", + "responses_bridge_chat_response", + "responses_parsed_response", + "usage_accounting_summary", + "responses_stored_response", + "responses_final_response", + ] + + +@pytest.mark.asyncio +async def test_service_usage_trace_includes_provider_reported_cost(tmp_path) -> None: + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + service = ResponsesService(store=InMemoryResponsesStore()) + + await service.create_response({"model": "gpt-test", "input": "Hello"}, FakeCostClient(), transaction_logger=logger) + + usage_entry = [entry for entry in _trace_entries(logger.log_dir) if entry["pass_name"] == "usage_accounting_summary"][-1] + assert usage_entry["data"]["cost"]["provider_reported_cost"] == 0.033 + assert usage_entry["metadata"]["pricing_source"] == "usage.cost_details" + + +def test_trace_responses_usage_returns_before_conversion_without_logger(monkeypatch) -> None: + service = ResponsesService(store=InMemoryResponsesStore()) + + def fail_extract(*args, **kwargs): + raise AssertionError("usage conversion should be skipped when tracing is disabled") + + monkeypatch.setattr(responses_service_module, "extract_usage_record", fail_extract) + + service._trace_responses_usage(None, {"usage": {"input_tokens": 1}}, "gpt-test", source="test") + + +@pytest.mark.asyncio +async def test_previous_response_trace_payload_skipped_without_logger() -> None: + class Parent: + id = "resp_parent" + response = {"output": []} + output_items = [] + input_items = [] + + def to_dict(self): + raise AssertionError("previous response trace payload should not be built without a logger") + + class Store(InMemoryResponsesStore): + async def get(self, response_id): + return Parent() + + service = ResponsesService(store=Store()) + + parent = await service._load_previous_response("resp_parent", None) + + assert parent.id == "resp_parent" diff --git a/tests/test_responses_store.py b/tests/test_responses_store.py new file mode 100644 index 000000000..15316bb55 --- /dev/null +++ b/tests/test_responses_store.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import json +import time + +import pytest + +from rotator_library.responses import InMemoryResponsesStore, ProviderCacheResponsesStore, StoredResponse, create_configured_responses_store, generate_response_id + + +def _stored(response_id: str = "resp_test") -> StoredResponse: + return StoredResponse( + id=response_id, + model="gpt-test", + status="completed", + request={"model": "gpt-test", "input": "hello"}, + response={"id": response_id, "object": "response", "output": []}, + input_items=[{"type": "message", "role": "user", "content": "hello"}], + output_items=[{"type": "message", "role": "assistant", "content": []}], + usage={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}, + metadata={"previous_response_id": None}, + ) + + +def test_generate_response_id_uses_responses_prefix() -> None: + assert generate_response_id().startswith("resp_") + assert generate_response_id() != generate_response_id() + + +@pytest.mark.asyncio +async def test_in_memory_store_save_get_delete_and_input_items() -> None: + store = InMemoryResponsesStore() + stored = _stored() + + await store.save(stored) + loaded = await store.get(stored.id) + input_items = await store.list_input_items(stored.id) + deleted = await store.delete(stored.id) + + assert loaded is not None + assert loaded.to_dict() == stored.to_dict() + assert input_items == stored.input_items + assert deleted is True + assert await store.get(stored.id) is None + + +@pytest.mark.asyncio +async def test_in_memory_store_returns_copies_and_expires() -> None: + store = InMemoryResponsesStore() + stored = _stored("resp_expiring") + stored.expires_at = time.time() + 100 + await store.save(stored) + + loaded = await store.get(stored.id) + assert loaded is not None + loaded.response["mutated"] = True + assert (await store.get(stored.id)).response.get("mutated") is None + + stored.expires_at = time.time() - 1 + await store.save(stored) + assert await store.get(stored.id) is None + + +@pytest.mark.asyncio +async def test_in_memory_store_prunes_oldest_when_max_items_exceeded() -> None: + store = InMemoryResponsesStore(max_items=2) + first = _stored("resp_first") + first.created_at = 1 + second = _stored("resp_second") + second.created_at = 2 + third = _stored("resp_third") + third.created_at = 3 + + await store.save(first) + await store.save(second) + await store.save(third) + + assert await store.get("resp_first") is None + assert await store.get("resp_second") is not None + assert await store.get("resp_third") is not None + + +@pytest.mark.asyncio +async def test_provider_cache_store_serializes_json_and_does_not_clear_without_key_delete() -> None: + class FakeProviderCache: + def __init__(self) -> None: + self.values = {} + + async def store_async(self, key: str, value: str) -> None: + json.loads(value) + self.values[key] = value + + async def retrieve_async(self, key: str): + return self.values.get(key) + + cache = FakeProviderCache() + store = ProviderCacheResponsesStore(cache) + stored = _stored("resp_provider_cache") + + await store.save(stored) + loaded = await store.get(stored.id) + + assert loaded is not None + assert loaded.id == stored.id + assert await store.list_input_items(stored.id) == stored.input_items + assert await store.delete(stored.id) is False + assert cache.values + + +@pytest.mark.asyncio +async def test_provider_cache_store_uses_key_delete_when_available() -> None: + class FakeProviderCache: + def __init__(self) -> None: + self.values = {} + + async def store_async(self, key: str, value: str) -> None: + self.values[key] = value + + async def retrieve_async(self, key: str): + return self.values.get(key) + + async def delete_async(self, key: str) -> bool: + return self.values.pop(key, None) is not None + + store = ProviderCacheResponsesStore(FakeProviderCache()) + stored = _stored("resp_delete") + await store.save(stored) + + assert await store.delete(stored.id) is True + assert await store.get(stored.id) is None + + +@pytest.mark.asyncio +async def test_configured_provider_cache_store_persists_between_instances(tmp_path) -> None: + env = { + "RESPONSES_STORE_BACKEND": "provider_cache", + "RESPONSES_STORE_CACHE_NAME": "responses_test", + "RESPONSES_STORE_CACHE_PREFIX": "responses", + "RESPONSES_STORE_CACHE_DIR": str(tmp_path), + "RESPONSES_STORE_CACHE_MEMORY_TTL_SECONDS": "60", + "RESPONSES_STORE_CACHE_DISK_TTL_SECONDS": "60", + } + first_store = create_configured_responses_store(env=env) + stored = _stored("resp_durable") + + await first_store.save(stored) + second_store = create_configured_responses_store(env=env) + loaded = await second_store.get(stored.id) + + assert loaded is not None + assert loaded.to_dict() == stored.to_dict() diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py new file mode 100644 index 000000000..82fdc9af1 --- /dev/null +++ b/tests/test_responses_streaming.py @@ -0,0 +1,697 @@ +from __future__ import annotations + +import asyncio +import json + +import pytest + +from rotator_library.responses import InMemoryResponsesStore, ResponsesSSEFormatter, ResponsesService, ResponsesStoreSettings, ResponsesStreamEvent, ResponsesWebSocketFormatter +from rotator_library.transaction_logger import TransactionLogger + + +def _event_names(events: list[str]) -> list[str]: + return [line.removeprefix("event: ") for event in events for line in event.splitlines() if line.startswith("event: ")] + + +class FakeStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield 'data: {"id":"chat_stream","model":"gpt-test","choices":[{"delta":{"role":"assistant"}}]}\n\n' + yield 'data: {"id":"chat_stream","model":"gpt-test","choices":[{"delta":{"content":"Hel"}}]}\n\n' + yield 'data: {"id":"chat_stream","model":"gpt-test","choices":[{"delta":{"content":"lo"}}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"cost_details":{"total_cost":0.044,"source":"stream_provider"}}}\n\n' + yield "data: [DONE]\n\n" + + return chunks() + + +class DelayedCloseableStream: + def __init__(self, chunks: list[str], delays: list[float]) -> None: + self.chunks = chunks + self.delays = delays + self.index = 0 + self.closed = False + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index >= len(self.chunks): + raise StopAsyncIteration + delay = self.delays[self.index] + chunk = self.chunks[self.index] + self.index += 1 + await asyncio.sleep(delay) + return chunk + + async def aclose(self) -> None: + self.closed = True + + +class DelayedStreamingClient: + def __init__(self, stream: DelayedCloseableStream) -> None: + self.stream = stream + + async def acompletion(self, **kwargs): + return self.stream + + +class SlowAcquireStreamingClient: + def __init__(self, stream: DelayedCloseableStream, delay: float) -> None: + self.stream = stream + self.delay = delay + self.calls = 0 + + async def acompletion(self, **kwargs): + self.calls += 1 + await asyncio.sleep(self.delay) + return self.stream + + +class CancelAwareAcquireClient: + def __init__(self, delay: float = 1.0) -> None: + self.cancelled = False + self.delay = delay + + async def acompletion(self, **kwargs): + try: + await asyncio.sleep(self.delay) + except asyncio.CancelledError: + self.cancelled = True + raise + return DelayedCloseableStream([], []) + + +class DisconnectRequest: + async def is_disconnected(self) -> bool: + return True + + +class DisconnectAfterAcquireRequest: + def __init__(self) -> None: + self.calls = 0 + + async def is_disconnected(self) -> bool: + self.calls += 1 + return self.calls > 1 + + +class FailingStreamingClient: + def __init__(self, message: str = "stream exploded") -> None: + self.message = message + + async def acompletion(self, **kwargs): + message = self.message + + async def chunks(): + yield 'data: {"choices":[{"delta":{"content":"before"}}]}\n\n' + raise RuntimeError(message) + + return chunks() + + +class ErrorChunkStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield 'data: {"choices":[{"delta":{"content":"partial"}}]}\n\n' + yield 'data: {"error":{"message":"provider failed"}}\n\n' + + return chunks() + + +class EventErrorStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield 'data: {"choices":[{"delta":{"content":"partial"}}]}\n\n' + yield 'event: error\ndata: {"message":"event failed"}\n\n' + + return chunks() + + +class CostCommentStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield ': cost {"total_cost":0.031,"currency":"USD","source":"responses_sse"}\n\n' + yield 'data: {"choices":[{"delta":{"content":"priced"}}]}\n\n' + yield "data: [DONE]\n\n" + + return chunks() + + +class CostEventStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield 'event: cost\ndata: {"total_cost":0.017,"currency":"EUR","source":"responses_event_cost"}\n\n' + yield 'data: {"choices":[{"delta":{"content":"priced"}}]}\n\n' + yield "data: [DONE]\n\n" + + return chunks() + + +class RequestCostCommentStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield ': cost {"request_cost_usd":0.023,"source":"reference_sse"}\n\n' + yield 'data: {"choices":[{"delta":{"content":"priced"}}]}\n\n' + yield "data: [DONE]\n\n" + + return chunks() + + +class EstimatedCostCommentStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield ': cost {"estimated_cost":0.024,"source":"reference_estimate"}\n\n' + yield 'data: {"choices":[{"delta":{"content":"priced"}}]}\n\n' + yield "data: [DONE]\n\n" + + return chunks() + + +class TopLevelUsageCostStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield 'data: {"choices":[{"delta":{"content":"priced"}}],"usage":{"prompt_tokens":1,"completion_tokens":1},"total_cost":0.029}\n\n' + yield "data: [DONE]\n\n" + + return chunks() + + +class CostCommentFinalOverrideStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield ': cost 0.031\n\n' + yield 'data: {"choices":[{"delta":{"content":"priced"}}],"usage":{"prompt_tokens":1,"completion_tokens":1,"cost_details":{"total_cost":0.062,"source":"final_usage"}}}\n\n' + yield "data: [DONE]\n\n" + + return chunks() + + +class FailingStore: + async def save(self, response): + raise RuntimeError("store failed Authorization: Bearer secret-token") + + async def get(self, response_id): + return None + + async def delete(self, response_id): + return False + + async def list_input_items(self, response_id): + return None + + +@pytest.mark.asyncio +async def test_stream_response_emits_responses_sse_events_and_stores_final_response() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, FakeStreamingClient())] + + event_text = "".join(events) + assert _event_names(events) == [ + "response.created", + "response.output_item.added", + "response.output_text.delta", + "response.output_text.delta", + "response.output_item.done", + "response.completed", + ] + assert event_text.endswith("data: [DONE]\n\n") + + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.output_items[0]["content"][0]["text"] == "Hello" + assert stored.usage == {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3, "cost_details": {"total_cost": 0.044, "source": "stream_provider"}} + + +@pytest.mark.asyncio +async def test_stream_response_store_false_does_not_persist(tmp_path) -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True, "store": False}, FakeStreamingClient(), transaction_logger=logger)] + + response_id = events[0].split('"id": "')[1].split('"')[0] + assert await store.get(response_id) is None + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "responses_store_skipped" in trace_text + assert "responses_stored_stream_response" not in trace_text + + +@pytest.mark.asyncio +async def test_stream_response_errors_emit_failed_event() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, FailingStreamingClient())] + + event_text = "".join(events) + assert "event: response.failed" in event_text + assert "stream exploded" in event_text + assert event_text.endswith("data: [DONE]\n\n") + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.status == "failed" + + +@pytest.mark.asyncio +async def test_stream_response_error_chunks_store_failed_with_partial_output() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, ErrorChunkStreamingClient())] + + event_text = "".join(events) + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert "event: response.failed" in event_text + assert stored is not None + assert stored.status == "failed" + assert stored.output_items[0]["content"][0]["text"] == "partial" + + +@pytest.mark.asyncio +async def test_stream_response_event_error_frames_are_failed() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, EventErrorStreamingClient())] + + event_text = "".join(events) + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert "event: response.failed" in event_text + assert stored is not None + assert stored.response["error"]["message"] == "event failed" + + +@pytest.mark.asyncio +async def test_stream_response_preserves_sse_cost_comment() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, CostCommentStreamingClient())] + + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.usage["cost_details"]["total_cost"] == 0.031 + assert stored.usage["cost_details"]["source"] == "responses_sse" + + +@pytest.mark.asyncio +async def test_stream_response_preserves_sse_cost_event() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, CostEventStreamingClient())] + + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.usage["cost_details"]["total_cost"] == 0.017 + assert stored.usage["cost_details"]["currency"] == "EUR" + + +@pytest.mark.asyncio +async def test_stream_response_preserves_reference_request_cost_comment() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, RequestCostCommentStreamingClient())] + + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.usage["cost_details"]["request_cost_usd"] == 0.023 + + +@pytest.mark.asyncio +async def test_stream_response_preserves_estimated_cost_comment() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, EstimatedCostCommentStreamingClient())] + + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.usage["cost_details"]["estimated_cost"] == 0.024 + + +@pytest.mark.asyncio +async def test_stream_response_preserves_top_level_chunk_cost_with_usage() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, TopLevelUsageCostStreamingClient())] + + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.usage["total_cost"] == 0.029 + + +@pytest.mark.asyncio +async def test_stream_response_final_usage_cost_overrides_sse_cost_comment() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, CostCommentFinalOverrideStreamingClient())] + + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.usage["cost_details"]["total_cost"] == 0.062 + assert stored.usage["cost_details"]["source"] == "final_usage" + + +@pytest.mark.asyncio +async def test_stream_response_can_skip_failed_storage() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store, store_settings=ResponsesStoreSettings(store_failed=False)) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, FailingStreamingClient())] + + response_id = events[0].split('"id": "')[1].split('"')[0] + assert await store.get(response_id) is None + + +@pytest.mark.asyncio +async def test_stream_events_can_store_in_progress_state() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store, store_settings=ResponsesStoreSettings(store_in_progress=True)) + stream = service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, FakeStreamingClient()) + + created = await anext(stream) + stored = await store.get(created.payload["id"]) + await stream.aclose() + + assert stored is not None + assert stored.status == "in_progress" + + +@pytest.mark.asyncio +async def test_stream_response_store_failures_emit_store_specific_trace(tmp_path) -> None: + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + service = ResponsesService(store=FailingStore()) + + with pytest.raises(RuntimeError): + _ = [ + chunk + async for chunk in service.stream_response( + {"model": "gpt-test", "input": "Hello", "stream": True}, + FakeStreamingClient(), + transaction_logger=logger, + ) + ] + + entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + errors = [entry for entry in entries if entry["pass_name"] == "transform_log_error"] + assert any(entry["data"]["failed_pass_name"] == "responses_store_stream_response" for entry in errors) + assert "secret-token" not in json.dumps(errors) + + +@pytest.mark.asyncio +async def test_stream_current_state_store_failures_emit_store_specific_trace(tmp_path) -> None: + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + service = ResponsesService(store=FailingStore(), store_settings=ResponsesStoreSettings(store_in_progress=True)) + + with pytest.raises(RuntimeError): + _ = [ + event + async for event in service.stream_events( + {"model": "gpt-test", "input": "Hello", "stream": True}, + FakeStreamingClient(), + transaction_logger=logger, + ) + ] + + entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + errors = [entry for entry in entries if entry["pass_name"] == "transform_log_error"] + assert any(entry["data"]["failed_pass_name"] == "responses_store_stream_current_state" for entry in errors) + assert "secret-token" not in json.dumps(errors) + + +@pytest.mark.asyncio +async def test_stream_response_failure_trace_scrubs_header_like_secret_text(tmp_path) -> None: + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + service = ResponsesService(store=InMemoryResponsesStore()) + + _ = [ + chunk + async for chunk in service.stream_response( + {"model": "gpt-test", "input": "Hello", "stream": True}, + FailingStreamingClient("{'Authorization': 'Bearer secret-token'}"), + transaction_logger=logger, + ) + ] + + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "secret-token" not in trace_text + assert "[REDACTED]" in trace_text + + +def test_transport_formatters_expose_sse_and_websocket_seam() -> None: + assert ResponsesSSEFormatter().transport == "sse" + assert ResponsesSSEFormatter().format_stream_event(ResponsesStreamEvent("heartbeat", {"comment": "heartbeat"})) == ": heartbeat\n\n" + websocket = ResponsesWebSocketFormatter() + assert websocket.transport == "websocket" + assert websocket.future_supported is True + assert websocket.format_stream_event(ResponsesStreamEvent("response.created", {"id": "resp"})) == '{"event": "response.created", "data": {"id": "resp"}}' + + +@pytest.mark.asyncio +async def test_stream_response_emits_non_visible_heartbeat(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + stream = DelayedCloseableStream( + ['data: {"choices":[{"delta":{"content":"hi"}}]}\n\n', "data: [DONE]\n\n"], + [0.03, 0.0], + ) + service = ResponsesService(store=InMemoryResponsesStore()) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream))] + + assert ": heartbeat\n\n" in events + assert "event: response.output_text.delta" in "".join(events) + + +@pytest.mark.asyncio +async def test_stream_response_ttfb_timeout_closes_upstream(monkeypatch) -> None: + monkeypatch.setenv("STREAM_TTFB_TIMEOUT_SECONDS", "0.01") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], [0.05]) + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream))] + + event_text = "".join(events) + assert stream.closed is True + assert "event: response.failed" in event_text + assert "ttfb" in event_text.lower() + assert event_text.endswith("data: [DONE]\n\n") + + +@pytest.mark.asyncio +async def test_stream_response_heartbeat_does_not_reset_ttfb_timeout(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + monkeypatch.setenv("STREAM_TTFB_TIMEOUT_SECONDS", "0.03") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], [0.1]) + service = ResponsesService(store=InMemoryResponsesStore()) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream))] + + event_text = "".join(events) + assert ": heartbeat\n\n" in events + assert stream.closed is True + assert "event: response.failed" in event_text + assert "ttfb" in event_text.lower() + + +@pytest.mark.asyncio +async def test_stream_response_acquire_wait_honors_ttfb_timeout(monkeypatch) -> None: + monkeypatch.setenv("STREAM_TTFB_TIMEOUT_SECONDS", "0.01") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], [0.0]) + service = ResponsesService(store=InMemoryResponsesStore()) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, SlowAcquireStreamingClient(stream, 0.05))] + + event_text = "".join(events) + assert "event: response.failed" in event_text + assert "ttfb" in event_text.lower() + + +@pytest.mark.asyncio +async def test_stream_response_ttfb_spans_acquire_and_first_chunk(monkeypatch) -> None: + monkeypatch.setenv("STREAM_TTFB_TIMEOUT_SECONDS", "0.05") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], [0.03]) + service = ResponsesService(store=InMemoryResponsesStore()) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, SlowAcquireStreamingClient(stream, 0.03))] + + event_text = "".join(events) + assert stream.closed is True + assert "event: response.failed" in event_text + assert "ttfb" in event_text.lower() + + +@pytest.mark.asyncio +async def test_stream_response_acquire_wait_can_emit_heartbeat(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"hi"}}]}\n\n', "data: [DONE]\n\n"], [0.0, 0.0]) + service = ResponsesService(store=InMemoryResponsesStore()) + + client = SlowAcquireStreamingClient(stream, 0.03) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, client)] + + assert ": heartbeat\n\n" in events + assert client.calls == 1 + assert "event: response.completed" in "".join(events) + + +@pytest.mark.asyncio +async def test_stream_response_heartbeat_does_not_drop_completed_first_chunk(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + stream = DelayedCloseableStream( + ['data: {"choices":[{"delta":{"content":"kept"}}]}\n\n', "data: [DONE]\n\n"], + [0.05, 0.0], + ) + service = ResponsesService(store=InMemoryResponsesStore()) + events = service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream)) + + created = await anext(events) + heartbeat = await anext(events) + await asyncio.sleep(0.06) + added = await anext(events) + delta = await anext(events) + await events.aclose() + + assert created.event_name == "response.created" + assert heartbeat.heartbeat is True + assert added.event_name == "response.output_item.added" + assert delta.event_name == "response.output_text.delta" + assert delta.payload["delta"] == "kept" + + +@pytest.mark.asyncio +async def test_stream_events_aclose_cancels_pending_read_after_heartbeat(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], [1.0]) + service = ResponsesService(store=InMemoryResponsesStore()) + events = service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream)) + + assert (await anext(events)).event_name == "response.created" + assert (await anext(events)).heartbeat is True + await events.aclose() + + assert stream.closed is True + + +@pytest.mark.asyncio +async def test_stream_events_aclose_cancels_pending_acquire_after_heartbeat(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + client = CancelAwareAcquireClient() + service = ResponsesService(store=InMemoryResponsesStore()) + events = service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, client) + + assert (await anext(events)).event_name == "response.created" + assert (await anext(events)).heartbeat is True + await events.aclose() + + assert client.cancelled is True + + +@pytest.mark.asyncio +async def test_stream_events_aclose_closes_completed_acquire_stream_after_heartbeat(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], [1.0]) + client = SlowAcquireStreamingClient(stream, 0.015) + service = ResponsesService(store=InMemoryResponsesStore()) + events = service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, client) + + assert (await anext(events)).event_name == "response.created" + assert (await anext(events)).heartbeat is True + await asyncio.sleep(0.02) + await events.aclose() + + assert client.calls == 1 + assert stream.closed is True + + +@pytest.mark.asyncio +async def test_stream_response_stall_timeout_preserves_partial_output(monkeypatch) -> None: + monkeypatch.setenv("STREAM_STALL_TIMEOUT_SECONDS", "0.01") + stream = DelayedCloseableStream( + ['data: {"choices":[{"delta":{"content":"partial"}}]}\n\n', 'data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], + [0.0, 0.05], + ) + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream))] + + event_text = "".join(events) + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stream.closed is True + assert "event: response.failed" in event_text + assert "stall" in event_text.lower() + assert stored is not None + assert stored.output_items[0]["content"][0]["text"] == "partial" + + +@pytest.mark.asyncio +async def test_stream_events_disconnect_closes_upstream(monkeypatch) -> None: + monkeypatch.setenv("STREAM_CANCEL_UPSTREAM_ON_DISCONNECT", "true") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], [0.05]) + service = ResponsesService(store=InMemoryResponsesStore()) + + events = [event async for event in service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream), request=DisconnectAfterAcquireRequest())] + + assert [event.event_name for event in events] == ["response.created"] + assert stream.closed is True + + +@pytest.mark.asyncio +async def test_responses_stream_records_common_stream_metrics(tmp_path) -> None: + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + service = ResponsesService(store=InMemoryResponsesStore()) + + _ = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, FakeStreamingClient(), transaction_logger=logger)] + + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "stream_started" in trace_text + assert "stream_first_byte" in trace_text + assert "stream_first_visible_output" in trace_text + assert "stream_completed" in trace_text + assert "stream_metrics_final" in trace_text + assert "stream_done_event" in trace_text + assert "responses_stream_event_created" in trace_text + assert "responses_stream_event_output_text_delta" in trace_text + assert "responses_stream_event_completed" in trace_text + assert "responses_sse_formatted_event" in trace_text + usage_entry = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines() if '"usage_accounting_summary"' in line][-1] + assert usage_entry["data"]["cost"]["provider_reported_cost"] == 0.044 + + +@pytest.mark.asyncio +async def test_stream_events_are_transport_neutral_and_sse_wraps_them() -> None: + service = ResponsesService(store=InMemoryResponsesStore()) + + events = [event async for event in service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, FakeStreamingClient())] + sse = [ResponsesSSEFormatter().format_stream_event(event) for event in events] + + assert [event.event_name for event in events if not event.terminal] == [ + "response.created", + "response.output_item.added", + "response.output_text.delta", + "response.output_text.delta", + "response.output_item.done", + "response.completed", + ] + assert events[-1].terminal is True + assert "".join(sse).endswith("data: [DONE]\n\n") diff --git a/tests/test_responses_usage_accounting.py b/tests/test_responses_usage_accounting.py new file mode 100644 index 000000000..789868bb0 --- /dev/null +++ b/tests/test_responses_usage_accounting.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.responses import InMemoryResponsesStore, ResponsesService +from rotator_library.transaction_logger import TransactionLogger + + +class FakeUsageClient: + async def acompletion(self, **kwargs): + return { + "id": "chatcmpl_1", + "model": "gpt-test", + "choices": [{"message": {"role": "assistant", "content": "hi"}, "finish_reason": "stop"}], + "usage": { + "prompt_tokens": 20, + "completion_tokens": 8, + "completion_tokens_details": {"reasoning_tokens": 3}, + }, + } + + +class FakeStreamingUsageClient: + async def acompletion(self, **kwargs): + async def gen(): + yield 'data: {"id":"chunk_1","choices":[{"delta":{"content":"hi"}}]}\n\n' + yield 'data: {"id":"chunk_2","choices":[{"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":20,"completion_tokens":8,"total_tokens":28,"completion_tokens_details":{"reasoning_tokens":3}}}\n\n' + + return gen() + + +@pytest.mark.asyncio +async def test_responses_create_traces_normalized_usage(tmp_path) -> None: + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + service = ResponsesService(store=InMemoryResponsesStore()) + + await service.create_response({"model": "gpt-test", "input": "hello"}, FakeUsageClient(), transaction_logger=logger) + + entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + usage_entries = [entry for entry in entries if entry["pass_name"] == "usage_accounting_summary"] + assert usage_entries + assert usage_entries[-1]["data"]["usage"]["completion_tokens"] == 5 + assert usage_entries[-1]["data"]["usage"]["reasoning_tokens"] == 3 + + +@pytest.mark.asyncio +async def test_responses_stream_preserves_usage_details_in_completed_event() -> None: + service = ResponsesService(store=InMemoryResponsesStore()) + + chunks = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "hello", "stream": True}, FakeStreamingUsageClient())] + completed = [chunk for chunk in chunks if "response.completed" in chunk][0] + payload = json.loads(completed.split("data: ", 1)[1]) + + assert payload["usage"]["output_tokens_details"] == {"reasoning_tokens": 3} diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py new file mode 100644 index 000000000..6e2fee72e --- /dev/null +++ b/tests/test_retry_policy.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import asyncio + +from rotator_library.error_handler import ClassifiedError, PreRequestCallbackError, classify_error +from rotator_library.retry_policy import ( + FailureHistory, + classify_route_error, + decide_provider_cooldown, + is_model_capacity_error, + is_target_failover_eligible, + should_retry_same_credential, + should_rotate_credential, +) + + +class ExplicitRouteError(Exception): + error_type = "unsupported_operation" + + +def _classified(error_type: str, **kwargs) -> ClassifiedError: + return ClassifiedError(error_type, original_exception=Exception(error_type), **kwargs) + + +def test_classifier_output_maps_to_fallback_policy_categories() -> None: + retryable = ["rate_limit", "quota_exceeded", "server_error", "api_connection", "unsupported_operation"] + stopped = [ + "authentication", + "forbidden", + "invalid_request", + "context_window_exceeded", + "credential_reauth_needed", + "pre_request_callback_error", + "cancelled", + ] + + assert all(is_target_failover_eligible(error_type) for error_type in retryable) + assert all(not is_target_failover_eligible(error_type) for error_type in stopped) + assert not is_target_failover_eligible("unknown") + + +def test_classify_route_error_preserves_explicit_and_cancelled_types() -> None: + assert classify_route_error(ExplicitRouteError()) == "unsupported_operation" + assert classify_route_error(asyncio.CancelledError()) == "cancelled" + assert classify_route_error(PreRequestCallbackError("boom")) == "pre_request_callback_error" + + +def test_retry_and_rotation_helpers_delegate_to_existing_semantics() -> None: + small_rate_limit = _classified("rate_limit", retry_after=3) + invalid_request = _classified("invalid_request") + + assert should_retry_same_credential(small_rate_limit, small_cooldown_threshold=10) is True + assert should_rotate_credential(small_rate_limit) is True + assert should_rotate_credential(invalid_request) is False + + +def test_provider_cooldown_uses_large_retry_after_not_small_retry_after() -> None: + small = decide_provider_cooldown( + _classified("rate_limit", retry_after=3), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + ) + large = decide_provider_cooldown( + _classified("rate_limit", retry_after=60), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + ) + + assert small.should_start is False + assert small.reason == "small_retry_after" + assert large.should_start is True + assert large.duration == 60 + assert large.scope == "provider" + + +def test_provider_cooldown_is_conservative_for_quota_by_default() -> None: + disabled = decide_provider_cooldown( + _classified("quota_exceeded", retry_after=3600), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + ) + enabled = decide_provider_cooldown( + _classified("quota_exceeded", retry_after=3600), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + cooldown_on_quota=True, + ) + + assert disabled.should_start is False + assert disabled.reason == "quota_cooldown_disabled" + assert enabled.should_start is True + + +def test_model_capacity_error_uses_model_scoped_cooldown() -> None: + error = Exception("503 MODEL_CAPACITY_EXHAUSTED") + decision = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=30, + model="gpt-test", + original_error=error, + ) + + assert is_model_capacity_error(error) is True + assert decision.should_start is True + assert decision.scope == "model" + assert decision.model == "gpt-test" + assert decision.reason == "model_capacity_cooldown" + + +def test_failure_history_escalates_repeated_transient_backoff(monkeypatch) -> None: + now = 1000.0 + history = FailureHistory(clock=lambda: now) + monkeypatch.setenv("PROVIDER_BACKOFF_THRESHOLD", "3") + monkeypatch.setenv("PROVIDER_BACKOFF_BASE_SECONDS", "10") + monkeypatch.setenv("PROVIDER_BACKOFF_MAX_SECONDS", "40") + history.record(provider="openai", model=None, error_type="server_error", scope="provider", duration=10, reason="test") + history.record(provider="openai", model=None, error_type="server_error", scope="provider", duration=10, reason="test") + + decision = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=10, + provider="openai", + failure_history=history, + ) + + assert decision.duration == 10 + assert decision.backoff_level == 1 + history.record(provider="openai", model=None, error_type="server_error", scope="provider", duration=10, reason="test") + decision = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=10, + provider="openai", + failure_history=history, + ) + assert decision.duration == 20 + assert decision.backoff_level == 2 + + +def test_single_generic_transient_without_retry_after_does_not_cooldown(monkeypatch) -> None: + history = FailureHistory(clock=lambda: 1000.0) + monkeypatch.setenv("PROVIDER_BACKOFF_THRESHOLD", "3") + + decision = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=10, + provider="openai", + failure_history=history, + ) + + assert decision.should_start is False + assert decision.reason == "transient_backoff_threshold_not_met" + + +def test_failure_history_clear_resets_repeated_transient_backoff(monkeypatch) -> None: + history = FailureHistory(clock=lambda: 1000.0) + monkeypatch.setenv("PROVIDER_BACKOFF_THRESHOLD", "2") + history.record(provider="openai", model=None, error_type="server_error", scope="provider", duration=0, reason="first") + + history.clear(provider="openai", model="gpt-test") + decision = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=10, + provider="openai", + failure_history=history, + ) + + assert history.snapshot() == () + assert decision.should_start is False + + +def test_failure_history_backoff_is_provider_scoped(monkeypatch) -> None: + history = FailureHistory(clock=lambda: 1000.0) + monkeypatch.setenv("PROVIDER_BACKOFF_THRESHOLD", "2") + monkeypatch.setenv("PROVIDER_BACKOFF_BASE_SECONDS", "10") + history.record(provider="provider-a", model=None, error_type="server_error", scope="provider", duration=10, reason="test") + + provider_b = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=10, + provider="provider-b", + failure_history=history, + ) + provider_a = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=10, + provider="provider-a", + failure_history=history, + ) + + assert provider_b.backoff_level == 0 + assert provider_a.backoff_level == 1 + + +def test_failure_history_backoff_requires_matching_provider(monkeypatch) -> None: + history = FailureHistory(clock=lambda: 1000.0) + monkeypatch.setenv("PROVIDER_BACKOFF_THRESHOLD", "2") + history.record(provider="provider-a", model=None, error_type="server_error", scope="provider", duration=10, reason="test") + + decision = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=10, + failure_history=history, + ) + + assert decision.backoff_level == 0 + + +def test_shared_classifier_handles_structured_dict_status_codes() -> None: + assert classify_error({"error": {"status": 401}}).error_type == "authentication" + assert classify_error({"error": {"details": {"status_code": 403}}}).error_type == "forbidden" + assert classify_error({"error": {"code": 429}}).error_type == "rate_limit" + + +def test_shared_classifier_handles_structured_dict_type_and_code_text() -> None: + assert classify_error({"error": {"type": "authentication"}}).error_type == "authentication" + assert classify_error({"error": {"code": "permission_denied"}}).error_type == "forbidden" + assert classify_error({"error": {"code": "context_length_exceeded"}}).error_type == "context_window_exceeded" + assert classify_error({"error": {"status_code": 400, "code": "context_length_exceeded"}}).error_type == "context_window_exceeded" + assert classify_error({"error": {"type": "rate_limit"}}).error_type == "rate_limit" + + +def test_shared_classifier_preserves_explicit_error_type_attributes() -> None: + class ConfigurationFailure(Exception): + error_type = "configuration_error" + + assert classify_error(ConfigurationFailure("raw secret message")).error_type == "configuration_error" diff --git a/tests/test_routing_attempts.py b/tests/test_routing_attempts.py new file mode 100644 index 000000000..4f84c76a9 --- /dev/null +++ b/tests/test_routing_attempts.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from rotator_library.core.types import RequestContext +from rotator_library.routing import clone_context_for_target, parse_route_target + + +def test_clone_context_for_target_updates_model_provider_without_mutating_original() -> None: + original = RequestContext( + model="requested", + provider="original", + kwargs={"model": "requested", "messages": []}, + streaming=False, + credentials=["cred-a"], + deadline=123.0, + usage_manager_key="original", + ) + target = parse_route_target("codex/gpt-5.1-codex@native") + + cloned = clone_context_for_target(original, target, credentials=["cred-b"], target_index=1) + + assert cloned.model == "codex/gpt-5.1-codex" + assert cloned.provider == "codex" + assert cloned.kwargs["model"] == "codex/gpt-5.1-codex" + assert cloned.credentials == ["cred-b"] + assert cloned.usage_manager_key == "codex" + assert cloned.routing_target_index == 1 + assert original.model == "requested" + assert original.kwargs["model"] == "requested" + assert original.credentials == ["cred-a"] + + +def test_clone_context_for_target_preserves_request_metadata() -> None: + original = RequestContext( + model="requested", + provider="original", + kwargs={"model": "requested"}, + streaming=True, + credentials=["cred-a"], + deadline=123.0, + session_id="session-1", + classifier="global", + routing_group_name="chain", + ) + + cloned = clone_context_for_target(original, parse_route_target("openai/gpt-5.1")) + + assert cloned.streaming is True + assert cloned.session_id == "session-1" + assert cloned.classifier == "global" + assert cloned.routing_group_name == "chain" + + +def test_clone_context_for_target_rewrites_standard_session_namespace() -> None: + original = RequestContext( + model="openai/gpt-5", + provider="openai", + kwargs={"model": "openai/gpt-5"}, + streaming=False, + credentials=["cred-a"], + deadline=123.0, + session_tracking_namespace="scope:openai:provider:openai:model:openai/gpt-5", + ) + + cloned = clone_context_for_target(original, parse_route_target("anthropic/claude")) + + assert cloned.session_tracking_namespace == "scope:anthropic:provider:anthropic:model:anthropic/claude" + + +def test_clone_context_for_target_preserves_custom_session_namespace() -> None: + original = RequestContext( + model="openai/gpt-5", + provider="openai", + kwargs={"model": "openai/gpt-5"}, + streaming=False, + credentials=["cred-a"], + deadline=123.0, + session_tracking_namespace="custom-namespace", + ) + + cloned = clone_context_for_target(original, parse_route_target("anthropic/claude")) + + assert cloned.session_tracking_namespace == "custom-namespace" diff --git a/tests/test_routing_config.py b/tests/test_routing_config.py new file mode 100644 index 000000000..353f0da60 --- /dev/null +++ b/tests/test_routing_config.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import pytest + +from rotator_library.routing import RoutingConfigError, load_routing_config_from_env, parse_route_target + + +def test_parse_route_target_supports_execution_suffix() -> None: + target = parse_route_target("codex/gpt-5.1-codex@native") + + assert target.provider == "codex" + assert target.model == "gpt-5.1-codex" + assert target.execution == "native" + assert target.prefixed_model == "codex/gpt-5.1-codex" + + +def test_parse_route_target_rejects_missing_provider() -> None: + with pytest.raises(RoutingConfigError): + parse_route_target("gpt-5.1") + + +def test_load_routing_config_from_env_parses_group_and_model_route() -> None: + config = load_routing_config_from_env( + { + "FALLBACK_GROUPS": "sonnet_chain", + "FALLBACK_GROUP_SONNET_CHAIN": "claude_code/claude-sonnet-4-5,copilot/claude-sonnet-4-5@litellm_fallback", + "MODEL_ROUTE_CLAUDE_SONNET": "group:sonnet_chain", + } + ) + + assert tuple(config.fallback_groups) == ("sonnet_chain",) + assert config.fallback_groups["sonnet_chain"].targets[1].execution == "litellm_fallback" + assert config.model_routes["claude_sonnet"] == "group:sonnet_chain" + + +def test_load_routing_config_rejects_empty_group() -> None: + with pytest.raises(RoutingConfigError): + load_routing_config_from_env({"FALLBACK_GROUPS": "empty", "FALLBACK_GROUP_EMPTY": ""}) + + +def test_load_routing_config_rejects_duplicate_group_names() -> None: + with pytest.raises(RoutingConfigError): + load_routing_config_from_env({"FALLBACK_GROUPS": "a,a", "FALLBACK_GROUP_A": "openai/gpt"}) + + +def test_load_routing_config_rejects_unknown_group_route() -> None: + with pytest.raises(RoutingConfigError): + load_routing_config_from_env({"MODEL_ROUTE_CODEX": "group:missing"}) + + +def test_load_routing_config_parses_group_policy_overrides() -> None: + config = load_routing_config_from_env( + { + "FALLBACK_GROUPS": "chain", + "FALLBACK_GROUP_CHAIN": "openai/gpt,copilot/gpt", + "FALLBACK_GROUP_CHAIN_FAILOVER_ON": "network,transient", + "FALLBACK_GROUP_CHAIN_STOP_ON": "validation", + "FALLBACK_GROUP_CHAIN_STREAMING_POLICY": "never", + } + ) + + group = config.fallback_groups["chain"] + assert group.failover_on == frozenset({"api_connection", "server_error"}) + assert group.stop_on == frozenset({"invalid_request"}) + assert group.streaming_policy == "never" + + +def test_load_routing_config_rejects_hard_stop_failover() -> None: + with pytest.raises(RoutingConfigError): + load_routing_config_from_env( + { + "FALLBACK_GROUPS": "chain", + "FALLBACK_GROUP_CHAIN": "openai/gpt,copilot/gpt", + "FALLBACK_GROUP_CHAIN_FAILOVER_ON": "auth", + } + ) + + with pytest.raises(RoutingConfigError): + load_routing_config_from_env( + { + "FALLBACK_GROUPS": "chain", + "FALLBACK_GROUP_CHAIN": "openai/gpt,copilot/gpt", + "FALLBACK_GROUP_CHAIN_FAILOVER_ON": "pre_request_callback", + } + ) + + +def test_load_routing_config_rejects_unknown_streaming_policy() -> None: + with pytest.raises(RoutingConfigError): + load_routing_config_from_env( + { + "FALLBACK_GROUPS": "chain", + "FALLBACK_GROUP_CHAIN": "openai/gpt,copilot/gpt", + "FALLBACK_GROUP_CHAIN_STREAMING_POLICY": "always", + } + ) diff --git a/tests/test_session_tracking.py b/tests/test_session_tracking.py index 22262252a..adb85ec8c 100644 --- a/tests/test_session_tracking.py +++ b/tests/test_session_tracking.py @@ -166,6 +166,30 @@ def test_response_anchors_bridge_next_request(self): self.assertEqual(inferred.session_id, continued.session_id) + def test_response_id_anchor_bridges_responses_previous_response_id(self): + tracker = SessionTracker(ttl_seconds=3600) + inferred = tracker.infer_session( + {"messages": [{"role": "user", "content": "Start a Responses API conversation."}]}, + provider="openai", + model="gpt-test", + ) + tracker.record_response( + inferred.session_id, + provider="openai", + model="gpt-test", + response={"id": "resp_parent", "object": "response"}, + ) + + continued = tracker.infer_session( + {"messages": [{"role": "user", "content": "Continue."}]}, + provider="openai", + model="gpt-test", + hints={"strong_anchors": ["responses_previous_response_id:resp_parent"]}, + ) + + self.assertEqual(inferred.session_id, continued.session_id) + self.assertEqual(continued.confidence, "strong") + def test_compaction_probe_detects_user_summary_from_prior_response(self): tracker = SessionTracker(ttl_seconds=3600) request = { diff --git a/tests/test_stream_events.py b/tests/test_stream_events.py new file mode 100644 index 000000000..365e9cfbe --- /dev/null +++ b/tests/test_stream_events.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from rotator_library.streaming import stream_event_from_sse_chunk + + +def test_stream_event_from_sse_chunk_detects_visible_chat_delta() -> None: + event = stream_event_from_sse_chunk('data: {"choices":[{"delta":{"content":"hi"}}]}\n\n') + + assert event.event_type == "delta" + assert event.visible_output is True + + +def test_stream_event_from_sse_chunk_treats_error_and_done_as_not_visible() -> None: + error_event = stream_event_from_sse_chunk('data: {"error":{"type":"rate_limit"}}\n\n') + done_event = stream_event_from_sse_chunk("data: [DONE]\n\n") + + assert error_event.event_type == "error" + assert error_event.visible_output is False + assert done_event.event_type == "completed" + assert done_event.visible_output is False + + +def test_stream_event_from_sse_chunk_malformed_fails_closed() -> None: + event = stream_event_from_sse_chunk("data: not-json\n\n") + + assert event.event_type == "metadata" + assert event.visible_output is False diff --git a/tests/test_stream_metrics.py b/tests/test_stream_metrics.py new file mode 100644 index 000000000..5f082f0c1 --- /dev/null +++ b/tests/test_stream_metrics.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from rotator_library.streaming import StreamEvent, StreamMonitor + + +class FakeClock: + def __init__(self) -> None: + self.now = 0.0 + + def __call__(self) -> float: + return self.now + + def advance(self, seconds: float) -> None: + self.now += seconds + + +def test_stream_metrics_records_ttfb_ttft_and_counts() -> None: + clock = FakeClock() + monitor = StreamMonitor(clock=clock) + clock.advance(0.5) + monitor.record_event(StreamEvent("parsed_chunk")) + clock.advance(0.7) + monitor.record_event(StreamEvent("delta", visible_output=True)) + clock.advance(0.3) + monitor.complete() + + assert monitor.metrics.ttfb_seconds == 0.5 + assert monitor.metrics.ttft_seconds == 1.2 + assert monitor.metrics.duration_seconds == 1.5 + assert monitor.metrics.chunk_count == 2 + assert monitor.metrics.visible_chunk_count == 1 + + +def test_stream_monitor_records_cancel_and_stall() -> None: + clock = FakeClock() + monitor = StreamMonitor(clock=clock) + monitor.record_event(StreamEvent("parsed_chunk")) + clock.advance(5) + + assert monitor.is_stalled(2) is True + monitor.cancel() + assert monitor.metrics.cancelled is True diff --git a/tests/test_stream_policy.py b/tests/test_stream_policy.py new file mode 100644 index 000000000..4f279eab2 --- /dev/null +++ b/tests/test_stream_policy.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from rotator_library.client.stream_retry_policy import can_retry_stream_after_error as compat_retry_policy +from rotator_library.streaming.policy import can_retry_stream_after_error, is_visible_stream_output + + +def test_reasoning_only_retry_policy_is_preserved() -> None: + reasoning_chunk = 'data: {"choices":[{"delta":{"reasoning_content":"thinking"}}]}\n\n' + text_chunk = 'data: {"choices":[{"delta":{"content":"visible"}}]}\n\n' + + assert can_retry_stream_after_error(None, False) is True + assert can_retry_stream_after_error(reasoning_chunk, True) is True + assert can_retry_stream_after_error(reasoning_chunk, False) is False + assert can_retry_stream_after_error(text_chunk, True) is False + assert compat_retry_policy(reasoning_chunk, True) is True + + +def test_heartbeat_comments_do_not_block_stream_retry() -> None: + assert can_retry_stream_after_error(': heartbeat\n\n', False) is True + + +def test_cost_events_do_not_count_as_visible_output() -> None: + cost_event = 'event: cost\ndata: {"total_cost":0.01}\n\n' + scalar_cost_event = "event: cost\ndata: 0.01\n\n" + + assert is_visible_stream_output(cost_event) is False + assert is_visible_stream_output(cost_event, protocol="responses") is False + assert can_retry_stream_after_error(cost_event, False) is True + assert is_visible_stream_output(scalar_cost_event) is False + assert can_retry_stream_after_error(scalar_cost_event, False) is True + + +def test_visible_output_detection_for_chat_chunks() -> None: + assert is_visible_stream_output('data: {"choices":[{"delta":{"content":"hello"}}]}\n\n') is True + assert is_visible_stream_output('data: {"choices":[{"delta":{"tool_calls":[{"id":"call_1"}]}}]}\n\n') is True + assert is_visible_stream_output('data: {"error":{"type":"rate_limit"}}\n\n') is False + assert is_visible_stream_output('event: error\ndata: {"type":"error","error":{"message":"x"}}\n\n') is False + assert is_visible_stream_output('data: {"type":"error","error":{"message":"x"}}\n\n') is False + assert is_visible_stream_output(': heartbeat\n\n') is False + assert is_visible_stream_output("data: [DONE]\n\n") is False + assert is_visible_stream_output("not-sse") is True + + +def test_visible_output_detection_for_responses_events() -> None: + assert is_visible_stream_output('data: {"event_type":"response.output_text.delta","delta":"hi"}\n\n', protocol="responses") is True + assert is_visible_stream_output('data: {"event_type":"response.output_text.delta","delta":"hi"}\n\n') is True + assert is_visible_stream_output('data: {"event_type":"response.function_call_arguments.delta","delta":"{\\"x\\":"}\n\n') is True + assert is_visible_stream_output('data: {"event_type":"response.failed","error":{"message":"x"}}\n\n', protocol="responses") is False + assert is_visible_stream_output('event: response.failed\ndata: {"error":{"message":"x"}}\n\n', protocol="responses") is False diff --git a/tests/test_stream_transport.py b/tests/test_stream_transport.py new file mode 100644 index 000000000..4b52e1e3f --- /dev/null +++ b/tests/test_stream_transport.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from rotator_library.streaming import JSONLineStreamFormatter, SSEStreamFormatter, StreamEvent, WebSocketStreamFormatter + + +def test_sse_formatter_outputs_named_event() -> None: + formatted = SSEStreamFormatter().format_event(StreamEvent("delta", data={"text": "hi"}, visible_output=True)) + + assert formatted.startswith("event: delta\n") + assert "data:" in formatted + + +def test_sse_formatter_outputs_comment_heartbeat() -> None: + formatted = SSEStreamFormatter().format_heartbeat("keepalive") + + assert formatted == ": keepalive\n\n" + + +def test_websocket_formatter_exposes_future_transport_shape() -> None: + formatted = WebSocketStreamFormatter().format_event(StreamEvent("delta", data={"text": "hi"})) + + assert formatted["type"] == "delta" + assert formatted["payload"]["transport"] == "sse" + assert WebSocketStreamFormatter().format_heartbeat()["type"] == "heartbeat" + + +def test_jsonl_formatter_outputs_one_line_json() -> None: + formatted = JSONLineStreamFormatter().format_event(StreamEvent("metadata", data={"ok": True})) + + assert formatted.endswith("\n") + assert "metadata" in formatted + assert "heartbeat" in JSONLineStreamFormatter().format_heartbeat() diff --git a/tests/test_streaming_error_handler.py b/tests/test_streaming_error_handler.py new file mode 100644 index 000000000..f85b513d1 --- /dev/null +++ b/tests/test_streaming_error_handler.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from litellm import RateLimitError + +from rotator_library.streaming.errors import decide_streaming_error_action + + +def _rate_limit(retry_after: int | None = None) -> Exception: + error = RateLimitError("rate limited", llm_provider="openai", model="gpt-test") + if retry_after is not None: + error.retry_after = retry_after + return error + + +def test_streaming_error_decision_retries_same_key_for_small_retry_after() -> None: + decision = decide_streaming_error_action( + _rate_limit(3), + provider="openai", + last_streamed_chunk=None, + attempt=0, + max_retries=2, + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + provider_cooldown_default_seconds=30, + ) + + assert decision.action == "retry_same" + assert decision.start_provider_cooldown is False + + +def test_streaming_error_decision_starts_cooldown_before_visible_output() -> None: + decision = decide_streaming_error_action( + _rate_limit(60), + provider="openai", + last_streamed_chunk=None, + attempt=1, + max_retries=2, + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + provider_cooldown_default_seconds=30, + ) + + assert decision.action == "rotate" + assert decision.start_provider_cooldown is True + assert decision.provider_cooldown_duration == 60 + assert decision.provider_cooldown_scope == "provider" + + +def test_streaming_error_decision_blocks_after_visible_output_and_skips_cooldown() -> None: + decision = decide_streaming_error_action( + _rate_limit(60), + provider="openai", + last_streamed_chunk='data: {"choices":[{"delta":{"content":"hi"}}]}\n\n', + attempt=0, + max_retries=2, + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + provider_cooldown_default_seconds=30, + ) + + assert decision.action == "fallback_blocked_after_output" + assert decision.start_provider_cooldown is False + + +def test_streaming_error_decision_blocks_after_prior_visible_output() -> None: + decision = decide_streaming_error_action( + _rate_limit(60), + provider="openai", + last_streamed_chunk='data: {"usage":{"total_tokens":1}}\n\n', + emitted_output=True, + attempt=0, + max_retries=2, + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + provider_cooldown_default_seconds=30, + ) + + assert decision.action == "fallback_blocked_after_output" + assert decision.start_provider_cooldown is False + + +def test_streaming_error_decision_allows_reasoning_only_retry_when_enabled() -> None: + decision = decide_streaming_error_action( + _rate_limit(3), + provider="openai", + last_streamed_chunk='data: {"choices":[{"delta":{"reasoning_content":"thinking"}}]}\n\n', + attempt=0, + max_retries=2, + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + provider_cooldown_default_seconds=30, + allow_reasoning_only_retry=True, + ) + + assert decision.action == "retry_same" + + +def test_streaming_error_decision_reports_model_cooldown_scope() -> None: + decision = decide_streaming_error_action( + Exception("MODEL_CAPACITY_EXHAUSTED"), + provider="openai", + model="gpt-5", + last_streamed_chunk=None, + attempt=1, + max_retries=2, + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + provider_cooldown_default_seconds=30, + ) + + assert decision.start_provider_cooldown is True + assert decision.provider_cooldown_scope == "model" + assert decision.provider_cooldown_model == "gpt-5" diff --git a/tests/test_streaming_fallback_policy.py b/tests/test_streaming_fallback_policy.py new file mode 100644 index 000000000..55084e55e --- /dev/null +++ b/tests/test_streaming_fallback_policy.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +import json +from types import MethodType + +import pytest + +from rotator_library.client import executor as executor_module +from rotator_library.client.executor import RequestExecutor +from rotator_library.core.types import RequestContext +from rotator_library.routing import parse_route_target +from rotator_library.routing.types import FallbackGroup +from rotator_library.transaction_logger import TransactionLogger + + +class StreamFailure(Exception): + def __init__(self, error_type: str) -> None: + super().__init__(error_type) + self.error_type = error_type + + +def _context(*, logger=None) -> RequestContext: + targets = (parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")) + return RequestContext( + model="code", + provider="requested", + kwargs={"model": "code", "messages": [], "stream": True}, + streaming=True, + credentials=["cred-a"], + deadline=9999999999.0, + transaction_logger=logger, + routing_targets=targets, + routing_group_name="code_chain", + routing_group=FallbackGroup(name="code_chain", targets=targets, failover_on=frozenset({"authentication", "rate_limit"}), stop_on=frozenset({"validation"})), + ) + + +def _context_never_streaming_fallback(*, logger=None) -> RequestContext: + context = _context(logger=logger) + context.routing_group = FallbackGroup(name="code_chain", targets=context.routing_targets, failover_on=frozenset({"rate_limit"}), streaming_policy="never") + return context + + +@pytest.mark.asyncio +async def test_streaming_fallback_hard_stops_auth_even_with_group_override() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_stream(self, context): + attempts.append(context.provider) + if len(attempts) == 1: + raise StreamFailure("authentication") + yield "data: [DONE]\n\n" + + executor._execute_streaming = MethodType(fake_stream, executor) + + with pytest.raises(StreamFailure): + [chunk async for chunk in executor._execute_streaming_with_fallback(_context())] + + assert attempts == ["codex"] + + +@pytest.mark.asyncio +async def test_streaming_fallback_tries_next_target_before_output() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_stream(self, context): + attempts.append(context.provider) + if len(attempts) == 1: + raise StreamFailure("rate_limit") + yield 'data: {"choices":[{"delta":{"content":"ok"}}]}\n\n' + yield "data: [DONE]\n\n" + + executor._execute_streaming = MethodType(fake_stream, executor) + + context = _context() + chunks = [chunk async for chunk in executor._execute_streaming_with_fallback(context)] + + assert attempts == ["codex", "openai"] + assert chunks[-1] == "data: [DONE]\n\n" + assert context.routing_attempt_history[0]["error_type"] == "rate_limit" + assert context.routing_attempt_history[0]["fallback_allowed"] is True + assert context.routing_attempt_history[1]["success"] is True + + +@pytest.mark.asyncio +async def test_streaming_fallback_blocks_after_visible_output() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + + async def fake_stream(self, context): + yield 'data: {"choices":[{"delta":{"content":"partial"}}]}\n\n' + raise StreamFailure("rate_limit") + + executor._execute_streaming = MethodType(fake_stream, executor) + + with pytest.raises(StreamFailure): + [chunk async for chunk in executor._execute_streaming_with_fallback(_context())] + + +@pytest.mark.asyncio +async def test_streaming_fallback_trace_records_blocked_after_output(tmp_path) -> None: + executor = RequestExecutor.__new__(RequestExecutor) + logger = TransactionLogger("routing", "code", parent_dir=tmp_path) + + async def fake_stream(self, context): + yield 'data: {"choices":[{"delta":{"content":"partial"}}]}\n\n' + raise StreamFailure("rate_limit") + + executor._execute_streaming = MethodType(fake_stream, executor) + + with pytest.raises(StreamFailure): + [chunk async for chunk in executor._execute_streaming_with_fallback(_context(logger=logger))] + + pass_names = [json.loads(line)["pass_name"] for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + assert "routing_stream_target_attempt_started" in pass_names + assert "routing_stream_target_attempt_failed" in pass_names + assert "routing_stream_fallback_blocked_after_output" in pass_names + + +@pytest.mark.asyncio +async def test_streaming_fallback_treats_error_chunk_as_not_visible_output() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_stream(self, context): + attempts.append(context.provider) + if len(attempts) == 1: + yield 'data: {"error":{"type":"rate_limit"}}\n\n' + yield "data: [DONE]\n\n" + return + yield "data: [DONE]\n\n" + + executor._execute_streaming = MethodType(fake_stream, executor) + + chunks = [chunk async for chunk in executor._execute_streaming_with_fallback(_context())] + + assert attempts == ["codex", "openai"] + assert chunks == ["data: [DONE]\n\n"] + + +@pytest.mark.asyncio +async def test_streaming_fallback_respects_never_policy_before_output() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_stream(self, context): + attempts.append(context.provider) + raise StreamFailure("rate_limit") + yield "" + + executor._execute_streaming = MethodType(fake_stream, executor) + + with pytest.raises(StreamFailure): + [chunk async for chunk in executor._execute_streaming_with_fallback(_context_never_streaming_fallback())] + + assert attempts == ["codex"] + + +@pytest.mark.asyncio +async def test_streaming_fallback_exhaustion_trace_uses_sanitized_summaries(tmp_path) -> None: + executor = RequestExecutor.__new__(RequestExecutor) + logger = TransactionLogger("routing", "code", parent_dir=tmp_path) + + async def fake_stream(self, context): + raise StreamFailure("rate_limit") + yield "" + + executor._execute_streaming = MethodType(fake_stream, executor) + + with pytest.raises(StreamFailure): + [chunk async for chunk in executor._execute_streaming_with_fallback(_context_never_streaming_fallback(logger=logger))] + + entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + exhausted = [entry for entry in entries if entry["pass_name"] == "routing_fallback_exhausted"][-1] + assert exhausted["metadata"]["fallback_targets"][0]["message"] == "" + assert exhausted["metadata"]["streaming_policy"] == "never" + + +def test_stream_timeout_details_merge_into_aggregate_error() -> None: + error_data = {"error": {"type": "proxy_error", "details": {"attempts": 1}}} + stream_error = {"error": {"type": "api_connection", "details": {"timeout_type": "ttfb", "timeout_seconds": 0.1}}} + + executor_module._merge_stream_error_details(error_data, stream_error) + + assert error_data["error"]["type"] == "api_connection" + assert error_data["error"]["details"]["attempts"] == 1 + assert error_data["error"]["details"]["timeout_type"] == "ttfb" diff --git a/tests/test_streaming_usage_accounting.py b/tests/test_streaming_usage_accounting.py new file mode 100644 index 000000000..8077ee634 --- /dev/null +++ b/tests/test_streaming_usage_accounting.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.client.streaming import StreamingHandler +from rotator_library.transaction_logger import TransactionLogger + + +class FakeCredentialContext: + def __init__(self) -> None: + self.success_kwargs = None + + def mark_success(self, **kwargs) -> None: + self.success_kwargs = kwargs + + +async def _usage_chunks(): + yield {"id": "chunk_1", "choices": [{"delta": {"content": "hi"}}]} + yield { + "id": "chunk_2", + "choices": [{"delta": {}, "finish_reason": "stop"}], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 30, + "prompt_tokens_details": {"cached_tokens": 40, "cache_creation_tokens": 5}, + "completion_tokens_details": {"reasoning_tokens": 10}, + }, + } + + +async def _zero_usage_chunks(): + yield {"id": "chunk_1", "choices": [{"delta": {"content": "hi"}}]} + yield {"id": "chunk_2", "choices": [{"delta": {}, "finish_reason": "stop"}]} + + +async def _cost_comment_chunks(): + yield ': cost {"total_cost":0.042,"currency":"USD","source":"provider_sse"}\n\n' + yield {"id": "chunk_1", "choices": [{"delta": {"content": "hi"}}]} + yield {"id": "chunk_2", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1}} + + +async def _cost_event_chunks(): + yield 'event: cost\ndata: {"total_cost":0.021,"currency":"EUR","source":"event_cost"}\n\n' + yield {"id": "chunk_1", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1}} + + +async def _scalar_cost_event_chunks(): + yield "event: cost\ndata: 0.033\n\n" + yield {"id": "chunk_1", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1}} + + +async def _request_cost_comment_chunks(): + yield ': cost {"request_cost_usd":0.044,"source":"reference_sse"}\n\n' + yield {"id": "chunk_1", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1}} + + +async def _estimated_cost_comment_chunks(): + yield ': cost {"estimated_cost":0.045,"source":"reference_estimate"}\n\n' + yield {"id": "chunk_1", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1}} + + +async def _top_level_cost_usage_sse_chunks(): + yield 'data: {"choices":[{"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1},"total_cost":0.055}\n\n' + + +async def _top_level_cost_usage_dict_chunks(): + yield {"id": "chunk_1", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1}, "total_cost": 0.066} + + +async def _cost_comment_overridden_by_final_usage_chunks(): + yield ': cost 0.042\n\n' + yield { + "id": "chunk_2", + "choices": [{"delta": {}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "cost_details": {"total_cost": 0.084, "source": "final_usage"}}, + } + + +@pytest.mark.asyncio +async def test_streaming_usage_uses_normalized_accounting_and_trace(tmp_path, monkeypatch) -> None: + monkeypatch.setattr( + "rotator_library.usage.costs.litellm.get_model_info", + lambda model: {"input_cost_per_token": 0.001, "output_cost_per_token": 0.002}, + ) + cred_context = FakeCredentialContext() + logger = TransactionLogger("openai", "gpt-test", parent_dir=tmp_path) + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(_usage_chunks(), "cred", "gpt-test", cred_context=cred_context, transaction_logger=logger)] + + assert chunks[-1] == "data: [DONE]\n\n" + assert cred_context.success_kwargs["prompt_tokens"] == 55 + assert cred_context.success_kwargs["prompt_tokens_cache_read"] == 40 + assert cred_context.success_kwargs["prompt_tokens_cache_write"] == 5 + assert cred_context.success_kwargs["completion_tokens"] == 20 + assert cred_context.success_kwargs["thinking_tokens"] == 10 + assert cred_context.success_kwargs["approx_cost"] > 0 + entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + assert any(entry["pass_name"] == "usage_accounting_summary" for entry in entries) + + +@pytest.mark.asyncio +async def test_streaming_usage_skip_cost_returns_zero() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_usage_chunks(), "cred", "gpt-test", cred_context=cred_context, skip_cost_calculation=True)] + + assert cred_context.success_kwargs["approx_cost"] == 0.0 + + +@pytest.mark.asyncio +async def test_streaming_without_usage_still_marks_success_with_zero_usage() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_zero_usage_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["prompt_tokens"] == 0 + assert cred_context.success_kwargs["completion_tokens"] == 0 + assert cred_context.success_kwargs["thinking_tokens"] == 0 + assert cred_context.success_kwargs["prompt_tokens_cache_read"] == 0 + assert cred_context.success_kwargs["prompt_tokens_cache_write"] == 0 + + +@pytest.mark.asyncio +async def test_streaming_completed_calls_success_callback() -> None: + called = [] + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_zero_usage_chunks(), "cred", "gpt-test", success_callback=lambda: called.append(True))] + + assert called == [True] + + +@pytest.mark.asyncio +async def test_streaming_usage_uses_configured_env_pricing(monkeypatch) -> None: + monkeypatch.setenv("MODEL_PRICE_OPENAI_GPT_TEST_INPUT", "2.0") + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_usage_chunks(), "cred", "openai/gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 110.0 + + +@pytest.mark.asyncio +async def test_streaming_cost_comment_updates_approx_cost() -> None: + cred_context = FakeCredentialContext() + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(_cost_comment_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert chunks[0].startswith(": cost") + assert cred_context.success_kwargs["approx_cost"] == 0.042 + + +@pytest.mark.asyncio +async def test_streaming_cost_event_updates_approx_cost() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_cost_event_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 0.021 + + +@pytest.mark.asyncio +async def test_streaming_scalar_cost_event_updates_approx_cost() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_scalar_cost_event_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 0.033 + + +@pytest.mark.asyncio +async def test_streaming_reference_request_cost_comment_updates_approx_cost() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_request_cost_comment_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 0.044 + + +@pytest.mark.asyncio +async def test_streaming_estimated_cost_comment_updates_approx_cost() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_estimated_cost_comment_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 0.045 + + +@pytest.mark.asyncio +async def test_streaming_sse_usage_preserves_top_level_cost() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_top_level_cost_usage_sse_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 0.055 + + +@pytest.mark.asyncio +async def test_streaming_dict_usage_preserves_top_level_cost() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_top_level_cost_usage_dict_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 0.066 + + +@pytest.mark.asyncio +async def test_streaming_final_usage_cost_overrides_comment_cost() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_cost_comment_overridden_by_final_usage_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 0.084 diff --git a/tests/test_transaction_logger_transform_trace.py b/tests/test_transaction_logger_transform_trace.py new file mode 100644 index 000000000..250a0df60 --- /dev/null +++ b/tests/test_transaction_logger_transform_trace.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.client.executor import RequestExecutor +from rotator_library.client.transforms import ProviderTransforms +from rotator_library.transaction_logger import ProviderLogger, TransactionLogger +from rotator_library.transform_trace import REDACTED + + +def _trace_entries(log_dir): + trace_file = log_dir / "transform_trace.jsonl" + return [json.loads(line) for line in trace_file.read_text(encoding="utf-8").splitlines()] + + +def test_log_request_writes_legacy_file_and_raw_trace(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + + logger.log_request({"model": "gpt-test", "api_key": "secret", "messages": [{"role": "user", "content": "hi"}]}) + + assert (logger.log_dir / "request.json").exists() + entries = _trace_entries(logger.log_dir) + assert entries[0]["pass_name"] == "raw_client_request" + assert entries[0]["request_id"] == logger.request_id + assert entries[0]["direction"] == "request" + assert entries[0]["data"]["api_key"] == REDACTED + assert (logger.log_dir / "transforms" / "0001_raw_client_request.json").exists() + + +def test_log_transformed_request_records_trace_even_when_legacy_file_skips(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + logger.set_trace_context(session_id="session_1", scope_key="scope_1", classifier="class_a") + request = {"model": "gpt-test", "messages": [{"role": "user", "content": "hi"}]} + + logger.log_transformed_request(request, dict(request), credential_id="cred_1") + + entries = _trace_entries(logger.log_dir) + assert entries[0]["pass_name"] == "prepared_provider_request" + assert entries[0]["changed_from_previous"] is False + assert entries[0]["credential_id"] == "cred_1" + assert entries[0]["session_id"] == "session_1" + assert entries[0]["scope_key"] == "scope_1" + assert entries[0]["classifier"] == "class_a" + assert not (logger.log_dir / "request_transformed.json").exists() + + +def test_log_response_and_stream_chunk_write_trace_entries(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + logger.log_request({"model": "gpt-test", "stream": True}) + + logger.log_stream_chunk({"choices": [{"delta": {"content": "Hi"}}]}) + logger.log_response({"model": "gpt-test", "choices": [], "usage": {"total_tokens": 2}}) + + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert "parsed_stream_chunk" in pass_names + assert "final_client_response" in pass_names + stream_entry = next(entry for entry in entries if entry["pass_name"] == "parsed_stream_chunk") + assert stream_entry["direction"] == "stream" + assert not (logger.log_dir / "transforms" / "0002_parsed_stream_chunk.json").exists() + + +def test_provider_logger_writes_provider_trace_entries(tmp_path) -> None: + logger = TransactionLogger("gemini_cli", "gemini_cli/gemini-test", parent_dir=tmp_path) + provider_logger = ProviderLogger(logger.get_context()) + + provider_logger.log_request({"credential_identifier": "secret", "body": {"text": "hi"}}) + provider_logger.log_response_chunk("data: chunk") + provider_logger.log_final_response({"ok": True}) + provider_logger.log_error("provider failed") + + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert pass_names == [ + "provider_request_payload", + "provider_raw_stream_chunk", + "provider_final_response", + "provider_error", + ] + assert entries[0]["component"] == "provider" + assert entries[0]["request_id"] == logger.request_id + assert entries[0]["data"]["credential_identifier"] == REDACTED + assert (logger.log_dir / "provider" / "request_payload.json").exists() + + +@pytest.mark.asyncio +async def test_stream_wrapper_records_raw_parsed_and_assembled_trace(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + executor = RequestExecutor.__new__(RequestExecutor) + + async def stream(): + yield 'data: {"id":"chunk_1","choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}\n\n' + yield 'data: {"choices":[{"delta":{},"finish_reason":"stop"}]}\n\n' + yield "data: [DONE]\n\n" + + chunks = [chunk async for chunk in executor._transaction_logging_stream_wrapper(stream(), logger, {})] + + assert chunks[-1] == "data: [DONE]\n\n" + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert pass_names.count("raw_stream_chunk") == 3 + assert pass_names.count("parsed_stream_chunk") == 2 + assert pass_names.count("stream_done_event") == 1 + assert "assembled_stream_response" in pass_names + assert "final_client_response" in pass_names + + +@pytest.mark.asyncio +async def test_stream_wrapper_records_error_events(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + executor = RequestExecutor.__new__(RequestExecutor) + + async def stream(): + yield 'data: {"error":{"type":"rate_limit","message":"slow down"}}\n\n' + yield "data: [DONE]\n\n" + + chunks = [chunk async for chunk in executor._transaction_logging_stream_wrapper(stream(), logger, {})] + + assert chunks[-1] == "data: [DONE]\n\n" + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "stream_error_event" in pass_names + assert "stream_done_event" in pass_names + + +def test_executor_terminal_stream_errors_are_traced(tmp_path) -> None: + class Context: + pass + + context = Context() + context.transaction_logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + context.streaming = True + context.provider = "openai" + context.model = "openai/gpt-test" + context.session_id = None + context.usage_manager_key = "openai" + context.classifier = None + + executor = RequestExecutor.__new__(RequestExecutor) + + lines = executor._terminal_stream_error_lines(context, {"error": {"type": "proxy_error"}}) + + assert lines[-1] == "data: [DONE]\n\n" + pass_names = [entry["pass_name"] for entry in _trace_entries(context.transaction_logger.log_dir)] + assert pass_names == ["stream_error_event", "stream_done_event"] + + +def test_transaction_logger_disabled_writes_no_trace(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", enabled=False, parent_dir=tmp_path) + + logger.log_request({"model": "gpt-test"}) + logger.log_response({"model": "gpt-test"}) + + assert logger.log_dir is None + assert not list(tmp_path.iterdir()) + + +def test_provider_error_trace_scrubs_header_like_secret_text(tmp_path) -> None: + logger = TransactionLogger("gemini_cli", "gemini_cli/gemini-test", parent_dir=tmp_path) + provider_logger = ProviderLogger(logger.get_context()) + + provider_logger.log_error("upstream failed Authorization: Bearer secret-token") + + entries = _trace_entries(logger.log_dir) + assert entries[0]["pass_name"] == "provider_error" + assert "secret-token" not in entries[0]["data"]["message"] + assert "[REDACTED]" in entries[0]["data"]["message"] + + +def test_log_transform_error_uses_standard_shape_and_scrubs_text(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + + logger.log_transform_error( + "after_field_cache_injection", + RuntimeError("bad Authorization: Bearer secret"), + payload={"cookie": "sid=secret"}, + ) + + entries = _trace_entries(logger.log_dir) + assert entries[0]["pass_name"] == "transform_log_error" + assert entries[0]["data"]["failed_pass_name"] == "after_field_cache_injection" + assert "secret" not in json.dumps(entries[0]["data"]) + + +def test_trace_redacts_camel_case_secret_keys(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + + logger.log_transform_pass( + "camel_secret_payload", + {"apiKey": "a", "accessToken": "b", "refreshToken": "c", "clientSecret": "d", "idToken": "e"}, + direction="request", + stage="client", + ) + + data = _trace_entries(logger.log_dir)[0]["data"] + assert set(data.values()) == {REDACTED} + + +@pytest.mark.asyncio +async def test_provider_transforms_trace_each_live_boundary(tmp_path) -> None: + class HookPlugin: + async def transform_request(self, kwargs, model, credential): + kwargs["hooked"] = credential + return ["hooked request"] + + def get_model_options(self, model): + return {"reasoning_effort": "low", "temperature": 0.2} + + class Config: + def convert_for_litellm(self, provider_override=None, **kwargs): + converted = dict(kwargs) + converted["converted_for_litellm"] = True + return converted + + logger = TransactionLogger("dedaluslabs", "dedaluslabs/test", parent_dir=tmp_path) + transforms = ProviderTransforms( + {"dedaluslabs": HookPlugin()}, + provider_config=Config(), + ) + + result = await transforms.apply( + "dedaluslabs", + "dedaluslabs/test", + "secret-credential", + {"model": "dedaluslabs/test", "tool_choice": "auto"}, + transaction_logger=logger, + credential_id="stable_cred", + transport="http", + trace_metadata={"scope_key": "scope"}, + ) + + assert "tool_choice" not in result + assert result["hooked"] == "secret-credential" + assert result["reasoning_effort"] == "low" + assert result["converted_for_litellm"] is True + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert pass_names == [ + "pre_provider_transform_request", + "after_builtin_provider_transform", + "after_provider_hook_transform", + "after_provider_model_options", + "before_litellm_conversion", + "after_litellm_conversion", + ] + assert all(entry["credential_id"] == "stable_cred" for entry in entries) + assert entries[1]["metadata"]["transform_provider"] == "dedaluslabs" + assert entries[-1]["changed_from_previous"] is True + + +@pytest.mark.asyncio +async def test_provider_builtin_transform_errors_are_traced(tmp_path) -> None: + def broken_transform(kwargs, model, provider): + raise RuntimeError("bad apiKey: secret") + + logger = TransactionLogger("broken", "broken/test", parent_dir=tmp_path) + transforms = ProviderTransforms({}) + transforms._transforms["broken"] = [broken_transform] + + with pytest.raises(RuntimeError): + await transforms.apply( + "broken", + "broken/test", + "cred", + {"model": "broken/test", "apiKey": "secret"}, + transaction_logger=logger, + credential_id="cred_1", + ) + + entries = _trace_entries(logger.log_dir) + error_entry = [entry for entry in entries if entry["pass_name"] == "transform_log_error"][-1] + assert error_entry["data"]["failed_pass_name"] == "builtin_provider_transform" + assert "secret" not in json.dumps(error_entry["data"]) + + +@pytest.mark.asyncio +async def test_provider_transforms_do_not_deepcopy_for_trace_when_disabled() -> None: + class NoCopyValue: + def __deepcopy__(self, memo): + raise AssertionError("trace comparison should not copy when tracing is disabled") + + class HookPlugin: + async def transform_request(self, kwargs, model, credential): + kwargs["hooked"] = True + return ["hooked request"] + + transforms = ProviderTransforms({"dedaluslabs": HookPlugin()}) + + result = await transforms.apply( + "dedaluslabs", + "dedaluslabs/test", + "secret-credential", + {"model": "dedaluslabs/test", "tool_choice": "auto", "opaque": NoCopyValue()}, + ) + + assert "tool_choice" not in result + assert result["hooked"] is True diff --git a/tests/test_transform_trace.py b/tests/test_transform_trace.py new file mode 100644 index 000000000..682ae6b1f --- /dev/null +++ b/tests/test_transform_trace.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass +from datetime import datetime +from decimal import Decimal + +from rotator_library.transform_trace import ( + REDACTED, + TransformTraceWriter, + sanitize_filename, + sanitize_for_trace, + scrub_sensitive_text, +) + + +@dataclass +class ExamplePayload: + when: datetime + amount: Decimal + + +def test_sanitize_for_trace_redacts_sensitive_keys_recursively() -> None: + payload = { + "api_key": "secret-key", + "headers": {"Authorization": "Bearer secret", "Cookie": "sid=secret", "normal": "token in normal text"}, + "items": [{"refresh_token": "refresh", "text": "token should remain in value"}], + } + + sanitized = sanitize_for_trace(payload) + + assert sanitized["api_key"] == REDACTED + assert sanitized["headers"]["Authorization"] == REDACTED + assert sanitized["headers"]["Cookie"] == REDACTED + assert sanitized["headers"]["normal"] == "token in normal text" + assert sanitized["items"][0]["refresh_token"] == REDACTED + assert sanitize_for_trace({"reasoning_content": "provider-state"})["reasoning_content"] == REDACTED + assert sanitize_for_trace({"thoughtSignature": "provider-state"})["thoughtSignature"] == REDACTED + assert sanitize_for_trace({"metadata": {"state": "provider-state"}})["metadata"]["state"] == REDACTED + assert sanitize_for_trace({"metadata": {"prompt_cache_key": "cache-state"}})["metadata"]["prompt_cache_key"] == REDACTED + assert sanitized["items"][0]["text"] == "token should remain in value" + + +def test_scrub_sensitive_text_targets_header_like_fragments_only() -> None: + text = "normal token text remains\nAuthorization: Bearer abc123\nset-cookie: sid=secret\n{'Authorization': 'Bearer quoted'}" + + scrubbed = scrub_sensitive_text(text) + + assert "normal token text remains" in scrubbed + assert "Authorization: [REDACTED]" in scrubbed + assert "set-cookie: [REDACTED]" in scrubbed + assert "quoted" not in scrubbed + + +def test_sanitize_for_trace_extracts_sdk_like_objects_before_repr() -> None: + class SdkError: + def __init__(self) -> None: + self.status_code = 401 + self.headers = {"Authorization": "Bearer secret"} + + def __repr__(self) -> str: + return "SdkError(Authorization: Bearer leaked)" + + sanitized = sanitize_for_trace(SdkError(), scrub_strings=True) + + assert sanitized["status_code"] == 401 + assert sanitized["headers"]["Authorization"] == REDACTED + assert "leaked" not in json.dumps(sanitized) + + +def test_sanitize_for_trace_serializes_common_non_json_values() -> None: + payload = ExamplePayload(when=datetime(2026, 1, 2, 3, 4, 5), amount=Decimal("1.5")) + + sanitized = sanitize_for_trace({"payload": payload, "bytes": b"hello"}) + + assert sanitized["payload"]["when"] == "2026-01-02T03:04:05" + assert sanitized["payload"]["amount"] == 1.5 + assert sanitized["bytes"] == "hello" + json.dumps(sanitized) + + +def test_sanitize_filename_is_stable_and_filesystem_safe() -> None: + assert sanitize_filename("Raw Client Request") == "raw_client_request" + assert sanitize_filename("provider/request:payload") == "provider_request_payload" + + +def test_transform_trace_writer_records_jsonl_and_snapshots(tmp_path) -> None: + writer = TransformTraceWriter(tmp_path, component="client", provider="openai", model="gpt-test", request_id="req_1") + + first = writer.record("raw_client_request", {"model": "gpt-test"}, direction="request", stage="client") + second = writer.record("parsed_stream_chunk", {"delta": "hi"}, direction="stream", stage="client", snapshot=True) + + assert first is not None + assert second is not None + assert first.sequence == 1 + assert second.sequence == 2 + + lines = (tmp_path / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines() + assert len(lines) == 2 + assert json.loads(lines[0])["pass_name"] == "raw_client_request" + assert json.loads(lines[0])["request_id"] == "req_1" + assert json.loads(lines[1])["direction"] == "stream" + assert (tmp_path / "transforms" / "0001_raw_client_request.json").exists() + assert not (tmp_path / "transforms" / "0002_parsed_stream_chunk.json").exists() + + +def test_transform_trace_writer_disabled_writes_nothing(tmp_path) -> None: + writer = TransformTraceWriter(tmp_path, component="client", enabled=False) + + assert writer.record("raw_client_request", {}, direction="request", stage="client") is None + assert not (tmp_path / "transform_trace.jsonl").exists() + + +def test_transform_trace_writer_snapshot_namespace_prevents_collisions(tmp_path) -> None: + first = TransformTraceWriter(tmp_path, component="provider", snapshot_namespace="provider_a") + second = TransformTraceWriter(tmp_path, component="provider", snapshot_namespace="provider_b") + + first.record("provider_request_payload", {"a": 1}, direction="request", stage="provider") + second.record("provider_request_payload", {"b": 2}, direction="request", stage="provider") + + assert (tmp_path / "transforms" / "0001_provider_a_provider_request_payload.json").exists() + assert (tmp_path / "transforms" / "0001_provider_b_provider_request_payload.json").exists() diff --git a/tests/test_usage_accounting.py b/tests/test_usage_accounting.py new file mode 100644 index 000000000..e30e82886 --- /dev/null +++ b/tests/test_usage_accounting.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from rotator_library.protocols.types import CostDetails, Usage +from rotator_library.usage.accounting import UsageRecord, extract_usage_record + + +def test_openai_dict_usage_extracts_cache_and_reasoning_without_double_counting() -> None: + record = extract_usage_record( + { + "usage": { + "prompt_tokens": 100, + "completion_tokens": 30, + "total_tokens": 130, + "prompt_tokens_details": {"cached_tokens": 40, "cache_creation_tokens": 5}, + "completion_tokens_details": {"reasoning_tokens": 10}, + } + }, + provider="openai", + model="gpt-test", + ) + + assert record.input_tokens == 55 + assert record.cache_read_tokens == 40 + assert record.cache_write_tokens == 5 + assert record.completion_tokens == 20 + assert record.reasoning_tokens == 10 + assert record.output_tokens == 30 + assert record.raw_total_tokens == 130 + assert record.total_tokens == 130 + + +def test_openai_usage_extracts_provider_reported_cost_details() -> None: + record = extract_usage_record( + { + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "cost_details": {"total_cost": 0.0123, "currency": "EUR", "source": "provider_usage"}, + } + } + ) + + assert record.provider_reported_cost == 0.0123 + assert record.cost_currency == "EUR" + assert record.cost_source == "provider_usage" + + +def test_top_level_cost_is_preserved_when_usage_exists() -> None: + record = extract_usage_record( + { + "usage": {"prompt_tokens": 10, "completion_tokens": 5}, + "total_cost": 0.03, + "currency": "EUR", + } + ) + + assert record.provider_reported_cost == 0.03 + assert record.cost_currency == "EUR" + + +def test_reference_request_cost_usd_is_preserved() -> None: + record = extract_usage_record({"usage": {"prompt_tokens": 1, "completion_tokens": 1}, "request_cost_usd": 0.019}) + + assert record.provider_reported_cost == 0.019 + + +def test_top_level_estimated_cost_is_preserved() -> None: + record = extract_usage_record({"usage": {"prompt_tokens": 1, "completion_tokens": 1}, "estimated_cost": 0.027}) + + assert record.provider_reported_cost == 0.027 + + +def test_structured_cost_breakdown_without_total_is_summed() -> None: + record = extract_usage_record( + { + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "cost_details": {"cached_input_cost": 0.01, "upstream_inference_cost": 0.02, "web_search_cost": 0.003}, + } + } + ) + + assert record.provider_reported_cost == 0.033 + assert record.cost_source == "provider_reported_breakdown" + + +def test_reference_extended_cost_breakdown_aliases_are_summed() -> None: + record = extract_usage_record( + { + "usage": { + "prompt_tokens": 1, + "completion_tokens": 1, + "cost_details": { + "upstream_inference_input_cost": 0.01, + "upstream_inference_output_cost": 0.02, + "image_input_cost": 0.003, + "audio_input_cost": 0.004, + "data_storage_cost": 0.005, + "estimated_cost": 0.006, + "cost_in_usd_ticks": 70_000_000, + }, + } + } + ) + + assert record.provider_reported_cost == 0.055 + + +def test_openai_object_usage_extracts_attributes() -> None: + response = SimpleNamespace( + usage=SimpleNamespace( + prompt_tokens=12, + completion_tokens=7, + prompt_tokens_details=SimpleNamespace(cached_tokens=2, cache_creation_tokens=1), + completion_tokens_details=SimpleNamespace(reasoning_tokens=3), + ) + ) + + record = extract_usage_record(response) + + assert record.input_tokens == 9 + assert record.cache_read_tokens == 2 + assert record.cache_write_tokens == 1 + assert record.completion_tokens == 4 + assert record.reasoning_tokens == 3 + + +def test_anthropic_usage_extracts_cache_buckets() -> None: + record = extract_usage_record( + { + "input_tokens": 100, + "output_tokens": 20, + "cache_creation_input_tokens": 8, + "cache_read_input_tokens": 12, + } + ) + + assert record.input_tokens == 80 + assert record.cache_read_tokens == 12 + assert record.cache_write_tokens == 8 + assert record.completion_tokens == 20 + assert record.metadata["shape"] == "anthropic" + + +def test_gemini_usage_metadata_extracts_thought_and_cached_tokens() -> None: + record = extract_usage_record( + { + "usageMetadata": { + "promptTokenCount": 80, + "candidatesTokenCount": 25, + "thoughtsTokenCount": 5, + "cachedContentTokenCount": 30, + "totalTokenCount": 110, + } + } + ) + + assert record.input_tokens == 50 + assert record.cache_read_tokens == 30 + assert record.completion_tokens == 20 + assert record.reasoning_tokens == 5 + assert record.metadata["shape"] == "gemini" + + +def test_gemini_usage_extracts_cost_metadata() -> None: + record = extract_usage_record( + { + "usageMetadata": { + "promptTokenCount": 1, + "totalTokenCount": 1, + "costMetadata": {"cost": "0.004", "currency": "USD", "source": "gemini_usage"}, + } + } + ) + + assert record.provider_reported_cost == 0.004 + assert record.cost_source == "gemini_usage" + + +def test_protocol_usage_cost_details_are_preserved() -> None: + record = extract_usage_record( + Usage(input_tokens=2, output_tokens=3, cost=CostDetails(provider_reported_cost=0.02, currency="GBP", source="protocol_usage")) + ) + + assert record.input_tokens == 2 + assert record.completion_tokens == 3 + assert record.provider_reported_cost == 0.02 + assert record.cost_currency == "GBP" + assert record.cost_source == "protocol_usage" + + +def test_responses_usage_extracts_nested_details() -> None: + record = extract_usage_record( + { + "input_tokens": 42, + "output_tokens": 13, + "input_tokens_details": {"cached_tokens": 10}, + "output_tokens_details": {"reasoning_tokens": 4}, + }, + source="responses", + ) + + assert record.input_tokens == 32 + assert record.cache_read_tokens == 10 + assert record.completion_tokens == 9 + assert record.reasoning_tokens == 4 + + +def test_unknown_usage_shape_returns_empty_record() -> None: + record = extract_usage_record({"not_usage": True}, provider="x", model="y") + + assert record == UsageRecord(provider="x", model="y", source="response", metadata={"shape": "openai_like"}) diff --git a/tests/test_usage_costs.py b/tests/test_usage_costs.py new file mode 100644 index 000000000..0ffe6ec5d --- /dev/null +++ b/tests/test_usage_costs.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from rotator_library.usage.accounting import UsageRecord +from rotator_library.usage.costs import CostCalculator, ModelPricing + + +class PricingProvider: + def get_model_pricing(self, model: str): + return ModelPricing( + input_cost_per_token=0.001, + cache_read_cost_per_token=0.0001, + cache_write_cost_per_token=0.0005, + output_cost_per_token=0.002, + reasoning_cost_per_token=0.003, + source="provider_test", + ) + + +class SkipCostProvider: + skip_cost_calculation = True + + +def test_explicit_pricing_calculates_all_usage_buckets() -> None: + usage = UsageRecord(input_tokens=10, cache_read_tokens=20, cache_write_tokens=5, completion_tokens=7, reasoning_tokens=3) + + cost = CostCalculator(provider_plugin=PricingProvider(), use_litellm_fallback=False).calculate(usage, model="test") + + assert cost.input_cost == 0.01 + assert cost.cache_read_cost == 0.002 + assert cost.cache_write_cost == 0.0025 + assert cost.output_cost == 0.014 + assert cost.reasoning_cost == 0.009000000000000001 + assert cost.pricing_source == "provider_test" + + +def test_skip_cost_provider_returns_zero_skipped_breakdown() -> None: + cost = CostCalculator(provider_plugin=SkipCostProvider()).calculate(UsageRecord(input_tokens=100), model="test") + + assert cost.total_cost == 0.0 + assert cost.pricing_source == "skipped" + + +def test_provider_reported_cost_wins_over_advisory_pricing() -> None: + usage = UsageRecord(input_tokens=10, completion_tokens=10, provider_reported_cost=0.123, cost_currency="EUR", cost_source="provider_actual") + + cost = CostCalculator(provider_plugin=PricingProvider(), use_litellm_fallback=False).calculate(usage, model="test") + + assert cost.total_cost == 0.123 + assert cost.provider_reported_cost == 0.123 + assert cost.currency == "EUR" + assert cost.pricing_source == "provider_actual" + assert cost.input_cost == 0.0 + + +def test_skip_cost_still_wins_over_provider_reported_cost() -> None: + usage = UsageRecord(provider_reported_cost=0.123) + + cost = CostCalculator(provider_plugin=SkipCostProvider()).calculate(usage, model="test") + + assert cost.total_cost == 0.0 + assert cost.pricing_source == "skipped" + + +def test_missing_pricing_returns_unavailable_zero() -> None: + cost = CostCalculator(use_litellm_fallback=False).calculate(UsageRecord(input_tokens=100), model="unknown-model") + + assert cost.total_cost == 0.0 + assert cost.pricing_source == "unavailable" + + +def test_litellm_model_info_fallback(monkeypatch) -> None: + monkeypatch.setattr( + "rotator_library.usage.costs.litellm.get_model_info", + lambda model: {"input_cost_per_token": 0.001, "output_cost_per_token": 0.002}, + ) + + cost = CostCalculator().calculate(UsageRecord(input_tokens=2, completion_tokens=3), model="gpt-test") + + assert cost.total_cost == 0.008 + assert cost.pricing_source == "litellm_model_info" diff --git a/tests/test_usage_quota_snapshots.py b/tests/test_usage_quota_snapshots.py new file mode 100644 index 000000000..f110ac4ef --- /dev/null +++ b/tests/test_usage_quota_snapshots.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from rotator_library.usage.quota import build_quota_snapshots +from rotator_library.usage.types import CredentialState, WindowStats + + +def _state() -> CredentialState: + state = CredentialState(stable_id="credential-secret", provider="openai", accessor="credential-secret") + model_stats = state.get_model_stats("gpt-test") + model_stats.windows["daily"] = WindowStats(name="daily", request_count=3, limit=10, reset_at=123.0) + group_stats = state.get_group_stats("chat") + group_stats.windows["daily"] = WindowStats(name="daily", request_count=5, limit=20, reset_at=456.0) + return state + + +def test_build_quota_snapshots_for_model_window() -> None: + snapshots = build_quota_snapshots(provider="openai", states={"credential-secret": _state()}, model="gpt-test") + + assert len(snapshots) == 1 + snapshot = snapshots[0] + assert snapshot.provider == "openai" + assert snapshot.model == "gpt-test" + assert snapshot.quota_group is None + assert snapshot.used == 3 + assert snapshot.remaining == 7 + assert snapshot.credential_id != "credential-secret" + assert snapshot.metadata == {"scope": "request_token_window"} + assert "cost_used" not in snapshot.to_dict() + + +def test_build_quota_snapshots_for_group_window_without_credentials() -> None: + snapshots = build_quota_snapshots( + provider="openai", + states={"credential-secret": _state()}, + model="gpt-test", + quota_group="chat", + include_credentials=False, + ) + + group_snapshot = [snapshot for snapshot in snapshots if snapshot.source == "group"][0] + assert group_snapshot.quota_group == "chat" + assert group_snapshot.used == 5 + assert group_snapshot.remaining == 15 + assert group_snapshot.credential_id is None + + +def test_build_quota_snapshots_missing_windows_returns_empty() -> None: + assert build_quota_snapshots(provider="openai", states={"credential-secret": _state()}, model="missing") == []