From f7d4c6cda812b8f22d91b85c6048d22e01f2953b Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 22:35:55 +0200 Subject: [PATCH 001/182] docs(experimental): add native protocol roadmap Captures the experimental branch workflow, protocol architecture goals, transform logging requirements, field-cache rules, provider priorities, routing, retry, usage, streaming, and config direction. Documents that every phase must be freshly planned in conversation, written as planning docs, reviewed by explore and explore-heavy agents, and reported to the user without committing reports by default. --- docs/experimental/00-master-plan.md | 132 +++++ docs/experimental/01-protocol-architecture.md | 132 +++++ docs/experimental/02-transform-logging.md | 84 +++ docs/experimental/03-field-cache-rules.md | 113 ++++ docs/experimental/04-provider-roadmap.md | 115 ++++ .../05-routing-retry-usage-roadmap.md | 114 ++++ docs/experimental/06-phase-workflow.md | 89 +++ .../experimental/07-detailed-phase-roadmap.md | 540 ++++++++++++++++++ docs/experimental/phase-1-protocol-core.md | 172 ++++++ 9 files changed, 1491 insertions(+) create mode 100644 docs/experimental/00-master-plan.md create mode 100644 docs/experimental/01-protocol-architecture.md create mode 100644 docs/experimental/02-transform-logging.md create mode 100644 docs/experimental/03-field-cache-rules.md create mode 100644 docs/experimental/04-provider-roadmap.md create mode 100644 docs/experimental/05-routing-retry-usage-roadmap.md create mode 100644 docs/experimental/06-phase-workflow.md create mode 100644 docs/experimental/07-detailed-phase-roadmap.md create mode 100644 docs/experimental/phase-1-protocol-core.md diff --git a/docs/experimental/00-master-plan.md b/docs/experimental/00-master-plan.md new file mode 100644 index 000000000..2c85482a2 --- /dev/null +++ b/docs/experimental/00-master-plan.md @@ -0,0 +1,132 @@ +# Experimental Native Protocol Roadmap + +This branch is for a long-running experimental rewrite that makes native protocol support the first-class extension point of `rotator_library`, while preserving the existing credential rotation, quota, fair-cycle, session tracking, and provider plugin strengths. + +## Operating Rules + +- Work only on the `experimental` branch. +- Keep all repository work inside `C:\Projects\test\LLM-API-Key-Proxy` and child paths. +- Treat commits as checkpoints. A phase may contain many commits. +- Commit messages must include a body describing what changed, why, tests run, and follow-up considerations. +- Do not commit phase reports written for the user unless explicitly requested. Planning docs under `docs/experimental/` are committed. +- Before each phase implementation, first produce a fresh exhaustive phase plan in conversation text, based on the current code state. Only after that plan is settled should it be written to `docs/experimental/phase-N-*.md`. +- After each phase implementation, call both `explore` and `explore-heavy` agents to review the work against the phase plan, external reference areas, and current proxy behavior. Fix findings and re-review as needed. +- Keep LiteLLM as a fallback path for protocols/providers that are not natively covered yet. Native protocol support should be preferred when available. + +## Strategic Goal + +The target architecture is: + +```text +client API request + -> protocol parse into unified representation + -> field-cache injection + -> adapter chain + -> provider override hooks + -> provider-native request build + -> provider execution and credential rotation + -> provider-native response/stream parse + -> field-cache extraction + -> adapter chain + -> protocol formatting for the client + -> transaction logging for every transform state +``` + +Providers should be able to declare an existing protocol and only override the parts that are genuinely provider-specific. A custom provider should usually be configurable through protocol choice, adapters, field-cache rules, auth strategy, and model options rather than requiring a large bespoke provider implementation. + +## Priority Order + +1. Native protocol foundations, unified types, transformers, adapters, and field-cache rules. +2. OpenAI Responses API support, including future WebSocket extension points. +3. Provider work following the protocol layer: Claude Code, Codex, Copilot, Antigravity, and Gemini CLI parity review. +4. Routing and fallback groups, with optional target-group selectors later. +5. Retry, provider/model cooldown, and failover cleanup. +6. Protocol-aware quota, usage, and cost normalization. +7. Streaming library hardening: SSE now, WebSocket-ready later. +8. Config polish using `.env` and optional JSON. No SQLite dependency for now. +9. Extensive staged tests and review-agent verification. + +## Non-Goals For This Branch + +- Do not make the proxy a full multi-user admin product yet. +- Do not require SQLite or Postgres for the main feature set. +- Do not remove LiteLLM before native coverage exists. +- Do not replace the existing `UsageManager`, fair-cycle, custom caps, or evidence-based `SessionTracker`. +- Do not port frontend/UI work from the external reference gateway. + +## Current Strengths To Preserve + +- Credential-level rotation and priority-aware selection. +- Fair cycle and custom caps. +- Windowed quota tracking and quota groups. +- Evidence-based session tracking with compaction handling. +- Provider plugin discovery. +- Gemini CLI provider behavior unless a reviewed change is clearly better. +- Resilient file/JSON state writing. +- Dynamic OpenAI-compatible provider discovery. + +## Reference Gateway Ideas To Import Carefully + +- Unified protocol/transformer style. +- Adapter registry and configurable provider/model adapters. +- Target groups and direct routing syntax, adapted into fallback-first routing. +- Responses API transformer and storage concepts. +- Stream TTFB/stall detection concepts, implemented with Python-native async primitives. +- Provider/model cooldown and retry-history concepts. +- Usage/cost normalization and provider-reported cost extraction. +- Broader provider support patterns for Claude Code, Codex, Copilot, and Antigravity. + +## Phase Index + +1. Protocol Core. +2. Transform Pass Logging. +3. Adapter and Field Cache System. +4. Responses API and WebSocket-Ready Transport Shape. +5. Provider Protocol Overhaul. +6. Routing and Fallback Groups. +7. Retry/Cooldown/Failover Cleanup. +8. Streaming Library Upgrade. +9. Usage, Quota, and Cost Accuracy. +10. Config Polish. + +Each phase may be subdivided if implementation scope becomes too large. + +## Completeness Matrix + +This matrix exists so the branch does not lose any requested scope while phases evolve. The phase plans are still refreshed before implementation, but every item below must remain accounted for. + +| Requested area | Planned coverage | +| --- | --- | +| Protocols are priority #1 | Phases 1 and 4 create native protocol foundations and Responses support before provider work. | +| Protocols are bases, not gospel | Phase 1 requires override-friendly protocol methods, subclassing, copy/mutate registration, and provider-specific overrides. | +| Move away from LiteLLM | Phase 1 adds a `litellm_fallback` protocol path; later providers should prefer native protocols and use LiteLLM only for unsupported coverage. | +| Add protocols automatically like providers | Phase 1 adds protocol auto-discovery and registry behavior modeled after provider discovery. | +| Cover current providers and reference providers | Phase 1 protocols must cover shapes used by current providers; Phase 5 covers Claude Code, Codex, Copilot, Antigravity, and Gemini CLI parity. | +| Responses API is very needed | Phase 4 is dedicated to Responses, `previous_response_id`, storage, SSE, and WebSocket-ready transport shape. | +| WebSocket support later | Phases 1, 4, and 8 require transport separation so WebSocket can be added without rewriting protocol logic. | +| Adapters/transformers tied to protocols | Phases 1, 2, and 3 define protocol parse/build plus transform tracing, adapter registry, and field-cache rules. | +| Cache and return provider fields | Phase 3 implements configurable extraction/injection rules for request, response, and stream fields with scope and mode controls. | +| Reasoning content and similar fields | Phase 3 explicitly covers reasoning content, thinking signatures, prompt cache keys, response IDs, and provider session IDs. | +| Return all possible or last user/assistant use | Phase 3 modes include `last`, `all`, `last_user_turn`, `last_assistant_turn`, and `per_tool_call`. | +| Per-model custom provider behavior | Phases 3, 5, and 10 cover provider/model field cache rules, adapters, model options, and optional JSON config. | +| Transaction logging after every transform | Phase 2 adds ordered request, response, and stream transform trace passes and integrates them with transaction logging. | +| Comments, docstrings, and key decisions | All implementation phases require docstrings for public abstractions and comments for non-obvious transform, protocol, and future-extension decisions. | +| Providers are priority #2 | Phase 5 follows protocol foundations with Claude Code, Codex, Copilot, Antigravity, and Gemini CLI parity review. | +| Antigravity comparison | Phase 5 explicitly compares the reference Antigravity behavior against `src/rotator_library/providers/_retired/`. | +| Routing is interesting | Phase 6 implements fallback chains first, with target-group selectors later if useful. | +| Fallback groups preferred over target groups | Phase 6 starts with ordered fallback groups and only adds target-group-style selectors after that base works. | +| Retry/cooldown/failover cleanup | Phase 7 makes provider/model cooldown real, adds retry history, backoff, retry-after precedence, and success reset. | +| Quota/usage/cost improvements | Phase 9 adds protocol-aware normalizers, provider-reported cost extraction, structured cost fields, and checker abstractions while keeping existing usage engines. | +| Streaming as library capability | Phase 8 hardens streaming below the proxy route layer with TTFB, TTFT, stall detection, cancellation, and transport-aware stream events. | +| Config via env/json, no SQLite | Phase 10 adds optional JSON config with env overrides and validation. SQLite remains out of scope. | +| Multi-user proxy later | The branch keeps multi-user/admin features as a future expansion and only preserves extension points where natural. | +| Exhaustive tests in stages | Every phase requires tests alongside implementation and phase-end review by both `explore` and `explore-heavy`. | +| Reports are for the user, not git | `06-phase-workflow.md` says planning docs are committed, but phase reports are not committed by default. | + +## Code Quality Expectations + +- Public protocol, adapter, transport, field-cache, and provider-extension classes must have docstrings that explain intent, override points, and future expansion hooks. +- Non-obvious transformations must have comments explaining why data is changed, preserved, reordered, or intentionally dropped. +- Lossy protocol conversions must be documented at the conversion site. +- Future WebSocket, target-group, and multi-user extension seams should be noted in comments where they affect today's design. +- Tests should prefer golden fixtures for protocol shapes and focused unit tests for transform edge cases. diff --git a/docs/experimental/01-protocol-architecture.md b/docs/experimental/01-protocol-architecture.md new file mode 100644 index 000000000..bf6349d83 --- /dev/null +++ b/docs/experimental/01-protocol-architecture.md @@ -0,0 +1,132 @@ +# Native Protocol Architecture + +Protocols are reusable bases, not rigid gospel. Providers can subclass, wrap, copy, or override protocol behavior when a provider deviates from an otherwise standard protocol. + +## Why Protocols First + +The current code relies heavily on LiteLLM and provider-specific transforms. That works, but it makes new protocols hard to reason about and makes debugging transformations difficult. The experimental goal is to make a provider mostly declarative: + +```text +provider = protocol + auth + adapters + field cache rules + model options + quota behavior +``` + +If a provider needs custom behavior, it should override a narrow protocol method instead of forcing an entirely bespoke request path. + +## Auto-Discovery + +Protocols should follow the provider plugin style: + +- protocol modules live under `src/rotator_library/protocols/`. +- modules register concrete protocol classes by name. +- a registry exposes names such as `openai_chat`, `anthropic_messages`, `gemini`, `responses`, and `litellm_fallback`. +- third-party or local protocol modules can be added later with minimal registry changes. + +## Core Types + +The unified representation should be explicit enough to cover all existing providers and the external reference protocols without losing important data. + +Suggested types: + +- `UnifiedRequest` +- `UnifiedResponse` +- `UnifiedStreamEvent` +- `UnifiedMessage` +- `ContentBlock` +- `ToolDefinition` +- `ToolCall` +- `ToolResult` +- `ReasoningBlock` +- `Usage` +- `CostDetails` +- `ProtocolMetadata` + +These types should retain unknown provider-specific metadata in explicit extension dictionaries instead of dropping it. Robustness matters more than a narrow perfect schema. + +## Protocol Interface + +The base protocol should provide default methods that can be overridden: + +- `parse_request(raw_request, context) -> UnifiedRequest` +- `build_request(unified_request, context) -> raw_provider_request` +- `parse_response(raw_response, context) -> UnifiedResponse` +- `format_response(unified_response, context) -> raw_client_response` +- `parse_stream_event(raw_event, context) -> UnifiedStreamEvent` +- `format_stream_event(unified_event, context) -> raw_stream_payload` +- `extract_usage(raw_or_unified, context) -> Usage | None` +- `supports_transport(transport_name) -> bool` + +Provider-specific overrides should receive context that includes provider name, model, credential identity, source protocol, target protocol, request ID, and session tracking information. + +## Initial Protocols + +### OpenAI Chat + +Must support: + +- chat completions request/response. +- stream chunks. +- tools and tool calls. +- function-call legacy shapes. +- reasoning fields from OpenAI-compatible providers. +- cached token and reasoning token usage details. + +### Anthropic Messages + +Must support: + +- messages request/response. +- system content extraction. +- text, image, thinking, redacted thinking, tool_use, tool_result blocks. +- stream lifecycle events. +- count_tokens path later if needed. + +### Gemini + +Must support: + +- generateContent and streamGenerateContent shapes. +- content parts. +- functionCall/functionResponse. +- thought signatures. +- safety settings passthrough without unsafe auto-injection. +- Google/Gemini usage metadata. + +### Responses + +Must support: + +- OpenAI Responses request/response. +- `previous_response_id`. +- output items. +- event streams. +- storage-friendly response objects. +- future WebSocket transport. + +### LiteLLM Fallback + +Must preserve existing behavior for providers/protocols not yet native. This path should be explicit and transaction-logged as a fallback, not hidden. + +## Transport Separation + +Protocol formatting must not be tied only to HTTP SSE. Define a transport boundary so the same unified stream events can be emitted through: + +- non-streaming HTTP JSON. +- HTTP SSE. +- future WebSocket. + +The Responses phase should leave clear extension points for WebSocket even if WebSocket is implemented later. + +## Error Handling + +Protocols should preserve provider error bodies where safe, but format client-facing errors consistently. Parsing errors should include transform-pass names and request IDs to make transaction logs useful. + +## Docstrings And Comments + +Protocol code should include docstrings explaining: + +- which external API shape it models. +- what fields are intentionally preserved in metadata. +- where provider overrides are expected. +- future expansion hooks. + +Comments should explain non-obvious transformations, especially lossy conversions between protocols. diff --git a/docs/experimental/02-transform-logging.md b/docs/experimental/02-transform-logging.md new file mode 100644 index 000000000..5cf74b1d1 --- /dev/null +++ b/docs/experimental/02-transform-logging.md @@ -0,0 +1,84 @@ +# Transform Pass Logging + +Debuggability is a core requirement. Every request, response, and stream payload must be inspectable after each transformation pass. + +## Existing Baseline + +The current project already has transaction and raw I/O logging, but it does not consistently show every intermediate transformed state. The experimental protocol layer must improve this without making normal operation too noisy. + +## Transform Trace Model + +Each request should have a trace containing ordered pass records. A pass record should include: + +- request ID. +- pass name. +- direction: request, response, stream_in, stream_out, error. +- protocol/provider/model context. +- timestamp. +- payload snapshot, redacted if needed. +- optional notes or warnings. +- exception information if the pass failed. + +## Required Request Passes + +Suggested pass names: + +- `raw_client_request` +- `parsed_unified_request` +- `after_session_inference` +- `after_field_cache_injection` +- `after_request_adapters` +- `after_provider_override` +- `provider_request` +- `litellm_fallback_request` when fallback is used + +## Required Response Passes + +- `raw_provider_response` +- `parsed_unified_response` +- `after_field_cache_extraction` +- `after_response_adapters` +- `after_client_protocol_format` +- `final_client_response` + +## Required Stream Passes + +For streams, every event can be large. Logging should support configurable sampling/full capture, but the architecture must be able to record: + +- `raw_provider_stream_event` +- `parsed_unified_stream_event` +- `after_stream_field_cache_extraction` +- `after_stream_adapters` +- `formatted_client_stream_event` + +The transaction logger should be able to record stream events as JSONL to avoid retaining large streams in memory. + +## Redaction + +Even though full multi-user security is out of scope, transform logging must avoid accidental credential leakage. + +Redaction hooks should cover: + +- `Authorization` headers. +- `x-api-key` and related API key headers. +- provider API keys. +- OAuth access and refresh tokens. +- cookies. +- obvious `api_key`, `access_token`, `refresh_token`, `client_secret`, and `Authorization` fields in JSON payloads. + +Redaction should happen at the logging boundary, not by mutating live request objects. + +## Failure Debugging + +When a transform fails, the trace should identify: + +- failed pass name. +- protocol class. +- provider/model. +- whether the failure occurred before or after provider execution. +- original error type. +- redacted payload snapshot when possible. + +## Future Expansion + +Later admin/debug endpoints can read these traces. For now, file-based transaction logging is sufficient. diff --git a/docs/experimental/03-field-cache-rules.md b/docs/experimental/03-field-cache-rules.md new file mode 100644 index 000000000..a93e13636 --- /dev/null +++ b/docs/experimental/03-field-cache-rules.md @@ -0,0 +1,113 @@ +# Field Cache Rules + +Field caching is required for providers that need values from previous responses or stream events to be returned on later requests. Examples include reasoning content, thought signatures, prompt cache keys, provider session IDs, and response IDs. + +## Goals + +- Let custom providers configure what fields to extract and where to inject them. +- Avoid hardcoding every provider-specific memory behavior. +- Preserve strict scoping so values never leak across provider, model, credential, classifier scope, or session. +- Support both non-streaming responses and streaming events. +- Support rules per provider and per model. + +## Rule Shape + +Illustrative JSON shape: + +```json +{ + "name": "reasoning_content", + "source": "response", + "path": "choices.*.message.reasoning_content", + "scope": "session", + "mode": "last", + "inject": { + "target": "request", + "path": "messages[-1].reasoning_content" + } +} +``` + +This is a design sketch, not the final schema. + +## Sources + +- `request` +- `response` +- `stream_event` +- `unified_request` +- `unified_response` +- `unified_stream_event` + +## Targets + +- raw provider request path. +- unified request field. +- protocol metadata. +- provider-specific extension field. + +## Scopes + +- `provider` +- `model` +- `credential` +- `session` +- `conversation` +- combinations of the above when needed. + +The default for conversation-affecting fields should be at least provider+model+session scoped. + +## Modes + +- `last`: only the latest matching value. +- `all`: all matching values within the scope. +- `last_user_turn`: latest value associated with the last user turn. +- `last_assistant_turn`: latest value associated with the last assistant turn. +- `per_tool_call`: keyed by tool call ID. + +## Backing Store + +Use the existing provider cache infrastructure first. Do not require SQLite. + +Potential cache keys should include: + +```text +provider / model / credential-or-scope / session-id / rule-name +``` + +Private/classifier scoped credentials must not share cached fields with global credentials. + +## Examples + +### DeepSeek Reasoning Content + +Extract reasoning content from assistant responses and inject it into the next provider request when the provider expects continuity. + +### Gemini Thought Signatures + +Extract thought signatures from Gemini response parts and return them with matching future content parts. + +### Responses Previous Response ID + +Store response IDs and output items so `previous_response_id` can load prior context. + +### Prompt Cache Keys + +Carry `prompt_cache_key` or equivalent provider cache routing values forward when a provider benefits from stable cache routing. + +## Tests + +Required test categories: + +- extraction from response. +- extraction from stream event. +- injection into next request. +- `last` versus `all` behavior. +- scope isolation. +- missing path is a no-op. +- malformed path produces useful validation error. +- redaction in transform logs. + +## Key Decision + +Field cache rules are a protocol/provider extension system, not a replacement for `SessionTracker`. Session tracking decides continuity and credential affinity; field cache rules preserve provider-specific protocol state. diff --git a/docs/experimental/04-provider-roadmap.md b/docs/experimental/04-provider-roadmap.md new file mode 100644 index 000000000..923f84d84 --- /dev/null +++ b/docs/experimental/04-provider-roadmap.md @@ -0,0 +1,115 @@ +# Provider Roadmap + +Providers follow protocols. The protocol layer must land first so provider work can be small, testable, and declarative where possible. + +## Provider Declaration Target + +A provider should eventually be expressible as: + +```text +provider name + + protocol(s) + + auth strategy + + model definitions/options + + adapter chain + + field cache rules + + quota checker/parser + + optional protocol/provider overrides +``` + +Providers can still have custom Python code. The point is to make custom code narrow. + +## Priority Providers + +1. Claude Code. +2. Codex. +3. Copilot. +4. Antigravity. +5. Gemini CLI review and parity improvements. + +## Claude Code + +Review the external reference gateway for: + +- OAuth/token handling. +- request/response protocol shape. +- tool filtering or tool proxy behavior. +- quota checks. +- stream behavior. +- Claude Code-specific headers and model naming. + +Expected implementation direction: + +- use Anthropic Messages or Responses where applicable. +- add provider-specific adapters for tool behavior. +- field cache rules for thinking/signatures if needed. + +## Codex + +Review the external reference gateway for: + +- Responses API route usage. +- Codex-specific user-agent/version behavior. +- OAuth/account handling. +- cooldown parsing for Codex usage limits. +- stream events. + +Expected implementation direction: + +- build on the Responses protocol. +- add Codex provider auth and headers. +- include version/user-agent support if needed. + +## Copilot + +Review the external reference gateway for: + +- GitHub Copilot OAuth flows. +- endpoint selection. +- model naming. +- quota checker behavior. +- provider-specific request filtering. + +Expected implementation direction: + +- use protocol adapters where possible. +- add provider-specific auth/token refresh only where necessary. + +## Antigravity + +Required comparison: + +- current retired implementation under `src/rotator_library/providers/_retired/`. +- external reference Antigravity provider/checker/OAuth behavior. + +Expected implementation direction: + +- restore only what is still valid. +- reuse protocol and field cache rules. +- avoid resurrecting obsolete device-profile behavior unless clearly required. + +## Gemini CLI + +The current Gemini CLI provider is already deep. Review the external reference gateway only for missed behavior: + +- quota checker details. +- thought signature handling. +- stream transform differences. +- OAuth edge cases. +- Gemini 3 tool behavior. +- request headers and endpoint details. + +Do not rewrite Gemini CLI just for architectural purity. + +## Provider Tests + +Each provider should have tests for: + +- config/load/registration. +- auth header or token acquisition. +- request translation. +- response translation. +- stream translation. +- quota parser/checker behavior. +- field cache extraction/injection. +- LiteLLM fallback not used when native path should apply. diff --git a/docs/experimental/05-routing-retry-usage-roadmap.md b/docs/experimental/05-routing-retry-usage-roadmap.md new file mode 100644 index 000000000..210ccf13a --- /dev/null +++ b/docs/experimental/05-routing-retry-usage-roadmap.md @@ -0,0 +1,114 @@ +# Routing, Retry, Cooldown, Usage, And Cost Roadmap + +These systems should be layered around the protocol/provider work without replacing the existing credential engine. + +## Routing Direction + +The preferred first routing model is fallback chains, not a full target-group router from the external reference gateway. + +Example: + +```json +{ + "fallback_groups": [ + ["gemini_cli/gemini-2.5-pro", "openrouter/google/gemini-2.5-pro", "openai/gpt-4.1"] + ] +} +``` + +Behavior: + +- If the requested model is in a group, start with the requested model and then continue through the rest of the group. +- Retryable provider/model failures can move to the next candidate. +- Non-retryable request errors stop unless explicitly configured otherwise. +- Each candidate delegates to the current credential rotation engine. +- Session tracking and classifier scopes must remain isolated. + +## Target Groups Later + +Reference target groups are still useful later for selectors: + +- `in_order` +- `random` +- `usage` +- `cost` +- `latency` +- `performance` + +But the first implementation should solve ordered fallback groups and the missing/stale fallback module before adding selector complexity. + +## Retry/Cooldown Direction + +Current provider cooldown is effectively inert because `CooldownManager.start_cooldown()` is not called. The real active cooldowns are per-credential in `UsageManager`. + +Upgrade direction: + +- keep per-credential cooldowns in `UsageManager`. +- add provider/model cooldown for provider-wide or model-wide failures. +- add consecutive failure tracking. +- add exponential backoff. +- `retry_after` should override computed backoff. +- success should clear provider/model failure count. +- credential quota exhaustion should not globally cool down a provider unless evidence says the provider/model is globally exhausted. + +## Retry History + +Add structured attempt records: + +- candidate provider/model. +- credential stable ID or masked identity. +- protocol path used. +- status: success, failed, skipped. +- error type. +- retryable decision. +- cooldown decision. +- timing. + +These records should be available to transaction logging and optionally client-facing debug output later. + +## Usage And Cost Direction + +Keep existing windowing, fair cycle, custom caps, quota groups, and usage persistence. Add protocol-aware normalization before usage is recorded. + +Needed normalizers: + +- OpenAI Chat. +- OpenAI Responses. +- Anthropic Messages. +- Gemini. +- OAuth/custom providers. + +Fields to preserve: + +- input tokens. +- output tokens. +- reasoning/thinking tokens. +- cache read tokens. +- cache write tokens. +- total tokens. +- provider-reported cost. +- estimated cost. +- cost source and metadata. + +## Cost Direction + +Provider-reported cost should win when available. Sources include: + +- `usage.cost_details`. +- provider-specific response fields. +- SSE comment lines such as `: cost { ... }`. + +Estimated cost should remain as fallback. + +## Tests + +Required tests: + +- fallback chain success after first failure. +- non-retryable error stops fallback. +- streaming fallback only before visible output. +- provider/model cooldown scope. +- backoff escalation. +- success reset. +- usage normalization for each protocol. +- provider-reported cost precedence. diff --git a/docs/experimental/06-phase-workflow.md b/docs/experimental/06-phase-workflow.md new file mode 100644 index 000000000..839fc88e8 --- /dev/null +++ b/docs/experimental/06-phase-workflow.md @@ -0,0 +1,89 @@ +# Phase Workflow + +This document describes how each phase should be executed. + +## 1. Refresh Understanding + +Before writing phase docs or code, inspect the current implementation relevant to the phase. The master plan is a guide, not a substitute for current-code analysis. + +## 2. Produce Phase Plan In Conversation + +First produce the phase plan in conversation text. This forces a fresh exhaustive design pass with the current implementation in mind. + +The plan should include: + +- goals. +- non-goals. +- files/modules to inspect or modify. +- data model changes. +- public API changes. +- transaction logging implications. +- docstrings, comments, and future-extension notes required for the phase. +- tests to add. +- commit checkpoints. +- risk/rollback notes. + +## 3. Write Phase Plan To Docs + +After the conversation plan is accepted or clearly settled, write it to: + +```text +docs/experimental/phase-N-*.md +``` + +Planning docs are committed. + +## 4. Implement In Checkpoint Commits + +Commits are not phase-only. Commit whenever a coherent slice is finished and tested. + +Each commit body should include: + +- what changed. +- why it changed. +- tests run. +- known limitations or follow-ups. + +## 5. Test Continuously + +Run the most relevant tests after each meaningful slice. Do not wait until the phase end. + +## 6. Review Agents + +At phase end, call exactly these two review perspectives: + +- `explore`: code/file-level verification. +- `explore-heavy`: deeper architecture/reference verification. + +Review prompts should compare the implementation against: + +- the phase plan. +- external reference areas where relevant. +- current proxy behavior that must be preserved. +- transaction logging expectations. +- tests. + +If either agent fails or runs out of context, restart it with a narrower prompt. + +## 7. Address Findings + +Fix real findings. If a finding is intentionally deferred, document it in the user-facing phase report, not necessarily in git. + +Repeat review if the changes are substantial. + +## 8. Report To User + +Write a phase report in the conversation. Reports are for the user and are not committed by default. + +The report should include: + +- completed work. +- commits made. +- tests run. +- review-agent findings and resolutions. +- known limitations. +- next phase recommendation. + +## 9. Move To Next Phase + +Do not rely blindly on previous plans. Start the next phase by refreshing context and producing the next phase plan in conversation text. diff --git a/docs/experimental/07-detailed-phase-roadmap.md b/docs/experimental/07-detailed-phase-roadmap.md new file mode 100644 index 000000000..45026167e --- /dev/null +++ b/docs/experimental/07-detailed-phase-roadmap.md @@ -0,0 +1,540 @@ +# Detailed Phase Roadmap + +This document expands the 10-phase roadmap before implementation begins. It exists to prevent later work from narrowing to only Phase 1 details and losing the full feature set. + +Each phase still requires a fresh conversation plan immediately before implementation. This document is the durable baseline; phase-specific plans can adapt after current-code inspection. + +## Phase 1: Protocol Core + +Purpose: create native protocol foundations without changing live execution. + +Primary deliverables: + +- `src/rotator_library/protocols/` package. +- Auto-discovered protocol registry modeled after provider discovery. +- Override-friendly `ProtocolAdapter` base class. +- Unified request, response, message, content, tool, reasoning, usage, cost, stream event, and context dataclasses. +- Base protocols for OpenAI Chat, Anthropic Messages, Gemini, OpenAI Responses, and LiteLLM fallback. +- JSON-safe serialization helpers for transaction tracing and fixtures. +- Protocol errors that identify protocol name, pass name, and payload preview. + +Key requirements: + +- Protocols are bases, not rigid implementations. +- Providers can subclass, wrap, copy, or override protocol methods. +- Unknown provider fields must be preserved in `extra`/metadata instead of dropped. +- Runtime behavior should not change yet. +- LiteLLM fallback remains explicit and named. + +Tests: + +- Registry discovery and alias resolution. +- Base preservation behavior. +- Round-trip and parse/format fixtures for OpenAI Chat, Anthropic Messages, Gemini, and Responses. +- Stream event smoke coverage. + +Review focus: + +- Ensure protocol abstractions are not too narrow for current providers or external reference protocols. +- Ensure types are trace-friendly for Phase 2. + +## Phase 2: Transform Pass Logging + +Purpose: make every transformation state inspectable for debugging. + +Primary deliverables: + +- Transform trace model with ordered pass records. +- Transaction logger integration for request, response, stream, and error transform states. +- Redaction-at-log-boundary helpers. +- JSONL stream transform logs. +- Pass names shared by protocols, adapters, provider overrides, field-cache rules, and fallback execution. + +Required pass coverage: + +- `raw_client_request` +- `parsed_unified_request` +- `after_session_inference` +- `after_field_cache_injection` +- `after_request_adapters` +- `after_provider_override` +- `provider_request` +- `litellm_fallback_request` +- `raw_provider_response` +- `parsed_unified_response` +- `after_field_cache_extraction` +- `after_response_adapters` +- `after_client_protocol_format` +- `final_client_response` +- `raw_provider_stream_event` +- `parsed_unified_stream_event` +- `after_stream_field_cache_extraction` +- `after_stream_adapters` +- `formatted_client_stream_event` + +Key requirements: + +- Log snapshots must not mutate live request/response objects. +- Logs must preserve enough context to debug provider-specific behavior. +- Secret redaction should cover API keys, OAuth tokens, auth headers, cookies, and common secret field names. +- Logging must be usable by future admin/debug endpoints but remain file-based for now. + +Tests: + +- Trace pass ordering. +- Redaction behavior. +- Request/response/stream JSON serialization. +- Transform failure logging with pass name and protocol/provider context. + +Review focus: + +- Verify every planned pass is reachable from the architecture. +- Verify log output is useful without leaking credentials. + +## Phase 3: Adapter And Field Cache System + +Purpose: let custom providers configure what to transform, cache, return, and reinject without hardcoding every provider. + +Primary deliverables: + +- Adapter registry. +- Adapter chain execution for request, response, and stream events. +- Field-cache rule schema and validation. +- JSON-path-like extraction/injection helpers. +- Cache scope builder using provider, model, credential or classifier scope, session, conversation, and rule name. +- Provider/model-level rule configuration hooks. +- Initial built-in adapter rules for reasoning/thinking-related fields. + +Field cache capabilities: + +- Sources: request, response, stream event, unified request, unified response, unified stream event. +- Targets: raw provider request path, unified request field, protocol metadata, provider extension field. +- Scopes: provider, model, credential, session, conversation, and combinations. +- Modes: `last`, `all`, `last_user_turn`, `last_assistant_turn`, `per_tool_call`. +- Missing paths are no-ops unless strict validation is enabled. + +Examples that must be supported by design: + +- DeepSeek-style reasoning content. +- Anthropic thinking and redacted thinking signatures. +- Gemini thought signatures. +- Prompt cache keys. +- Provider session IDs. +- Responses `previous_response_id` metadata. + +Key requirements: + +- Field cache rules complement `SessionTracker`; they do not replace it. +- Rules must never leak across provider/model/credential/session boundaries. +- Providers can define default rules, and model config can override them. +- Transform logging must capture before/after extraction and injection passes. + +Tests: + +- Extract from response and stream events. +- Inject into later requests. +- Scope isolation. +- Mode behavior. +- Malformed path validation. +- Redacted transform logs. + +Review focus: + +- Verify custom provider authoring becomes declarative for common stateful fields. +- Verify no cross-session or cross-credential leaks. + +## Phase 4: Responses API And WebSocket-Ready Transport Shape + +Purpose: add high-priority OpenAI Responses API support while designing stream transports for future WebSocket support. + +Primary deliverables: + +- `/v1/responses` route. +- `GET /v1/responses/{response_id}` route. +- `DELETE /v1/responses/{response_id}` route. +- Optional Codex alias route if supported by provider work. +- JSON/file response storage, not SQLite. +- TTL cleanup for stored responses. +- `previous_response_id` loading and session anchor integration. +- SSE Responses streaming. +- Transport interfaces that can later support WebSocket without rewriting protocol logic. + +Transport shape: + +- Non-streaming HTTP JSON transport. +- HTTP SSE transport. +- Future WebSocket transport extension point. +- Unified stream events from Phase 1 should flow through transports. + +Key requirements: + +- `previous_response_id` must become strong session evidence where safe. +- Stored response objects must be scoped and cleaned up. +- Responses usage and output items must be normalizable in Phase 9. +- Transform logging must show Responses parse/build/storage-relevant states. + +Tests: + +- Create response. +- Retrieve response. +- Delete response. +- Previous response continuation. +- SSE event formatting. +- TTL cleanup. +- Transport interface extension smoke tests. + +Review focus: + +- Verify compatibility with OpenAI Responses expectations and external reference Responses behavior. +- Verify WebSocket is not blocked by the design. + +## Phase 5: Provider Protocol Overhaul + +Purpose: make providers use native protocols where practical and add priority providers after protocol foundations exist. + +Primary deliverables: + +- Provider protocol declaration mechanism. +- Provider hooks for protocol selection, adapter rules, field cache rules, and model options. +- Claude Code provider implementation or integration path. +- Codex provider implementation or integration path. +- Copilot provider implementation or integration path. +- Restored Antigravity provider if current/reference behavior supports it. +- Gemini CLI parity review and targeted fixes only where the external reference gateway has real improvements. + +Provider priorities: + +- Claude Code. +- Codex. +- Copilot. +- Antigravity. +- Gemini CLI review. + +Antigravity comparison requirements: + +- Compare the external reference implementation to `src/rotator_library/providers/_retired/antigravity_provider.py`. +- Restore only valid behavior. +- Avoid obsolete device-profile or fragile logic unless required by the current service. + +Key requirements: + +- Providers can override protocol methods when the base is close but not exact. +- New providers should avoid monolithic transform logic where adapter rules suffice. +- Native path must be testable independently of live credentials. +- LiteLLM fallback should be explicit if used. + +Tests: + +- Provider registration. +- Protocol selection. +- Auth header/token behavior via mocks. +- Request/response/stream translation. +- Quota/checker parsing. +- Field cache rules. +- No accidental LiteLLM fallback when native path should apply. + +Review focus: + +- Verify provider logic follows protocol architecture instead of reintroducing bespoke protocol code everywhere. +- Verify Gemini CLI improvements do not regress existing behavior. + +## Phase 6: Routing And Fallback Groups + +Purpose: add ordered model/provider fallback while keeping credential rotation inside each candidate. + +Primary deliverables: + +- Fallback group config parser. +- Ordered candidate planner for provider/model fallback chains. +- Retryable/non-retryable fallback decisions. +- Streaming fallback rules that only fallback before visible output. +- Optional target-group structure after fallback groups work. +- Optional selectors after ordered fallback works. + +Fallback behavior: + +- If requested model is in a group, try it first. +- Continue through remaining candidates only for retryable failures or configured fallback conditions. +- Preserve classifier/private credential scopes. +- Preserve session namespace isolation. +- Each candidate delegates to current credential selection and usage tracking. + +Future target group selectors: + +- `in_order`. +- `random`. +- `usage`. +- `cost`. +- `latency`. +- `performance`. + +Key requirements: + +- Fix or replace stale `fallback_groups` expectations without breaking current model resolution. +- Do not replace `UsageManager` or `SelectionEngine`. +- Transform logging should show chosen candidate and fallback attempt history. + +Tests: + +- Ordered fallback after retryable failure. +- Non-retryable failure stops. +- Requested model promotion. +- Exhausted chain reports useful error. +- Streaming fallback before output only. +- Scope/session isolation. + +Review focus: + +- Verify fallback integrates above credential rotation, not inside it. +- Verify behavior matches user preference for fallback chains over target-group complexity. + +## Phase 7: Retry/Cooldown/Failover Cleanup + +Purpose: streamline retry behavior and make provider/model cooldown real. + +Primary deliverables: + +- Replace or activate the currently inert provider `CooldownManager` path. +- Provider/model cooldown keys. +- Consecutive failure tracking. +- Exponential backoff. +- Retry-after precedence over computed cooldown. +- Success reset behavior. +- Retry history records for logging and future debug surfaces. +- Integration with fallback groups from Phase 6. + +Cooldown layers: + +- Credential-level cooldown remains in `UsageManager`. +- Provider/model cooldown applies only to evidence of provider-wide or model-wide failure. +- Credential quota exhaustion must not automatically cool an entire provider. +- Model cooldown should not block healthy models on the same provider. +- Provider cooldown should be reserved for provider-wide failure evidence. + +Retry history fields: + +- candidate provider/model. +- credential stable ID or masked identity. +- protocol path used. +- attempt number. +- status: success, failed, skipped, cooled_down. +- error category and raw classifier result. +- retryable decision. +- cooldown decision and duration. +- timing and latency. + +Key requirements: + +- Preserve the current strong retry-after parser, especially Google/Gemini compound duration parsing. +- Preserve streaming safety: no retry after visible output unless explicitly safe. +- Make cooldown state observable by transaction logs and future endpoint surfaces. + +Tests: + +- `start_cooldown` or successor has production callers. +- Provider cooldown blocks provider-wide only. +- Model cooldown blocks one model only. +- Credential cooldown does not become provider cooldown. +- Retry-after overrides exponential backoff. +- Backoff escalates with repeated failures. +- Success clears counts. +- Retry history is recorded. +- Fallback respects cooldown skips. + +Review focus: + +- Verify no healthy credential/model is suppressed too broadly. +- Verify retry/fallback/streaming interactions are deterministic. + +## Phase 8: Streaming Library Upgrade + +Purpose: harden streaming as a reusable library capability, not just a proxy route behavior. + +Primary deliverables: + +- Transport-aware stream event pipeline using unified stream events. +- Explicit upstream iterator close/cancel on client disconnect. +- TTFB timeout before first emitted output. +- TTFT metrics. +- Throughput stall detector. +- Stream usage extraction through protocol normalizers. +- Stream transform logging for raw, unified, adapted, and formatted states. +- SSE transport improvements and WebSocket-ready transport boundaries. + +Streaming behavior: + +- Retry is allowed before client-visible output when the error is retryable. +- Retry is not allowed after visible output unless a future protocol explicitly proves safety. +- Partial streams should not record accepted session response anchors. +- Completed streams should record response anchors and normalized usage. +- Disconnects should close upstream resources promptly. + +Metrics: + +- time to first byte. +- time to first token/content. +- tokens per second where token counts are available. +- chunk count. +- stall duration. +- completion status. + +Key requirements: + +- Implement with Python-native async patterns, not runtime-specific socket internals from the external reference gateway. +- Keep current conservative retry safety. +- Preserve current usage recording behavior and then improve it via Phase 9 normalizers. + +Tests: + +- Disconnect closes upstream async iterator. +- TTFB timeout retries before output. +- Stall detection trips after grace period. +- No retry after visible content. +- Retry before first output. +- Stream transform logging writes expected pass records. +- SSE formatting remains compatible. + +Review focus: + +- Verify library-level design is not tied only to FastAPI route wrappers. +- Verify future WebSocket can reuse the same unified event stream. + +## Phase 9: Usage, Quota, And Cost Accuracy + +Purpose: make usage and cost accounting protocol-aware while preserving the current usage engine. + +Primary deliverables: + +- Protocol-aware usage normalizer interface. +- Usage normalizers for OpenAI Chat, OpenAI Responses, Anthropic Messages, Gemini, OAuth/custom providers, and LiteLLM fallback. +- Structured cost details model integrated with protocol usage. +- Provider-reported cost extraction. +- SSE cost comment parsing where providers emit cost metadata. +- Reasoning/thinking token normalization. +- Cache read/write token normalization. +- Meter/checker abstraction for proactive provider quota checks where useful. + +Usage fields: + +- input tokens. +- output tokens. +- total tokens. +- reasoning/thinking tokens. +- cache read tokens. +- cache write tokens. +- provider-reported cost. +- estimated cost. +- cost source. +- raw provider usage metadata. + +Cost precedence: + +- Provider-reported cost wins when present and trusted. +- Structured provider cost fields are preferred over text parsing. +- SSE `: cost` comments can supply provider-reported cost. +- Estimated cost remains fallback. + +Current systems to preserve: + +- `UsageManager` facade. +- windowed tracking. +- fair cycle. +- custom caps. +- quota groups. +- classifier/private scope separation. +- JSON usage persistence. + +Key requirements: + +- Do not replace the usage engine. +- Normalize before recording when native protocol path is used. +- Existing LiteLLM response usage should continue working. +- Cost details should be transaction-loggable and eventually available to APIs/TUI. + +Tests: + +- OpenAI Chat usage details. +- Responses usage details. +- Anthropic cache read/write tokens. +- Gemini usage metadata including thoughts/cache when available. +- Reasoning token extraction. +- Provider-reported cost precedence. +- SSE cost comment extraction. +- Existing usage aggregation tests still pass. + +Review focus: + +- Verify accounting is more accurate without disrupting selection and quota logic. +- Verify provider-specific raw usage remains available for debugging. + +## Phase 10: Config Polish + +Purpose: support powerful protocol/routing/provider configuration without SQLite. + +Primary deliverables: + +- Optional JSON config file support. +- Env var pointing to JSON config path. +- Env overrides JSON. +- Validation with actionable errors. +- Config sections for protocols, adapters, field-cache rules, fallback groups, providers, model overrides, quota checkers, and stream settings. +- Documentation/examples for custom provider setup using existing protocols. + +Config priorities: + +- `.env` remains enough for simple setups. +- JSON supports complex nested config that is painful in env vars. +- Env overrides allow quick local changes. +- No SQLite/Postgres requirement. + +Potential JSON sections: + +- `protocols`. +- `providers`. +- `models`. +- `adapters`. +- `field_cache`. +- `fallback_groups`. +- `quota_checkers`. +- `streaming`. +- `logging`. + +Key requirements: + +- Validation errors must name the config section and field. +- Bad config should fail early when possible. +- Config should support per-provider and per-model overrides. +- Config should not require rewriting existing `.env` usage immediately. +- Config docs should show custom provider examples using OpenAI Chat, Anthropic Messages, Gemini, Responses, and LiteLLM fallback. + +Tests: + +- Env-only config. +- JSON-only config. +- Env overrides JSON. +- Bad config failure messages. +- Provider/model adapter rule config. +- Fallback group config. +- Field-cache rule config. + +Review focus: + +- Verify custom provider setup becomes practical and documented. +- Verify no accidental database dependency is introduced. + +## Cross-Phase Review Contract + +Every phase ends with two review agents: + +- `explore` for code/file-level verification. +- `explore-heavy` for deeper architecture and reference comparison. + +Review prompts must include the relevant phase plan, external reference areas, current proxy behavior to preserve, tests run, and transaction logging expectations. If either agent fails or runs out of context, restart with narrower scope. + +## Cross-Phase Testing Contract + +Every implementation phase must add or update tests before being considered complete. New test files may need `git add -f` because the repository ignores most `tests/*` by default. + +## Cross-Phase Documentation Contract + +Implementation code must include docstrings and comments for public extension points, non-obvious transformations, lossy conversions, future WebSocket seams, provider override points, and config decisions. diff --git a/docs/experimental/phase-1-protocol-core.md b/docs/experimental/phase-1-protocol-core.md new file mode 100644 index 000000000..aacfea76d --- /dev/null +++ b/docs/experimental/phase-1-protocol-core.md @@ -0,0 +1,172 @@ +# Phase 1: Protocol Core + +## Goal + +Introduce the native protocol foundation without changing live request execution yet. This phase creates robust, documented primitives that later phases can wire into `RequestExecutor`, provider declarations, adapters, field-cache rules, Responses, and streaming transports. + +## Non-Goals + +- Do not replace LiteLLM execution yet. +- Do not add `/v1/responses` routes yet. +- Do not migrate providers to native protocols yet. +- Do not implement field-cache persistence yet. +- Do not rewrite current Anthropic compatibility routes yet. +- Do not change existing request behavior unless a test-only import path requires a harmless export. + +## Current Code Context + +- No `src/rotator_library/protocols/` package exists yet. +- Provider auto-discovery in `src/rotator_library/providers/__init__.py` is the model for protocol discovery. +- `RequestContext` currently holds execution/session/logging fields and can later receive protocol metadata, but Phase 1 should avoid mutating it unless necessary. +- `ProviderTransforms` is hardcoded and will remain active until later adapter migration. +- `TransactionLogger` currently logs initial request, transformed request, response, and stream chunks; Phase 2 will add transform-pass tracing, but Phase 1 types should be trace-friendly. +- `anthropic_compat` already has useful conversion knowledge, especially thinking/tool block handling, but Phase 1 should not delete or replace it. + +## Files To Add + +- `src/rotator_library/protocols/__init__.py` +- `src/rotator_library/protocols/types.py` +- `src/rotator_library/protocols/base.py` +- `src/rotator_library/protocols/registry.py` +- `src/rotator_library/protocols/openai_chat.py` +- `src/rotator_library/protocols/anthropic_messages.py` +- `src/rotator_library/protocols/gemini.py` +- `src/rotator_library/protocols/responses.py` +- `src/rotator_library/protocols/litellm_fallback.py` +- `tests/test_protocol_registry.py` +- `tests/test_protocol_openai_chat.py` +- `tests/test_protocol_anthropic_messages.py` +- `tests/test_protocol_gemini.py` +- `tests/test_protocol_responses.py` + +## Possible Files To Touch + +- `src/rotator_library/__init__.py` only if a public lazy export is needed. +- `src/rotator_library/core/__init__.py` only if shared exports are cleaner there. +- Avoid modifying `RequestExecutor` in Phase 1 unless tests reveal a strict import issue. + +## Data Model Plan + +- `ProtocolRole`: role names should remain strings in payloads, but internal dataclasses can use simple `str` fields to avoid over-constraining custom protocols. +- `ContentBlock`: typed block with `type`, optional text/image/source/tool fields, and `extra` dict for provider-specific data. +- `UnifiedMessage`: `role`, `content`, `name`, `tool_call_id`, `tool_calls`, `extra`. +- `ToolDefinition`: protocol-neutral tool schema with `name`, `description`, `input_schema`, `extra`. +- `ToolCall`: `id`, `name`, `arguments`, `type`, `index`, `extra`. +- `ToolResult`: `tool_call_id`, `content`, `is_error`, `extra`. +- `ReasoningBlock`: `type`, `text`, `signature`, `redacted`, `extra`. +- `Usage`: input/output/total tokens, cache read/write, reasoning tokens, raw usage, cost details. +- `CostDetails`: provider reported cost, estimated cost, currency, source, metadata. +- `UnifiedRequest`: model, messages, tools, system, stream flag, generation params, response format, previous response ID, metadata, raw payload, extra. +- `UnifiedResponse`: id, model, messages/output, stop reason, usage, metadata, raw payload, extra. +- `UnifiedStreamEvent`: event type, delta/message/tool/usage/error metadata, raw event, extra. +- `ProtocolContext`: provider, model, source protocol, target protocol, request ID, session ID, credential stable ID, transport, transaction metadata, provider options. +- `ProtocolResult` is probably unnecessary in Phase 1; keep methods direct and simple. + +## Protocol Interface Plan + +- `ProtocolAdapter` abstract/base class with override-friendly methods. +- Required class attributes: `name`, `aliases`, `supported_transports`. +- Methods: + - `parse_request(raw_request, context) -> UnifiedRequest` + - `build_request(unified_request, context) -> dict` + - `parse_response(raw_response, context) -> UnifiedResponse` + - `format_response(unified_response, context) -> dict` + - `parse_stream_event(raw_event, context) -> UnifiedStreamEvent` + - `format_stream_event(unified_event, context) -> Any` + - `extract_usage(raw_or_unified, context) -> Usage | None` + - `supports_transport(transport_name) -> bool` +- Defaults should preserve unknown data in `extra` rather than raising. +- Errors should be `ProtocolError` with protocol name, pass name, and optional payload preview. + +## Registry Plan + +- `PROTOCOL_PLUGINS: dict[str, type[ProtocolAdapter]]` +- `register_protocol(cls)` for explicit class registration. +- `get_protocol(name)` returns an instance or class consistently. +- `list_protocols()`. +- Auto-discover modules under `src/rotator_library/protocols/`, skipping private modules and infrastructure modules. +- Support aliases so `openai`, `chat_completions`, and `openai_chat` can resolve to the same adapter. +- Prevent duplicate names unless the class is identical or explicit replacement is requested later. +- Keep registry import safe and lightweight. + +## Initial Protocol Behavior + +### OpenAI Chat + +- Parse chat completions request messages, system/developer/user/assistant/tool roles, text content, multimodal content arrays, tools, tool calls, tool_choice, response_format, stream, temperature/top_p/max_tokens/stop. +- Parse responses with choices, assistant messages, tool_calls, reasoning/reasoning_content, and usage details. +- Parse stream chunks into `UnifiedStreamEvent` while preserving raw delta. +- Format back to OpenAI Chat without losing unknown fields in `extra`. + +### Anthropic Messages + +- Parse separate `system`, messages with content blocks, thinking/redacted_thinking, tool_use, tool_result, images, documents where currently supported. +- Map Anthropic usage to `Usage`. +- Preserve thinking signatures in `ReasoningBlock.extra`. +- Build/format enough for round-trip tests, not full replacement of `anthropic_compat` yet. + +### Gemini + +- Parse `contents`, roles, parts, text, inline data, file data, functionCall, functionResponse, thought/thoughtSignature where present. +- Parse `generationConfig`, `safetySettings`, `tools`, stream flag metadata. +- Map usage metadata prompt/candidates/thoughts/cache where possible. +- Preserve provider-specific safety and generation fields. + +### Responses + +- Parse `input`, `instructions`, `previous_response_id`, tools, metadata, stream. +- Parse `output` items, message content, reasoning items, function/tool calls. +- Parse common response stream events into unified stream events. +- Do not add storage or routes yet. + +### LiteLLM Fallback + +- Wrap existing OpenAI-compatible dicts into unified request/response with raw preservation. +- Exist mainly as a named explicit protocol path for later logging. + +## Transaction Logging Implications + +- Phase 1 does not integrate runtime transform logs yet. +- All dataclasses need `to_dict()`/`from_dict()` or safe serialization helpers so Phase 2 can log every pass cleanly. +- `ProtocolError` should include pass names that Phase 2 can reuse. +- Avoid mutating raw payloads in protocol parse methods unless explicitly building a new provider request. + +## Docstrings And Comments Required + +- Every public protocol class explains which external API shape it models and which parts are intentionally partial/base behavior. +- `ProtocolAdapter` docstring explains override contract and why protocols are bases. +- Registry comments explain auto-discovery and skip rules. +- Conversion helpers comment on lossy or approximate mappings. +- Future-extension comments note where WebSocket, field-cache rules, provider overrides, and target transports will attach. + +## Tests + +- Registry auto-discovers built-in protocols. +- Aliases resolve correctly. +- Duplicate registration behavior is deterministic. +- Base protocol default methods preserve raw payloads. +- OpenAI Chat request round-trip with system/developer/user/assistant/tool messages, tool definitions, tool calls, reasoning content, and usage with cache/reasoning token details. +- Anthropic Messages request round-trip with system field, text blocks, thinking/redacted_thinking blocks, tool_use/tool_result, and usage cache fields. +- Gemini request round-trip with contents and parts, functionCall/functionResponse, thoughtSignature, generationConfig/safetySettings, and usageMetadata. +- Responses request/response parse with instructions, input string and input message list, previous_response_id, output messages, reasoning items, and function/tool call items. +- Stream event parse smoke tests for each protocol where practical. +- Serialization tests ensure unified types are JSON-serializable. + +## Commit Checkpoints + +1. Add protocol dataclasses, errors, and serialization helpers. +2. Add base adapter and registry auto-discovery. +3. Add LiteLLM fallback and OpenAI Chat protocol with tests. +4. Add Anthropic Messages protocol with tests. +5. Add Gemini protocol with tests. +6. Add Responses protocol with tests. +7. Run focused protocol test set and fix issues. +8. Phase review with `explore` and `explore-heavy`, then fixes if needed. + +## Risk And Rollback + +- Keep Phase 1 isolated so rollback is just removing the new package/tests. +- Avoid touching executor behavior to prevent regressions. +- If auto-discovery causes import cycles, switch to explicit built-in imports inside registry while preserving the public registry API. +- If tests become too large for one file, split into `tests/protocols/` only if ignore rules are handled with force-add when committing. +- Since `.gitignore` ignores most `tests/*`, remember to force-add new test files when committing. From 58415f58d9d4b78fd8cb85bec881e72792326076 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 22:37:15 +0200 Subject: [PATCH 002/182] feat(protocols): add native protocol core Introduces protocol-neutral request, response, stream event, content, tool, reasoning, usage, cost, and context dataclasses with JSON-safe serialization for future transform tracing. Adds the override-friendly ProtocolAdapter base and auto-discovery registry with alias handling, duplicate detection, shared stateless instances, and tests for serialization, default preservation, registration, aliases, and protocol errors. Tests: python -m pytest tests/test_protocol_registry.py --- src/rotator_library/protocols/__init__.py | 65 ++++ src/rotator_library/protocols/base.py | 129 ++++++++ src/rotator_library/protocols/registry.py | 115 +++++++ src/rotator_library/protocols/types.py | 385 ++++++++++++++++++++++ tests/test_protocol_registry.py | 96 ++++++ 5 files changed, 790 insertions(+) create mode 100644 src/rotator_library/protocols/__init__.py create mode 100644 src/rotator_library/protocols/base.py create mode 100644 src/rotator_library/protocols/registry.py create mode 100644 src/rotator_library/protocols/types.py create mode 100644 tests/test_protocol_registry.py diff --git a/src/rotator_library/protocols/__init__.py b/src/rotator_library/protocols/__init__.py new file mode 100644 index 000000000..a16b9fc20 --- /dev/null +++ b/src/rotator_library/protocols/__init__.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Native protocol adapters for provider-independent request handling. + +Importing this package auto-discovers built-in protocol adapters, similar to the +provider plugin system. Runtime execution is not changed by Phase 1; the package +only exposes reusable protocol primitives for later phases. +""" + +from .base import ProtocolAdapter +from .registry import ( + PROTOCOL_ALIASES, + PROTOCOL_PLUGINS, + get_protocol, + get_protocol_class, + list_protocols, + register_protocol, + resolve_protocol_name, +) +from .types import ( + ContentBlock, + CostDetails, + ProtocolContext, + ProtocolError, + ReasoningBlock, + ToolCall, + ToolDefinition, + ToolResult, + UnifiedMessage, + UnifiedRequest, + UnifiedResponse, + UnifiedStreamEvent, + Usage, + first_text, + serialize_value, + text_blocks, +) + +__all__ = [ + "PROTOCOL_ALIASES", + "PROTOCOL_PLUGINS", + "ContentBlock", + "CostDetails", + "ProtocolAdapter", + "ProtocolContext", + "ProtocolError", + "ReasoningBlock", + "ToolCall", + "ToolDefinition", + "ToolResult", + "UnifiedMessage", + "UnifiedRequest", + "UnifiedResponse", + "UnifiedStreamEvent", + "Usage", + "first_text", + "get_protocol", + "get_protocol_class", + "list_protocols", + "register_protocol", + "resolve_protocol_name", + "serialize_value", + "text_blocks", +] diff --git a/src/rotator_library/protocols/base.py b/src/rotator_library/protocols/base.py new file mode 100644 index 000000000..fb5e87079 --- /dev/null +++ b/src/rotator_library/protocols/base.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Base classes for native protocol adapters. + +Protocol adapters are intentionally override-friendly. They provide reusable +defaults for custom providers, but providers can override any method when a +service uses an almost-standard protocol with provider-specific quirks. +""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, ClassVar + +from .types import ( + ProtocolContext, + ProtocolError, + UnifiedRequest, + UnifiedResponse, + UnifiedStreamEvent, + Usage, +) + + +class ProtocolAdapter: + """Base adapter for converting between raw protocol payloads and unified types. + + Subclasses should override only the methods they need. The default behavior + is deliberately conservative: preserve raw payloads and avoid lossy + assumptions. Later phases will layer transform logging, field-cache rules, + and provider override hooks around this interface. + """ + + name: ClassVar[str] = "base" + aliases: ClassVar[tuple[str, ...]] = () + supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + + def supports_transport(self, transport_name: str) -> bool: + """Return whether this protocol can format the requested transport.""" + + return transport_name in self.supported_transports + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + """Parse a raw client/provider request into a unified request.""" + + request = dict(raw_request or {}) + return UnifiedRequest( + model=str(request.get("model") or getattr(context, "model", None) or ""), + stream=bool(request.get("stream", False)), + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in {"model", "stream"}}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + """Build a provider request from a unified request. + + The base implementation returns the original raw dict when present. This + keeps fallback providers safe and gives custom protocol subclasses a + predictable starting point. + """ + + if isinstance(unified_request.raw, dict): + return deepcopy(unified_request.raw) + if not isinstance(unified_request.raw, type(None)): + raise ProtocolError( + "cannot build dict request from non-dict raw payload", + protocol=self.name, + pass_name="build_request", + payload=unified_request.raw, + ) + payload = {"model": unified_request.model, "stream": unified_request.stream} + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + """Parse a raw response into a unified response.""" + + response = raw_response if isinstance(raw_response, dict) else {} + return UnifiedResponse( + id=response.get("id") if isinstance(response, dict) else None, + model=response.get("model") if isinstance(response, dict) else getattr(context, "model", None), + raw=deepcopy(raw_response), + extra=deepcopy(response) if isinstance(response, dict) else {}, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + """Format a unified response for a client protocol.""" + + if isinstance(unified_response.raw, dict): + return deepcopy(unified_response.raw) + payload = deepcopy(unified_response.extra) + if unified_response.id is not None: + payload.setdefault("id", unified_response.id) + if unified_response.model is not None: + payload.setdefault("model", unified_response.model) + return payload + + def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: + """Parse one raw stream event. + + Subclasses should preserve the original event in ``raw`` because Phase 2 + transform logging needs both provider-native and unified states. + """ + + event_type = "done" if raw_event == "[DONE]" else "chunk" + return UnifiedStreamEvent(type=event_type, raw=deepcopy(raw_event)) + + def format_stream_event(self, unified_event: UnifiedStreamEvent, context: ProtocolContext | None = None) -> Any: + """Format one unified stream event for the target transport.""" + + return deepcopy(unified_event.raw) if unified_event.raw is not None else unified_event.to_dict() + + def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: + """Extract normalized usage when the protocol can identify it.""" + + if isinstance(raw_or_unified, UnifiedResponse): + return raw_or_unified.usage + if isinstance(raw_or_unified, UnifiedStreamEvent): + return raw_or_unified.usage + if isinstance(raw_or_unified, dict) and isinstance(raw_or_unified.get("usage"), dict): + usage = raw_or_unified["usage"] + return Usage( + input_tokens=int(usage.get("prompt_tokens") or usage.get("input_tokens") or 0), + output_tokens=int(usage.get("completion_tokens") or usage.get("output_tokens") or 0), + total_tokens=int(usage.get("total_tokens") or 0), + raw=deepcopy(usage), + ) + return None diff --git a/src/rotator_library/protocols/registry.py b/src/rotator_library/protocols/registry.py new file mode 100644 index 000000000..ae9812f3c --- /dev/null +++ b/src/rotator_library/protocols/registry.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Auto-discovery registry for native protocol adapters.""" + +from __future__ import annotations + +import importlib +import inspect +import logging +import pkgutil +from typing import Type + +from .base import ProtocolAdapter + +lib_logger = logging.getLogger("rotator_library") + +PROTOCOL_PLUGINS: dict[str, Type[ProtocolAdapter]] = {} +PROTOCOL_ALIASES: dict[str, str] = {} +_PROTOCOL_INSTANCES: dict[str, ProtocolAdapter] = {} + +_INFRASTRUCTURE_MODULES = {"base", "registry", "types"} + + +def register_protocol(protocol_class: Type[ProtocolAdapter], *, replace: bool = False) -> Type[ProtocolAdapter]: + """Register a protocol adapter class and its aliases. + + The registry mirrors the provider plugin system while staying stricter about + duplicate names. This matters for custom protocol modules: accidental alias + collisions should fail early instead of silently changing conversion logic. + """ + + if not inspect.isclass(protocol_class) or not issubclass(protocol_class, ProtocolAdapter): + raise TypeError("protocol_class must inherit ProtocolAdapter") + if protocol_class is ProtocolAdapter: + raise TypeError("cannot register ProtocolAdapter itself") + + name = protocol_class.name + if not name: + raise ValueError(f"Protocol {protocol_class.__name__} must define a name") + + existing = PROTOCOL_PLUGINS.get(name) + if existing and existing is not protocol_class and not replace: + raise ValueError(f"Protocol name already registered: {name}") + + PROTOCOL_PLUGINS[name] = protocol_class + _PROTOCOL_INSTANCES.pop(name, None) + + for alias in protocol_class.aliases: + existing_name = PROTOCOL_ALIASES.get(alias) + if existing_name and existing_name != name and not replace: + raise ValueError(f"Protocol alias already registered: {alias}") + if alias in PROTOCOL_PLUGINS and alias != name and not replace: + raise ValueError(f"Protocol alias conflicts with registered name: {alias}") + PROTOCOL_ALIASES[alias] = name + + lib_logger.debug("Registered protocol: %s", name) + return protocol_class + + +def resolve_protocol_name(name: str) -> str: + """Resolve aliases to canonical protocol names.""" + + if name in PROTOCOL_PLUGINS: + return name + if name in PROTOCOL_ALIASES: + return PROTOCOL_ALIASES[name] + raise KeyError(f"Unknown protocol: {name}") + + +def get_protocol_class(name: str) -> Type[ProtocolAdapter]: + """Return a registered protocol adapter class by name or alias.""" + + return PROTOCOL_PLUGINS[resolve_protocol_name(name)] + + +def get_protocol(name: str) -> ProtocolAdapter: + """Return a shared stateless protocol adapter instance by name or alias.""" + + canonical = resolve_protocol_name(name) + if canonical not in _PROTOCOL_INSTANCES: + _PROTOCOL_INSTANCES[canonical] = PROTOCOL_PLUGINS[canonical]() + return _PROTOCOL_INSTANCES[canonical] + + +def list_protocols() -> list[str]: + """Return canonical protocol names in deterministic order.""" + + return sorted(PROTOCOL_PLUGINS) + + +def _register_protocols() -> None: + """Discover protocol modules in this package and register adapter classes. + + Private modules and infrastructure modules are skipped so local experiments + can live next to production protocols without being imported accidentally. + """ + + package = importlib.import_module(__package__ or "rotator_library.protocols") + for _, module_name, _ in pkgutil.iter_modules(package.__path__): + if module_name.startswith("_") or module_name in _INFRASTRUCTURE_MODULES: + continue + module = importlib.import_module(f"{package.__name__}.{module_name}") + for attribute_name in dir(module): + attribute = getattr(module, attribute_name) + if ( + inspect.isclass(attribute) + and issubclass(attribute, ProtocolAdapter) + and attribute is not ProtocolAdapter + and attribute.__module__ == module.__name__ + ): + register_protocol(attribute) + + +_register_protocols() diff --git a/src/rotator_library/protocols/types.py b/src/rotator_library/protocols/types.py new file mode 100644 index 000000000..4440ce7c3 --- /dev/null +++ b/src/rotator_library/protocols/types.py @@ -0,0 +1,385 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Protocol-neutral data structures used by native protocol adapters. + +These types intentionally model common LLM API concepts without pretending the +set is complete. Providers and future protocol implementations can preserve +non-standard fields in ``extra`` or ``raw`` instead of dropping them. That +preservation is important for transform-pass logging, field-cache rules, and +provider-specific overrides added in later experimental phases. +""" + +from __future__ import annotations + +from copy import deepcopy +from dataclasses import asdict, dataclass, field, is_dataclass +from typing import Any, ClassVar, Iterable, Mapping, Optional + + +JsonObject = dict[str, Any] + + +def serialize_value(value: Any) -> Any: + """Return a JSON-friendly copy of a protocol value. + + The transaction logging phase needs reliable snapshots after every protocol + and adapter pass. This helper keeps that concern centralized and avoids + mutating live request/response objects while preparing logs or fixtures. + """ + + if isinstance(value, ProtocolSerializable): + return value.to_dict() + if is_dataclass(value): + return serialize_value(asdict(value)) + if isinstance(value, Mapping): + return {str(k): serialize_value(v) for k, v in value.items()} + if isinstance(value, (list, tuple, set, frozenset)): + return [serialize_value(v) for v in value] + return deepcopy(value) + + +def copy_mapping(value: Optional[Mapping[str, Any]]) -> JsonObject: + """Return a deep-copied dict for extension fields.""" + + return serialize_value(dict(value or {})) + + +class ProtocolSerializable: + """Mixin for dataclasses that need stable dict serialization. + + Concrete protocol types define ``_fields`` so ``to_dict`` stays explicit and + future additions do not accidentally disappear from transform logs. + """ + + _fields: ClassVar[tuple[str, ...]] = () + + def to_dict(self) -> JsonObject: + return {field_name: serialize_value(getattr(self, field_name)) for field_name in self._fields} + + +@dataclass +class CostDetails(ProtocolSerializable): + """Normalized cost metadata from provider-reported or estimated sources.""" + + provider_reported_cost: Optional[float] = None + estimated_cost: Optional[float] = None + currency: str = "USD" + source: Optional[str] = None + metadata: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "provider_reported_cost", + "estimated_cost", + "currency", + "source", + "metadata", + ) + + +@dataclass +class Usage(ProtocolSerializable): + """Protocol-neutral token and cost usage values. + + Existing usage tracking has provider-specific extraction paths. Native + protocols use this shape first, then later phases can normalize it into the + current usage manager without replacing that engine. + """ + + input_tokens: int = 0 + output_tokens: int = 0 + total_tokens: int = 0 + cache_read_tokens: int = 0 + cache_write_tokens: int = 0 + reasoning_tokens: int = 0 + cost: Optional[CostDetails] = None + raw: Any = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "input_tokens", + "output_tokens", + "total_tokens", + "cache_read_tokens", + "cache_write_tokens", + "reasoning_tokens", + "cost", + "raw", + "extra", + ) + + def __post_init__(self) -> None: + if self.total_tokens <= 0: + self.total_tokens = self.input_tokens + self.output_tokens + self.reasoning_tokens + + +@dataclass +class ReasoningBlock(ProtocolSerializable): + """Reasoning/thinking content and signatures preserved across protocols.""" + + type: str = "reasoning" + text: Optional[str] = None + signature: Optional[str] = None + redacted: bool = False + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ("type", "text", "signature", "redacted", "extra") + + +@dataclass +class ToolCall(ProtocolSerializable): + """Protocol-neutral tool/function call emitted by an assistant.""" + + id: Optional[str] = None + name: Optional[str] = None + arguments: Any = None + type: str = "function" + index: Optional[int] = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ("id", "name", "arguments", "type", "index", "extra") + + +@dataclass +class ToolResult(ProtocolSerializable): + """Protocol-neutral result associated with a prior tool call.""" + + tool_call_id: Optional[str] = None + content: Any = None + is_error: Optional[bool] = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ("tool_call_id", "content", "is_error", "extra") + + +@dataclass +class ToolDefinition(ProtocolSerializable): + """Protocol-neutral tool schema exposed to a model.""" + + name: str = "" + description: Optional[str] = None + input_schema: JsonObject = field(default_factory=dict) + type: str = "function" + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ("name", "description", "input_schema", "type", "extra") + + +@dataclass +class ContentBlock(ProtocolSerializable): + """A single message content block. + + ``type`` follows the source protocol when practical. The dedicated fields + cover common text, image, document, tool, and reasoning cases, while ``extra`` + keeps provider-specific payloads for later field-cache extraction. + """ + + type: str = "text" + text: Optional[str] = None + source: Any = None + tool_call: Optional[ToolCall] = None + tool_result: Optional[ToolResult] = None + reasoning: Optional[ReasoningBlock] = None + raw: Any = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "type", + "text", + "source", + "tool_call", + "tool_result", + "reasoning", + "raw", + "extra", + ) + + +@dataclass +class UnifiedMessage(ProtocolSerializable): + """A protocol-neutral chat/message turn.""" + + role: str + content: list[ContentBlock] = field(default_factory=list) + name: Optional[str] = None + tool_call_id: Optional[str] = None + tool_calls: list[ToolCall] = field(default_factory=list) + reasoning: list[ReasoningBlock] = field(default_factory=list) + raw: Any = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "role", + "content", + "name", + "tool_call_id", + "tool_calls", + "reasoning", + "raw", + "extra", + ) + + +@dataclass +class UnifiedRequest(ProtocolSerializable): + """A request after parsing from a client or provider protocol.""" + + model: str = "" + messages: list[UnifiedMessage] = field(default_factory=list) + system: list[ContentBlock] = field(default_factory=list) + tools: list[ToolDefinition] = field(default_factory=list) + stream: bool = False + generation_params: JsonObject = field(default_factory=dict) + response_format: Any = None + previous_response_id: Optional[str] = None + metadata: JsonObject = field(default_factory=dict) + raw: Any = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "model", + "messages", + "system", + "tools", + "stream", + "generation_params", + "response_format", + "previous_response_id", + "metadata", + "raw", + "extra", + ) + + +@dataclass +class UnifiedResponse(ProtocolSerializable): + """A complete provider/client response in protocol-neutral form.""" + + id: Optional[str] = None + model: Optional[str] = None + messages: list[UnifiedMessage] = field(default_factory=list) + output: list[Any] = field(default_factory=list) + stop_reason: Optional[str] = None + usage: Optional[Usage] = None + metadata: JsonObject = field(default_factory=dict) + raw: Any = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "id", + "model", + "messages", + "output", + "stop_reason", + "usage", + "metadata", + "raw", + "extra", + ) + + +@dataclass +class UnifiedStreamEvent(ProtocolSerializable): + """A single protocol-neutral stream event. + + Future SSE and WebSocket transports should consume this type instead of raw + provider chunks so transport code can stay independent from protocol parsing. + """ + + type: str + delta: Optional[UnifiedMessage] = None + message: Optional[UnifiedMessage] = None + tool_call: Optional[ToolCall] = None + usage: Optional[Usage] = None + error: Any = None + raw: Any = None + extra: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "type", + "delta", + "message", + "tool_call", + "usage", + "error", + "raw", + "extra", + ) + + +@dataclass +class ProtocolContext(ProtocolSerializable): + """Execution context passed through protocol methods. + + Only a small subset is needed in Phase 1, but the fields anticipate later + provider overrides, transaction tracing, field-cache scoping, and transport + selection without forcing those systems to exist yet. + """ + + provider: Optional[str] = None + model: Optional[str] = None + source_protocol: Optional[str] = None + target_protocol: Optional[str] = None + request_id: Optional[str] = None + session_id: Optional[str] = None + credential_stable_id: Optional[str] = None + transport: str = "http" + provider_options: JsonObject = field(default_factory=dict) + metadata: JsonObject = field(default_factory=dict) + + _fields: ClassVar[tuple[str, ...]] = ( + "provider", + "model", + "source_protocol", + "target_protocol", + "request_id", + "session_id", + "credential_stable_id", + "transport", + "provider_options", + "metadata", + ) + + +class ProtocolError(ValueError): + """Error raised by protocol parsing/building passes.""" + + def __init__( + self, + message: str, + *, + protocol: str, + pass_name: str, + payload: Any = None, + ): + self.protocol = protocol + self.pass_name = pass_name + self.payload_preview = _payload_preview(payload) + details = f"{protocol}.{pass_name}: {message}" + if self.payload_preview is not None: + details = f"{details} | payload={self.payload_preview}" + super().__init__(details) + + +def _payload_preview(payload: Any, limit: int = 500) -> Optional[str]: + if payload is None: + return None + text = repr(serialize_value(payload)) + if len(text) > limit: + return text[: limit - 3] + "..." + return text + + +def text_blocks(text: Optional[str]) -> list[ContentBlock]: + """Return a single text block list for simple string content.""" + + if text is None: + return [] + return [ContentBlock(type="text", text=str(text))] + + +def first_text(blocks: Iterable[ContentBlock]) -> Optional[str]: + """Return concatenated text from content blocks, or ``None`` if absent.""" + + parts = [block.text for block in blocks if block.type == "text" and block.text] + return "".join(parts) if parts else None diff --git a/tests/test_protocol_registry.py b/tests/test_protocol_registry.py new file mode 100644 index 000000000..d3151b4c4 --- /dev/null +++ b/tests/test_protocol_registry.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import json + +import pytest + +from src.rotator_library.protocols import ( + ContentBlock, + ProtocolAdapter, + ProtocolContext, + ProtocolError, + ToolCall, + UnifiedMessage, + UnifiedRequest, + get_protocol, + get_protocol_class, + list_protocols, + register_protocol, + serialize_value, +) + + +class ExampleProtocol(ProtocolAdapter): + name = "example_test_protocol" + aliases = ("example_test_alias",) + + +def test_protocol_serialization_preserves_nested_values() -> None: + request = UnifiedRequest( + model="provider/model", + messages=[ + UnifiedMessage( + role="assistant", + content=[ContentBlock(type="text", text="hello")], + tool_calls=[ToolCall(id="call_1", name="lookup", arguments={"q": "x"})], + extra={"provider_field": {"kept": True}}, + ) + ], + metadata={"request_id": "req_1"}, + ) + + serialized = request.to_dict() + + assert serialized["messages"][0]["content"][0]["text"] == "hello" + assert serialized["messages"][0]["tool_calls"][0]["arguments"] == {"q": "x"} + assert serialized["messages"][0]["extra"]["provider_field"] == {"kept": True} + json.dumps(serialize_value(request)) + + +def test_base_protocol_preserves_unknown_request_fields() -> None: + adapter = ProtocolAdapter() + raw = {"model": "provider/model", "stream": True, "custom": {"value": 1}} + + unified = adapter.parse_request(raw, ProtocolContext(provider="provider")) + rebuilt = adapter.build_request(unified) + + assert unified.model == "provider/model" + assert unified.stream is True + assert unified.extra == {"custom": {"value": 1}} + assert rebuilt == raw + + +def test_register_protocol_resolves_alias_and_reuses_instance() -> None: + register_protocol(ExampleProtocol, replace=True) + + assert get_protocol_class("example_test_protocol") is ExampleProtocol + assert get_protocol_class("example_test_alias") is ExampleProtocol + assert get_protocol("example_test_protocol") is get_protocol("example_test_alias") + assert "example_test_protocol" in list_protocols() + + +def test_register_protocol_rejects_duplicate_alias() -> None: + class FirstProtocol(ProtocolAdapter): + name = "duplicate_alias_first" + aliases = ("duplicate_alias",) + + class SecondProtocol(ProtocolAdapter): + name = "duplicate_alias_second" + aliases = ("duplicate_alias",) + + register_protocol(FirstProtocol, replace=True) + + with pytest.raises(ValueError, match="alias already registered"): + register_protocol(SecondProtocol) + + +def test_protocol_error_includes_pass_and_payload_preview() -> None: + error = ProtocolError( + "failed", + protocol="example", + pass_name="parse_request", + payload={"secret": "not-redacted-here", "value": 1}, + ) + + assert "example.parse_request" in str(error) + assert "value" in str(error) From 004d471817f3b4be68a64809475dae672acf9c41 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 22:41:29 +0200 Subject: [PATCH 003/182] feat(protocols): add OpenAI chat adapter Adds the explicit LiteLLM fallback protocol marker and a native OpenAI Chat Completions adapter for request parsing/building, response parsing/formatting, usage and provider-reported cost extraction, reasoning preservation, tool calls, multimodal content blocks, and SSE chunk parsing. The adapter remains isolated from runtime execution and preserves unknown extension fields for future adapter, field-cache, and transform logging phases. Tests: python -m pytest tests/test_protocol_registry.py tests/test_protocol_openai_chat.py --- .../protocols/litellm_fallback.py | 28 ++ src/rotator_library/protocols/openai_chat.py | 416 ++++++++++++++++++ tests/test_protocol_openai_chat.py | 132 ++++++ 3 files changed, 576 insertions(+) create mode 100644 src/rotator_library/protocols/litellm_fallback.py create mode 100644 src/rotator_library/protocols/openai_chat.py create mode 100644 tests/test_protocol_openai_chat.py diff --git a/src/rotator_library/protocols/litellm_fallback.py b/src/rotator_library/protocols/litellm_fallback.py new file mode 100644 index 000000000..f8bd5a794 --- /dev/null +++ b/src/rotator_library/protocols/litellm_fallback.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Explicit fallback protocol for the existing LiteLLM-shaped path. + +This adapter intentionally does very little beyond preserving raw payloads. It +gives later execution and transform-logging code a named protocol path whenever +native adapters do not yet cover a provider or request shape. +""" + +from __future__ import annotations + +from typing import ClassVar + +from .base import ProtocolAdapter + + +class LiteLLMFallbackProtocol(ProtocolAdapter): + """Protocol adapter that keeps the current LiteLLM-compatible payload shape. + + Providers should prefer native protocol adapters when available. This class + remains as a safe compatibility base for unsupported provider shapes and as + a clear marker in future transaction transform traces. + """ + + name: ClassVar[str] = "litellm_fallback" + aliases: ClassVar[tuple[str, ...]] = ("litellm", "fallback") + supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") diff --git a/src/rotator_library/protocols/openai_chat.py b/src/rotator_library/protocols/openai_chat.py new file mode 100644 index 000000000..3c7e44a5c --- /dev/null +++ b/src/rotator_library/protocols/openai_chat.py @@ -0,0 +1,416 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""OpenAI Chat Completions protocol adapter. + +The adapter models the common OpenAI-compatible chat shape used by many current +providers. It is a reusable base, not a final authority: providers can subclass +or override pieces when they need non-standard fields, stricter ordering, or +different stream semantics. +""" + +from __future__ import annotations + +import json +from copy import deepcopy +from typing import Any, ClassVar, Iterable + +from .base import ProtocolAdapter +from .types import ( + ContentBlock, + CostDetails, + ProtocolContext, + ReasoningBlock, + ToolCall, + ToolDefinition, + ToolResult, + UnifiedMessage, + UnifiedRequest, + UnifiedResponse, + UnifiedStreamEvent, + Usage, + first_text, + serialize_value, + text_blocks, +) + +_GENERATION_PARAMS = { + "frequency_penalty", + "logit_bias", + "logprobs", + "max_completion_tokens", + "max_tokens", + "n", + "parallel_tool_calls", + "presence_penalty", + "reasoning_effort", + "seed", + "service_tier", + "stop", + "stream_options", + "temperature", + "tool_choice", + "top_logprobs", + "top_p", + "user", +} + +_REQUEST_CORE_FIELDS = { + "model", + "messages", + "tools", + "stream", + "response_format", + "metadata", + *_GENERATION_PARAMS, +} + + +class OpenAIChatProtocol(ProtocolAdapter): + """Adapter for OpenAI Chat Completions request, response, and stream chunks. + + Unknown OpenAI-compatible extension fields are preserved in ``extra`` so a + custom provider can still use them through later adapter or field-cache + phases. Lossy conversions are avoided unless the source shape itself uses a + compact representation, such as string message content. + """ + + name: ClassVar[str] = "openai_chat" + aliases: ClassVar[tuple[str, ...]] = ( + "openai", + "chat_completions", + "openai_chat_completions", + ) + supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + messages = [self._parse_message(message) for message in request.get("messages") or []] + tools = [self._parse_tool_definition(tool) for tool in request.get("tools") or []] + generation_params = {k: deepcopy(request[k]) for k in _GENERATION_PARAMS if k in request} + extra = {k: deepcopy(v) for k, v in request.items() if k not in _REQUEST_CORE_FIELDS} + + return UnifiedRequest( + model=str(request.get("model") or getattr(context, "model", None) or ""), + messages=messages, + tools=tools, + stream=bool(request.get("stream", False)), + generation_params=generation_params, + response_format=deepcopy(request.get("response_format")), + metadata=deepcopy(request.get("metadata") or {}), + raw=deepcopy(raw_request), + extra=extra, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload: dict[str, Any] = { + "model": unified_request.model, + "messages": [self._format_message(message) for message in unified_request.messages], + } + if unified_request.tools: + payload["tools"] = [self._format_tool_definition(tool) for tool in unified_request.tools] + if unified_request.stream: + payload["stream"] = True + if unified_request.response_format is not None: + payload["response_format"] = deepcopy(unified_request.response_format) + if unified_request.metadata: + payload["metadata"] = deepcopy(unified_request.metadata) + payload.update(deepcopy(unified_request.generation_params)) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + response = _as_dict(raw_response) + messages: list[UnifiedMessage] = [] + stop_reason = None + for choice in response.get("choices") or []: + if not isinstance(choice, dict): + continue + message_payload = choice.get("message") or {} + if message_payload: + messages.append(self._parse_message(message_payload)) + if choice.get("finish_reason") is not None: + stop_reason = choice.get("finish_reason") + + return UnifiedResponse( + id=response.get("id"), + model=response.get("model") or getattr(context, "model", None), + messages=messages, + stop_reason=stop_reason, + usage=self.extract_usage(response, context), + metadata={ + "object": response.get("object"), + "created": response.get("created"), + "system_fingerprint": response.get("system_fingerprint"), + }, + raw=deepcopy(response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"id", "object", "created", "model", "choices", "usage", "system_fingerprint"}}, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + if isinstance(unified_response.raw, dict): + return deepcopy(unified_response.raw) + choices = [] + for index, message in enumerate(unified_response.messages): + choices.append( + { + "index": index, + "message": self._format_message(message), + "finish_reason": unified_response.stop_reason, + } + ) + payload = { + "id": unified_response.id, + "object": unified_response.metadata.get("object", "chat.completion"), + "created": unified_response.metadata.get("created"), + "model": unified_response.model, + "choices": choices, + "usage": unified_response.usage.to_dict() if unified_response.usage else None, + } + payload.update(deepcopy(unified_response.extra)) + return payload + + def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: + event = _decode_sse_data(raw_event) + if event == "[DONE]": + return UnifiedStreamEvent(type="done", raw=deepcopy(raw_event)) + data = _as_dict(event) + if data.get("error") is not None: + return UnifiedStreamEvent(type="error", error=deepcopy(data["error"]), raw=deepcopy(raw_event), extra={"payload": data}) + + delta_message = None + finish_reason = None + for choice in data.get("choices") or []: + if not isinstance(choice, dict): + continue + delta = choice.get("delta") or {} + if delta: + delta_message = self._parse_message({"role": delta.get("role", "assistant"), **delta}) + finish_reason = choice.get("finish_reason") if choice.get("finish_reason") is not None else finish_reason + break + + usage = self.extract_usage(data, context) + return UnifiedStreamEvent( + type="message_delta" if delta_message else "chunk", + delta=delta_message, + usage=usage, + raw=deepcopy(raw_event), + extra={ + "id": data.get("id"), + "model": data.get("model"), + "finish_reason": finish_reason, + "payload": data, + }, + ) + + def format_stream_event(self, unified_event: UnifiedStreamEvent, context: ProtocolContext | None = None) -> Any: + if unified_event.raw is not None: + return deepcopy(unified_event.raw) + if unified_event.type == "done": + return "data: [DONE]\n\n" + return f"data: {json.dumps(unified_event.to_dict())}\n\n" + + def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: + if isinstance(raw_or_unified, (UnifiedResponse, UnifiedStreamEvent)): + return raw_or_unified.usage + payload = _as_dict(raw_or_unified) + usage = payload.get("usage") + if not isinstance(usage, dict): + return None + prompt_details = usage.get("prompt_tokens_details") or {} + completion_details = usage.get("completion_tokens_details") or {} + if not isinstance(prompt_details, dict): + prompt_details = {} + if not isinstance(completion_details, dict): + completion_details = {} + cost = None + cost_details = usage.get("cost_details") + if isinstance(cost_details, dict): + provider_cost = cost_details.get("total_cost") or cost_details.get("cost") + cost = CostDetails( + provider_reported_cost=float(provider_cost) if provider_cost is not None else None, + currency=str(cost_details.get("currency") or "USD"), + source="usage.cost_details", + metadata={k: deepcopy(v) for k, v in cost_details.items() if k not in {"total_cost", "cost", "currency"}}, + ) + return Usage( + input_tokens=int(usage.get("prompt_tokens") or usage.get("input_tokens") or 0), + output_tokens=int(usage.get("completion_tokens") or usage.get("output_tokens") or 0), + total_tokens=int(usage.get("total_tokens") or 0), + cache_read_tokens=int(prompt_details.get("cached_tokens") or usage.get("cache_read_tokens") or 0), + cache_write_tokens=int(prompt_details.get("cache_creation_tokens") or usage.get("cache_creation_tokens") or 0), + reasoning_tokens=int(completion_details.get("reasoning_tokens") or usage.get("reasoning_tokens") or 0), + cost=cost, + raw=deepcopy(usage), + ) + + def _parse_message(self, message: dict[str, Any]) -> UnifiedMessage: + payload = dict(message or {}) + reasoning = _extract_reasoning(payload) + return UnifiedMessage( + role=str(payload.get("role") or "assistant"), + content=self._parse_content(payload.get("content")), + name=payload.get("name"), + tool_call_id=payload.get("tool_call_id"), + tool_calls=[self._parse_tool_call(call) for call in payload.get("tool_calls") or []], + reasoning=reasoning, + raw=deepcopy(message), + extra={k: deepcopy(v) for k, v in payload.items() if k not in {"role", "content", "name", "tool_call_id", "tool_calls", "reasoning", "reasoning_content"}}, + ) + + def _format_message(self, message: UnifiedMessage) -> dict[str, Any]: + payload: dict[str, Any] = {"role": message.role} + if message.name: + payload["name"] = message.name + if message.tool_call_id: + payload["tool_call_id"] = message.tool_call_id + content = self._format_content(message.content) + if content is not None: + payload["content"] = content + if message.tool_calls: + payload["tool_calls"] = [self._format_tool_call(call) for call in message.tool_calls] + if message.reasoning: + # OpenAI-compatible providers use multiple names for reasoning text. + # Prefer the common extension field while keeping all blocks in extra. + text = "".join(block.text or "" for block in message.reasoning if block.text) + if text: + payload["reasoning_content"] = text + payload.update(deepcopy(message.extra)) + return payload + + def _parse_content(self, content: Any) -> list[ContentBlock]: + if content is None: + return [] + if isinstance(content, str): + return text_blocks(content) + if not isinstance(content, list): + return [ContentBlock(type="unknown", raw=deepcopy(content))] + blocks = [] + for block in content: + if isinstance(block, str): + blocks.append(ContentBlock(type="text", text=block, raw=block)) + continue + if not isinstance(block, dict): + blocks.append(ContentBlock(type="unknown", raw=deepcopy(block))) + continue + block_type = block.get("type", "text") + if block_type == "text": + blocks.append(ContentBlock(type="text", text=block.get("text", ""), raw=deepcopy(block))) + elif block_type in {"image_url", "input_image"}: + blocks.append(ContentBlock(type=block_type, source=deepcopy(block.get("image_url") or block.get("source")), raw=deepcopy(block), extra=_without(block, {"type", "image_url", "source"}))) + else: + blocks.append(ContentBlock(type=str(block_type), raw=deepcopy(block), extra=_without(block, {"type"}))) + return blocks + + def _format_content(self, blocks: Iterable[ContentBlock]) -> Any: + block_list = list(blocks) + if not block_list: + return None + if all(block.type == "text" for block in block_list): + return first_text(block_list) or "" + formatted = [] + for block in block_list: + if block.type == "text": + formatted.append({"type": "text", "text": block.text or ""}) + elif block.type in {"image_url", "input_image"}: + payload = {"type": block.type, "image_url": deepcopy(block.source)} + payload.update(deepcopy(block.extra)) + formatted.append(payload) + elif isinstance(block.raw, dict): + formatted.append(deepcopy(block.raw)) + else: + payload = {"type": block.type} + payload.update(deepcopy(block.extra)) + formatted.append(payload) + return formatted + + def _parse_tool_definition(self, tool: dict[str, Any]) -> ToolDefinition: + payload = dict(tool or {}) + function = payload.get("function") if isinstance(payload.get("function"), dict) else payload + return ToolDefinition( + name=str(function.get("name") or ""), + description=function.get("description"), + input_schema=deepcopy(function.get("parameters") or function.get("input_schema") or {}), + type=str(payload.get("type") or "function"), + extra={"raw": deepcopy(tool), **_without(payload, {"type", "function"})}, + ) + + def _format_tool_definition(self, tool: ToolDefinition) -> dict[str, Any]: + raw = tool.extra.get("raw") + if isinstance(raw, dict): + return deepcopy(raw) + return { + "type": tool.type, + "function": { + "name": tool.name, + "description": tool.description, + "parameters": deepcopy(tool.input_schema), + }, + } + + def _parse_tool_call(self, call: dict[str, Any]) -> ToolCall: + payload = dict(call or {}) + function = payload.get("function") if isinstance(payload.get("function"), dict) else {} + arguments: Any = function.get("arguments") + return ToolCall( + id=payload.get("id"), + name=function.get("name") or payload.get("name"), + arguments=arguments, + type=str(payload.get("type") or "function"), + index=payload.get("index"), + extra=_without(payload, {"id", "function", "type", "index", "name"}), + ) + + def _format_tool_call(self, call: ToolCall) -> dict[str, Any]: + payload = {"type": call.type, "function": {"name": call.name or "", "arguments": _format_arguments(call.arguments)}} + if call.id: + payload["id"] = call.id + if call.index is not None: + payload["index"] = call.index + payload.update(deepcopy(call.extra)) + return payload + + +def _as_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return deepcopy(value) + if hasattr(value, "model_dump"): + return value.model_dump() + if hasattr(value, "dict"): + return value.dict() + return {} + + +def _decode_sse_data(raw_event: Any) -> Any: + if not isinstance(raw_event, str): + return raw_event + text = raw_event.strip() + if text.startswith("data:"): + text = text[5:].strip() + if text == "[DONE]": + return text + try: + return json.loads(text) + except json.JSONDecodeError: + return raw_event + + +def _extract_reasoning(payload: dict[str, Any]) -> list[ReasoningBlock]: + blocks = [] + for field_name in ("reasoning_content", "reasoning"): + value = payload.get(field_name) + if value: + blocks.append(ReasoningBlock(type=field_name, text=str(value), extra={"source_field": field_name})) + return blocks + + +def _format_arguments(arguments: Any) -> str: + if arguments is None: + return "" + if isinstance(arguments, str): + return arguments + return json.dumps(serialize_value(arguments), separators=(",", ":")) + + +def _without(payload: dict[str, Any], keys: set[str]) -> dict[str, Any]: + return {k: deepcopy(v) for k, v in payload.items() if k not in keys} diff --git a/tests/test_protocol_openai_chat.py b/tests/test_protocol_openai_chat.py new file mode 100644 index 000000000..214b96e6b --- /dev/null +++ b/tests/test_protocol_openai_chat.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import json + +from src.rotator_library.protocols import get_protocol, list_protocols + + +def test_openai_chat_protocol_is_discovered_with_aliases() -> None: + assert "openai_chat" in list_protocols() + assert get_protocol("openai") is get_protocol("openai_chat") + assert get_protocol("chat_completions") is get_protocol("openai_chat") + + +def test_litellm_fallback_protocol_preserves_raw_payload() -> None: + adapter = get_protocol("litellm_fallback") + raw = {"model": "custom/model", "messages": [{"role": "user", "content": "hi"}], "vendor_flag": True} + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + + assert unified.extra["messages"] == raw["messages"] + assert rebuilt == raw + + +def test_openai_chat_request_round_trip_preserves_tools_and_reasoning() -> None: + adapter = get_protocol("openai_chat") + raw = { + "model": "openai/gpt-test", + "stream": True, + "temperature": 0.2, + "messages": [ + {"role": "system", "content": "system"}, + {"role": "developer", "content": "dev"}, + {"role": "user", "content": [{"type": "text", "text": "hello"}]}, + { + "role": "assistant", + "content": "thinking done", + "reasoning_content": "internal chain", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "lookup", "arguments": "{\"q\":\"x\"}"}, + } + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "result"}, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "lookup", + "description": "Lookup a value", + "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}, + }, + } + ], + "vendor_extension": {"kept": True}, + } + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + + assert unified.model == "openai/gpt-test" + assert unified.stream is True + assert unified.generation_params["temperature"] == 0.2 + assert unified.messages[3].reasoning[0].text == "internal chain" + assert unified.messages[3].tool_calls[0].name == "lookup" + assert unified.tools[0].input_schema["properties"]["q"]["type"] == "string" + assert rebuilt["vendor_extension"] == {"kept": True} + assert rebuilt["messages"][3]["reasoning_content"] == "internal chain" + + +def test_openai_chat_response_extracts_usage_cost_and_reasoning() -> None: + adapter = get_protocol("openai_chat") + raw = { + "id": "chatcmpl_1", + "object": "chat.completion", + "created": 123, + "model": "openai/gpt-test", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "answer", + "reasoning_content": "reasoned", + }, + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 18, + "prompt_tokens_details": {"cached_tokens": 3, "cache_creation_tokens": 2}, + "completion_tokens_details": {"reasoning_tokens": 3}, + "cost_details": {"total_cost": 0.01, "currency": "USD"}, + }, + } + + unified = adapter.parse_response(raw) + + assert unified.id == "chatcmpl_1" + assert unified.messages[0].content[0].text == "answer" + assert unified.messages[0].reasoning[0].text == "reasoned" + assert unified.usage is not None + assert unified.usage.input_tokens == 10 + assert unified.usage.output_tokens == 5 + assert unified.usage.reasoning_tokens == 3 + assert unified.usage.cache_read_tokens == 3 + assert unified.usage.cache_write_tokens == 2 + assert unified.usage.cost is not None + assert unified.usage.cost.provider_reported_cost == 0.01 + + +def test_openai_chat_stream_event_parses_sse_delta_and_done() -> None: + adapter = get_protocol("openai_chat") + event = { + "id": "chunk_1", + "model": "openai/gpt-test", + "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hel"}, "finish_reason": None}], + } + + parsed = adapter.parse_stream_event(f"data: {json.dumps(event)}\n\n") + done = adapter.parse_stream_event("data: [DONE]\n\n") + + assert parsed.type == "message_delta" + assert parsed.delta is not None + assert parsed.delta.content[0].text == "Hel" + assert done.type == "done" From 1a6b5641c780c56e7f4566908ebb2a1818cf88f2 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 22:44:05 +0200 Subject: [PATCH 004/182] feat(protocols): add Anthropic messages adapter Adds a native Anthropic Messages protocol adapter for request parsing/building, response formatting, stream event parsing, tool_use/tool_result blocks, thinking and redacted-thinking signature preservation, and cache usage normalization. The existing compatibility routes remain untouched; this adapter is an isolated base for later native provider execution, field-cache rules, and transform logging. Tests: python -m pytest tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py --- .../protocols/anthropic_messages.py | 347 ++++++++++++++++++ tests/test_protocol_anthropic_messages.py | 95 +++++ 2 files changed, 442 insertions(+) create mode 100644 src/rotator_library/protocols/anthropic_messages.py create mode 100644 tests/test_protocol_anthropic_messages.py diff --git a/src/rotator_library/protocols/anthropic_messages.py b/src/rotator_library/protocols/anthropic_messages.py new file mode 100644 index 000000000..0605d5134 --- /dev/null +++ b/src/rotator_library/protocols/anthropic_messages.py @@ -0,0 +1,347 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Anthropic Messages protocol adapter. + +This adapter captures the native Messages shape as a reusable base. The existing +compatibility routes remain active; this module gives future provider-native +execution a loss-conscious parser/builder with thinking and tool block support. +""" + +from __future__ import annotations + +import json +from copy import deepcopy +from typing import Any, ClassVar, Iterable + +from .base import ProtocolAdapter +from .types import ( + ContentBlock, + ProtocolContext, + ReasoningBlock, + ToolCall, + ToolDefinition, + ToolResult, + UnifiedMessage, + UnifiedRequest, + UnifiedResponse, + UnifiedStreamEvent, + Usage, + first_text, + text_blocks, +) + +_GENERATION_PARAMS = { + "max_tokens", + "metadata", + "stop_sequences", + "temperature", + "thinking", + "tool_choice", + "top_k", + "top_p", +} + +_REQUEST_CORE_FIELDS = {"model", "messages", "system", "tools", "stream", *_GENERATION_PARAMS} + + +class AnthropicMessagesProtocol(ProtocolAdapter): + """Adapter for Anthropic Messages requests, responses, and stream events. + + Thinking and redacted-thinking blocks are represented as reasoning blocks so + later field-cache rules can extract signatures without relying on a bespoke + provider implementation. + """ + + name: ClassVar[str] = "anthropic_messages" + aliases: ClassVar[tuple[str, ...]] = ("anthropic", "messages", "claude_messages") + supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + return UnifiedRequest( + model=str(request.get("model") or getattr(context, "model", None) or ""), + messages=[self._parse_message(message) for message in request.get("messages") or []], + system=self._parse_system(request.get("system")), + tools=[self._parse_tool_definition(tool) for tool in request.get("tools") or []], + stream=bool(request.get("stream", False)), + generation_params={k: deepcopy(request[k]) for k in _GENERATION_PARAMS if k in request and k != "metadata"}, + metadata=deepcopy(request.get("metadata") or {}), + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _REQUEST_CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload: dict[str, Any] = { + "model": unified_request.model, + "messages": [self._format_message(message) for message in unified_request.messages], + } + system = self._format_system(unified_request.system) + if system is not None: + payload["system"] = system + if unified_request.tools: + payload["tools"] = [self._format_tool_definition(tool) for tool in unified_request.tools] + if unified_request.stream: + payload["stream"] = True + if unified_request.metadata: + payload["metadata"] = deepcopy(unified_request.metadata) + payload.update(deepcopy(unified_request.generation_params)) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + response = _as_dict(raw_response) + message = UnifiedMessage( + role=str(response.get("role") or "assistant"), + content=self._parse_content(response.get("content")), + raw=deepcopy(response), + extra={"type": response.get("type")}, + ) + self._promote_message_blocks(message) + return UnifiedResponse( + id=response.get("id"), + model=response.get("model") or getattr(context, "model", None), + messages=[message] if response else [], + stop_reason=response.get("stop_reason"), + usage=self.extract_usage(response, context), + metadata={"stop_sequence": response.get("stop_sequence"), "type": response.get("type")}, + raw=deepcopy(response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"id", "type", "role", "content", "model", "stop_reason", "stop_sequence", "usage"}}, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + if isinstance(unified_response.raw, dict): + return deepcopy(unified_response.raw) + message = unified_response.messages[0] if unified_response.messages else UnifiedMessage(role="assistant") + payload = { + "id": unified_response.id, + "type": unified_response.metadata.get("type", "message"), + "role": message.role, + "content": self._format_content(message.content), + "model": unified_response.model, + "stop_reason": unified_response.stop_reason, + "stop_sequence": unified_response.metadata.get("stop_sequence"), + "usage": self._format_usage(unified_response.usage), + } + payload.update(deepcopy(unified_response.extra)) + return payload + + def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: + event = _decode_sse_data(raw_event) + if event == "[DONE]": + return UnifiedStreamEvent(type="done", raw=deepcopy(raw_event)) + data = _as_dict(event) + event_type = str(data.get("type") or "chunk") + + if event_type == "error" or data.get("error") is not None: + return UnifiedStreamEvent(type="error", error=deepcopy(data.get("error", data)), raw=deepcopy(raw_event), extra={"payload": data}) + if event_type == "message_start": + response = self.parse_response(data.get("message") or {}, context) + return UnifiedStreamEvent(type="message_start", message=response.messages[0] if response.messages else None, usage=response.usage, raw=deepcopy(raw_event), extra={"payload": data}) + if event_type == "message_delta": + return UnifiedStreamEvent(type="message_delta", usage=self.extract_usage(data.get("usage") or {}, context), raw=deepcopy(raw_event), extra={"payload": data, "stop_reason": (data.get("delta") or {}).get("stop_reason")}) + if event_type in {"content_block_start", "content_block_delta", "content_block_stop"}: + return self._parse_content_stream_event(data, raw_event) + return UnifiedStreamEvent(type=event_type, raw=deepcopy(raw_event), extra={"payload": data}) + + def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: + if isinstance(raw_or_unified, (UnifiedResponse, UnifiedStreamEvent)): + return raw_or_unified.usage + payload = _as_dict(raw_or_unified) + usage = payload.get("usage") if isinstance(payload.get("usage"), dict) else payload + if not isinstance(usage, dict) or not any(k.endswith("tokens") for k in usage): + return None + input_tokens = int(usage.get("input_tokens") or 0) + output_tokens = int(usage.get("output_tokens") or 0) + cache_write = int(usage.get("cache_creation_input_tokens") or 0) + cache_read = int(usage.get("cache_read_input_tokens") or 0) + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=int(usage.get("total_tokens") or input_tokens + output_tokens), + cache_read_tokens=cache_read, + cache_write_tokens=cache_write, + raw=deepcopy(usage), + ) + + def _parse_system(self, system: Any) -> list[ContentBlock]: + if system is None: + return [] + if isinstance(system, str): + return text_blocks(system) + return self._parse_content(system) + + def _format_system(self, blocks: Iterable[ContentBlock]) -> Any: + block_list = list(blocks) + if not block_list: + return None + if all(block.type == "text" for block in block_list): + return first_text(block_list) or "" + return self._format_content(block_list) + + def _parse_message(self, message: dict[str, Any]) -> UnifiedMessage: + payload = dict(message or {}) + unified = UnifiedMessage( + role=str(payload.get("role") or "user"), + content=self._parse_content(payload.get("content")), + raw=deepcopy(message), + extra={k: deepcopy(v) for k, v in payload.items() if k not in {"role", "content"}}, + ) + self._promote_message_blocks(unified) + return unified + + def _format_message(self, message: UnifiedMessage) -> dict[str, Any]: + payload = {"role": message.role, "content": self._format_content(message.content)} + payload.update(deepcopy(message.extra)) + return payload + + def _parse_content(self, content: Any) -> list[ContentBlock]: + if content is None: + return [] + if isinstance(content, str): + return text_blocks(content) + if not isinstance(content, list): + return [ContentBlock(type="unknown", raw=deepcopy(content))] + return [self._parse_content_block(block) for block in content] + + def _parse_content_block(self, block: Any) -> ContentBlock: + if isinstance(block, str): + return ContentBlock(type="text", text=block, raw=block) + if not isinstance(block, dict): + return ContentBlock(type="unknown", raw=deepcopy(block)) + block_type = str(block.get("type") or "text") + if block_type == "text": + return ContentBlock(type="text", text=block.get("text", ""), raw=deepcopy(block)) + if block_type in {"image", "document"}: + return ContentBlock(type=block_type, source=deepcopy(block.get("source")), raw=deepcopy(block), extra=_without(block, {"type", "source"})) + if block_type in {"thinking", "redacted_thinking"}: + reasoning = ReasoningBlock( + type=block_type, + text=block.get("thinking"), + signature=block.get("signature"), + redacted=block_type == "redacted_thinking", + extra=_without(block, {"type", "thinking", "signature"}), + ) + return ContentBlock(type=block_type, reasoning=reasoning, raw=deepcopy(block)) + if block_type == "tool_use": + return ContentBlock( + type="tool_use", + tool_call=ToolCall(id=block.get("id"), name=block.get("name"), arguments=deepcopy(block.get("input")), type="tool_use"), + raw=deepcopy(block), + extra=_without(block, {"type", "id", "name", "input"}), + ) + if block_type == "tool_result": + return ContentBlock( + type="tool_result", + tool_result=ToolResult(tool_call_id=block.get("tool_use_id"), content=deepcopy(block.get("content")), is_error=block.get("is_error")), + raw=deepcopy(block), + extra=_without(block, {"type", "tool_use_id", "content", "is_error"}), + ) + return ContentBlock(type=block_type, raw=deepcopy(block), extra=_without(block, {"type"})) + + def _format_content(self, blocks: Iterable[ContentBlock]) -> list[dict[str, Any]]: + formatted = [] + for block in blocks: + if isinstance(block.raw, dict): + formatted.append(deepcopy(block.raw)) + elif block.type == "text": + formatted.append({"type": "text", "text": block.text or ""}) + elif block.reasoning: + payload = {"type": block.reasoning.type} + if block.reasoning.text is not None: + payload["thinking"] = block.reasoning.text + if block.reasoning.signature is not None: + payload["signature"] = block.reasoning.signature + payload.update(deepcopy(block.reasoning.extra)) + formatted.append(payload) + elif block.tool_call: + formatted.append({"type": "tool_use", "id": block.tool_call.id, "name": block.tool_call.name, "input": deepcopy(block.tool_call.arguments)}) + elif block.tool_result: + payload = {"type": "tool_result", "tool_use_id": block.tool_result.tool_call_id, "content": deepcopy(block.tool_result.content)} + if block.tool_result.is_error is not None: + payload["is_error"] = block.tool_result.is_error + formatted.append(payload) + else: + payload = {"type": block.type} + payload.update(deepcopy(block.extra)) + formatted.append(payload) + return formatted + + def _parse_tool_definition(self, tool: dict[str, Any]) -> ToolDefinition: + payload = dict(tool or {}) + return ToolDefinition( + name=str(payload.get("name") or ""), + description=payload.get("description"), + input_schema=deepcopy(payload.get("input_schema") or {}), + type="tool", + extra=_without(payload, {"name", "description", "input_schema"}), + ) + + def _format_tool_definition(self, tool: ToolDefinition) -> dict[str, Any]: + payload = {"name": tool.name, "input_schema": deepcopy(tool.input_schema)} + if tool.description is not None: + payload["description"] = tool.description + payload.update(deepcopy(tool.extra)) + return payload + + def _format_usage(self, usage: Usage | None) -> dict[str, int] | None: + if usage is None: + return None + payload = {"input_tokens": usage.input_tokens, "output_tokens": usage.output_tokens} + if usage.cache_write_tokens: + payload["cache_creation_input_tokens"] = usage.cache_write_tokens + if usage.cache_read_tokens: + payload["cache_read_input_tokens"] = usage.cache_read_tokens + return payload + + def _promote_message_blocks(self, message: UnifiedMessage) -> None: + for block in message.content: + if block.tool_call: + message.tool_calls.append(block.tool_call) + if block.reasoning: + message.reasoning.append(block.reasoning) + + def _parse_content_stream_event(self, data: dict[str, Any], raw_event: Any) -> UnifiedStreamEvent: + block = data.get("content_block") if isinstance(data.get("content_block"), dict) else None + delta = data.get("delta") if isinstance(data.get("delta"), dict) else None + content_block = None + if block: + content_block = self._parse_content_block(block) + elif delta: + delta_type = delta.get("type") + if delta_type == "text_delta": + content_block = ContentBlock(type="text", text=delta.get("text"), raw=deepcopy(delta)) + elif delta_type in {"thinking_delta", "signature_delta"}: + reasoning = ReasoningBlock(type=str(delta_type), text=delta.get("thinking"), signature=delta.get("signature"), extra=_without(delta, {"type", "thinking", "signature"})) + content_block = ContentBlock(type=str(delta_type), reasoning=reasoning, raw=deepcopy(delta)) + message = UnifiedMessage(role="assistant", content=[content_block] if content_block else []) + self._promote_message_blocks(message) + return UnifiedStreamEvent(type=str(data.get("type") or "content_block_delta"), delta=message, raw=deepcopy(raw_event), extra={"payload": data, "index": data.get("index")}) + + +def _as_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return deepcopy(value) + if hasattr(value, "model_dump"): + return value.model_dump() + if hasattr(value, "dict"): + return value.dict() + return {} + + +def _decode_sse_data(raw_event: Any) -> Any: + if not isinstance(raw_event, str): + return raw_event + text = raw_event.strip() + if text.startswith("data:"): + text = text[5:].strip() + if text == "[DONE]": + return text + try: + return json.loads(text) + except json.JSONDecodeError: + return raw_event + + +def _without(payload: dict[str, Any], keys: set[str]) -> dict[str, Any]: + return {k: deepcopy(v) for k, v in payload.items() if k not in keys} diff --git a/tests/test_protocol_anthropic_messages.py b/tests/test_protocol_anthropic_messages.py new file mode 100644 index 000000000..b1c161559 --- /dev/null +++ b/tests/test_protocol_anthropic_messages.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import json + +from src.rotator_library.protocols import get_protocol, list_protocols + + +def test_anthropic_messages_protocol_is_discovered_with_aliases() -> None: + assert "anthropic_messages" in list_protocols() + assert get_protocol("anthropic") is get_protocol("anthropic_messages") + assert get_protocol("messages") is get_protocol("anthropic_messages") + + +def test_anthropic_request_round_trip_preserves_thinking_tools_and_cache_metadata() -> None: + adapter = get_protocol("anthropic_messages") + raw = { + "model": "anthropic/claude-test", + "system": [{"type": "text", "text": "system"}], + "max_tokens": 100, + "stream": True, + "messages": [ + {"role": "user", "content": "hello"}, + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "worked", "signature": "sig_1"}, + {"type": "text", "text": "answer"}, + {"type": "tool_use", "id": "toolu_1", "name": "lookup", "input": {"q": "x"}}, + ], + }, + {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "toolu_1", "content": "result"}]}, + ], + "tools": [{"name": "lookup", "description": "Lookup", "input_schema": {"type": "object"}}], + "vendor_extension": {"kept": True}, + } + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + + assert unified.model == "anthropic/claude-test" + assert unified.system[0].text == "system" + assert unified.messages[1].reasoning[0].signature == "sig_1" + assert unified.messages[1].tool_calls[0].name == "lookup" + assert unified.messages[2].content[0].tool_result.tool_call_id == "toolu_1" + assert unified.tools[0].input_schema == {"type": "object"} + assert rebuilt["vendor_extension"] == {"kept": True} + assert rebuilt["messages"][1]["content"][0]["signature"] == "sig_1" + + +def test_anthropic_response_extracts_content_and_usage() -> None: + adapter = get_protocol("anthropic_messages") + raw = { + "id": "msg_1", + "type": "message", + "role": "assistant", + "model": "anthropic/claude-test", + "content": [ + {"type": "redacted_thinking", "signature": "sig_2"}, + {"type": "text", "text": "answer"}, + ], + "stop_reason": "end_turn", + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "cache_creation_input_tokens": 2, + "cache_read_input_tokens": 3, + }, + } + + unified = adapter.parse_response(raw) + + assert unified.id == "msg_1" + assert unified.messages[0].content[1].text == "answer" + assert unified.messages[0].reasoning[0].signature == "sig_2" + assert unified.messages[0].reasoning[0].redacted is True + assert unified.usage is not None + assert unified.usage.input_tokens == 10 + assert unified.usage.output_tokens == 5 + assert unified.usage.cache_write_tokens == 2 + assert unified.usage.cache_read_tokens == 3 + + +def test_anthropic_stream_event_parses_text_delta_and_usage() -> None: + adapter = get_protocol("anthropic_messages") + text_event = {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hi"}} + usage_event = {"type": "message_delta", "delta": {"stop_reason": "end_turn"}, "usage": {"output_tokens": 7}} + + parsed_text = adapter.parse_stream_event(f"data: {json.dumps(text_event)}\n\n") + parsed_usage = adapter.parse_stream_event(usage_event) + + assert parsed_text.type == "content_block_delta" + assert parsed_text.delta is not None + assert parsed_text.delta.content[0].text == "Hi" + assert parsed_usage.usage is not None + assert parsed_usage.usage.output_tokens == 7 From 7df4ae93c9d1f63be3ef24bfe6a14e784af637f1 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 22:46:27 +0200 Subject: [PATCH 005/182] feat(protocols): add Gemini adapter Adds a native Gemini generateContent adapter for request parsing/building, response formatting, stream event parsing, content parts, function calls/responses, thought signatures, generation config, safety settings, tools, and Gemini usage metadata. The adapter preserves raw Gemini-native fields and remains isolated from runtime execution so provider migration can happen in later checkpoints. Tests: python -m pytest tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py --- src/rotator_library/protocols/gemini.py | 316 ++++++++++++++++++++++++ tests/test_protocol_gemini.py | 97 ++++++++ 2 files changed, 413 insertions(+) create mode 100644 src/rotator_library/protocols/gemini.py create mode 100644 tests/test_protocol_gemini.py diff --git a/src/rotator_library/protocols/gemini.py b/src/rotator_library/protocols/gemini.py new file mode 100644 index 000000000..f6e5b3007 --- /dev/null +++ b/src/rotator_library/protocols/gemini.py @@ -0,0 +1,316 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Gemini generateContent protocol adapter. + +The adapter preserves Gemini-native content parts, thought signatures, safety +settings, tools, and generation configuration so later native providers can use +the same base without forcing an OpenAI-compatible intermediate shape. +""" + +from __future__ import annotations + +import json +from copy import deepcopy +from typing import Any, ClassVar, Iterable + +from .base import ProtocolAdapter +from .types import ( + ContentBlock, + ProtocolContext, + ReasoningBlock, + ToolCall, + ToolDefinition, + ToolResult, + UnifiedMessage, + UnifiedRequest, + UnifiedResponse, + UnifiedStreamEvent, + Usage, +) + +_REQUEST_CORE_FIELDS = { + "model", + "contents", + "systemInstruction", + "system_instruction", + "tools", + "generationConfig", + "generation_config", + "safetySettings", + "safety_settings", + "stream", +} + + +class GeminiProtocol(ProtocolAdapter): + """Adapter for Gemini ``generateContent`` and stream event shapes. + + Gemini parts are richer than simple chat messages. Unknown part fields remain + in ``extra`` and raw payloads are preserved so provider-specific subclasses + can refine behavior without losing data. + """ + + name: ClassVar[str] = "gemini" + aliases: ClassVar[tuple[str, ...]] = ("google_gemini", "generate_content") + supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + generation_config = deepcopy(request.get("generationConfig") or request.get("generation_config") or {}) + safety_settings = deepcopy(request.get("safetySettings") or request.get("safety_settings") or []) + return UnifiedRequest( + model=str(request.get("model") or getattr(context, "model", None) or ""), + messages=[self._parse_content(content) for content in request.get("contents") or []], + system=self._parse_system(request.get("systemInstruction") or request.get("system_instruction")), + tools=[self._parse_tool(tool) for tool in request.get("tools") or []], + stream=bool(request.get("stream", False)), + generation_params={"generationConfig": generation_config, "safetySettings": safety_settings}, + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _REQUEST_CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload: dict[str, Any] = { + "contents": [self._format_content(message) for message in unified_request.messages], + } + if unified_request.model: + payload["model"] = unified_request.model + if unified_request.system: + payload["systemInstruction"] = {"parts": self._format_parts(unified_request.system)} + generation_config = unified_request.generation_params.get("generationConfig") + safety_settings = unified_request.generation_params.get("safetySettings") + if generation_config: + payload["generationConfig"] = deepcopy(generation_config) + if safety_settings: + payload["safetySettings"] = deepcopy(safety_settings) + if unified_request.tools: + payload["tools"] = [self._format_tool(tool) for tool in unified_request.tools] + if unified_request.stream: + payload["stream"] = True + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + response = _as_dict(raw_response) + messages: list[UnifiedMessage] = [] + stop_reason = None + for candidate in response.get("candidates") or []: + if not isinstance(candidate, dict): + continue + content = candidate.get("content") if isinstance(candidate.get("content"), dict) else {} + message = self._parse_content(content) + message.extra["candidate"] = _without(candidate, {"content"}) + messages.append(message) + if candidate.get("finishReason") is not None: + stop_reason = candidate.get("finishReason") + return UnifiedResponse( + id=response.get("responseId") or response.get("id"), + model=response.get("modelVersion") or getattr(context, "model", None), + messages=messages, + stop_reason=stop_reason, + usage=self.extract_usage(response, context), + metadata={"promptFeedback": deepcopy(response.get("promptFeedback")), "modelVersion": response.get("modelVersion")}, + raw=deepcopy(response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"responseId", "id", "modelVersion", "candidates", "usageMetadata", "promptFeedback"}}, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + if isinstance(unified_response.raw, dict): + return deepcopy(unified_response.raw) + candidates = [] + for index, message in enumerate(unified_response.messages): + candidate = {"index": index, "content": self._format_content(message)} + if unified_response.stop_reason: + candidate["finishReason"] = unified_response.stop_reason + candidate.update(deepcopy(message.extra.get("candidate") or {})) + candidates.append(candidate) + payload = { + "responseId": unified_response.id, + "modelVersion": unified_response.model, + "candidates": candidates, + "usageMetadata": self._format_usage(unified_response.usage), + "promptFeedback": deepcopy(unified_response.metadata.get("promptFeedback")), + } + payload.update(deepcopy(unified_response.extra)) + return payload + + def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: + event = _decode_sse_data(raw_event) + if event == "[DONE]": + return UnifiedStreamEvent(type="done", raw=deepcopy(raw_event)) + data = _as_dict(event) + response = self.parse_response(data, context) + message = response.messages[0] if response.messages else None + return UnifiedStreamEvent( + type="message_delta" if message else "chunk", + delta=message, + usage=response.usage, + raw=deepcopy(raw_event), + extra={"payload": data, "finish_reason": response.stop_reason}, + ) + + def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: + if isinstance(raw_or_unified, (UnifiedResponse, UnifiedStreamEvent)): + return raw_or_unified.usage + payload = _as_dict(raw_or_unified) + usage = payload.get("usageMetadata") if isinstance(payload.get("usageMetadata"), dict) else payload + if not isinstance(usage, dict) or not any(key.endswith("TokenCount") for key in usage): + return None + input_tokens = int(usage.get("promptTokenCount") or 0) + output_tokens = int(usage.get("candidatesTokenCount") or 0) + reasoning_tokens = int(usage.get("thoughtsTokenCount") or 0) + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=int(usage.get("totalTokenCount") or input_tokens + output_tokens + reasoning_tokens), + cache_read_tokens=int(usage.get("cachedContentTokenCount") or 0), + reasoning_tokens=reasoning_tokens, + raw=deepcopy(usage), + ) + + def _parse_system(self, system: Any) -> list[ContentBlock]: + if system is None: + return [] + if isinstance(system, str): + return [ContentBlock(type="text", text=system, raw=system)] + if isinstance(system, dict): + return self._parse_parts(system.get("parts") or []) + return [] + + def _parse_content(self, content: dict[str, Any]) -> UnifiedMessage: + payload = dict(content or {}) + role = str(payload.get("role") or "model") + # Gemini uses "model" where chat protocols usually say "assistant". + normalized_role = "assistant" if role == "model" else role + message = UnifiedMessage( + role=normalized_role, + content=self._parse_parts(payload.get("parts") or []), + raw=deepcopy(content), + extra={"gemini_role": role, **_without(payload, {"role", "parts"})}, + ) + for block in message.content: + if block.tool_call: + message.tool_calls.append(block.tool_call) + if block.reasoning: + message.reasoning.append(block.reasoning) + return message + + def _format_content(self, message: UnifiedMessage) -> dict[str, Any]: + role = message.extra.get("gemini_role") or ("model" if message.role == "assistant" else message.role) + payload = {"role": role, "parts": self._format_parts(message.content)} + payload.update({k: deepcopy(v) for k, v in message.extra.items() if k != "gemini_role"}) + return payload + + def _parse_parts(self, parts: Iterable[Any]) -> list[ContentBlock]: + blocks = [] + for part in parts: + blocks.append(self._parse_part(part)) + return blocks + + def _parse_part(self, part: Any) -> ContentBlock: + if isinstance(part, str): + return ContentBlock(type="text", text=part, raw=part) + if not isinstance(part, dict): + return ContentBlock(type="unknown", raw=deepcopy(part)) + if "text" in part: + reasoning = None + if part.get("thought") or part.get("thoughtSignature"): + reasoning = ReasoningBlock(type="gemini_thought", text=part.get("text"), signature=part.get("thoughtSignature"), extra=_without(part, {"text", "thought", "thoughtSignature"})) + return ContentBlock(type="text", text=part.get("text", ""), reasoning=reasoning, raw=deepcopy(part), extra=_without(part, {"text"})) + if "inlineData" in part or "inline_data" in part: + source = part.get("inlineData") or part.get("inline_data") + return ContentBlock(type="inline_data", source=deepcopy(source), raw=deepcopy(part), extra=_without(part, {"inlineData", "inline_data"})) + if "fileData" in part or "file_data" in part: + source = part.get("fileData") or part.get("file_data") + return ContentBlock(type="file_data", source=deepcopy(source), raw=deepcopy(part), extra=_without(part, {"fileData", "file_data"})) + if "functionCall" in part or "function_call" in part: + call = part.get("functionCall") or part.get("function_call") or {} + return ContentBlock(type="function_call", tool_call=ToolCall(name=call.get("name"), arguments=deepcopy(call.get("args")), type="function_call"), raw=deepcopy(part), extra=_without(part, {"functionCall", "function_call"})) + if "functionResponse" in part or "function_response" in part: + response = part.get("functionResponse") or part.get("function_response") or {} + return ContentBlock(type="function_response", tool_result=ToolResult(tool_call_id=response.get("name"), content=deepcopy(response.get("response"))), raw=deepcopy(part), extra=_without(part, {"functionResponse", "function_response"})) + return ContentBlock(type="unknown", raw=deepcopy(part), extra=deepcopy(part)) + + def _format_parts(self, blocks: Iterable[ContentBlock]) -> list[dict[str, Any]]: + parts = [] + for block in blocks: + if isinstance(block.raw, dict): + parts.append(deepcopy(block.raw)) + elif block.tool_call: + parts.append({"functionCall": {"name": block.tool_call.name, "args": deepcopy(block.tool_call.arguments)}}) + elif block.tool_result: + parts.append({"functionResponse": {"name": block.tool_result.tool_call_id, "response": deepcopy(block.tool_result.content)}}) + elif block.type == "inline_data": + parts.append({"inlineData": deepcopy(block.source)}) + elif block.type == "file_data": + parts.append({"fileData": deepcopy(block.source)}) + else: + payload = {"text": block.text or ""} + if block.reasoning: + payload["thought"] = True + if block.reasoning.signature: + payload["thoughtSignature"] = block.reasoning.signature + payload.update(deepcopy(block.extra)) + parts.append(payload) + return parts + + def _parse_tool(self, tool: dict[str, Any]) -> ToolDefinition: + payload = dict(tool or {}) + declarations = payload.get("functionDeclarations") or payload.get("function_declarations") or [] + first = declarations[0] if declarations and isinstance(declarations[0], dict) else payload + return ToolDefinition( + name=str(first.get("name") or ""), + description=first.get("description"), + input_schema=deepcopy(first.get("parameters") or {}), + type="function", + extra={"raw": deepcopy(tool)}, + ) + + def _format_tool(self, tool: ToolDefinition) -> dict[str, Any]: + raw = tool.extra.get("raw") + if isinstance(raw, dict): + return deepcopy(raw) + return {"functionDeclarations": [{"name": tool.name, "description": tool.description, "parameters": deepcopy(tool.input_schema)}]} + + def _format_usage(self, usage: Usage | None) -> dict[str, int] | None: + if usage is None: + return None + payload = { + "promptTokenCount": usage.input_tokens, + "candidatesTokenCount": usage.output_tokens, + "totalTokenCount": usage.total_tokens, + } + if usage.reasoning_tokens: + payload["thoughtsTokenCount"] = usage.reasoning_tokens + if usage.cache_read_tokens: + payload["cachedContentTokenCount"] = usage.cache_read_tokens + return payload + + +def _as_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return deepcopy(value) + if hasattr(value, "model_dump"): + return value.model_dump() + if hasattr(value, "dict"): + return value.dict() + return {} + + +def _decode_sse_data(raw_event: Any) -> Any: + if not isinstance(raw_event, str): + return raw_event + text = raw_event.strip() + if text.startswith("data:"): + text = text[5:].strip() + if text == "[DONE]": + return text + try: + return json.loads(text) + except json.JSONDecodeError: + return raw_event + + +def _without(payload: dict[str, Any], keys: set[str]) -> dict[str, Any]: + return {k: deepcopy(v) for k, v in payload.items() if k not in keys} diff --git a/tests/test_protocol_gemini.py b/tests/test_protocol_gemini.py new file mode 100644 index 000000000..2ce1eaab1 --- /dev/null +++ b/tests/test_protocol_gemini.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import json + +from src.rotator_library.protocols import get_protocol, list_protocols + + +def test_gemini_protocol_is_discovered_with_aliases() -> None: + assert "gemini" in list_protocols() + assert get_protocol("google_gemini") is get_protocol("gemini") + assert get_protocol("generate_content") is get_protocol("gemini") + + +def test_gemini_request_round_trip_preserves_parts_tools_and_settings() -> None: + adapter = get_protocol("gemini") + raw = { + "model": "gemini/gemini-test", + "systemInstruction": {"parts": [{"text": "system"}]}, + "contents": [ + {"role": "user", "parts": [{"text": "hello"}, {"inlineData": {"mimeType": "image/png", "data": "abc"}}]}, + { + "role": "model", + "parts": [ + {"text": "thought", "thought": True, "thoughtSignature": "sig_1"}, + {"functionCall": {"name": "lookup", "args": {"q": "x"}}}, + ], + }, + {"role": "user", "parts": [{"functionResponse": {"name": "lookup", "response": {"value": 1}}}]}, + ], + "tools": [{"functionDeclarations": [{"name": "lookup", "description": "Lookup", "parameters": {"type": "object"}}]}], + "generationConfig": {"temperature": 0.3}, + "safetySettings": [{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}], + "vendor_extension": {"kept": True}, + } + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + + assert unified.model == "gemini/gemini-test" + assert unified.system[0].text == "system" + assert unified.messages[1].role == "assistant" + assert unified.messages[1].reasoning[0].signature == "sig_1" + assert unified.messages[1].tool_calls[0].name == "lookup" + assert unified.messages[2].content[0].tool_result.content == {"value": 1} + assert unified.tools[0].name == "lookup" + assert rebuilt["generationConfig"] == {"temperature": 0.3} + assert rebuilt["safetySettings"][0]["threshold"] == "BLOCK_NONE" + assert rebuilt["vendor_extension"] == {"kept": True} + + +def test_gemini_response_extracts_usage_and_thought_signature() -> None: + adapter = get_protocol("gemini") + raw = { + "responseId": "resp_1", + "modelVersion": "gemini-test-001", + "candidates": [ + { + "finishReason": "STOP", + "content": {"role": "model", "parts": [{"text": "answer", "thought": True, "thoughtSignature": "sig_2"}]}, + } + ], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 5, + "thoughtsTokenCount": 2, + "cachedContentTokenCount": 3, + "totalTokenCount": 17, + }, + } + + unified = adapter.parse_response(raw) + + assert unified.id == "resp_1" + assert unified.stop_reason == "STOP" + assert unified.messages[0].content[0].text == "answer" + assert unified.messages[0].reasoning[0].signature == "sig_2" + assert unified.usage is not None + assert unified.usage.input_tokens == 10 + assert unified.usage.output_tokens == 5 + assert unified.usage.reasoning_tokens == 2 + assert unified.usage.cache_read_tokens == 3 + + +def test_gemini_stream_event_parses_candidate_delta() -> None: + adapter = get_protocol("gemini") + event = { + "candidates": [{"content": {"role": "model", "parts": [{"text": "Hi"}]}, "finishReason": None}], + "usageMetadata": {"promptTokenCount": 1, "candidatesTokenCount": 1, "totalTokenCount": 2}, + } + + parsed = adapter.parse_stream_event(f"data: {json.dumps(event)}\n\n") + + assert parsed.type == "message_delta" + assert parsed.delta is not None + assert parsed.delta.content[0].text == "Hi" + assert parsed.usage is not None + assert parsed.usage.total_tokens == 2 From aaefc64711b2d2a1cf91139cd188f953afa1cf77 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 22:49:13 +0200 Subject: [PATCH 006/182] feat(protocols): add Responses adapter Adds a native Responses protocol adapter for request parsing/building, response formatting, event-stream parsing, previous_response_id preservation, input and output item handling, reasoning items, function calls, usage details, provider-reported costs, and a WebSocket-ready transport capability flag. Routes, storage, and runtime wiring remain deferred to later checkpoints; this commit only adds the reusable protocol base and tests. Tests: python -m pytest tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py --- src/rotator_library/protocols/responses.py | 347 +++++++++++++++++++++ tests/test_protocol_responses.py | 104 ++++++ 2 files changed, 451 insertions(+) create mode 100644 src/rotator_library/protocols/responses.py create mode 100644 tests/test_protocol_responses.py diff --git a/src/rotator_library/protocols/responses.py b/src/rotator_library/protocols/responses.py new file mode 100644 index 000000000..1b0580e76 --- /dev/null +++ b/src/rotator_library/protocols/responses.py @@ -0,0 +1,347 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""OpenAI Responses protocol adapter. + +Responses is important enough to model natively rather than forcing it through a +chat-completions shape. This adapter focuses on loss-conscious parsing and +formatting; storage, routes, and WebSocket transport are later phases. +""" + +from __future__ import annotations + +import json +from copy import deepcopy +from typing import Any, ClassVar, Iterable + +from .base import ProtocolAdapter +from .types import ( + ContentBlock, + CostDetails, + ProtocolContext, + ReasoningBlock, + ToolCall, + ToolDefinition, + ToolResult, + UnifiedMessage, + UnifiedRequest, + UnifiedResponse, + UnifiedStreamEvent, + Usage, + first_text, + text_blocks, +) + +_GENERATION_PARAMS = { + "include", + "instructions", + "max_output_tokens", + "parallel_tool_calls", + "reasoning", + "store", + "temperature", + "text", + "tool_choice", + "top_p", + "truncation", + "user", +} + +_REQUEST_CORE_FIELDS = { + "model", + "input", + "metadata", + "previous_response_id", + "stream", + "tools", + *_GENERATION_PARAMS, +} + + +class ResponsesProtocol(ProtocolAdapter): + """Adapter for OpenAI Responses request, response, and event stream shapes. + + The protocol keeps output items in addition to parsed messages because later + response storage and continuation features need item-level fidelity. + """ + + name: ClassVar[str] = "responses" + aliases: ClassVar[tuple[str, ...]] = ("openai_responses", "response_api") + supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse", "websocket") + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + generation_params = {k: deepcopy(request[k]) for k in _GENERATION_PARAMS if k in request and k != "instructions"} + return UnifiedRequest( + model=str(request.get("model") or getattr(context, "model", None) or ""), + messages=self._parse_input(request.get("input")), + system=text_blocks(request.get("instructions")) if request.get("instructions") is not None else [], + tools=[self._parse_tool(tool) for tool in request.get("tools") or []], + stream=bool(request.get("stream", False)), + generation_params=generation_params, + previous_response_id=request.get("previous_response_id"), + metadata=deepcopy(request.get("metadata") or {}), + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _REQUEST_CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload: dict[str, Any] = { + "model": unified_request.model, + "input": [self._format_input_message(message) for message in unified_request.messages], + } + instructions = first_text(unified_request.system) + if instructions is not None: + payload["instructions"] = instructions + if unified_request.previous_response_id: + payload["previous_response_id"] = unified_request.previous_response_id + if unified_request.tools: + payload["tools"] = [self._format_tool(tool) for tool in unified_request.tools] + if unified_request.stream: + payload["stream"] = True + if unified_request.metadata: + payload["metadata"] = deepcopy(unified_request.metadata) + payload.update(deepcopy(unified_request.generation_params)) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + response = _as_dict(raw_response) + output = deepcopy(response.get("output") or []) + messages: list[UnifiedMessage] = [] + for item in output: + if isinstance(item, dict): + parsed = self._parse_output_item(item) + if parsed: + messages.append(parsed) + return UnifiedResponse( + id=response.get("id"), + model=response.get("model") or getattr(context, "model", None), + messages=messages, + output=output, + stop_reason=response.get("status"), + usage=self.extract_usage(response, context), + metadata={"object": response.get("object"), "created_at": response.get("created_at")}, + raw=deepcopy(response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"id", "object", "created_at", "model", "output", "usage", "status"}}, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + if isinstance(unified_response.raw, dict): + return deepcopy(unified_response.raw) + output = deepcopy(unified_response.output) if unified_response.output else [self._format_output_message(message, index) for index, message in enumerate(unified_response.messages)] + payload = { + "id": unified_response.id, + "object": unified_response.metadata.get("object", "response"), + "created_at": unified_response.metadata.get("created_at"), + "model": unified_response.model, + "status": unified_response.stop_reason, + "output": output, + "usage": unified_response.usage.to_dict() if unified_response.usage else None, + } + payload.update(deepcopy(unified_response.extra)) + return payload + + def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: + event = _decode_sse_data(raw_event) + if event == "[DONE]": + return UnifiedStreamEvent(type="done", raw=deepcopy(raw_event)) + data = _as_dict(event) + event_type = str(data.get("type") or data.get("event") or "chunk") + if event_type in {"error", "response.error"} or data.get("error") is not None: + return UnifiedStreamEvent(type="error", error=deepcopy(data.get("error", data)), raw=deepcopy(raw_event), extra={"payload": data}) + if event_type in {"response.completed", "response.failed", "response.incomplete"}: + response = self.parse_response(data.get("response") or {}, context) + return UnifiedStreamEvent(type=event_type, message=response.messages[0] if response.messages else None, usage=response.usage, raw=deepcopy(raw_event), extra={"payload": data}) + if event_type == "response.output_text.delta": + message = UnifiedMessage(role="assistant", content=text_blocks(data.get("delta") or "")) + return UnifiedStreamEvent(type="message_delta", delta=message, raw=deepcopy(raw_event), extra={"payload": data, "output_index": data.get("output_index"), "content_index": data.get("content_index")}) + if event_type in {"response.output_item.added", "response.output_item.done"} and isinstance(data.get("item"), dict): + message = self._parse_output_item(data["item"]) + return UnifiedStreamEvent(type=event_type, message=message, raw=deepcopy(raw_event), extra={"payload": data}) + return UnifiedStreamEvent(type=event_type, raw=deepcopy(raw_event), extra={"payload": data}) + + def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: + if isinstance(raw_or_unified, (UnifiedResponse, UnifiedStreamEvent)): + return raw_or_unified.usage + payload = _as_dict(raw_or_unified) + usage = payload.get("usage") if isinstance(payload.get("usage"), dict) else payload + if not isinstance(usage, dict) or not any(key.endswith("tokens") for key in usage): + return None + input_details = usage.get("input_tokens_details") if isinstance(usage.get("input_tokens_details"), dict) else {} + output_details = usage.get("output_tokens_details") if isinstance(usage.get("output_tokens_details"), dict) else {} + cost = None + cost_details = usage.get("cost_details") + if isinstance(cost_details, dict): + provider_cost = cost_details.get("total_cost") or cost_details.get("cost") + cost = CostDetails( + provider_reported_cost=float(provider_cost) if provider_cost is not None else None, + currency=str(cost_details.get("currency") or "USD"), + source="usage.cost_details", + metadata={k: deepcopy(v) for k, v in cost_details.items() if k not in {"total_cost", "cost", "currency"}}, + ) + return Usage( + input_tokens=int(usage.get("input_tokens") or 0), + output_tokens=int(usage.get("output_tokens") or 0), + total_tokens=int(usage.get("total_tokens") or 0), + cache_read_tokens=int(input_details.get("cached_tokens") or 0), + reasoning_tokens=int(output_details.get("reasoning_tokens") or 0), + cost=cost, + raw=deepcopy(usage), + ) + + def _parse_input(self, input_value: Any) -> list[UnifiedMessage]: + if input_value is None: + return [] + if isinstance(input_value, str): + return [UnifiedMessage(role="user", content=text_blocks(input_value), raw=input_value)] + if not isinstance(input_value, list): + return [UnifiedMessage(role="user", content=[ContentBlock(type="unknown", raw=deepcopy(input_value))], raw=deepcopy(input_value))] + messages = [] + for item in input_value: + if isinstance(item, dict): + messages.append(self._parse_input_item(item)) + else: + messages.append(UnifiedMessage(role="user", content=text_blocks(str(item)), raw=deepcopy(item))) + return messages + + def _parse_input_item(self, item: dict[str, Any]) -> UnifiedMessage: + item_type = item.get("type") + if item_type in {"message", None}: + return UnifiedMessage( + role=str(item.get("role") or "user"), + content=self._parse_content(item.get("content")), + raw=deepcopy(item), + extra={k: deepcopy(v) for k, v in item.items() if k not in {"type", "role", "content"}}, + ) + if item_type == "function_call_output": + return UnifiedMessage( + role="tool", + content=[ContentBlock(type="tool_result", tool_result=ToolResult(tool_call_id=item.get("call_id"), content=item.get("output")), raw=deepcopy(item))], + tool_call_id=item.get("call_id"), + raw=deepcopy(item), + ) + return UnifiedMessage(role=str(item.get("role") or "user"), content=[ContentBlock(type=str(item_type or "unknown"), raw=deepcopy(item))], raw=deepcopy(item)) + + def _format_input_message(self, message: UnifiedMessage) -> dict[str, Any]: + if isinstance(message.raw, dict): + return deepcopy(message.raw) + return {"type": "message", "role": message.role, "content": self._format_content(message.content)} + + def _parse_output_item(self, item: dict[str, Any]) -> UnifiedMessage | None: + item_type = item.get("type") + if item_type == "message": + return UnifiedMessage( + role=str(item.get("role") or "assistant"), + content=self._parse_content(item.get("content")), + raw=deepcopy(item), + extra={k: deepcopy(v) for k, v in item.items() if k not in {"type", "role", "content"}}, + ) + if item_type == "reasoning": + reasoning = ReasoningBlock(type="reasoning", text=_reasoning_text(item), extra={k: deepcopy(v) for k, v in item.items() if k not in {"type", "summary"}}) + return UnifiedMessage(role="assistant", content=[ContentBlock(type="reasoning", reasoning=reasoning, raw=deepcopy(item))], reasoning=[reasoning], raw=deepcopy(item)) + if item_type in {"function_call", "custom_tool_call"}: + call = ToolCall(id=item.get("call_id") or item.get("id"), name=item.get("name"), arguments=item.get("arguments") or item.get("input"), type=str(item_type)) + return UnifiedMessage(role="assistant", content=[ContentBlock(type=str(item_type), tool_call=call, raw=deepcopy(item))], tool_calls=[call], raw=deepcopy(item)) + return None + + def _format_output_message(self, message: UnifiedMessage, index: int) -> dict[str, Any]: + return {"id": f"msg_{index}", "type": "message", "role": message.role, "content": self._format_content(message.content)} + + def _parse_content(self, content: Any) -> list[ContentBlock]: + if content is None: + return [] + if isinstance(content, str): + return text_blocks(content) + if not isinstance(content, list): + return [ContentBlock(type="unknown", raw=deepcopy(content))] + blocks = [] + for block in content: + if isinstance(block, str): + blocks.append(ContentBlock(type="input_text", text=block, raw=block)) + continue + if not isinstance(block, dict): + blocks.append(ContentBlock(type="unknown", raw=deepcopy(block))) + continue + block_type = str(block.get("type") or "text") + if block_type in {"input_text", "output_text", "text"}: + blocks.append(ContentBlock(type=block_type, text=block.get("text", ""), raw=deepcopy(block), extra=_without(block, {"type", "text"}))) + elif block_type in {"input_image", "image_url"}: + blocks.append(ContentBlock(type=block_type, source=deepcopy(block.get("image_url") or block.get("source")), raw=deepcopy(block), extra=_without(block, {"type", "image_url", "source"}))) + else: + blocks.append(ContentBlock(type=block_type, raw=deepcopy(block), extra=_without(block, {"type"}))) + return blocks + + def _format_content(self, blocks: Iterable[ContentBlock]) -> list[dict[str, Any]]: + formatted = [] + for block in blocks: + if isinstance(block.raw, dict): + formatted.append(deepcopy(block.raw)) + elif block.type in {"input_text", "output_text", "text"}: + formatted.append({"type": block.type, "text": block.text or ""}) + elif block.type in {"input_image", "image_url"}: + formatted.append({"type": block.type, "image_url": deepcopy(block.source)}) + else: + payload = {"type": block.type} + payload.update(deepcopy(block.extra)) + formatted.append(payload) + return formatted + + def _parse_tool(self, tool: dict[str, Any]) -> ToolDefinition: + payload = dict(tool or {}) + parameters = payload.get("parameters") or payload.get("input_schema") or {} + return ToolDefinition( + name=str(payload.get("name") or ""), + description=payload.get("description"), + input_schema=deepcopy(parameters), + type=str(payload.get("type") or "function"), + extra={k: deepcopy(v) for k, v in payload.items() if k not in {"type", "name", "description", "parameters", "input_schema"}}, + ) + + def _format_tool(self, tool: ToolDefinition) -> dict[str, Any]: + payload = {"type": tool.type, "name": tool.name, "parameters": deepcopy(tool.input_schema)} + if tool.description is not None: + payload["description"] = tool.description + payload.update(deepcopy(tool.extra)) + return payload + + +def _as_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return deepcopy(value) + if hasattr(value, "model_dump"): + return value.model_dump() + if hasattr(value, "dict"): + return value.dict() + return {} + + +def _decode_sse_data(raw_event: Any) -> Any: + if not isinstance(raw_event, str): + return raw_event + text = raw_event.strip() + if text.startswith("data:"): + text = text[5:].strip() + if text == "[DONE]": + return text + try: + return json.loads(text) + except json.JSONDecodeError: + return raw_event + + +def _reasoning_text(item: dict[str, Any]) -> str | None: + summary = item.get("summary") + if isinstance(summary, list): + parts = [] + for part in summary: + if isinstance(part, dict) and part.get("text"): + parts.append(str(part["text"])) + elif isinstance(part, str): + parts.append(part) + return "".join(parts) if parts else None + return str(summary) if summary else None + + +def _without(payload: dict[str, Any], keys: set[str]) -> dict[str, Any]: + return {k: deepcopy(v) for k, v in payload.items() if k not in keys} diff --git a/tests/test_protocol_responses.py b/tests/test_protocol_responses.py new file mode 100644 index 000000000..b6474c7d7 --- /dev/null +++ b/tests/test_protocol_responses.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import json + +from src.rotator_library.protocols import get_protocol, list_protocols + + +def test_responses_protocol_is_discovered_with_aliases_and_websocket_support() -> None: + adapter = get_protocol("responses") + + assert "responses" in list_protocols() + assert get_protocol("openai_responses") is adapter + assert adapter.supports_transport("websocket") is True + + +def test_responses_request_round_trip_preserves_previous_response_and_tools() -> None: + adapter = get_protocol("responses") + raw = { + "model": "openai/gpt-test", + "instructions": "system", + "previous_response_id": "resp_prev", + "stream": True, + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hello"}]}, + {"type": "function_call_output", "call_id": "call_1", "output": "result"}, + ], + "tools": [{"type": "function", "name": "lookup", "description": "Lookup", "parameters": {"type": "object"}}], + "reasoning": {"effort": "medium"}, + "vendor_extension": {"kept": True}, + } + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + + assert unified.previous_response_id == "resp_prev" + assert unified.system[0].text == "system" + assert unified.messages[0].content[0].text == "hello" + assert unified.messages[1].tool_call_id == "call_1" + assert unified.tools[0].name == "lookup" + assert rebuilt["previous_response_id"] == "resp_prev" + assert rebuilt["reasoning"] == {"effort": "medium"} + assert rebuilt["vendor_extension"] == {"kept": True} + + +def test_responses_response_extracts_output_items_reasoning_calls_and_usage() -> None: + adapter = get_protocol("responses") + raw = { + "id": "resp_1", + "object": "response", + "created_at": 123, + "model": "openai/gpt-test", + "status": "completed", + "output": [ + {"id": "msg_1", "type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "answer"}]}, + {"id": "rs_1", "type": "reasoning", "summary": [{"type": "summary_text", "text": "reasoned"}]}, + {"id": "fc_1", "type": "function_call", "call_id": "call_1", "name": "lookup", "arguments": "{\"q\":\"x\"}"}, + ], + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 18, + "input_tokens_details": {"cached_tokens": 3}, + "output_tokens_details": {"reasoning_tokens": 3}, + "cost_details": {"total_cost": 0.02, "currency": "USD"}, + }, + } + + unified = adapter.parse_response(raw) + + assert unified.id == "resp_1" + assert len(unified.output) == 3 + assert unified.messages[0].content[0].text == "answer" + assert unified.messages[1].reasoning[0].text == "reasoned" + assert unified.messages[2].tool_calls[0].name == "lookup" + assert unified.usage is not None + assert unified.usage.input_tokens == 10 + assert unified.usage.reasoning_tokens == 3 + assert unified.usage.cache_read_tokens == 3 + assert unified.usage.cost is not None + assert unified.usage.cost.provider_reported_cost == 0.02 + + +def test_responses_stream_event_parses_text_delta_and_completed_response() -> None: + adapter = get_protocol("responses") + delta = {"type": "response.output_text.delta", "output_index": 0, "content_index": 0, "delta": "Hi"} + completed = { + "type": "response.completed", + "response": { + "id": "resp_1", + "model": "openai/gpt-test", + "output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Hi"}]}], + "usage": {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + }, + } + + parsed_delta = adapter.parse_stream_event(f"data: {json.dumps(delta)}\n\n") + parsed_completed = adapter.parse_stream_event(completed) + + assert parsed_delta.type == "message_delta" + assert parsed_delta.delta is not None + assert parsed_delta.delta.content[0].text == "Hi" + assert parsed_completed.type == "response.completed" + assert parsed_completed.usage is not None + assert parsed_completed.usage.total_tokens == 2 From 8a25309fba525c93198591be56ef0626baf960f6 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 22:59:28 +0200 Subject: [PATCH 007/182] fix(protocols): harden Phase 1 review findings Addresses Phase 1 review findings by treating raw payloads as provenance instead of stale formatting authority in native adapters, tightening registry alias/name collision handling, adding JSON-safe serialization fallbacks, and avoiding default reasoning-token double counting. Adds nested raw preservation for tool, result, and reasoning structures, exposes WebSocket as a future Responses transport seam rather than current formatting support, expands Gemini tool declaration parsing, and switches protocol tests to the public package import path through a local test path fixture. Tests: python -m pytest tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py Tests: python -m pytest tests/test_session_tracking.py tests/test_selection_engine.py --- .../protocols/anthropic_messages.py | 28 ++++--- src/rotator_library/protocols/base.py | 6 ++ src/rotator_library/protocols/gemini.py | 75 ++++++++++++------- src/rotator_library/protocols/openai_chat.py | 29 ++++--- src/rotator_library/protocols/registry.py | 7 ++ src/rotator_library/protocols/responses.py | 55 +++++++++++--- src/rotator_library/protocols/types.py | 36 +++++++-- tests/conftest.py | 11 +++ tests/test_protocol_anthropic_messages.py | 17 ++++- tests/test_protocol_gemini.py | 21 +++++- tests/test_protocol_openai_chat.py | 23 +++++- tests/test_protocol_registry.py | 42 ++++++++++- tests/test_protocol_responses.py | 5 +- 13 files changed, 284 insertions(+), 71 deletions(-) create mode 100644 tests/conftest.py diff --git a/src/rotator_library/protocols/anthropic_messages.py b/src/rotator_library/protocols/anthropic_messages.py index 0605d5134..ed1034a0f 100644 --- a/src/rotator_library/protocols/anthropic_messages.py +++ b/src/rotator_library/protocols/anthropic_messages.py @@ -110,8 +110,6 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No ) def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: - if isinstance(unified_response.raw, dict): - return deepcopy(unified_response.raw) message = unified_response.messages[0] if unified_response.messages else UnifiedMessage(role="assistant") payload = { "id": unified_response.id, @@ -124,7 +122,7 @@ def format_response(self, unified_response: UnifiedResponse, context: ProtocolCo "usage": self._format_usage(unified_response.usage), } payload.update(deepcopy(unified_response.extra)) - return payload + return {k: v for k, v in payload.items() if v is not None} def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: event = _decode_sse_data(raw_event) @@ -220,20 +218,21 @@ def _parse_content_block(self, block: Any) -> ContentBlock: text=block.get("thinking"), signature=block.get("signature"), redacted=block_type == "redacted_thinking", + raw=deepcopy(block), extra=_without(block, {"type", "thinking", "signature"}), ) return ContentBlock(type=block_type, reasoning=reasoning, raw=deepcopy(block)) if block_type == "tool_use": return ContentBlock( type="tool_use", - tool_call=ToolCall(id=block.get("id"), name=block.get("name"), arguments=deepcopy(block.get("input")), type="tool_use"), + tool_call=ToolCall(id=block.get("id"), name=block.get("name"), arguments=deepcopy(block.get("input")), type="tool_use", raw=deepcopy(block)), raw=deepcopy(block), extra=_without(block, {"type", "id", "name", "input"}), ) if block_type == "tool_result": return ContentBlock( type="tool_result", - tool_result=ToolResult(tool_call_id=block.get("tool_use_id"), content=deepcopy(block.get("content")), is_error=block.get("is_error")), + tool_result=ToolResult(tool_call_id=block.get("tool_use_id"), content=deepcopy(block.get("content")), is_error=block.get("is_error"), raw=deepcopy(block)), raw=deepcopy(block), extra=_without(block, {"type", "tool_use_id", "content", "is_error"}), ) @@ -242,12 +241,14 @@ def _parse_content_block(self, block: Any) -> ContentBlock: def _format_content(self, blocks: Iterable[ContentBlock]) -> list[dict[str, Any]]: formatted = [] for block in blocks: - if isinstance(block.raw, dict): - formatted.append(deepcopy(block.raw)) - elif block.type == "text": - formatted.append({"type": "text", "text": block.text or ""}) + if block.type == "text": + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": "text"} + payload["type"] = "text" + payload["text"] = block.text or "" + formatted.append(payload) elif block.reasoning: - payload = {"type": block.reasoning.type} + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": block.reasoning.type} + payload["type"] = block.reasoning.type if block.reasoning.text is not None: payload["thinking"] = block.reasoning.text if block.reasoning.signature is not None: @@ -255,9 +256,12 @@ def _format_content(self, blocks: Iterable[ContentBlock]) -> list[dict[str, Any] payload.update(deepcopy(block.reasoning.extra)) formatted.append(payload) elif block.tool_call: - formatted.append({"type": "tool_use", "id": block.tool_call.id, "name": block.tool_call.name, "input": deepcopy(block.tool_call.arguments)}) + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": "tool_use"} + payload.update({"type": "tool_use", "id": block.tool_call.id, "name": block.tool_call.name, "input": deepcopy(block.tool_call.arguments)}) + formatted.append(payload) elif block.tool_result: - payload = {"type": "tool_result", "tool_use_id": block.tool_result.tool_call_id, "content": deepcopy(block.tool_result.content)} + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": "tool_result"} + payload.update({"type": "tool_result", "tool_use_id": block.tool_result.tool_call_id, "content": deepcopy(block.tool_result.content)}) if block.tool_result.is_error is not None: payload["is_error"] = block.tool_result.is_error formatted.append(payload) diff --git a/src/rotator_library/protocols/base.py b/src/rotator_library/protocols/base.py index fb5e87079..967d26926 100644 --- a/src/rotator_library/protocols/base.py +++ b/src/rotator_library/protocols/base.py @@ -35,12 +35,18 @@ class ProtocolAdapter: name: ClassVar[str] = "base" aliases: ClassVar[tuple[str, ...]] = () supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + future_transports: ClassVar[tuple[str, ...]] = () def supports_transport(self, transport_name: str) -> bool: """Return whether this protocol can format the requested transport.""" return transport_name in self.supported_transports + def is_future_transport(self, transport_name: str) -> bool: + """Return whether this protocol has an intentional future transport seam.""" + + return transport_name in self.future_transports + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: """Parse a raw client/provider request into a unified request.""" diff --git a/src/rotator_library/protocols/gemini.py b/src/rotator_library/protocols/gemini.py index f6e5b3007..89b52b13a 100644 --- a/src/rotator_library/protocols/gemini.py +++ b/src/rotator_library/protocols/gemini.py @@ -63,7 +63,7 @@ def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | model=str(request.get("model") or getattr(context, "model", None) or ""), messages=[self._parse_content(content) for content in request.get("contents") or []], system=self._parse_system(request.get("systemInstruction") or request.get("system_instruction")), - tools=[self._parse_tool(tool) for tool in request.get("tools") or []], + tools=self._parse_tools(request.get("tools") or []), stream=bool(request.get("stream", False)), generation_params={"generationConfig": generation_config, "safetySettings": safety_settings}, raw=deepcopy(raw_request), @@ -116,8 +116,6 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No ) def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: - if isinstance(unified_response.raw, dict): - return deepcopy(unified_response.raw) candidates = [] for index, message in enumerate(unified_response.messages): candidate = {"index": index, "content": self._format_content(message)} @@ -133,7 +131,7 @@ def format_response(self, unified_response: UnifiedResponse, context: ProtocolCo "promptFeedback": deepcopy(unified_response.metadata.get("promptFeedback")), } payload.update(deepcopy(unified_response.extra)) - return payload + return {k: v for k, v in payload.items() if v is not None} def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: event = _decode_sse_data(raw_event) @@ -216,7 +214,7 @@ def _parse_part(self, part: Any) -> ContentBlock: if "text" in part: reasoning = None if part.get("thought") or part.get("thoughtSignature"): - reasoning = ReasoningBlock(type="gemini_thought", text=part.get("text"), signature=part.get("thoughtSignature"), extra=_without(part, {"text", "thought", "thoughtSignature"})) + reasoning = ReasoningBlock(type="gemini_thought", text=part.get("text"), signature=part.get("thoughtSignature"), raw=deepcopy(part), extra=_without(part, {"text", "thought", "thoughtSignature"})) return ContentBlock(type="text", text=part.get("text", ""), reasoning=reasoning, raw=deepcopy(part), extra=_without(part, {"text"})) if "inlineData" in part or "inline_data" in part: source = part.get("inlineData") or part.get("inline_data") @@ -226,27 +224,34 @@ def _parse_part(self, part: Any) -> ContentBlock: return ContentBlock(type="file_data", source=deepcopy(source), raw=deepcopy(part), extra=_without(part, {"fileData", "file_data"})) if "functionCall" in part or "function_call" in part: call = part.get("functionCall") or part.get("function_call") or {} - return ContentBlock(type="function_call", tool_call=ToolCall(name=call.get("name"), arguments=deepcopy(call.get("args")), type="function_call"), raw=deepcopy(part), extra=_without(part, {"functionCall", "function_call"})) + return ContentBlock(type="function_call", tool_call=ToolCall(name=call.get("name"), arguments=deepcopy(call.get("args")), type="function_call", raw=deepcopy(call)), raw=deepcopy(part), extra=_without(part, {"functionCall", "function_call"})) if "functionResponse" in part or "function_response" in part: response = part.get("functionResponse") or part.get("function_response") or {} - return ContentBlock(type="function_response", tool_result=ToolResult(tool_call_id=response.get("name"), content=deepcopy(response.get("response"))), raw=deepcopy(part), extra=_without(part, {"functionResponse", "function_response"})) + return ContentBlock(type="function_response", tool_result=ToolResult(tool_call_id=response.get("name"), content=deepcopy(response.get("response")), raw=deepcopy(response)), raw=deepcopy(part), extra=_without(part, {"functionResponse", "function_response"})) return ContentBlock(type="unknown", raw=deepcopy(part), extra=deepcopy(part)) def _format_parts(self, blocks: Iterable[ContentBlock]) -> list[dict[str, Any]]: parts = [] for block in blocks: - if isinstance(block.raw, dict): - parts.append(deepcopy(block.raw)) - elif block.tool_call: - parts.append({"functionCall": {"name": block.tool_call.name, "args": deepcopy(block.tool_call.arguments)}}) + if block.tool_call: + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {} + payload["functionCall"] = {"name": block.tool_call.name, "args": deepcopy(block.tool_call.arguments)} + parts.append(payload) elif block.tool_result: - parts.append({"functionResponse": {"name": block.tool_result.tool_call_id, "response": deepcopy(block.tool_result.content)}}) + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {} + payload["functionResponse"] = {"name": block.tool_result.tool_call_id, "response": deepcopy(block.tool_result.content)} + parts.append(payload) elif block.type == "inline_data": - parts.append({"inlineData": deepcopy(block.source)}) + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {} + payload["inlineData"] = deepcopy(block.source) + parts.append(payload) elif block.type == "file_data": - parts.append({"fileData": deepcopy(block.source)}) + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {} + payload["fileData"] = deepcopy(block.source) + parts.append(payload) else: - payload = {"text": block.text or ""} + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {} + payload["text"] = block.text or "" if block.reasoning: payload["thought"] = True if block.reasoning.signature: @@ -255,17 +260,35 @@ def _format_parts(self, blocks: Iterable[ContentBlock]) -> list[dict[str, Any]]: parts.append(payload) return parts - def _parse_tool(self, tool: dict[str, Any]) -> ToolDefinition: - payload = dict(tool or {}) - declarations = payload.get("functionDeclarations") or payload.get("function_declarations") or [] - first = declarations[0] if declarations and isinstance(declarations[0], dict) else payload - return ToolDefinition( - name=str(first.get("name") or ""), - description=first.get("description"), - input_schema=deepcopy(first.get("parameters") or {}), - type="function", - extra={"raw": deepcopy(tool)}, - ) + def _parse_tools(self, tools: Iterable[dict[str, Any]]) -> list[ToolDefinition]: + parsed: list[ToolDefinition] = [] + for tool in tools: + payload = dict(tool or {}) + declarations = payload.get("functionDeclarations") or payload.get("function_declarations") or [] + if declarations: + for index, declaration in enumerate(declarations): + if not isinstance(declaration, dict): + continue + parsed.append( + ToolDefinition( + name=str(declaration.get("name") or ""), + description=declaration.get("description"), + input_schema=deepcopy(declaration.get("parameters") or {}), + type="function", + extra={"raw_container": deepcopy(tool), "declaration_index": index}, + ) + ) + continue + parsed.append( + ToolDefinition( + name=str(payload.get("name") or payload.get("type") or "gemini_tool"), + description=payload.get("description"), + input_schema=deepcopy(payload.get("parameters") or {}), + type=str(payload.get("type") or next(iter(payload.keys()), "tool")), + extra={"raw": deepcopy(tool)}, + ) + ) + return parsed def _format_tool(self, tool: ToolDefinition) -> dict[str, Any]: raw = tool.extra.get("raw") diff --git a/src/rotator_library/protocols/openai_chat.py b/src/rotator_library/protocols/openai_chat.py index 3c7e44a5c..aaaa32d94 100644 --- a/src/rotator_library/protocols/openai_chat.py +++ b/src/rotator_library/protocols/openai_chat.py @@ -148,8 +148,6 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No ) def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: - if isinstance(unified_response.raw, dict): - return deepcopy(unified_response.raw) choices = [] for index, message in enumerate(unified_response.messages): choices.append( @@ -168,7 +166,7 @@ def format_response(self, unified_response: UnifiedResponse, context: ProtocolCo "usage": unified_response.usage.to_dict() if unified_response.usage else None, } payload.update(deepcopy(unified_response.extra)) - return payload + return {k: v for k, v in payload.items() if v is not None} def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: event = _decode_sse_data(raw_event) @@ -295,7 +293,7 @@ def _parse_content(self, content: Any) -> list[ContentBlock]: continue block_type = block.get("type", "text") if block_type == "text": - blocks.append(ContentBlock(type="text", text=block.get("text", ""), raw=deepcopy(block))) + blocks.append(ContentBlock(type="text", text=block.get("text", ""), raw=deepcopy(block), extra=_without(block, {"type", "text"}))) elif block_type in {"image_url", "input_image"}: blocks.append(ContentBlock(type=block_type, source=deepcopy(block.get("image_url") or block.get("source")), raw=deepcopy(block), extra=_without(block, {"type", "image_url", "source"}))) else: @@ -306,14 +304,20 @@ def _format_content(self, blocks: Iterable[ContentBlock]) -> Any: block_list = list(blocks) if not block_list: return None - if all(block.type == "text" for block in block_list): + if all(block.type == "text" and not isinstance(block.raw, dict) and not block.extra for block in block_list): return first_text(block_list) or "" formatted = [] for block in block_list: if block.type == "text": - formatted.append({"type": "text", "text": block.text or ""}) + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": "text"} + payload["type"] = payload.get("type", "text") + payload["text"] = block.text or "" + payload.update(deepcopy(block.extra)) + formatted.append(payload) elif block.type in {"image_url", "input_image"}: - payload = {"type": block.type, "image_url": deepcopy(block.source)} + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": block.type} + payload["type"] = block.type + payload["image_url"] = deepcopy(block.source) payload.update(deepcopy(block.extra)) formatted.append(payload) elif isinstance(block.raw, dict): @@ -358,16 +362,21 @@ def _parse_tool_call(self, call: dict[str, Any]) -> ToolCall: arguments=arguments, type=str(payload.get("type") or "function"), index=payload.get("index"), - extra=_without(payload, {"id", "function", "type", "index", "name"}), + raw=deepcopy(call), + extra={**_without(function, {"name", "arguments"}), **_without(payload, {"id", "function", "type", "index", "name"})}, ) def _format_tool_call(self, call: ToolCall) -> dict[str, Any]: - payload = {"type": call.type, "function": {"name": call.name or "", "arguments": _format_arguments(call.arguments)}} + payload = deepcopy(call.raw) if isinstance(call.raw, dict) else {} + payload["type"] = call.type if call.id: payload["id"] = call.id if call.index is not None: payload["index"] = call.index - payload.update(deepcopy(call.extra)) + function = deepcopy(payload.get("function")) if isinstance(payload.get("function"), dict) else {} + function["name"] = call.name or "" + function["arguments"] = _format_arguments(call.arguments) + payload["function"] = function return payload diff --git a/src/rotator_library/protocols/registry.py b/src/rotator_library/protocols/registry.py index ae9812f3c..1923d5830 100644 --- a/src/rotator_library/protocols/registry.py +++ b/src/rotator_library/protocols/registry.py @@ -38,11 +38,18 @@ def register_protocol(protocol_class: Type[ProtocolAdapter], *, replace: bool = name = protocol_class.name if not name: raise ValueError(f"Protocol {protocol_class.__name__} must define a name") + alias_owner = PROTOCOL_ALIASES.get(name) + if alias_owner and alias_owner != name and not replace: + raise ValueError(f"Protocol name conflicts with registered alias: {name}") existing = PROTOCOL_PLUGINS.get(name) if existing and existing is not protocol_class and not replace: raise ValueError(f"Protocol name already registered: {name}") + if replace and existing and existing is not protocol_class: + for alias, owner in list(PROTOCOL_ALIASES.items()): + if owner == name: + PROTOCOL_ALIASES.pop(alias, None) PROTOCOL_PLUGINS[name] = protocol_class _PROTOCOL_INSTANCES.pop(name, None) diff --git a/src/rotator_library/protocols/responses.py b/src/rotator_library/protocols/responses.py index 1b0580e76..8cb262017 100644 --- a/src/rotator_library/protocols/responses.py +++ b/src/rotator_library/protocols/responses.py @@ -67,7 +67,8 @@ class ResponsesProtocol(ProtocolAdapter): name: ClassVar[str] = "responses" aliases: ClassVar[tuple[str, ...]] = ("openai_responses", "response_api") - supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse", "websocket") + supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + future_transports: ClassVar[tuple[str, ...]] = ("websocket",) def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: request = dict(raw_request or {}) @@ -127,9 +128,7 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No ) def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: - if isinstance(unified_response.raw, dict): - return deepcopy(unified_response.raw) - output = deepcopy(unified_response.output) if unified_response.output else [self._format_output_message(message, index) for index, message in enumerate(unified_response.messages)] + output = [self._format_output_message(message, index) for index, message in enumerate(unified_response.messages)] payload = { "id": unified_response.id, "object": unified_response.metadata.get("object", "response"), @@ -140,7 +139,7 @@ def format_response(self, unified_response: UnifiedResponse, context: ProtocolCo "usage": unified_response.usage.to_dict() if unified_response.usage else None, } payload.update(deepcopy(unified_response.extra)) - return payload + return {k: v for k, v in payload.items() if v is not None} def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: event = _decode_sse_data(raw_event) @@ -225,7 +224,16 @@ def _parse_input_item(self, item: dict[str, Any]) -> UnifiedMessage: def _format_input_message(self, message: UnifiedMessage) -> dict[str, Any]: if isinstance(message.raw, dict): - return deepcopy(message.raw) + payload = deepcopy(message.raw) + if payload.get("type") == "function_call_output": + payload["call_id"] = message.tool_call_id or payload.get("call_id") + result = message.content[0].tool_result if message.content and message.content[0].tool_result else None + if result: + payload["output"] = deepcopy(result.content) + return payload + payload["role"] = message.role + payload["content"] = self._format_content(message.content) + return payload return {"type": "message", "role": message.role, "content": self._format_content(message.content)} def _parse_output_item(self, item: dict[str, Any]) -> UnifiedMessage | None: @@ -239,13 +247,30 @@ def _parse_output_item(self, item: dict[str, Any]) -> UnifiedMessage | None: ) if item_type == "reasoning": reasoning = ReasoningBlock(type="reasoning", text=_reasoning_text(item), extra={k: deepcopy(v) for k, v in item.items() if k not in {"type", "summary"}}) + reasoning.raw = deepcopy(item) return UnifiedMessage(role="assistant", content=[ContentBlock(type="reasoning", reasoning=reasoning, raw=deepcopy(item))], reasoning=[reasoning], raw=deepcopy(item)) if item_type in {"function_call", "custom_tool_call"}: - call = ToolCall(id=item.get("call_id") or item.get("id"), name=item.get("name"), arguments=item.get("arguments") or item.get("input"), type=str(item_type)) + call = ToolCall(id=item.get("call_id") or item.get("id"), name=item.get("name"), arguments=item.get("arguments") or item.get("input"), type=str(item_type), raw=deepcopy(item)) return UnifiedMessage(role="assistant", content=[ContentBlock(type=str(item_type), tool_call=call, raw=deepcopy(item))], tool_calls=[call], raw=deepcopy(item)) return None def _format_output_message(self, message: UnifiedMessage, index: int) -> dict[str, Any]: + if isinstance(message.raw, dict): + payload = deepcopy(message.raw) + item_type = payload.get("type") + if item_type == "message": + payload["role"] = message.role + payload["content"] = self._format_content(message.content) + return payload + if item_type == "reasoning" and message.reasoning: + payload["summary"] = [{"type": "summary_text", "text": message.reasoning[0].text or ""}] + return payload + if item_type in {"function_call", "custom_tool_call"} and message.tool_calls: + call = message.tool_calls[0] + payload["call_id"] = call.id + payload["name"] = call.name + payload["arguments"] = deepcopy(call.arguments) + return payload return {"id": f"msg_{index}", "type": "message", "role": message.role, "content": self._format_content(message.content)} def _parse_content(self, content: Any) -> list[ContentBlock]: @@ -275,12 +300,18 @@ def _parse_content(self, content: Any) -> list[ContentBlock]: def _format_content(self, blocks: Iterable[ContentBlock]) -> list[dict[str, Any]]: formatted = [] for block in blocks: - if isinstance(block.raw, dict): - formatted.append(deepcopy(block.raw)) - elif block.type in {"input_text", "output_text", "text"}: - formatted.append({"type": block.type, "text": block.text or ""}) + if block.type in {"input_text", "output_text", "text"}: + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": block.type} + payload["type"] = block.type + payload["text"] = block.text or "" + payload.update(deepcopy(block.extra)) + formatted.append(payload) elif block.type in {"input_image", "image_url"}: - formatted.append({"type": block.type, "image_url": deepcopy(block.source)}) + payload = deepcopy(block.raw) if isinstance(block.raw, dict) else {"type": block.type} + payload["type"] = block.type + payload["image_url"] = deepcopy(block.source) + payload.update(deepcopy(block.extra)) + formatted.append(payload) else: payload = {"type": block.type} payload.update(deepcopy(block.extra)) diff --git a/src/rotator_library/protocols/types.py b/src/rotator_library/protocols/types.py index 4440ce7c3..4eb4a889a 100644 --- a/src/rotator_library/protocols/types.py +++ b/src/rotator_library/protocols/types.py @@ -14,6 +14,9 @@ from copy import deepcopy from dataclasses import asdict, dataclass, field, is_dataclass +from datetime import date, datetime +from decimal import Decimal +from pathlib import Path from typing import Any, ClassVar, Iterable, Mapping, Optional @@ -36,7 +39,24 @@ def serialize_value(value: Any) -> Any: return {str(k): serialize_value(v) for k, v in value.items()} if isinstance(value, (list, tuple, set, frozenset)): return [serialize_value(v) for v in value] - return deepcopy(value) + if isinstance(value, bytes): + try: + return value.decode("utf-8") + except UnicodeDecodeError: + return value.hex() + if isinstance(value, (datetime, date)): + return value.isoformat() + if isinstance(value, Decimal): + return float(value) + if isinstance(value, Path): + return str(value) + try: + import json + + json.dumps(value) + return deepcopy(value) + except (TypeError, ValueError): + return repr(value) def copy_mapping(value: Optional[Mapping[str, Any]]) -> JsonObject: @@ -110,7 +130,10 @@ class Usage(ProtocolSerializable): def __post_init__(self) -> None: if self.total_tokens <= 0: - self.total_tokens = self.input_tokens + self.output_tokens + self.reasoning_tokens + # Most APIs report reasoning/thinking tokens as a detail of output + # tokens. Protocol-specific normalizers can provide a larger total + # when a provider documents reasoning as a separate bucket. + self.total_tokens = self.input_tokens + self.output_tokens @dataclass @@ -121,9 +144,10 @@ class ReasoningBlock(ProtocolSerializable): text: Optional[str] = None signature: Optional[str] = None redacted: bool = False + raw: Any = None extra: JsonObject = field(default_factory=dict) - _fields: ClassVar[tuple[str, ...]] = ("type", "text", "signature", "redacted", "extra") + _fields: ClassVar[tuple[str, ...]] = ("type", "text", "signature", "redacted", "raw", "extra") @dataclass @@ -135,9 +159,10 @@ class ToolCall(ProtocolSerializable): arguments: Any = None type: str = "function" index: Optional[int] = None + raw: Any = None extra: JsonObject = field(default_factory=dict) - _fields: ClassVar[tuple[str, ...]] = ("id", "name", "arguments", "type", "index", "extra") + _fields: ClassVar[tuple[str, ...]] = ("id", "name", "arguments", "type", "index", "raw", "extra") @dataclass @@ -147,9 +172,10 @@ class ToolResult(ProtocolSerializable): tool_call_id: Optional[str] = None content: Any = None is_error: Optional[bool] = None + raw: Any = None extra: JsonObject = field(default_factory=dict) - _fields: ClassVar[tuple[str, ...]] = ("tool_call_id", "content", "is_error", "extra") + _fields: ClassVar[tuple[str, ...]] = ("tool_call_id", "content", "is_error", "raw", "extra") @dataclass diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..df150dc59 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import sys +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] +SRC = ROOT / "src" + +if str(SRC) not in sys.path: + sys.path.insert(0, str(SRC)) diff --git a/tests/test_protocol_anthropic_messages.py b/tests/test_protocol_anthropic_messages.py index b1c161559..34708d60f 100644 --- a/tests/test_protocol_anthropic_messages.py +++ b/tests/test_protocol_anthropic_messages.py @@ -2,7 +2,22 @@ import json -from src.rotator_library.protocols import get_protocol, list_protocols +from rotator_library.protocols import get_protocol, list_protocols + + +def test_anthropic_build_uses_mutated_unified_thinking_signature() -> None: + adapter = get_protocol("anthropic_messages") + raw = { + "model": "anthropic/claude-test", + "messages": [{"role": "assistant", "content": [{"type": "thinking", "thinking": "before", "signature": "sig_1"}]}], + } + + unified = adapter.parse_request(raw) + unified.messages[0].content[0].reasoning.text = "after" + rebuilt = adapter.build_request(unified) + + assert rebuilt["messages"][0]["content"][0]["thinking"] == "after" + assert rebuilt["messages"][0]["content"][0]["signature"] == "sig_1" def test_anthropic_messages_protocol_is_discovered_with_aliases() -> None: diff --git a/tests/test_protocol_gemini.py b/tests/test_protocol_gemini.py index 2ce1eaab1..d90592c75 100644 --- a/tests/test_protocol_gemini.py +++ b/tests/test_protocol_gemini.py @@ -2,7 +2,7 @@ import json -from src.rotator_library.protocols import get_protocol, list_protocols +from rotator_library.protocols import get_protocol, list_protocols def test_gemini_protocol_is_discovered_with_aliases() -> None: @@ -48,6 +48,25 @@ def test_gemini_request_round_trip_preserves_parts_tools_and_settings() -> None: assert rebuilt["vendor_extension"] == {"kept": True} +def test_gemini_parses_multiple_function_declarations() -> None: + adapter = get_protocol("gemini") + raw = { + "contents": [], + "tools": [ + { + "functionDeclarations": [ + {"name": "one", "parameters": {"type": "object"}}, + {"name": "two", "parameters": {"type": "object"}}, + ] + } + ], + } + + unified = adapter.parse_request(raw) + + assert [tool.name for tool in unified.tools] == ["one", "two"] + + def test_gemini_response_extracts_usage_and_thought_signature() -> None: adapter = get_protocol("gemini") raw = { diff --git a/tests/test_protocol_openai_chat.py b/tests/test_protocol_openai_chat.py index 214b96e6b..fce7f39db 100644 --- a/tests/test_protocol_openai_chat.py +++ b/tests/test_protocol_openai_chat.py @@ -2,7 +2,7 @@ import json -from src.rotator_library.protocols import get_protocol, list_protocols +from rotator_library.protocols import get_protocol, list_protocols def test_openai_chat_protocol_is_discovered_with_aliases() -> None: @@ -72,6 +72,27 @@ def test_openai_chat_request_round_trip_preserves_tools_and_reasoning() -> None: assert rebuilt["messages"][3]["reasoning_content"] == "internal chain" +def test_openai_chat_build_uses_mutated_unified_fields_and_preserves_block_extras() -> None: + adapter = get_protocol("openai_chat") + raw = { + "model": "openai/gpt-test", + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "before", "cache_control": {"type": "ephemeral"}}], + } + ], + } + + unified = adapter.parse_request(raw) + unified.messages[0].content[0].text = "after" + rebuilt = adapter.build_request(unified) + + assert isinstance(rebuilt["messages"][0]["content"], list) + assert rebuilt["messages"][0]["content"][0]["text"] == "after" + assert rebuilt["messages"][0]["content"][0]["cache_control"] == {"type": "ephemeral"} + + def test_openai_chat_response_extracts_usage_cost_and_reasoning() -> None: adapter = get_protocol("openai_chat") raw = { diff --git a/tests/test_protocol_registry.py b/tests/test_protocol_registry.py index d3151b4c4..0b449eaf6 100644 --- a/tests/test_protocol_registry.py +++ b/tests/test_protocol_registry.py @@ -1,10 +1,12 @@ from __future__ import annotations import json +from datetime import datetime +from decimal import Decimal import pytest -from src.rotator_library.protocols import ( +from rotator_library.protocols import ( ContentBlock, ProtocolAdapter, ProtocolContext, @@ -84,6 +86,44 @@ class SecondProtocol(ProtocolAdapter): register_protocol(SecondProtocol) +def test_register_protocol_rejects_name_that_conflicts_with_existing_alias() -> None: + class AliasOwnerProtocol(ProtocolAdapter): + name = "alias_owner_protocol" + aliases = ("reserved_protocol_name",) + + class ConflictingNameProtocol(ProtocolAdapter): + name = "reserved_protocol_name" + + register_protocol(AliasOwnerProtocol, replace=True) + + with pytest.raises(ValueError, match="name conflicts with registered alias"): + register_protocol(ConflictingNameProtocol) + + +def test_protocol_serialization_handles_common_non_json_values() -> None: + value = serialize_value( + { + "bytes": b"hello", + "decimal": Decimal("1.25"), + "datetime": datetime(2026, 1, 2, 3, 4, 5), + "object": object(), + } + ) + + assert value["bytes"] == "hello" + assert value["decimal"] == 1.25 + assert value["datetime"] == "2026-01-02T03:04:05" + json.dumps(value) + + +def test_usage_total_does_not_double_count_reasoning_by_default() -> None: + from rotator_library.protocols import Usage + + usage = Usage(input_tokens=10, output_tokens=5, reasoning_tokens=3) + + assert usage.total_tokens == 15 + + def test_protocol_error_includes_pass_and_payload_preview() -> None: error = ProtocolError( "failed", diff --git a/tests/test_protocol_responses.py b/tests/test_protocol_responses.py index b6474c7d7..a461addef 100644 --- a/tests/test_protocol_responses.py +++ b/tests/test_protocol_responses.py @@ -2,7 +2,7 @@ import json -from src.rotator_library.protocols import get_protocol, list_protocols +from rotator_library.protocols import get_protocol, list_protocols def test_responses_protocol_is_discovered_with_aliases_and_websocket_support() -> None: @@ -10,7 +10,8 @@ def test_responses_protocol_is_discovered_with_aliases_and_websocket_support() - assert "responses" in list_protocols() assert get_protocol("openai_responses") is adapter - assert adapter.supports_transport("websocket") is True + assert adapter.supports_transport("websocket") is False + assert adapter.is_future_transport("websocket") is True def test_responses_request_round_trip_preserves_previous_response_and_tools() -> None: From cb66027c63e045510c6e99b92763898f990b8aef Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 23:05:09 +0200 Subject: [PATCH 008/182] fix(protocols): preserve native block and output fidelity Preserves Anthropic system block shape and metadata during rebuilds, keeps unknown Responses output items while still applying unified-message mutations, and groups Gemini multi-declaration tools back into their original native tool container. Adds tests for Anthropic system cache metadata, Responses future output-item preservation, and Gemini multi-declaration rebuild fidelity. Tests: python -m pytest tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py Tests: python -m pytest tests/test_session_tracking.py tests/test_selection_engine.py --- .../protocols/anthropic_messages.py | 2 +- src/rotator_library/protocols/gemini.py | 28 +++++++++++++++++-- src/rotator_library/protocols/responses.py | 14 ++++++++-- tests/test_protocol_anthropic_messages.py | 17 +++++++++++ tests/test_protocol_gemini.py | 5 ++++ tests/test_protocol_responses.py | 19 +++++++++++++ 6 files changed, 79 insertions(+), 6 deletions(-) diff --git a/src/rotator_library/protocols/anthropic_messages.py b/src/rotator_library/protocols/anthropic_messages.py index ed1034a0f..b9a81c248 100644 --- a/src/rotator_library/protocols/anthropic_messages.py +++ b/src/rotator_library/protocols/anthropic_messages.py @@ -173,7 +173,7 @@ def _format_system(self, blocks: Iterable[ContentBlock]) -> Any: block_list = list(blocks) if not block_list: return None - if all(block.type == "text" for block in block_list): + if all(block.type == "text" and not isinstance(block.raw, dict) and not block.extra for block in block_list): return first_text(block_list) or "" return self._format_content(block_list) diff --git a/src/rotator_library/protocols/gemini.py b/src/rotator_library/protocols/gemini.py index 89b52b13a..a812c41db 100644 --- a/src/rotator_library/protocols/gemini.py +++ b/src/rotator_library/protocols/gemini.py @@ -85,7 +85,7 @@ def build_request(self, unified_request: UnifiedRequest, context: ProtocolContex if safety_settings: payload["safetySettings"] = deepcopy(safety_settings) if unified_request.tools: - payload["tools"] = [self._format_tool(tool) for tool in unified_request.tools] + payload["tools"] = self._format_tools(unified_request.tools) if unified_request.stream: payload["stream"] = True payload.update(deepcopy(unified_request.extra)) @@ -262,7 +262,7 @@ def _format_parts(self, blocks: Iterable[ContentBlock]) -> list[dict[str, Any]]: def _parse_tools(self, tools: Iterable[dict[str, Any]]) -> list[ToolDefinition]: parsed: list[ToolDefinition] = [] - for tool in tools: + for container_index, tool in enumerate(tools): payload = dict(tool or {}) declarations = payload.get("functionDeclarations") or payload.get("function_declarations") or [] if declarations: @@ -275,7 +275,7 @@ def _parse_tools(self, tools: Iterable[dict[str, Any]]) -> list[ToolDefinition]: description=declaration.get("description"), input_schema=deepcopy(declaration.get("parameters") or {}), type="function", - extra={"raw_container": deepcopy(tool), "declaration_index": index}, + extra={"raw_container": deepcopy(tool), "container_index": container_index, "declaration_index": index}, ) ) continue @@ -290,6 +290,28 @@ def _parse_tools(self, tools: Iterable[dict[str, Any]]) -> list[ToolDefinition]: ) return parsed + def _format_tools(self, tools: Iterable[ToolDefinition]) -> list[dict[str, Any]]: + grouped: dict[int, dict[str, Any]] = {} + ungrouped: list[dict[str, Any]] = [] + for tool in tools: + raw_container = tool.extra.get("raw_container") + container_index = tool.extra.get("container_index") + declaration_index = tool.extra.get("declaration_index") + if isinstance(raw_container, dict) and isinstance(container_index, int) and isinstance(declaration_index, int): + container = grouped.setdefault(container_index, deepcopy(raw_container)) + declarations = container.setdefault("functionDeclarations", []) + while len(declarations) <= declaration_index: + declarations.append({}) + declaration = deepcopy(declarations[declaration_index]) if isinstance(declarations[declaration_index], dict) else {} + declaration["name"] = tool.name + if tool.description is not None: + declaration["description"] = tool.description + declaration["parameters"] = deepcopy(tool.input_schema) + declarations[declaration_index] = declaration + continue + ungrouped.append(self._format_tool(tool)) + return [grouped[index] for index in sorted(grouped)] + ungrouped + def _format_tool(self, tool: ToolDefinition) -> dict[str, Any]: raw = tool.extra.get("raw") if isinstance(raw, dict): diff --git a/src/rotator_library/protocols/responses.py b/src/rotator_library/protocols/responses.py index 8cb262017..1330e72d7 100644 --- a/src/rotator_library/protocols/responses.py +++ b/src/rotator_library/protocols/responses.py @@ -110,10 +110,11 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No response = _as_dict(raw_response) output = deepcopy(response.get("output") or []) messages: list[UnifiedMessage] = [] - for item in output: + for index, item in enumerate(output): if isinstance(item, dict): parsed = self._parse_output_item(item) if parsed: + parsed.extra["_output_index"] = index messages.append(parsed) return UnifiedResponse( id=response.get("id"), @@ -128,7 +129,16 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No ) def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: - output = [self._format_output_message(message, index) for index, message in enumerate(unified_response.messages)] + if unified_response.output: + output = deepcopy(unified_response.output) + for fallback_index, message in enumerate(unified_response.messages): + output_index = message.extra.get("_output_index", fallback_index) + if isinstance(output_index, int) and 0 <= output_index < len(output): + output[output_index] = self._format_output_message(message, output_index) + else: + output.append(self._format_output_message(message, fallback_index)) + else: + output = [self._format_output_message(message, index) for index, message in enumerate(unified_response.messages)] payload = { "id": unified_response.id, "object": unified_response.metadata.get("object", "response"), diff --git a/tests/test_protocol_anthropic_messages.py b/tests/test_protocol_anthropic_messages.py index 34708d60f..9cc0aaf54 100644 --- a/tests/test_protocol_anthropic_messages.py +++ b/tests/test_protocol_anthropic_messages.py @@ -59,9 +59,26 @@ def test_anthropic_request_round_trip_preserves_thinking_tools_and_cache_metadat assert unified.messages[2].content[0].tool_result.tool_call_id == "toolu_1" assert unified.tools[0].input_schema == {"type": "object"} assert rebuilt["vendor_extension"] == {"kept": True} + assert isinstance(rebuilt["system"], list) + assert rebuilt["system"][0]["text"] == "system" assert rebuilt["messages"][1]["content"][0]["signature"] == "sig_1" +def test_anthropic_system_preserves_block_metadata() -> None: + adapter = get_protocol("anthropic_messages") + raw = { + "model": "anthropic/claude-test", + "system": [{"type": "text", "text": "system", "cache_control": {"type": "ephemeral"}}], + "messages": [{"role": "user", "content": "hello"}], + } + + unified = adapter.parse_request(raw) + unified.system[0].text = "updated" + rebuilt = adapter.build_request(unified) + + assert rebuilt["system"] == [{"type": "text", "text": "updated", "cache_control": {"type": "ephemeral"}}] + + def test_anthropic_response_extracts_content_and_usage() -> None: adapter = get_protocol("anthropic_messages") raw = { diff --git a/tests/test_protocol_gemini.py b/tests/test_protocol_gemini.py index d90592c75..709a8afab 100644 --- a/tests/test_protocol_gemini.py +++ b/tests/test_protocol_gemini.py @@ -63,8 +63,13 @@ def test_gemini_parses_multiple_function_declarations() -> None: } unified = adapter.parse_request(raw) + unified.tools[1].description = "Second tool" + rebuilt = adapter.build_request(unified) assert [tool.name for tool in unified.tools] == ["one", "two"] + assert len(rebuilt["tools"]) == 1 + assert [declaration["name"] for declaration in rebuilt["tools"][0]["functionDeclarations"]] == ["one", "two"] + assert rebuilt["tools"][0]["functionDeclarations"][1]["description"] == "Second tool" def test_gemini_response_extracts_usage_and_thought_signature() -> None: diff --git a/tests/test_protocol_responses.py b/tests/test_protocol_responses.py index a461addef..0ae2d752c 100644 --- a/tests/test_protocol_responses.py +++ b/tests/test_protocol_responses.py @@ -81,6 +81,25 @@ def test_responses_response_extracts_output_items_reasoning_calls_and_usage() -> assert unified.usage.cost.provider_reported_cost == 0.02 +def test_responses_format_preserves_unknown_output_items() -> None: + adapter = get_protocol("responses") + raw = { + "id": "resp_1", + "model": "openai/gpt-test", + "output": [ + {"id": "msg_1", "type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "before"}]}, + {"id": "future_1", "type": "future_item", "custom": {"kept": True}}, + ], + } + + unified = adapter.parse_response(raw) + unified.messages[0].content[0].text = "after" + rebuilt = adapter.format_response(unified) + + assert rebuilt["output"][0]["content"][0]["text"] == "after" + assert rebuilt["output"][1] == {"id": "future_1", "type": "future_item", "custom": {"kept": True}} + + def test_responses_stream_event_parses_text_delta_and_completed_response() -> None: adapter = get_protocol("responses") delta = {"type": "response.output_text.delta", "output_index": 0, "content_index": 0, "delta": "Hi"} From 46696d59efabcfea1f7462d925b681f75d409768 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 23:14:22 +0200 Subject: [PATCH 009/182] docs(experimental): plan transform trace logging Adds the Phase 2 plan for additive transform-pass transaction logging, including trace entry shape, writer behavior, request/response/stream pass names, sanitization, TransactionLogger and ProviderLogger integration, tests, risks, and review checkpoints. The Phase 1 report remains uncommitted for user-facing review only. --- .../experimental/phase-2-transform-logging.md | 203 ++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 docs/experimental/phase-2-transform-logging.md diff --git a/docs/experimental/phase-2-transform-logging.md b/docs/experimental/phase-2-transform-logging.md new file mode 100644 index 000000000..77cfbaeec --- /dev/null +++ b/docs/experimental/phase-2-transform-logging.md @@ -0,0 +1,203 @@ +# Phase 2 Plan: Transform-Pass Logging + +## Goal + +Make transaction logging capable of recording every meaningful request, response, and stream state after each transformation pass, without yet replacing runtime execution with the native protocol adapters. This phase creates the durable trace format and wires it into existing request/executor/provider logging points so later protocol, adapter, and field-cache phases can add more trace entries without redesigning observability. + +## Non-Goals + +- Do not route execution through native protocols yet. +- Do not implement field-cache extraction or injection yet. +- Do not add Responses routes yet. +- Do not add full multi-user/security work. +- Do not rewrite provider-specific logging systems. +- Do not change request/response behavior. + +## Current Code Context + +- `TransactionLogger` creates one directory per transaction and writes legacy files such as `request.json`, `request_transformed.json`, `response.json`, `streaming_chunks.jsonl`, and `metadata.json`. +- `RequestContextBuilder` creates `TransactionLogger` and calls `log_request(kwargs)`. +- `RequestExecutor` calls `log_transformed_request(kwargs, context.kwargs)` after request preparation. +- `RequestExecutor` logs non-streaming responses through `log_response(response_data)`. +- `RequestExecutor._transaction_logging_stream_wrapper()` logs parsed chunks and assembles/logs the final streamed response. +- `ProviderLogger` writes provider-level request payloads, raw response chunks, final responses, and provider errors. +- Phase 1 unified types already provide JSON-safe serialization through `serialize_value()`. + +## Files To Add + +- `src/rotator_library/transform_trace.py` +- `tests/test_transform_trace.py` +- `tests/test_transaction_logger_transform_trace.py` + +## Files Likely To Touch + +- `src/rotator_library/transaction_logger.py` +- `src/rotator_library/client/request_builder.py` +- `src/rotator_library/client/executor.py` + +## Trace Model + +`TransformTraceEntry` fields: + +- `sequence`: monotonically increasing integer per writer instance. +- `timestamp_utc`. +- `component`: `client` or `provider`. +- `pass_name`: stable machine-readable name. +- `direction`: `request`, `response`, `stream`, `error`, or `metadata`. +- `stage`: `client`, `protocol`, `adapter`, `provider`, or `final`. +- `protocol`: optional protocol name. +- `provider`: optional provider name. +- `model`: optional model name. +- `credential_id`: optional stable or masked credential identifier only. +- `transport`: `http`, `sse`, `websocket`, or a future string. +- `changed_from_previous`: optional bool. +- `metadata`: JSON-safe dict. +- `data`: JSON-safe sanitized payload. + +`TransformTraceWriter` responsibilities: + +- Maintain a local sequence counter. +- Write compact ordered entries to `transform_trace.jsonl`. +- Optionally write individual request/response snapshots under `transforms/`. +- Avoid per-chunk snapshot files for stream chunks. +- Never raise logging failures into request execution. + +## Log Files + +Existing files stay for compatibility: + +- `request.json` +- `request_transformed.json` +- `response.json` +- `streaming_chunks.jsonl` +- `metadata.json` + +New files: + +- `transform_trace.jsonl` +- `transforms/0001_raw_client_request.json` +- `transforms/0002_prepared_provider_request.json` +- `transforms/NNNN_.json` for non-stream snapshots + +Stream chunks use JSONL entries rather than one file per chunk. + +## Required Pass Names + +Request passes: + +- `raw_client_request` +- `prepared_provider_request` +- `provider_request_payload` + +Response passes: + +- `raw_client_response` +- `final_client_response` +- `provider_final_response` + +Stream passes: + +- `raw_stream_chunk` +- `parsed_stream_chunk` +- `assembled_stream_response` +- `provider_raw_stream_chunk` + +Error passes: + +- `provider_error` +- `transform_log_error` + +## Sanitization + +This is pragmatic trace hygiene, not full security work. + +Redact recursively by key name: + +- `api_key` +- `credential_identifier` +- `authorization` +- `x-api-key` +- `x-goog-api-key` +- `access_token` +- `refresh_token` +- `client_secret` +- `password` +- `secret` +- `token` + +Rules: + +- Redacted value is `"[REDACTED]"`. +- Redact by key only, not by text value. +- Do not hide ordinary model text containing words such as token. +- Always serialize through Phase 1 JSON-safe helpers. + +## Integration Behavior + +- Trace logging is additive and backward compatible. +- Logging failures never fail requests. +- `log_request()` records `raw_client_request`. +- `log_transformed_request()` records `prepared_provider_request` even when legacy transformed file is skipped because payloads compare equal. +- `log_response()` records `final_client_response`. +- A dedicated method can record `raw_client_response` before final normalization when available. +- `log_stream_chunk()` records `parsed_stream_chunk`. +- `_transaction_logging_stream_wrapper()` records `raw_stream_chunk` before JSON parsing and `assembled_stream_response` before final `log_response()`. +- `ProviderLogger.log_request()` records `provider_request_payload`. +- `ProviderLogger.log_response_chunk()` records `provider_raw_stream_chunk`. +- `ProviderLogger.log_final_response()` records `provider_final_response`. +- `ProviderLogger.log_error()` records `provider_error`. + +## Context Design + +- `TransactionLogger` owns a client-side `TransformTraceWriter`. +- `TransactionContext` carries enough trace metadata for providers to create provider-side trace entries without sharing mutable logger objects. +- `ProviderLogger` creates its own writer with `component="provider"`. +- Client and provider sequence numbers are local to their writer instances. Entries include component and timestamps, so Phase 2 does not promise one global order across independent writers. + +## Comments And Docstrings + +- Public trace classes must explain they are observability-only and must not affect request behavior. +- Redaction helpers must explain why redaction is key-based rather than value-based. +- TransactionLogger comments should distinguish legacy files from the new trace ledger. +- Stream comments should explain why stream chunks use JSONL rather than per-chunk snapshots. + +## Tests + +`tests/test_transform_trace.py`: + +- JSON-safe serialization handles dataclasses and non-primitives. +- Redaction recurses through nested payloads. +- Redaction does not redact normal text values. +- Sequence increments per writer. +- Snapshot filenames are stable and sanitized. + +`tests/test_transaction_logger_transform_trace.py`: + +- `log_request()` writes legacy `request.json` and trace entry `raw_client_request`. +- `log_transformed_request()` writes trace entry even when legacy transformed file is skipped because payloads are equal. +- `log_response()` writes `final_client_response`. +- `log_stream_chunk()` writes `parsed_stream_chunk`. +- ProviderLogger writes provider request, chunk, final, and error trace entries. +- Logging disabled writes nothing. + +Regression tests: + +- Phase 1 protocol tests. +- `tests/test_session_tracking.py`. +- `tests/test_selection_engine.py`. + +## Commit Checkpoints + +1. Add transform trace dataclasses, writer, redaction, and tests. +2. Integrate client-side `TransactionLogger` trace entries and tests. +3. Integrate provider-side `ProviderLogger` trace entries and tests. +4. Run focused and regression tests. +5. Review with `explore` and `explore-heavy`, fix findings, and write the uncommitted Phase 2 report. + +## Risks And Mitigations + +- Stream logs can grow quickly. Use JSONL only for stream chunks. +- Redaction can hide useful fields if too broad. Redact by sensitive key names only. +- Provider/client entry ordering is not globally sequenced in Phase 2. Include component and timestamps. +- Provider SDK objects may not be JSON-native. Always use `serialize_value()`. +- Tests may import an installed package if local `src` is not first on `sys.path`. Keep `tests/conftest.py`. From 1b843218b6859ca010f409a375e88a6a58269b4b Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 23:16:26 +0200 Subject: [PATCH 010/182] feat(logging): add transform trace writer Introduces transform trace entries, a local-sequence JSONL/snapshot writer, recursive key-based redaction, filesystem-safe snapshot names, and JSON-safe payload serialization for future protocol and adapter pass logging. The trace writer is observability-only and isolated from runtime transaction logging in this checkpoint. Tests: python -m pytest tests/test_transform_trace.py --- src/rotator_library/transform_trace.py | 193 +++++++++++++++++++++++++ tests/test_transform_trace.py | 77 ++++++++++ 2 files changed, 270 insertions(+) create mode 100644 src/rotator_library/transform_trace.py create mode 100644 tests/test_transform_trace.py diff --git a/src/rotator_library/transform_trace.py b/src/rotator_library/transform_trace.py new file mode 100644 index 000000000..55c00c105 --- /dev/null +++ b/src/rotator_library/transform_trace.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Transform-pass trace primitives for transaction logging. + +The trace layer is observability-only: failures here must never change request +execution. Transaction logging uses this module to snapshot each meaningful +request, response, and stream state as later protocol/adapters mutate payloads. +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass, field +from datetime import UTC, datetime +from pathlib import Path +from typing import Any, Optional + +from .protocols import serialize_value + +lib_logger = logging.getLogger("rotator_library") + +REDACTED = "[REDACTED]" + +_SENSITIVE_KEYS = frozenset( + { + "api-key", + "credential-identifier", + "authorization", + "x-api-key", + "x-goog-api-key", + "access-token", + "refresh-token", + "client-secret", + "password", + "secret", + "token", + } +) + + +def _normalise_key(key: Any) -> str: + return str(key).strip().lower().replace("_", "-") + + +def sanitize_for_trace(value: Any) -> Any: + """Return a JSON-safe, recursively redacted value for trace files. + + Redaction is intentionally key-based rather than value-based. Model text may + legitimately mention tokens, passwords, or secrets; hiding by value would make + debugging transformations unreliable. Sensitive framework fields are redacted + only when their key name is known to carry credentials. + """ + + serialized = serialize_value(value) + if isinstance(serialized, dict): + sanitized = {} + for key, item in serialized.items(): + if _normalise_key(key) in _SENSITIVE_KEYS: + sanitized[str(key)] = REDACTED + else: + sanitized[str(key)] = sanitize_for_trace(item) + return sanitized + if isinstance(serialized, list): + return [sanitize_for_trace(item) for item in serialized] + return serialized + + +def sanitize_filename(value: str) -> str: + """Return a stable, filesystem-safe name component for snapshot files.""" + + safe = value.strip().lower().replace(" ", "_") or "trace" + for char in '/\\:*?"<>|': + safe = safe.replace(char, "_") + return "".join(char if char.isalnum() or char in "._-" else "_" for char in safe) + + +@dataclass +class TransformTraceEntry: + """A single transform-pass observation. + + Entries are deliberately protocol-neutral. Later phases can add protocol, + adapter, field-cache, routing, and transport passes without changing the file + format used by this phase. + """ + + sequence: int + component: str + pass_name: str + direction: str + stage: str + timestamp_utc: str = field(default_factory=lambda: datetime.now(UTC).isoformat()) + protocol: Optional[str] = None + provider: Optional[str] = None + model: Optional[str] = None + credential_id: Optional[str] = None + transport: Optional[str] = None + changed_from_previous: Optional[bool] = None + metadata: dict[str, Any] = field(default_factory=dict) + data: Any = None + + def to_dict(self) -> dict[str, Any]: + return { + "sequence": self.sequence, + "component": self.component, + "pass_name": self.pass_name, + "direction": self.direction, + "stage": self.stage, + "timestamp_utc": self.timestamp_utc, + "protocol": self.protocol, + "provider": self.provider, + "model": self.model, + "credential_id": self.credential_id, + "transport": self.transport, + "changed_from_previous": self.changed_from_previous, + "metadata": sanitize_for_trace(self.metadata), + "data": sanitize_for_trace(self.data), + } + + +class TransformTraceWriter: + """Append-only writer for transform trace entries and snapshots. + + One writer owns one local sequence counter. Phase 2 intentionally does not + promise a global ordering across client and provider writers; entries include + component and timestamps so interleaved logs remain understandable. + """ + + def __init__( + self, + log_dir: Path, + *, + component: str, + provider: Optional[str] = None, + model: Optional[str] = None, + enabled: bool = True, + ) -> None: + self.log_dir = log_dir + self.component = component + self.provider = provider + self.model = model + self.enabled = enabled + self._sequence = 0 + self.trace_file = log_dir / "transform_trace.jsonl" + self.snapshot_dir = log_dir / "transforms" + + def record( + self, + pass_name: str, + data: Any, + *, + direction: str, + stage: str, + protocol: Optional[str] = None, + credential_id: Optional[str] = None, + transport: Optional[str] = None, + changed_from_previous: Optional[bool] = None, + metadata: Optional[dict[str, Any]] = None, + snapshot: bool = True, + ) -> Optional[TransformTraceEntry]: + """Record a transform pass, swallowing logging failures.""" + + if not self.enabled: + return None + self._sequence += 1 + entry = TransformTraceEntry( + sequence=self._sequence, + component=self.component, + pass_name=pass_name, + direction=direction, + stage=stage, + protocol=protocol, + provider=self.provider, + model=self.model, + credential_id=credential_id, + transport=transport, + changed_from_previous=changed_from_previous, + metadata=metadata or {}, + data=data, + ) + try: + self.log_dir.mkdir(parents=True, exist_ok=True) + with open(self.trace_file, "a", encoding="utf-8") as handle: + handle.write(json.dumps(entry.to_dict(), ensure_ascii=False) + "\n") + if snapshot and direction != "stream": + self.snapshot_dir.mkdir(parents=True, exist_ok=True) + snapshot_name = f"{entry.sequence:04d}_{sanitize_filename(pass_name)}.json" + with open(self.snapshot_dir / snapshot_name, "w", encoding="utf-8") as handle: + json.dump(entry.to_dict(), handle, indent=2, ensure_ascii=False) + except Exception as exc: + lib_logger.debug("Transform trace write failed for %s: %s", pass_name, exc) + return entry diff --git a/tests/test_transform_trace.py b/tests/test_transform_trace.py new file mode 100644 index 000000000..f24a1b4cb --- /dev/null +++ b/tests/test_transform_trace.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass +from datetime import datetime +from decimal import Decimal + +from rotator_library.transform_trace import ( + REDACTED, + TransformTraceWriter, + sanitize_filename, + sanitize_for_trace, +) + + +@dataclass +class ExamplePayload: + when: datetime + amount: Decimal + + +def test_sanitize_for_trace_redacts_sensitive_keys_recursively() -> None: + payload = { + "api_key": "secret-key", + "headers": {"Authorization": "Bearer secret", "normal": "token in normal text"}, + "items": [{"refresh_token": "refresh", "text": "token should remain in value"}], + } + + sanitized = sanitize_for_trace(payload) + + assert sanitized["api_key"] == REDACTED + assert sanitized["headers"]["Authorization"] == REDACTED + assert sanitized["headers"]["normal"] == "token in normal text" + assert sanitized["items"][0]["refresh_token"] == REDACTED + assert sanitized["items"][0]["text"] == "token should remain in value" + + +def test_sanitize_for_trace_serializes_common_non_json_values() -> None: + payload = ExamplePayload(when=datetime(2026, 1, 2, 3, 4, 5), amount=Decimal("1.5")) + + sanitized = sanitize_for_trace({"payload": payload, "bytes": b"hello"}) + + assert sanitized["payload"]["when"] == "2026-01-02T03:04:05" + assert sanitized["payload"]["amount"] == 1.5 + assert sanitized["bytes"] == "hello" + json.dumps(sanitized) + + +def test_sanitize_filename_is_stable_and_filesystem_safe() -> None: + assert sanitize_filename("Raw Client Request") == "raw_client_request" + assert sanitize_filename("provider/request:payload") == "provider_request_payload" + + +def test_transform_trace_writer_records_jsonl_and_snapshots(tmp_path) -> None: + writer = TransformTraceWriter(tmp_path, component="client", provider="openai", model="gpt-test") + + first = writer.record("raw_client_request", {"model": "gpt-test"}, direction="request", stage="client") + second = writer.record("parsed_stream_chunk", {"delta": "hi"}, direction="stream", stage="client", snapshot=True) + + assert first is not None + assert second is not None + assert first.sequence == 1 + assert second.sequence == 2 + + lines = (tmp_path / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines() + assert len(lines) == 2 + assert json.loads(lines[0])["pass_name"] == "raw_client_request" + assert json.loads(lines[1])["direction"] == "stream" + assert (tmp_path / "transforms" / "0001_raw_client_request.json").exists() + assert not (tmp_path / "transforms" / "0002_parsed_stream_chunk.json").exists() + + +def test_transform_trace_writer_disabled_writes_nothing(tmp_path) -> None: + writer = TransformTraceWriter(tmp_path, component="client", enabled=False) + + assert writer.record("raw_client_request", {}, direction="request", stage="client") is None + assert not (tmp_path / "transform_trace.jsonl").exists() From 78fac9316c8481d5cfeea5b79a9d863465327f82 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 23:23:26 +0200 Subject: [PATCH 011/182] feat(logging): trace transaction transform passes Wires the transform trace writer into TransactionLogger and ProviderLogger while preserving legacy request, transformed request, response, streaming chunk, metadata, and provider log files. Adds trace entries for raw client requests, prepared provider requests, raw and parsed stream chunks, assembled stream responses, final client responses, provider request payloads, provider raw stream chunks, provider final responses, and provider errors. Includes transaction logger tests for legacy compatibility, redaction, equality-skipped transformed requests, provider traces, streaming wrapper traces, and disabled logging. Tests: python -m pytest tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py Tests: python -m pytest tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py Tests: python -m pytest tests/test_session_tracking.py tests/test_selection_engine.py --- src/rotator_library/client/executor.py | 15 ++ src/rotator_library/transaction_logger.py | 164 ++++++++++++++++-- ...test_transaction_logger_transform_trace.py | 108 ++++++++++++ 3 files changed, 275 insertions(+), 12 deletions(-) create mode 100644 tests/test_transaction_logger_transform_trace.py diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index d83c61192..cbcd8a570 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -1503,6 +1503,14 @@ async def _transaction_logging_stream_wrapper( chunks = [] async for sse_line in stream: + transaction_logger.log_transform_pass( + "raw_stream_chunk", + sse_line, + direction="stream", + stage="client", + transport="sse", + snapshot=False, + ) yield sse_line # Parse and accumulate for final logging @@ -1524,6 +1532,13 @@ async def _transaction_logging_stream_wrapper( if chunks: try: final_response = TransactionLogger.assemble_streaming_response(chunks) + transaction_logger.log_transform_pass( + "assembled_stream_response", + final_response, + direction="response", + stage="client", + transport="sse", + ) transaction_logger.log_response(final_response) except Exception as e: lib_logger.debug( diff --git a/src/rotator_library/transaction_logger.py b/src/rotator_library/transaction_logger.py index 1a5fa6df8..9f77b9821 100644 --- a/src/rotator_library/transaction_logger.py +++ b/src/rotator_library/transaction_logger.py @@ -29,10 +29,11 @@ import time import uuid from dataclasses import dataclass -from datetime import datetime +from datetime import UTC, datetime from pathlib import Path from typing import Any, Dict, Optional, Union +from .transform_trace import TransformTraceWriter from .utils.paths import get_logs_dir lib_logger = logging.getLogger("rotator_library") @@ -58,6 +59,12 @@ def _get_transactions_dir() -> Path: return transactions_dir +def _utc_timestamp() -> str: + """Return an ISO-8601 UTC timestamp for log records.""" + + return datetime.now(UTC).isoformat() + + def _sanitize_name(name: str) -> str: """Sanitize a name for use in directory/file names.""" # Replace problematic characters with underscores @@ -90,6 +97,9 @@ class TransactionContext: model: str """Model name (sanitized for filesystem use).""" + trace_enabled: bool = False + """Whether provider loggers should append transform trace entries.""" + class TransactionLogger: """ @@ -116,6 +126,7 @@ class TransactionLogger: "api_format", "_dir_available", "_context", + "_trace_writer", ) def __init__( @@ -153,6 +164,7 @@ def __init__( self.log_dir: Optional[Path] = None self._dir_available = False self._context: Optional[TransactionContext] = None + self._trace_writer: Optional[TransformTraceWriter] = None if not enabled: return @@ -174,6 +186,13 @@ def __init__( try: self.log_dir.mkdir(parents=True, exist_ok=True) self._dir_available = True + self._trace_writer = TransformTraceWriter( + self.log_dir, + component="client", + provider=provider, + model=self.model, + enabled=True, + ) except Exception as e: lib_logger.error(f"TransactionLogger: Failed to create directory: {e}") self.enabled = False @@ -192,9 +211,41 @@ def get_context(self) -> TransactionContext: enabled=self.enabled, provider=self.provider, model=self.model, + trace_enabled=bool(self._trace_writer), ) return self._context + def log_transform_pass( + self, + pass_name: str, + data: Any, + *, + direction: str, + stage: str, + protocol: Optional[str] = None, + credential_id: Optional[str] = None, + transport: Optional[str] = None, + changed_from_previous: Optional[bool] = None, + metadata: Optional[Dict[str, Any]] = None, + snapshot: bool = True, + ) -> None: + """Record an additive transform trace entry if tracing is available.""" + + if not self.enabled or not self._dir_available or not self._trace_writer: + return + self._trace_writer.record( + pass_name, + data, + direction=direction, + stage=stage, + protocol=protocol, + credential_id=credential_id, + transport=transport, + changed_from_previous=changed_from_previous, + metadata=metadata, + snapshot=snapshot, + ) + def log_request( self, request_data: Dict[str, Any], filename: str = "request.json" ) -> None: @@ -212,9 +263,16 @@ def log_request( data = { "request_id": self.request_id, - "timestamp_utc": datetime.utcnow().isoformat(), + "timestamp_utc": _utc_timestamp(), "data": request_data, } + self.log_transform_pass( + "raw_client_request", + request_data, + direction="request", + stage="client", + transport="sse" if self.streaming else "http", + ) self._write_json(filename, data) def log_transformed_request( @@ -239,18 +297,30 @@ def log_transformed_request( stripped_transformed = _strip_framework_keys(transformed_data) stripped_original = _strip_framework_keys(original_data) + changed_from_previous: Optional[bool] = None try: - if json.dumps(stripped_transformed, sort_keys=True, default=str) == json.dumps( + changed_from_previous = json.dumps(stripped_transformed, sort_keys=True, default=str) != json.dumps( stripped_original, sort_keys=True, default=str - ): - return + ) except (TypeError, ValueError): - pass + changed_from_previous = None + + self.log_transform_pass( + "prepared_provider_request", + transformed_data, + direction="request", + stage="client", + transport="sse" if transformed_data.get("stream") else "http", + changed_from_previous=changed_from_previous, + ) + + if changed_from_previous is False: + return logged = _strip_framework_keys(transformed_data) data = { "request_id": self.request_id, - "timestamp_utc": datetime.utcnow().isoformat(), + "timestamp_utc": _utc_timestamp(), "data": logged, } self._write_json("request_transformed.json", data) @@ -266,9 +336,17 @@ def log_stream_chunk(self, chunk: Dict[str, Any]) -> None: return log_entry = { - "timestamp_utc": datetime.utcnow().isoformat(), + "timestamp_utc": _utc_timestamp(), "chunk": chunk, } + self.log_transform_pass( + "parsed_stream_chunk", + chunk, + direction="stream", + stage="client", + transport="sse", + snapshot=False, + ) content = json.dumps(log_entry, ensure_ascii=False) + "\n" self._append_text("streaming_chunks.jsonl", content) @@ -296,12 +374,20 @@ def log_response( data = { "request_id": self.request_id, - "timestamp_utc": datetime.utcnow().isoformat(), + "timestamp_utc": _utc_timestamp(), "status_code": status_code, "duration_ms": round(duration_ms), "headers": dict(headers) if headers else None, "data": response_data, } + self.log_transform_pass( + "final_client_response", + response_data, + direction="response", + stage="final", + transport="sse" if self.streaming else "http", + metadata={"status_code": status_code, "headers": dict(headers) if headers else None}, + ) self._write_json(filename, data) # Also write metadata @@ -331,7 +417,7 @@ def _log_metadata( metadata = { "request_id": self.request_id, - "timestamp_utc": datetime.utcnow().isoformat(), + "timestamp_utc": _utc_timestamp(), "duration_ms": round(duration_ms), "status_code": status_code, "provider": self.provider, @@ -543,7 +629,7 @@ class ProviderLogger: or add custom methods for provider-specific logging needs. """ - __slots__ = ("enabled", "log_dir") + __slots__ = ("enabled", "log_dir", "_trace_writer") def __init__(self, context: Optional[TransactionContext]): """ @@ -554,6 +640,7 @@ def __init__(self, context: Optional[TransactionContext]): """ self.enabled = False self.log_dir: Optional[Path] = None + self._trace_writer: Optional[TransformTraceWriter] = None if context is None or not context.enabled: return @@ -563,10 +650,39 @@ def __init__(self, context: Optional[TransactionContext]): try: self.log_dir.mkdir(parents=True, exist_ok=True) + if getattr(context, "trace_enabled", False): + self._trace_writer = TransformTraceWriter( + context.log_dir, + component="provider", + provider=context.provider, + model=context.model, + enabled=True, + ) except Exception as e: lib_logger.error(f"ProviderLogger: Failed to create directory: {e}") self.enabled = False + def _log_transform_pass( + self, + pass_name: str, + data: Any, + *, + direction: str, + stage: str = "provider", + transport: Optional[str] = None, + snapshot: bool = True, + ) -> None: + if not self.enabled or not self._trace_writer: + return + self._trace_writer.record( + pass_name, + data, + direction=direction, + stage=stage, + transport=transport, + snapshot=snapshot, + ) + def log_request(self, payload: Dict[str, Any]) -> None: """ Log the request payload sent to the provider API. @@ -574,6 +690,12 @@ def log_request(self, payload: Dict[str, Any]) -> None: Args: payload: The transformed request payload """ + self._log_transform_pass( + "provider_request_payload", + payload, + direction="request", + transport="http", + ) self._write_json("request_payload.json", payload) def log_response_chunk(self, chunk: str) -> None: @@ -583,6 +705,13 @@ def log_response_chunk(self, chunk: str) -> None: Args: chunk: Raw chunk string from the stream """ + self._log_transform_pass( + "provider_raw_stream_chunk", + chunk, + direction="stream", + transport="sse", + snapshot=False, + ) self._append_text("response_stream.log", chunk + "\n") def log_final_response(self, response_data: Dict[str, Any]) -> None: @@ -592,6 +721,11 @@ def log_final_response(self, response_data: Dict[str, Any]) -> None: Args: response_data: The complete response data """ + self._log_transform_pass( + "provider_final_response", + response_data, + direction="response", + ) self._write_json("final_response.json", response_data) def log_error(self, error_message: str) -> None: @@ -601,7 +735,13 @@ def log_error(self, error_message: str) -> None: Args: error_message: The error message to log """ - timestamp = datetime.utcnow().isoformat() + timestamp = _utc_timestamp() + self._log_transform_pass( + "provider_error", + {"timestamp_utc": timestamp, "message": error_message}, + direction="error", + snapshot=False, + ) self._append_text("error.log", f"[{timestamp}] {error_message}\n") def log_extra(self, filename: str, data: Union[Dict[str, Any], str]) -> None: diff --git a/tests/test_transaction_logger_transform_trace.py b/tests/test_transaction_logger_transform_trace.py new file mode 100644 index 000000000..93dd75927 --- /dev/null +++ b/tests/test_transaction_logger_transform_trace.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.client.executor import RequestExecutor +from rotator_library.transaction_logger import ProviderLogger, TransactionLogger +from rotator_library.transform_trace import REDACTED + + +def _trace_entries(log_dir): + trace_file = log_dir / "transform_trace.jsonl" + return [json.loads(line) for line in trace_file.read_text(encoding="utf-8").splitlines()] + + +def test_log_request_writes_legacy_file_and_raw_trace(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + + logger.log_request({"model": "gpt-test", "api_key": "secret", "messages": [{"role": "user", "content": "hi"}]}) + + assert (logger.log_dir / "request.json").exists() + entries = _trace_entries(logger.log_dir) + assert entries[0]["pass_name"] == "raw_client_request" + assert entries[0]["direction"] == "request" + assert entries[0]["data"]["api_key"] == REDACTED + assert (logger.log_dir / "transforms" / "0001_raw_client_request.json").exists() + + +def test_log_transformed_request_records_trace_even_when_legacy_file_skips(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + request = {"model": "gpt-test", "messages": [{"role": "user", "content": "hi"}]} + + logger.log_transformed_request(request, dict(request)) + + entries = _trace_entries(logger.log_dir) + assert entries[0]["pass_name"] == "prepared_provider_request" + assert entries[0]["changed_from_previous"] is False + assert not (logger.log_dir / "request_transformed.json").exists() + + +def test_log_response_and_stream_chunk_write_trace_entries(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + logger.log_request({"model": "gpt-test", "stream": True}) + + logger.log_stream_chunk({"choices": [{"delta": {"content": "Hi"}}]}) + logger.log_response({"model": "gpt-test", "choices": [], "usage": {"total_tokens": 2}}) + + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert "parsed_stream_chunk" in pass_names + assert "final_client_response" in pass_names + stream_entry = next(entry for entry in entries if entry["pass_name"] == "parsed_stream_chunk") + assert stream_entry["direction"] == "stream" + assert not (logger.log_dir / "transforms" / "0002_parsed_stream_chunk.json").exists() + + +def test_provider_logger_writes_provider_trace_entries(tmp_path) -> None: + logger = TransactionLogger("gemini_cli", "gemini_cli/gemini-test", parent_dir=tmp_path) + provider_logger = ProviderLogger(logger.get_context()) + + provider_logger.log_request({"credential_identifier": "secret", "body": {"text": "hi"}}) + provider_logger.log_response_chunk("data: chunk") + provider_logger.log_final_response({"ok": True}) + provider_logger.log_error("provider failed") + + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert pass_names == [ + "provider_request_payload", + "provider_raw_stream_chunk", + "provider_final_response", + "provider_error", + ] + assert entries[0]["component"] == "provider" + assert entries[0]["data"]["credential_identifier"] == REDACTED + assert (logger.log_dir / "provider" / "request_payload.json").exists() + + +@pytest.mark.asyncio +async def test_stream_wrapper_records_raw_parsed_and_assembled_trace(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + executor = RequestExecutor.__new__(RequestExecutor) + + async def stream(): + yield 'data: {"id":"chunk_1","choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}\n\n' + yield 'data: {"choices":[{"delta":{},"finish_reason":"stop"}]}\n\n' + yield "data: [DONE]\n\n" + + chunks = [chunk async for chunk in executor._transaction_logging_stream_wrapper(stream(), logger, {})] + + assert chunks[-1] == "data: [DONE]\n\n" + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert pass_names.count("raw_stream_chunk") == 3 + assert pass_names.count("parsed_stream_chunk") == 2 + assert "assembled_stream_response" in pass_names + assert "final_client_response" in pass_names + + +def test_transaction_logger_disabled_writes_no_trace(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", enabled=False, parent_dir=tmp_path) + + logger.log_request({"model": "gpt-test"}) + logger.log_response({"model": "gpt-test"}) + + assert logger.log_dir is None + assert not list(tmp_path.iterdir()) From f20162623438c715c77b8d76b24e1217b07180ef Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 23:40:27 +0200 Subject: [PATCH 012/182] fix(logging): harden transform trace correlation Hardens Phase 2 tracing after review by adding request, session, scope, classifier, exact model, and credential correlation to trace entries where available. Expands redaction for cookies and credential-bearing headers, extracts structured fields from SDK-like objects before repr fallback, scrubs header-like secrets from provider error text, and adds a standardized transform_log_error helper. Prevents provider snapshot collisions by namespacing provider writer snapshots while keeping stream chunks in JSONL only. Tests: python -m pytest tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py Tests: python -m pytest tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py Tests: python -m pytest tests/test_session_tracking.py tests/test_selection_engine.py --- src/rotator_library/client/executor.py | 18 ++- src/rotator_library/client/request_builder.py | 6 + src/rotator_library/transaction_logger.py | 103 ++++++++++++- src/rotator_library/transform_trace.py | 138 ++++++++++++++++-- ...test_transaction_logger_transform_trace.py | 36 ++++- tests/test_transform_trace.py | 44 +++++- 6 files changed, 327 insertions(+), 18 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index cbcd8a570..589a6c33e 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -608,7 +608,14 @@ async def _execute_non_streaming( # Log transformed request if it differs from original if context.transaction_logger: context.transaction_logger.log_transformed_request( - kwargs, context.kwargs + kwargs, + context.kwargs, + credential_id=cred_context.stable_id, + metadata={ + "session_id": context.session_id, + "scope_key": context.usage_manager_key, + "classifier": context.classifier, + }, ) # Get provider plugin @@ -824,7 +831,14 @@ async def _execute_streaming( # Log transformed request if it differs from original if context.transaction_logger: context.transaction_logger.log_transformed_request( - kwargs, context.kwargs + kwargs, + context.kwargs, + credential_id=cred_context.stable_id, + metadata={ + "session_id": context.session_id, + "scope_key": context.usage_manager_key, + "classifier": context.classifier, + }, ) # Add stream usage metadata for active providers. diff --git a/src/rotator_library/client/request_builder.py b/src/rotator_library/client/request_builder.py index c9455ec23..7fb6573b4 100644 --- a/src/rotator_library/client/request_builder.py +++ b/src/rotator_library/client/request_builder.py @@ -133,6 +133,12 @@ async def build_completion_context( scope_key=scope["usage_manager_key"], hints=await self._get_session_hints(provider, resolved_model, kwargs), ) + if transaction_logger: + transaction_logger.set_trace_context( + session_id=session.session_id, + scope_key=scope["usage_manager_key"], + classifier=scope["classifier"], + ) return RequestContext( model=resolved_model, diff --git a/src/rotator_library/transaction_logger.py b/src/rotator_library/transaction_logger.py index 9f77b9821..5a74e5a43 100644 --- a/src/rotator_library/transaction_logger.py +++ b/src/rotator_library/transaction_logger.py @@ -33,7 +33,7 @@ from pathlib import Path from typing import Any, Dict, Optional, Union -from .transform_trace import TransformTraceWriter +from .transform_trace import TransformTraceWriter, provider_snapshot_namespace from .utils.paths import get_logs_dir lib_logger = logging.getLogger("rotator_library") @@ -97,6 +97,18 @@ class TransactionContext: model: str """Model name (sanitized for filesystem use).""" + trace_model: Optional[str] = None + """Exact model name used in transform trace entries.""" + + session_id: Optional[str] = None + """Inferred session id for trace correlation, when available.""" + + scope_key: Optional[str] = None + """Usage scope key for trace correlation, when available.""" + + classifier: Optional[str] = None + """Classifier/private routing label for trace correlation, when available.""" + trace_enabled: bool = False """Whether provider loggers should append transform trace entries.""" @@ -122,6 +134,10 @@ class TransactionLogger: "request_id", "provider", "model", + "trace_model", + "session_id", + "scope_key", + "classifier", "streaming", "api_format", "_dir_available", @@ -151,6 +167,10 @@ def __init__( self.start_time = time.time() self.request_id = str(uuid.uuid4())[:8] # 8-char short ID self.provider = provider + self.trace_model = model + self.session_id: Optional[str] = None + self.scope_key: Optional[str] = None + self.classifier: Optional[str] = None self.api_format = api_format # Strip provider prefix from model if present @@ -190,7 +210,8 @@ def __init__( self.log_dir, component="client", provider=provider, - model=self.model, + model=self.trace_model, + request_id=self.request_id, enabled=True, ) except Exception as e: @@ -211,10 +232,40 @@ def get_context(self) -> TransactionContext: enabled=self.enabled, provider=self.provider, model=self.model, + trace_model=self.trace_model, + session_id=self.session_id, + scope_key=self.scope_key, + classifier=self.classifier, trace_enabled=bool(self._trace_writer), ) return self._context + def set_trace_context( + self, + *, + session_id: Optional[str] = None, + scope_key: Optional[str] = None, + classifier: Optional[str] = None, + ) -> None: + """Attach routing/session metadata discovered after logger creation.""" + + if session_id is not None: + self.session_id = session_id + if scope_key is not None: + self.scope_key = scope_key + if classifier is not None: + self.classifier = classifier + if self._trace_writer: + self._trace_writer.update_context( + session_id=self.session_id, + scope_key=self.scope_key, + classifier=self.classifier, + ) + if self._context: + self._context.session_id = self.session_id + self._context.scope_key = self.scope_key + self._context.classifier = self.classifier + def log_transform_pass( self, pass_name: str, @@ -227,6 +278,7 @@ def log_transform_pass( transport: Optional[str] = None, changed_from_previous: Optional[bool] = None, metadata: Optional[Dict[str, Any]] = None, + scrub_strings: bool = False, snapshot: bool = True, ) -> None: """Record an additive transform trace entry if tracing is available.""" @@ -243,9 +295,41 @@ def log_transform_pass( transport=transport, changed_from_previous=changed_from_previous, metadata=metadata, + scrub_strings=scrub_strings, snapshot=snapshot, ) + def log_transform_error( + self, + failed_pass_name: str, + error: BaseException, + *, + payload: Any = None, + stage: str = "client", + protocol: Optional[str] = None, + transport: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Record a standardized transform/logging failure without raising.""" + + error_data = { + "failed_pass_name": failed_pass_name, + "error_type": type(error).__name__, + "message": str(error), + "payload": payload, + } + self.log_transform_pass( + "transform_log_error", + error_data, + direction="error", + stage=stage, + protocol=protocol, + transport=transport, + metadata=metadata, + scrub_strings=True, + snapshot=False, + ) + def log_request( self, request_data: Dict[str, Any], filename: str = "request.json" ) -> None: @@ -279,6 +363,9 @@ def log_transformed_request( self, transformed_data: Dict[str, Any], original_data: Dict[str, Any], + *, + credential_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, ) -> None: """ Log the transformed request if it differs from the original. @@ -310,8 +397,10 @@ def log_transformed_request( transformed_data, direction="request", stage="client", + credential_id=credential_id, transport="sse" if transformed_data.get("stream") else "http", changed_from_previous=changed_from_previous, + metadata=metadata, ) if changed_from_previous is False: @@ -655,7 +744,12 @@ def __init__(self, context: Optional[TransactionContext]): context.log_dir, component="provider", provider=context.provider, - model=context.model, + model=context.trace_model or context.model, + request_id=context.request_id, + session_id=context.session_id, + scope_key=context.scope_key, + classifier=context.classifier, + snapshot_namespace=provider_snapshot_namespace(), enabled=True, ) except Exception as e: @@ -670,6 +764,7 @@ def _log_transform_pass( direction: str, stage: str = "provider", transport: Optional[str] = None, + scrub_strings: bool = False, snapshot: bool = True, ) -> None: if not self.enabled or not self._trace_writer: @@ -680,6 +775,7 @@ def _log_transform_pass( direction=direction, stage=stage, transport=transport, + scrub_strings=scrub_strings, snapshot=snapshot, ) @@ -740,6 +836,7 @@ def log_error(self, error_message: str) -> None: "provider_error", {"timestamp_utc": timestamp, "message": error_message}, direction="error", + scrub_strings=True, snapshot=False, ) self._append_text("error.log", f"[{timestamp}] {error_message}\n") diff --git a/src/rotator_library/transform_trace.py b/src/rotator_library/transform_trace.py index 55c00c105..00144facb 100644 --- a/src/rotator_library/transform_trace.py +++ b/src/rotator_library/transform_trace.py @@ -12,10 +12,13 @@ import json import logging +import re +import uuid from dataclasses import dataclass, field +from dataclasses import is_dataclass from datetime import UTC, datetime from pathlib import Path -from typing import Any, Optional +from typing import Any, Mapping, Optional from .protocols import serialize_value @@ -28,8 +31,12 @@ "api-key", "credential-identifier", "authorization", + "proxy-authorization", + "cookie", + "set-cookie", "x-api-key", "x-goog-api-key", + "openai-api-key", "access-token", "refresh-token", "client-secret", @@ -39,12 +46,59 @@ } ) +_SENSITIVE_TEXT_RE = re.compile( + r"(?im)\b(authorization|proxy-authorization|x-api-key|x-goog-api-key|openai-api-key|api[_-]?key|access[_-]?token|refresh[_-]?token|client[_-]?secret|cookie|set-cookie)\b(\s*[:=]\s*)([^\r\n]+)" +) + def _normalise_key(key: Any) -> str: return str(key).strip().lower().replace("_", "-") -def sanitize_for_trace(value: Any) -> Any: +def scrub_sensitive_text(value: str) -> str: + """Scrub obvious credential-bearing header/query fragments from text. + + General trace payloads use key-based redaction to avoid hiding model text. + Error strings and object reprs are different: providers often embed HTTP + headers or query strings in exception text, so this targeted scrub only + applies when callers opt into string scrubbing. + """ + + return _SENSITIVE_TEXT_RE.sub(lambda match: f"{match.group(1)}{match.group(2)}{REDACTED}", value) + + +def _object_mapping(value: Any) -> Optional[dict[str, Any]]: + """Extract structured data from common SDK objects before using repr().""" + + for method_name in ("model_dump", "dict"): + method = getattr(value, method_name, None) + if callable(method): + try: + result = method() + if isinstance(result, Mapping): + return dict(result) + except Exception: + pass + structured: dict[str, Any] = {"type": f"{type(value).__module__}.{type(value).__name__}"} + found = False + for attr in ("status_code", "headers", "url", "method", "text", "content"): + if hasattr(value, attr): + try: + structured[attr] = getattr(value, attr) + found = True + except Exception: + pass + if found: + return structured + if hasattr(value, "__dict__") and not is_dataclass(value): + try: + return {"type": structured["type"], "attributes": vars(value)} + except Exception: + return None + return None + + +def sanitize_for_trace(value: Any, *, scrub_strings: bool = False) -> Any: """Return a JSON-safe, recursively redacted value for trace files. Redaction is intentionally key-based rather than value-based. Model text may @@ -53,17 +107,30 @@ def sanitize_for_trace(value: Any) -> Any: only when their key name is known to carry credentials. """ - serialized = serialize_value(value) - if isinstance(serialized, dict): + if isinstance(value, Mapping): sanitized = {} - for key, item in serialized.items(): + for key, item in value.items(): if _normalise_key(key) in _SENSITIVE_KEYS: sanitized[str(key)] = REDACTED else: - sanitized[str(key)] = sanitize_for_trace(item) + sanitized[str(key)] = sanitize_for_trace(item, scrub_strings=scrub_strings) return sanitized + if isinstance(value, (list, tuple, set, frozenset)): + return [sanitize_for_trace(item, scrub_strings=scrub_strings) for item in value] + if isinstance(value, str): + return scrub_sensitive_text(value) if scrub_strings else value + + object_mapping = _object_mapping(value) + if object_mapping is not None: + return sanitize_for_trace(object_mapping, scrub_strings=scrub_strings) + + serialized = serialize_value(value) + if isinstance(serialized, dict): + return sanitize_for_trace(serialized, scrub_strings=scrub_strings) if isinstance(serialized, list): - return [sanitize_for_trace(item) for item in serialized] + return [sanitize_for_trace(item, scrub_strings=scrub_strings) for item in serialized] + if isinstance(serialized, str): + return scrub_sensitive_text(serialized) if scrub_strings else serialized return serialized @@ -90,6 +157,7 @@ class TransformTraceEntry: pass_name: str direction: str stage: str + request_id: Optional[str] = None timestamp_utc: str = field(default_factory=lambda: datetime.now(UTC).isoformat()) protocol: Optional[str] = None provider: Optional[str] = None @@ -97,8 +165,12 @@ class TransformTraceEntry: credential_id: Optional[str] = None transport: Optional[str] = None changed_from_previous: Optional[bool] = None + session_id: Optional[str] = None + scope_key: Optional[str] = None + classifier: Optional[str] = None metadata: dict[str, Any] = field(default_factory=dict) data: Any = None + scrub_strings: bool = False def to_dict(self) -> dict[str, Any]: return { @@ -107,6 +179,7 @@ def to_dict(self) -> dict[str, Any]: "pass_name": self.pass_name, "direction": self.direction, "stage": self.stage, + "request_id": self.request_id, "timestamp_utc": self.timestamp_utc, "protocol": self.protocol, "provider": self.provider, @@ -114,8 +187,11 @@ def to_dict(self) -> dict[str, Any]: "credential_id": self.credential_id, "transport": self.transport, "changed_from_previous": self.changed_from_previous, - "metadata": sanitize_for_trace(self.metadata), - "data": sanitize_for_trace(self.data), + "session_id": self.session_id, + "scope_key": self.scope_key, + "classifier": self.classifier, + "metadata": sanitize_for_trace(self.metadata, scrub_strings=self.scrub_strings), + "data": sanitize_for_trace(self.data, scrub_strings=self.scrub_strings), } @@ -134,17 +210,43 @@ def __init__( component: str, provider: Optional[str] = None, model: Optional[str] = None, + request_id: Optional[str] = None, + session_id: Optional[str] = None, + scope_key: Optional[str] = None, + classifier: Optional[str] = None, + snapshot_namespace: Optional[str] = None, enabled: bool = True, ) -> None: self.log_dir = log_dir self.component = component self.provider = provider self.model = model + self.request_id = request_id + self.session_id = session_id + self.scope_key = scope_key + self.classifier = classifier + self.snapshot_namespace = snapshot_namespace self.enabled = enabled self._sequence = 0 self.trace_file = log_dir / "transform_trace.jsonl" self.snapshot_dir = log_dir / "transforms" + def update_context( + self, + *, + session_id: Optional[str] = None, + scope_key: Optional[str] = None, + classifier: Optional[str] = None, + ) -> None: + """Update immutable-ish correlation fields discovered after creation.""" + + if session_id is not None: + self.session_id = session_id + if scope_key is not None: + self.scope_key = scope_key + if classifier is not None: + self.classifier = classifier + def record( self, pass_name: str, @@ -156,7 +258,11 @@ def record( credential_id: Optional[str] = None, transport: Optional[str] = None, changed_from_previous: Optional[bool] = None, + session_id: Optional[str] = None, + scope_key: Optional[str] = None, + classifier: Optional[str] = None, metadata: Optional[dict[str, Any]] = None, + scrub_strings: bool = False, snapshot: bool = True, ) -> Optional[TransformTraceEntry]: """Record a transform pass, swallowing logging failures.""" @@ -170,14 +276,19 @@ def record( pass_name=pass_name, direction=direction, stage=stage, + request_id=self.request_id, protocol=protocol, provider=self.provider, model=self.model, credential_id=credential_id, transport=transport, changed_from_previous=changed_from_previous, + session_id=session_id if session_id is not None else self.session_id, + scope_key=scope_key if scope_key is not None else self.scope_key, + classifier=classifier if classifier is not None else self.classifier, metadata=metadata or {}, data=data, + scrub_strings=scrub_strings, ) try: self.log_dir.mkdir(parents=True, exist_ok=True) @@ -185,9 +296,16 @@ def record( handle.write(json.dumps(entry.to_dict(), ensure_ascii=False) + "\n") if snapshot and direction != "stream": self.snapshot_dir.mkdir(parents=True, exist_ok=True) - snapshot_name = f"{entry.sequence:04d}_{sanitize_filename(pass_name)}.json" + namespace = f"{sanitize_filename(self.snapshot_namespace)}_" if self.snapshot_namespace else "" + snapshot_name = f"{entry.sequence:04d}_{namespace}{sanitize_filename(pass_name)}.json" with open(self.snapshot_dir / snapshot_name, "w", encoding="utf-8") as handle: json.dump(entry.to_dict(), handle, indent=2, ensure_ascii=False) except Exception as exc: lib_logger.debug("Transform trace write failed for %s: %s", pass_name, exc) return entry + + +def provider_snapshot_namespace() -> str: + """Return a short namespace that prevents provider snapshot collisions.""" + + return f"provider_{uuid.uuid4().hex[:8]}" diff --git a/tests/test_transaction_logger_transform_trace.py b/tests/test_transaction_logger_transform_trace.py index 93dd75927..956b23ba3 100644 --- a/tests/test_transaction_logger_transform_trace.py +++ b/tests/test_transaction_logger_transform_trace.py @@ -22,6 +22,7 @@ def test_log_request_writes_legacy_file_and_raw_trace(tmp_path) -> None: assert (logger.log_dir / "request.json").exists() entries = _trace_entries(logger.log_dir) assert entries[0]["pass_name"] == "raw_client_request" + assert entries[0]["request_id"] == logger.request_id assert entries[0]["direction"] == "request" assert entries[0]["data"]["api_key"] == REDACTED assert (logger.log_dir / "transforms" / "0001_raw_client_request.json").exists() @@ -29,13 +30,18 @@ def test_log_request_writes_legacy_file_and_raw_trace(tmp_path) -> None: def test_log_transformed_request_records_trace_even_when_legacy_file_skips(tmp_path) -> None: logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + logger.set_trace_context(session_id="session_1", scope_key="scope_1", classifier="class_a") request = {"model": "gpt-test", "messages": [{"role": "user", "content": "hi"}]} - logger.log_transformed_request(request, dict(request)) + logger.log_transformed_request(request, dict(request), credential_id="cred_1") entries = _trace_entries(logger.log_dir) assert entries[0]["pass_name"] == "prepared_provider_request" assert entries[0]["changed_from_previous"] is False + assert entries[0]["credential_id"] == "cred_1" + assert entries[0]["session_id"] == "session_1" + assert entries[0]["scope_key"] == "scope_1" + assert entries[0]["classifier"] == "class_a" assert not (logger.log_dir / "request_transformed.json").exists() @@ -73,6 +79,7 @@ def test_provider_logger_writes_provider_trace_entries(tmp_path) -> None: "provider_error", ] assert entries[0]["component"] == "provider" + assert entries[0]["request_id"] == logger.request_id assert entries[0]["data"]["credential_identifier"] == REDACTED assert (logger.log_dir / "provider" / "request_payload.json").exists() @@ -106,3 +113,30 @@ def test_transaction_logger_disabled_writes_no_trace(tmp_path) -> None: assert logger.log_dir is None assert not list(tmp_path.iterdir()) + + +def test_provider_error_trace_scrubs_header_like_secret_text(tmp_path) -> None: + logger = TransactionLogger("gemini_cli", "gemini_cli/gemini-test", parent_dir=tmp_path) + provider_logger = ProviderLogger(logger.get_context()) + + provider_logger.log_error("upstream failed Authorization: Bearer secret-token") + + entries = _trace_entries(logger.log_dir) + assert entries[0]["pass_name"] == "provider_error" + assert "secret-token" not in entries[0]["data"]["message"] + assert "[REDACTED]" in entries[0]["data"]["message"] + + +def test_log_transform_error_uses_standard_shape_and_scrubs_text(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + + logger.log_transform_error( + "after_field_cache_injection", + RuntimeError("bad Authorization: Bearer secret"), + payload={"cookie": "sid=secret"}, + ) + + entries = _trace_entries(logger.log_dir) + assert entries[0]["pass_name"] == "transform_log_error" + assert entries[0]["data"]["failed_pass_name"] == "after_field_cache_injection" + assert "secret" not in json.dumps(entries[0]["data"]) diff --git a/tests/test_transform_trace.py b/tests/test_transform_trace.py index f24a1b4cb..928be5720 100644 --- a/tests/test_transform_trace.py +++ b/tests/test_transform_trace.py @@ -10,6 +10,7 @@ TransformTraceWriter, sanitize_filename, sanitize_for_trace, + scrub_sensitive_text, ) @@ -22,7 +23,7 @@ class ExamplePayload: def test_sanitize_for_trace_redacts_sensitive_keys_recursively() -> None: payload = { "api_key": "secret-key", - "headers": {"Authorization": "Bearer secret", "normal": "token in normal text"}, + "headers": {"Authorization": "Bearer secret", "Cookie": "sid=secret", "normal": "token in normal text"}, "items": [{"refresh_token": "refresh", "text": "token should remain in value"}], } @@ -30,11 +31,38 @@ def test_sanitize_for_trace_redacts_sensitive_keys_recursively() -> None: assert sanitized["api_key"] == REDACTED assert sanitized["headers"]["Authorization"] == REDACTED + assert sanitized["headers"]["Cookie"] == REDACTED assert sanitized["headers"]["normal"] == "token in normal text" assert sanitized["items"][0]["refresh_token"] == REDACTED assert sanitized["items"][0]["text"] == "token should remain in value" +def test_scrub_sensitive_text_targets_header_like_fragments_only() -> None: + text = "normal token text remains\nAuthorization: Bearer abc123\nset-cookie: sid=secret" + + scrubbed = scrub_sensitive_text(text) + + assert "normal token text remains" in scrubbed + assert "Authorization: [REDACTED]" in scrubbed + assert "set-cookie: [REDACTED]" in scrubbed + + +def test_sanitize_for_trace_extracts_sdk_like_objects_before_repr() -> None: + class SdkError: + def __init__(self) -> None: + self.status_code = 401 + self.headers = {"Authorization": "Bearer secret"} + + def __repr__(self) -> str: + return "SdkError(Authorization: Bearer leaked)" + + sanitized = sanitize_for_trace(SdkError(), scrub_strings=True) + + assert sanitized["status_code"] == 401 + assert sanitized["headers"]["Authorization"] == REDACTED + assert "leaked" not in json.dumps(sanitized) + + def test_sanitize_for_trace_serializes_common_non_json_values() -> None: payload = ExamplePayload(when=datetime(2026, 1, 2, 3, 4, 5), amount=Decimal("1.5")) @@ -52,7 +80,7 @@ def test_sanitize_filename_is_stable_and_filesystem_safe() -> None: def test_transform_trace_writer_records_jsonl_and_snapshots(tmp_path) -> None: - writer = TransformTraceWriter(tmp_path, component="client", provider="openai", model="gpt-test") + writer = TransformTraceWriter(tmp_path, component="client", provider="openai", model="gpt-test", request_id="req_1") first = writer.record("raw_client_request", {"model": "gpt-test"}, direction="request", stage="client") second = writer.record("parsed_stream_chunk", {"delta": "hi"}, direction="stream", stage="client", snapshot=True) @@ -65,6 +93,7 @@ def test_transform_trace_writer_records_jsonl_and_snapshots(tmp_path) -> None: lines = (tmp_path / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines() assert len(lines) == 2 assert json.loads(lines[0])["pass_name"] == "raw_client_request" + assert json.loads(lines[0])["request_id"] == "req_1" assert json.loads(lines[1])["direction"] == "stream" assert (tmp_path / "transforms" / "0001_raw_client_request.json").exists() assert not (tmp_path / "transforms" / "0002_parsed_stream_chunk.json").exists() @@ -75,3 +104,14 @@ def test_transform_trace_writer_disabled_writes_nothing(tmp_path) -> None: assert writer.record("raw_client_request", {}, direction="request", stage="client") is None assert not (tmp_path / "transform_trace.jsonl").exists() + + +def test_transform_trace_writer_snapshot_namespace_prevents_collisions(tmp_path) -> None: + first = TransformTraceWriter(tmp_path, component="provider", snapshot_namespace="provider_a") + second = TransformTraceWriter(tmp_path, component="provider", snapshot_namespace="provider_b") + + first.record("provider_request_payload", {"a": 1}, direction="request", stage="provider") + second.record("provider_request_payload", {"b": 2}, direction="request", stage="provider") + + assert (tmp_path / "transforms" / "0001_provider_a_provider_request_payload.json").exists() + assert (tmp_path / "transforms" / "0001_provider_b_provider_request_payload.json").exists() From ef3cc1960060c0e358ea452f5ef99c89ad7f505a Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 23:49:21 +0200 Subject: [PATCH 013/182] docs(experimental): plan adapters and field cache Adds the Phase 3 plan for the adapter registry, built-in adapter bases, field-cache rule schema, path engine, store abstractions, scoped key behavior, transform trace integration, tests, risks, and review checkpoints. Reports remain uncommitted for user-facing review only. --- .../phase-3-adapters-field-cache.md | 284 ++++++++++++++++++ 1 file changed, 284 insertions(+) create mode 100644 docs/experimental/phase-3-adapters-field-cache.md diff --git a/docs/experimental/phase-3-adapters-field-cache.md b/docs/experimental/phase-3-adapters-field-cache.md new file mode 100644 index 000000000..d7fbf53cf --- /dev/null +++ b/docs/experimental/phase-3-adapters-field-cache.md @@ -0,0 +1,284 @@ +# Phase 3 Plan: Adapter And Field-Cache System + +## Goal + +Add the configurable adapter and field-cache foundation that lets protocol/provider implementations preserve and re-inject provider-specific state without hardcoding every provider. This phase should make it possible to define rules for values like reasoning content, thought signatures, prompt cache keys, provider session IDs, and response IDs, with strict provider/model/credential/session/classifier scoping and transform trace visibility. + +## Non-Goals + +- Do not migrate all existing providers to native protocols yet. +- Do not add `/v1/responses` routes yet. +- Do not replace `ProviderTransforms` yet. +- Do not implement JSON config loading for these rules yet; Phase 10 owns config polish. +- Do not instantiate SQLite or any DB. +- Do not implement provider-specific Claude/Codex/Copilot/Antigravity behavior yet. +- Do not make field cache a replacement for `SessionTracker`. + +## Current Code Context + +- `ProviderTransforms` is still the active runtime transform path and should remain untouched unless a small hook is needed. +- Phase 1 provides unified request/response/stream dataclasses and protocol adapters. +- Phase 2 provides `log_transform_pass()` and `log_transform_error()` for request/response/stream trace states. +- `ProviderCache` already exists and supports async `store_async()` / `retrieve_async()` with memory+disk TTL behavior, but it stores strings and starts async background tasks on construction. +- `RequestContext` has session, scope, classifier, provider, model, and credential metadata needed for cache isolation. +- Existing runtime behavior should stay stable while the new adapter/cache foundation is built and tested in isolation. + +## Files To Add + +- `src/rotator_library/adapters/__init__.py` +- `src/rotator_library/adapters/base.py` +- `src/rotator_library/adapters/registry.py` +- `src/rotator_library/adapters/builtin.py` +- `src/rotator_library/field_cache/__init__.py` +- `src/rotator_library/field_cache/types.py` +- `src/rotator_library/field_cache/paths.py` +- `src/rotator_library/field_cache/store.py` +- `src/rotator_library/field_cache/engine.py` +- `tests/test_adapter_registry.py` +- `tests/test_field_cache_paths.py` +- `tests/test_field_cache_engine.py` +- `tests/test_field_cache_trace.py` + +## Files Possibly Touched + +- `src/rotator_library/providers/provider_interface.py` to add optional provider declarations for adapter names and field-cache rules. +- `src/rotator_library/client/transforms.py` only if adding a no-op future hook is cleaner. +- `src/rotator_library/__init__.py` only if public lazy exports are useful. +- Avoid `RequestExecutor` runtime wiring in the first adapter/cache commits unless tests prove the hook is harmless and necessary. + +## Adapter System + +- Create a protocol-neutral `Adapter` base class. +- Adapters operate on raw dicts, unified dataclasses, or stream events depending on `supported_stages`. +- Adapter methods are override-friendly: + - `transform_request(payload, context) -> payload` + - `transform_response(payload, context) -> payload` + - `transform_stream_event(payload, context) -> payload` +- Adapter context should include provider, model, protocol, credential_id, session_id, scope_key, classifier, transport, metadata, and optional `transaction_logger`. +- Registry should support: + - explicit registration + - alias resolution + - auto-discovery from `src/rotator_library/adapters/` + - duplicate/collision checks similar to protocol registry + - order-preserving adapter chain resolution + +Built-in adapters: + +- `reasoning_rewrite`: copy/rename reasoning fields between configured paths. +- `reasoning_content`: normalize common reasoning content field names. +- `suppress_developer_role`: convert or remove developer-role messages for providers that reject it. +- `model_override`: replace model in outbound payload. +- `field_rename`: generic configured source-path to target-path copy/move. + +These adapters are bases, not gospel. Providers can subclass/copy/mutate adapters or override provider methods. + +## Field-Cache Rules + +`FieldCacheRule` fields: + +- `name` +- `source`: `request`, `response`, `stream_event`, `unified_request`, `unified_response`, `unified_stream_event` +- `path`: JSON-path-like selector +- `mode`: `last`, `all`, `last_user_turn`, `last_assistant_turn`, `per_tool_call` +- `scope`: list/tuple of scope dimensions +- `inject`: optional `FieldCacheInjection` +- `enabled` +- `ttl_seconds` optional, for future store implementations +- `metadata` for provider/model notes + +`FieldCacheInjection` fields: + +- `target`: `request`, `unified_request`, `metadata`, or provider-specific future target +- `path` +- `when_missing_only` +- `insert` +- `as_list` + +Scope dimensions: + +- `provider` +- `model` +- `credential` +- `session` +- `conversation` +- `classifier` + +Default scope for conversation-affecting rules: + +- provider + model + classifier + session. + +Cache keys must include rule name and selected scope values. Missing optional values should use stable placeholders only where safe; for session-scoped rules with no session, default to no-op unless a rule explicitly allows fallback. + +## Path Engine + +Support dotted paths: + +- `choices.0.message.reasoning_content` +- `choices.*.message.reasoning_content` +- `messages[-1].reasoning_content` +- `candidates.*.content.parts.*.thoughtSignature` + +Behavior: + +- Supports dict keys, list indices, `*`, and `[-1]`. +- Extraction returns all matches in stable traversal order. +- Injection creates missing dict containers when unambiguous. +- Injection does not create containers through wildcard paths. +- Malformed paths raise `FieldCachePathError` with useful messages. +- Missing paths are no-op, not errors. + +## Store Plan + +- `FieldCacheStore` protocol/interface with async `get`, `set`, `append`, and `clear`. +- `InMemoryFieldCacheStore` for tests and simple runtime. +- `ProviderCacheFieldStore` wrapper around an injected `ProviderCache`, JSON serializing values. +- Do not instantiate `ProviderCache` in module import paths because it starts async tasks. +- No SQLite. + +## Engine Plan + +`FieldCacheEngine` responsibilities: + +- hold rules and store +- validate rule names and paths +- `extract(source, payload, context, transaction_logger=None)` +- `inject(target, payload, context, transaction_logger=None)` + +Extraction: + +- select rules matching source +- extract values by path +- apply mode +- store under scoped key +- trace `after_field_cache_extraction` with rule name, match count, mode, scope key, and sanitized values summary + +Injection: + +- select rules with matching injection target +- retrieve scoped values +- apply mode result +- inject into a payload copy by path +- trace `after_field_cache_injection` with rule name, hit/miss, target path, and changed flag + +Engine rules: + +- Do not mutate caller payload unless explicitly requested; default returns a deep copy. +- Failures should call `log_transform_error()` with failed rule/pass and then raise validation errors in tests. Runtime integration later can choose fail-open or fail-closed per provider. + +## Modes + +- `last`: store latest extracted value. +- `all`: append extracted values preserving order. +- `last_user_turn`: use metadata/turn marker when available; for raw messages fallback to latest user-associated match if path includes messages. +- `last_assistant_turn`: same for assistant. +- `per_tool_call`: requires a tool-call ID path or extracted object with obvious `id` / `tool_call_id`; otherwise validation error. + +Phase 3 should implement `last` and `all` fully, and provide validated structure for turn/tool modes with tests for validation and simple cases. If a mode needs runtime conversation indexing not yet available, document it as present but limited. + +## Transform Trace Requirements + +Every adapter invocation should be traceable: + +- `before_adapter_chain` +- `after_adapter` +- `after_adapter_chain` + +Every field-cache injection/extraction should be traceable: + +- `before_field_cache_injection` +- `after_field_cache_injection` +- `before_field_cache_extraction` +- `after_field_cache_extraction` + +Errors use Phase 2 `transform_log_error`. + +Trace metadata should include: + +- adapter name +- rule name +- source/target +- path +- mode +- scope dimensions +- cache key hash or readable scoped key with no secrets +- match count +- changed flag + +Do not log huge cached payloads by default; log summaries and sanitized sample values. + +## Provider Declaration Plan + +Add optional provider methods/class attributes only if needed: + +- `protocol_name` +- `adapter_names` +- `field_cache_rules` +- `get_adapter_config(model)` +- `get_field_cache_rules(model)` + +Keep defaults empty/no-op to avoid provider changes. These declarations are for later provider migration. + +## Tests + +Adapter registry: + +- auto-discovers built-ins +- aliases resolve +- duplicate names/collisions are deterministic +- chain order is preserved +- no-op adapter preserves payload + +Built-in adapters: + +- `model_override` +- `suppress_developer_role` +- `field_rename` +- reasoning field copy/rename + +Path engine: + +- dict path extraction +- list index extraction +- wildcard extraction +- `[-1]` extraction +- injection into simple dict path +- missing path no-op +- wildcard injection rejected +- malformed path error + +Field-cache engine: + +- extract response value and inject into next request +- `last` overwrites +- `all` appends +- scope isolation by provider/model/session/classifier/credential +- missing path no-op +- malformed rule validation +- stream_event extraction +- trace entries emitted for injection/extraction +- transform errors emitted for rule failures + +Regression: + +- Phase 1 protocol tests. +- Phase 2 logging tests. +- core session/selection tests. + +## Commit Checkpoints + +1. Add adapter base/registry and built-in no-op/simple adapters with tests. +2. Add field-cache rule dataclasses and path engine with tests. +3. Add field-cache stores and scoped key builder with tests. +4. Add field-cache engine extraction/injection with trace tests. +5. Add optional provider declaration methods if needed. +6. Run focused and regression tests. +7. Review with `explore` and `explore-heavy`, fix findings, and write the uncommitted Phase 3 report. + +## Risks And Mitigations + +- JSON path scope could grow too large. Keep Phase 3 path syntax intentionally small and explicit. +- Cache leaks across providers/models/sessions would be severe. Scope-key tests must be extensive. +- Storing huge model outputs can bloat caches. Engine should store only matched values and trace summaries. +- ProviderCache lifecycle can be awkward because it starts async tasks. Wrap injected instances; do not instantiate globally. +- Turn/tool modes can be under-specified. Validate and implement safe subsets rather than guessing. +- Runtime integration can destabilize requests. Keep Phase 3 mostly isolated until review confirms the foundation. From 80f45068a91e8a782848dcf17917a3026ca8da67 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 23:51:34 +0200 Subject: [PATCH 014/182] feat(adapters): add payload adapter registry Adds the Phase 3 adapter foundation with an override-friendly async base adapter, adapter context, ordered chain runner, auto-discovered registry, aliases, duplicate collision checks, and built-in base adapters for no-op, model override, developer-role suppression, and reasoning content normalization. Runtime request execution is not wired to the adapter chain yet; this checkpoint keeps behavior unchanged while establishing the extension point for native protocols and providers. Tests: python -m pytest tests/test_adapter_registry.py --- src/rotator_library/adapters/__init__.py | 28 +++++ src/rotator_library/adapters/base.py | 136 +++++++++++++++++++++++ src/rotator_library/adapters/builtin.py | 103 +++++++++++++++++ src/rotator_library/adapters/registry.py | 98 ++++++++++++++++ tests/test_adapter_registry.py | 100 +++++++++++++++++ 5 files changed, 465 insertions(+) create mode 100644 src/rotator_library/adapters/__init__.py create mode 100644 src/rotator_library/adapters/base.py create mode 100644 src/rotator_library/adapters/builtin.py create mode 100644 src/rotator_library/adapters/registry.py create mode 100644 tests/test_adapter_registry.py diff --git a/src/rotator_library/adapters/__init__.py b/src/rotator_library/adapters/__init__.py new file mode 100644 index 000000000..0c28ed20c --- /dev/null +++ b/src/rotator_library/adapters/__init__.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Composable payload adapters used between protocols and providers.""" + +from .base import AdapterContext, PayloadAdapter, run_adapter_chain +from .registry import ( + ADAPTER_ALIASES, + ADAPTER_PLUGINS, + get_adapter, + get_adapter_class, + list_adapters, + register_adapter, + resolve_adapter_name, +) + +__all__ = [ + "ADAPTER_ALIASES", + "ADAPTER_PLUGINS", + "AdapterContext", + "PayloadAdapter", + "get_adapter", + "get_adapter_class", + "list_adapters", + "register_adapter", + "resolve_adapter_name", + "run_adapter_chain", +] diff --git a/src/rotator_library/adapters/base.py b/src/rotator_library/adapters/base.py new file mode 100644 index 000000000..9e7caa6da --- /dev/null +++ b/src/rotator_library/adapters/base.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Base classes for configurable protocol/provider payload adapters.""" + +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Any, Iterable, Mapping, Optional + + +@dataclass +class AdapterContext: + """Context available to every adapter pass. + + Adapters are intentionally base implementations, not provider law. Provider + code can subclass adapters, copy and mutate them, or bypass them when a + protocol needs special behavior. The context mirrors transform trace fields + so each adapter pass can be debugged by provider/model/session/scope. + """ + + provider: Optional[str] = None + model: Optional[str] = None + protocol: Optional[str] = None + credential_id: Optional[str] = None + session_id: Optional[str] = None + scope_key: Optional[str] = None + classifier: Optional[str] = None + transport: Optional[str] = None + metadata: dict[str, Any] = field(default_factory=dict) + adapter_config: dict[str, dict[str, Any]] = field(default_factory=dict) + transaction_logger: Optional[Any] = None + + def config_for(self, adapter_name: str) -> dict[str, Any]: + """Return adapter-specific config without mutating the original context.""" + + return dict(self.adapter_config.get(adapter_name, {})) + + +class PayloadAdapter: + """Override-friendly base adapter. + + The default adapter is a no-op. Concrete adapters override only the stages + they support. Methods are async so future provider adapters can consult + caches, metadata services, or remote capability probes without changing the + chain runner API. + """ + + name: str = "" + aliases: tuple[str, ...] = () + supported_stages: tuple[str, ...] = ("request", "response", "stream_event") + + async def transform_request(self, payload: Any, context: AdapterContext) -> Any: + return payload + + async def transform_response(self, payload: Any, context: AdapterContext) -> Any: + return payload + + async def transform_stream_event(self, payload: Any, context: AdapterContext) -> Any: + return payload + + async def transform(self, stage: str, payload: Any, context: AdapterContext) -> Any: + """Dispatch a stage-specific transform with a useful error for typos.""" + + if stage not in self.supported_stages: + return payload + if stage == "request": + return await self.transform_request(payload, context) + if stage == "response": + return await self.transform_response(payload, context) + if stage == "stream_event": + return await self.transform_stream_event(payload, context) + raise ValueError(f"Unknown adapter stage: {stage}") + + +def _trace(context: AdapterContext, pass_name: str, payload: Any, *, stage: str, metadata: Mapping[str, Any]) -> None: + logger = context.transaction_logger + if not logger: + return + logger.log_transform_pass( + pass_name, + payload, + direction="stream" if stage == "stream_event" else stage, + stage="adapter", + protocol=context.protocol, + credential_id=context.credential_id, + transport=context.transport, + metadata=dict(metadata), + snapshot=stage != "stream_event", + ) + + +async def run_adapter_chain( + adapters: Iterable[PayloadAdapter], + payload: Any, + context: AdapterContext, + *, + stage: str, + mutate: bool = False, +) -> Any: + """Run adapters in order and emit trace entries around the chain. + + By default the payload is deep-copied before the first adapter. This keeps + Phase 3 isolated and prevents surprise mutations until runtime integration + explicitly chooses mutating behavior for performance. + """ + + current = payload if mutate else deepcopy(payload) + adapter_names = [adapter.name for adapter in adapters] + _trace(context, "before_adapter_chain", current, stage=stage, metadata={"adapters": adapter_names}) + for adapter in adapters: + before = deepcopy(current) + try: + current = await adapter.transform(stage, current, context) + except Exception as exc: + if context.transaction_logger: + context.transaction_logger.log_transform_error( + f"adapter:{adapter.name}:{stage}", + exc, + payload=before, + stage="adapter", + protocol=context.protocol, + transport=context.transport, + metadata={"adapter": adapter.name, "adapter_stage": stage}, + ) + raise + _trace( + context, + "after_adapter", + current, + stage=stage, + metadata={"adapter": adapter.name, "adapter_stage": stage, "changed": current != before}, + ) + _trace(context, "after_adapter_chain", current, stage=stage, metadata={"adapters": adapter_names}) + return current diff --git a/src/rotator_library/adapters/builtin.py b/src/rotator_library/adapters/builtin.py new file mode 100644 index 000000000..08c588688 --- /dev/null +++ b/src/rotator_library/adapters/builtin.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Built-in base adapters for common provider payload quirks.""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any + +from .base import AdapterContext, PayloadAdapter + + +class NoOpAdapter(PayloadAdapter): + """Adapter that intentionally leaves payloads unchanged.""" + + name = "noop" + aliases = ("none", "passthrough") + + +class ModelOverrideAdapter(PayloadAdapter): + """Replace the outbound model field from adapter config. + + Config shape: + `{ "model": "provider/native-model-name" }` + """ + + name = "model_override" + aliases = ("override_model",) + supported_stages = ("request",) + + async def transform_request(self, payload: Any, context: AdapterContext) -> Any: + config = context.config_for(self.name) + override = config.get("model") or context.metadata.get("model_override") + if not override or not isinstance(payload, dict): + return payload + updated = deepcopy(payload) + updated["model"] = override + return updated + + +class SuppressDeveloperRoleAdapter(PayloadAdapter): + """Convert or remove developer-role messages for providers that reject them. + + Config shape: + `{ "mode": "system" | "user" | "drop" }` + """ + + name = "suppress_developer_role" + aliases = ("developer_role",) + supported_stages = ("request",) + + async def transform_request(self, payload: Any, context: AdapterContext) -> Any: + if not isinstance(payload, dict) or not isinstance(payload.get("messages"), list): + return payload + mode = context.config_for(self.name).get("mode", "system") + if mode not in {"system", "user", "drop"}: + raise ValueError("suppress_developer_role mode must be system, user, or drop") + updated = deepcopy(payload) + messages = [] + for message in updated.get("messages", []): + if not isinstance(message, dict) or message.get("role") != "developer": + messages.append(message) + continue + if mode == "drop": + continue + converted = dict(message) + converted["role"] = mode + messages.append(converted) + updated["messages"] = messages + return updated + + +class ReasoningContentAdapter(PayloadAdapter): + """Normalize common reasoning fields on assistant messages. + + This base adapter copies `reasoning`, `reasoning_content`, or configured + aliases into the configured output field. It deliberately does not delete + source fields; provider-specific subclasses can choose stricter behavior. + """ + + name = "reasoning_content" + aliases = ("reasoning_rewrite",) + supported_stages = ("response",) + + async def transform_response(self, payload: Any, context: AdapterContext) -> Any: + if not isinstance(payload, dict): + return payload + config = context.config_for(self.name) + output_field = config.get("output_field", "reasoning_content") + source_fields = tuple(config.get("source_fields", ("reasoning_content", "reasoning"))) + updated = deepcopy(payload) + for choice in updated.get("choices", []) if isinstance(updated.get("choices"), list) else []: + message = choice.get("message") if isinstance(choice, dict) else None + if not isinstance(message, dict): + continue + if output_field in message: + continue + for source_field in source_fields: + if source_field in message: + message[output_field] = message[source_field] + break + return updated diff --git a/src/rotator_library/adapters/registry.py b/src/rotator_library/adapters/registry.py new file mode 100644 index 000000000..7d702c0be --- /dev/null +++ b/src/rotator_library/adapters/registry.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Auto-discovery registry for payload adapters.""" + +from __future__ import annotations + +import importlib +import inspect +import logging +import pkgutil +from typing import Type + +from .base import PayloadAdapter + +lib_logger = logging.getLogger("rotator_library") + +ADAPTER_PLUGINS: dict[str, Type[PayloadAdapter]] = {} +ADAPTER_ALIASES: dict[str, str] = {} +_ADAPTER_INSTANCES: dict[str, PayloadAdapter] = {} + +_INFRASTRUCTURE_MODULES = {"base", "registry"} + + +def register_adapter(adapter_class: Type[PayloadAdapter], *, replace: bool = False) -> Type[PayloadAdapter]: + """Register an adapter class and its aliases with collision checks.""" + + if not inspect.isclass(adapter_class) or not issubclass(adapter_class, PayloadAdapter): + raise TypeError("adapter_class must inherit PayloadAdapter") + if adapter_class is PayloadAdapter: + raise TypeError("cannot register PayloadAdapter itself") + name = adapter_class.name + if not name: + raise ValueError(f"Adapter {adapter_class.__name__} must define a name") + alias_owner = ADAPTER_ALIASES.get(name) + if alias_owner and alias_owner != name and not replace: + raise ValueError(f"Adapter name conflicts with registered alias: {name}") + existing = ADAPTER_PLUGINS.get(name) + if existing and existing is not adapter_class and not replace: + raise ValueError(f"Adapter name already registered: {name}") + if replace and existing and existing is not adapter_class: + for alias, owner in list(ADAPTER_ALIASES.items()): + if owner == name: + ADAPTER_ALIASES.pop(alias, None) + ADAPTER_PLUGINS[name] = adapter_class + _ADAPTER_INSTANCES.pop(name, None) + for alias in adapter_class.aliases: + existing_name = ADAPTER_ALIASES.get(alias) + if existing_name and existing_name != name and not replace: + raise ValueError(f"Adapter alias already registered: {alias}") + if alias in ADAPTER_PLUGINS and alias != name and not replace: + raise ValueError(f"Adapter alias conflicts with registered name: {alias}") + ADAPTER_ALIASES[alias] = name + lib_logger.debug("Registered adapter: %s", name) + return adapter_class + + +def resolve_adapter_name(name: str) -> str: + if name in ADAPTER_PLUGINS: + return name + if name in ADAPTER_ALIASES: + return ADAPTER_ALIASES[name] + raise KeyError(f"Unknown adapter: {name}") + + +def get_adapter_class(name: str) -> Type[PayloadAdapter]: + return ADAPTER_PLUGINS[resolve_adapter_name(name)] + + +def get_adapter(name: str) -> PayloadAdapter: + canonical = resolve_adapter_name(name) + if canonical not in _ADAPTER_INSTANCES: + _ADAPTER_INSTANCES[canonical] = ADAPTER_PLUGINS[canonical]() + return _ADAPTER_INSTANCES[canonical] + + +def list_adapters() -> list[str]: + return sorted(ADAPTER_PLUGINS) + + +def _register_adapters() -> None: + package = importlib.import_module(__package__ or "rotator_library.adapters") + for _, module_name, _ in pkgutil.iter_modules(package.__path__): + if module_name.startswith("_") or module_name in _INFRASTRUCTURE_MODULES: + continue + module = importlib.import_module(f"{package.__name__}.{module_name}") + for attribute_name in dir(module): + attribute = getattr(module, attribute_name) + if ( + inspect.isclass(attribute) + and issubclass(attribute, PayloadAdapter) + and attribute is not PayloadAdapter + and attribute.__module__ == module.__name__ + ): + register_adapter(attribute) + + +_register_adapters() diff --git a/tests/test_adapter_registry.py b/tests/test_adapter_registry.py new file mode 100644 index 000000000..90050dc31 --- /dev/null +++ b/tests/test_adapter_registry.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import pytest + +from rotator_library.adapters import ( + AdapterContext, + PayloadAdapter, + get_adapter, + get_adapter_class, + list_adapters, + register_adapter, + resolve_adapter_name, + run_adapter_chain, +) + + +def test_adapter_registry_auto_discovers_builtins_and_aliases() -> None: + adapters = list_adapters() + + assert "noop" in adapters + assert "model_override" in adapters + assert resolve_adapter_name("passthrough") == "noop" + assert get_adapter_class("override_model").name == "model_override" + assert get_adapter("none") is get_adapter("noop") + + +def test_adapter_registry_rejects_duplicate_names_and_alias_collisions() -> None: + class DuplicateNoop(PayloadAdapter): + name = "noop" + + class AliasCollision(PayloadAdapter): + name = "custom_alias_collision" + aliases = ("noop",) + + with pytest.raises(ValueError): + register_adapter(DuplicateNoop) + with pytest.raises(ValueError): + register_adapter(AliasCollision) + + +@pytest.mark.asyncio +async def test_noop_adapter_preserves_payload_without_mutating_original() -> None: + payload = {"model": "original", "messages": []} + + result = await run_adapter_chain([get_adapter("noop")], payload, AdapterContext(), stage="request") + + assert result == payload + assert result is not payload + + +@pytest.mark.asyncio +async def test_model_override_adapter_changes_model_from_config() -> None: + payload = {"model": "public-name", "messages": []} + context = AdapterContext(adapter_config={"model_override": {"model": "native-name"}}) + + result = await run_adapter_chain([get_adapter("model_override")], payload, context, stage="request") + + assert result["model"] == "native-name" + assert payload["model"] == "public-name" + + +@pytest.mark.asyncio +async def test_suppress_developer_role_converts_or_drops_messages() -> None: + payload = {"messages": [{"role": "developer", "content": "rules"}, {"role": "user", "content": "hi"}]} + + system_result = await run_adapter_chain([get_adapter("suppress_developer_role")], payload, AdapterContext(), stage="request") + drop_result = await run_adapter_chain( + [get_adapter("suppress_developer_role")], + payload, + AdapterContext(adapter_config={"suppress_developer_role": {"mode": "drop"}}), + stage="request", + ) + + assert system_result["messages"][0]["role"] == "system" + assert [message["role"] for message in drop_result["messages"]] == ["user"] + + +@pytest.mark.asyncio +async def test_reasoning_content_adapter_copies_common_reasoning_field() -> None: + payload = {"choices": [{"message": {"role": "assistant", "reasoning": "hidden"}}]} + + result = await run_adapter_chain([get_adapter("reasoning_content")], payload, AdapterContext(), stage="response") + + assert result["choices"][0]["message"]["reasoning_content"] == "hidden" + + +@pytest.mark.asyncio +async def test_adapter_chain_order_is_preserved() -> None: + payload = {"model": "public", "messages": [{"role": "developer", "content": "rules"}]} + context = AdapterContext(adapter_config={"model_override": {"model": "native"}}) + + result = await run_adapter_chain( + [get_adapter("model_override"), get_adapter("suppress_developer_role")], + payload, + context, + stage="request", + ) + + assert result["model"] == "native" + assert result["messages"][0]["role"] == "system" From eef4383b23bd17e18246149cf54ab519888d176d Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 23:53:39 +0200 Subject: [PATCH 015/182] feat(field-cache): add rules and path helpers Adds field-cache rule and injection dataclasses, cache context scope values, default provider/model/classifier/session scoping, and a small JSON-path-like engine for extraction and predictable injection. The path helper supports dict keys, list indexes, wildcard extraction, tail indexes, missing-path no-ops, and explicit errors for malformed paths or wildcard injection. Tests: python -m pytest tests/test_field_cache_paths.py --- src/rotator_library/field_cache/__init__.py | 17 +++ src/rotator_library/field_cache/paths.py | 151 ++++++++++++++++++++ src/rotator_library/field_cache/types.py | 91 ++++++++++++ tests/test_field_cache_paths.py | 76 ++++++++++ 4 files changed, 335 insertions(+) create mode 100644 src/rotator_library/field_cache/__init__.py create mode 100644 src/rotator_library/field_cache/paths.py create mode 100644 src/rotator_library/field_cache/types.py create mode 100644 tests/test_field_cache_paths.py diff --git a/src/rotator_library/field_cache/__init__.py b/src/rotator_library/field_cache/__init__.py new file mode 100644 index 000000000..27ec118b5 --- /dev/null +++ b/src/rotator_library/field_cache/__init__.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Configurable provider field-cache rules and helpers.""" + +from .paths import FieldCachePathError, extract_path, inject_path, parse_path +from .types import FieldCacheContext, FieldCacheInjection, FieldCacheRule + +__all__ = [ + "FieldCacheContext", + "FieldCacheInjection", + "FieldCachePathError", + "FieldCacheRule", + "extract_path", + "inject_path", + "parse_path", +] diff --git a/src/rotator_library/field_cache/paths.py b/src/rotator_library/field_cache/paths.py new file mode 100644 index 000000000..96997ecad --- /dev/null +++ b/src/rotator_library/field_cache/paths.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Small JSON-path-like selector used by field-cache rules.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal + + +class FieldCachePathError(ValueError): + """Raised when a field-cache path is malformed or cannot be injected.""" + + +@dataclass(frozen=True) +class PathToken: + kind: Literal["key", "index", "wildcard"] + value: str | int | None = None + + +def parse_path(path: str) -> tuple[PathToken, ...]: + """Parse a minimal dotted path with indexes and wildcards. + + Supported examples: `choices.0.message.content`, `choices.*.message`, + `messages[-1].reasoning_content`. Escaping is intentionally unsupported in + Phase 3 so malformed provider configs fail early and clearly. + """ + + if not path or path.startswith(".") or path.endswith(".") or ".." in path: + raise FieldCachePathError(f"Malformed field-cache path: {path!r}") + tokens: list[PathToken] = [] + for segment in path.split("."): + if not segment: + raise FieldCachePathError(f"Malformed field-cache path: {path!r}") + _parse_segment(segment, tokens, path) + return tuple(tokens) + + +def _parse_segment(segment: str, tokens: list[PathToken], full_path: str) -> None: + if segment == "*": + tokens.append(PathToken("wildcard")) + return + cursor = 0 + while cursor < len(segment): + bracket = segment.find("[", cursor) + if bracket == -1: + value = segment[cursor:] + if value: + if value.lstrip("-").isdigit(): + tokens.append(PathToken("index", int(value))) + else: + tokens.append(PathToken("key", value)) + return + if bracket > cursor: + tokens.append(PathToken("key", segment[cursor:bracket])) + close = segment.find("]", bracket) + if close == -1: + raise FieldCachePathError(f"Unclosed index in field-cache path: {full_path!r}") + raw_index = segment[bracket + 1 : close] + if raw_index == "*": + tokens.append(PathToken("wildcard")) + else: + try: + tokens.append(PathToken("index", int(raw_index))) + except ValueError as exc: + raise FieldCachePathError(f"Invalid index {raw_index!r} in field-cache path: {full_path!r}") from exc + cursor = close + 1 + + +def extract_path(payload: Any, path: str) -> list[Any]: + """Return every value matching path in stable traversal order.""" + + tokens = parse_path(path) + current = [payload] + for token in tokens: + next_values: list[Any] = [] + for value in current: + next_values.extend(_extract_token(value, token)) + current = next_values + if not current: + break + return current + + +def _extract_token(value: Any, token: PathToken) -> list[Any]: + if token.kind == "key": + if isinstance(value, dict) and token.value in value: + return [value[token.value]] + return [] + if token.kind == "index": + if isinstance(value, list) and value: + index = int(token.value) + if -len(value) <= index < len(value): + return [value[index]] + return [] + if token.kind == "wildcard": + if isinstance(value, dict): + return list(value.values()) + if isinstance(value, list): + return list(value) + return [] + raise FieldCachePathError(f"Unknown path token: {token}") + + +def inject_path(payload: Any, path: str, injected_value: Any, *, when_missing_only: bool = False) -> bool: + """Inject a value at a simple path, creating dict containers as needed. + + Wildcard injection is rejected because creating multiple branches can be + ambiguous and provider-specific. List indexes must already exist; this keeps + mutation predictable for message-tail use cases like `messages[-1].field`. + """ + + tokens = parse_path(path) + if any(token.kind == "wildcard" for token in tokens): + raise FieldCachePathError("Wildcard injection is not supported") + if not isinstance(payload, dict): + raise FieldCachePathError("Field-cache injection root must be a dict") + current = payload + for index, token in enumerate(tokens): + is_last = index == len(tokens) - 1 + if token.kind == "key": + if not isinstance(current, dict): + raise FieldCachePathError(f"Cannot inject key {token.value!r} into non-dict value") + key = str(token.value) + if is_last: + if when_missing_only and key in current: + return False + changed = current.get(key) != injected_value + current[key] = injected_value + return changed + if key not in current or current[key] is None: + current[key] = [] if tokens[index + 1].kind == "index" else {} + current = current[key] + continue + if token.kind == "index": + if not isinstance(current, list) or not current: + raise FieldCachePathError("Cannot inject into missing or empty list") + list_index = int(token.value) + if not (-len(current) <= list_index < len(current)): + raise FieldCachePathError(f"List index out of range for field-cache injection: {list_index}") + if is_last: + if when_missing_only and current[list_index] is not None: + return False + changed = current[list_index] != injected_value + current[list_index] = injected_value + return changed + current = current[list_index] + continue + raise FieldCachePathError(f"Unsupported injection token: {token}") + return False diff --git a/src/rotator_library/field_cache/types.py b/src/rotator_library/field_cache/types.py new file mode 100644 index 000000000..02a44add8 --- /dev/null +++ b/src/rotator_library/field_cache/types.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Data types for provider field-cache rules.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal, Optional + +FieldCacheSource = Literal[ + "request", + "response", + "stream_event", + "unified_request", + "unified_response", + "unified_stream_event", +] +FieldCacheTarget = Literal["request", "unified_request", "metadata"] +FieldCacheMode = Literal["last", "all", "last_user_turn", "last_assistant_turn", "per_tool_call"] +FieldCacheScope = Literal["provider", "model", "credential", "session", "conversation", "classifier"] + +DEFAULT_SCOPE: tuple[FieldCacheScope, ...] = ("provider", "model", "classifier", "session") + + +@dataclass(frozen=True) +class FieldCacheInjection: + """Where and how a cached value should be injected into a later payload.""" + + target: FieldCacheTarget + path: str + when_missing_only: bool = False + insert: bool = False + as_list: bool = False + + +@dataclass(frozen=True) +class FieldCacheRule: + """Declarative rule for extracting and re-injecting provider state. + + Rules are protocol/provider extensions, not session-affinity logic. Session + tracking decides continuity; field-cache rules preserve protocol state such + as reasoning content, thought signatures, prompt cache keys, and response IDs. + """ + + name: str + source: FieldCacheSource + path: str + mode: FieldCacheMode = "last" + scope: tuple[FieldCacheScope, ...] = DEFAULT_SCOPE + inject: Optional[FieldCacheInjection] = None + enabled: bool = True + ttl_seconds: Optional[int] = None + metadata: dict[str, Any] = field(default_factory=dict) + allow_missing_session: bool = False + + def __post_init__(self) -> None: + if not self.name or any(char in self.name for char in "/\\:"): + raise ValueError("FieldCacheRule.name must be non-empty and filesystem-safe") + if self.mode not in {"last", "all", "last_user_turn", "last_assistant_turn", "per_tool_call"}: + raise ValueError(f"Unsupported field-cache mode: {self.mode}") + if not self.scope: + raise ValueError("FieldCacheRule.scope must contain at least one dimension") + + +@dataclass(frozen=True) +class FieldCacheContext: + """Scope values used to isolate cached provider fields.""" + + provider: Optional[str] = None + model: Optional[str] = None + credential_id: Optional[str] = None + session_id: Optional[str] = None + conversation_id: Optional[str] = None + classifier: Optional[str] = None + metadata: dict[str, Any] = field(default_factory=dict) + + def value_for_scope(self, scope: FieldCacheScope) -> Optional[str]: + if scope == "provider": + return self.provider + if scope == "model": + return self.model + if scope == "credential": + return self.credential_id + if scope == "session": + return self.session_id + if scope == "conversation": + return self.conversation_id + if scope == "classifier": + return self.classifier + raise ValueError(f"Unsupported field-cache scope: {scope}") diff --git a/tests/test_field_cache_paths.py b/tests/test_field_cache_paths.py new file mode 100644 index 000000000..4f5b22066 --- /dev/null +++ b/tests/test_field_cache_paths.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import pytest + +from rotator_library.field_cache.paths import FieldCachePathError, extract_path, inject_path, parse_path +from rotator_library.field_cache.types import FieldCacheInjection, FieldCacheRule + + +def test_parse_path_supports_keys_indexes_wildcards_and_tail_index() -> None: + tokens = parse_path("choices.*.message.parts[-1]") + + assert [token.kind for token in tokens] == ["key", "wildcard", "key", "key", "index"] + assert tokens[-1].value == -1 + + +def test_extract_path_handles_nested_dict_list_and_wildcard() -> None: + payload = { + "choices": [ + {"message": {"reasoning_content": "a"}}, + {"message": {"reasoning_content": "b"}}, + ] + } + + assert extract_path(payload, "choices.*.message.reasoning_content") == ["a", "b"] + assert extract_path(payload, "choices.1.message.reasoning_content") == ["b"] + + +def test_extract_path_missing_values_are_noop() -> None: + assert extract_path({"choices": []}, "choices.*.message.reasoning_content") == [] + assert extract_path({}, "missing.path") == [] + + +def test_extract_path_tail_index() -> None: + payload = {"messages": [{"content": "first"}, {"content": "last"}]} + + assert extract_path(payload, "messages[-1].content") == ["last"] + + +def test_inject_path_creates_dict_containers() -> None: + payload = {"messages": [{"role": "assistant"}]} + + changed = inject_path(payload, "messages[-1].reasoning_content", "hidden") + + assert changed is True + assert payload["messages"][-1]["reasoning_content"] == "hidden" + + +def test_inject_path_respects_when_missing_only() -> None: + payload = {"metadata": {"prompt_cache_key": "existing"}} + + changed = inject_path(payload, "metadata.prompt_cache_key", "new", when_missing_only=True) + + assert changed is False + assert payload["metadata"]["prompt_cache_key"] == "existing" + + +def test_inject_path_rejects_wildcards_and_missing_lists() -> None: + with pytest.raises(FieldCachePathError): + inject_path({"choices": []}, "choices.*.message.reasoning_content", "x") + with pytest.raises(FieldCachePathError): + inject_path({}, "messages[-1].reasoning_content", "x") + + +def test_malformed_paths_and_rules_raise_useful_errors() -> None: + with pytest.raises(FieldCachePathError): + parse_path("choices..message") + with pytest.raises(FieldCachePathError): + parse_path("messages[abc]") + with pytest.raises(ValueError): + FieldCacheRule(name="bad/name", source="response", path="x") + assert FieldCacheRule( + name="reasoning_content", + source="response", + path="choices.*.message.reasoning_content", + inject=FieldCacheInjection(target="request", path="messages[-1].reasoning_content"), + ).scope == ("provider", "model", "classifier", "session") From ce9d6606ef63c81e437de108fd3f21d204bb8957 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 23:55:53 +0200 Subject: [PATCH 016/182] feat(field-cache): add scoped cache engine Adds async field-cache stores, a ProviderCache-backed wrapper, scoped cache key construction, and the extraction/injection engine for last, all, turn-compatible, stream-event, and per-tool-call-validated rules. The engine copies payloads by default, isolates values by provider/model/session/classifier/credential scope, skips missing required session scope, and emits transform trace metadata when a transaction logger is supplied. Tests: python -m pytest tests/test_field_cache_engine.py tests/test_field_cache_paths.py --- src/rotator_library/field_cache/__init__.py | 8 + src/rotator_library/field_cache/engine.py | 228 ++++++++++++++++++++ src/rotator_library/field_cache/store.py | 80 +++++++ tests/test_field_cache_engine.py | 161 ++++++++++++++ 4 files changed, 477 insertions(+) create mode 100644 src/rotator_library/field_cache/engine.py create mode 100644 src/rotator_library/field_cache/store.py create mode 100644 tests/test_field_cache_engine.py diff --git a/src/rotator_library/field_cache/__init__.py b/src/rotator_library/field_cache/__init__.py index 27ec118b5..ae5d62aa2 100644 --- a/src/rotator_library/field_cache/__init__.py +++ b/src/rotator_library/field_cache/__init__.py @@ -4,13 +4,21 @@ """Configurable provider field-cache rules and helpers.""" from .paths import FieldCachePathError, extract_path, inject_path, parse_path +from .engine import FieldCacheEngine, FieldCacheOperation, build_cache_key +from .store import FieldCacheStore, InMemoryFieldCacheStore, ProviderCacheFieldStore from .types import FieldCacheContext, FieldCacheInjection, FieldCacheRule __all__ = [ "FieldCacheContext", + "FieldCacheEngine", "FieldCacheInjection", + "FieldCacheOperation", "FieldCachePathError", "FieldCacheRule", + "FieldCacheStore", + "InMemoryFieldCacheStore", + "ProviderCacheFieldStore", + "build_cache_key", "extract_path", "inject_path", "parse_path", diff --git a/src/rotator_library/field_cache/engine.py b/src/rotator_library/field_cache/engine.py new file mode 100644 index 000000000..6d06c36e1 --- /dev/null +++ b/src/rotator_library/field_cache/engine.py @@ -0,0 +1,228 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Extraction and injection engine for field-cache rules.""" + +from __future__ import annotations + +import hashlib +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Any, Iterable, Optional + +from .paths import FieldCachePathError, extract_path, inject_path, parse_path +from .store import FieldCacheStore, InMemoryFieldCacheStore +from .types import FieldCacheContext, FieldCacheRule + + +@dataclass +class FieldCacheOperation: + """Summary of one field-cache rule application.""" + + rule_name: str + cache_key: Optional[str] + matched: int = 0 + changed: bool = False + hit: bool = False + skipped: bool = False + reason: Optional[str] = None + sample_values: list[Any] = field(default_factory=list) + + +def _safe_scope_value(value: str) -> str: + digest = hashlib.sha256(value.encode("utf-8")).hexdigest()[:16] + return digest + + +def build_cache_key(rule: FieldCacheRule, context: FieldCacheContext) -> Optional[str]: + """Build a scoped cache key or return None when required scope is absent.""" + + parts = [f"rule={rule.name}"] + for scope in rule.scope: + value = context.value_for_scope(scope) + if value is None or value == "": + if scope == "session" and not rule.allow_missing_session: + return None + value = "_none" + if scope in {"provider", "model"}: + safe_value = value.replace("/", "_").replace("\\", "_").replace(":", "_") + else: + safe_value = _safe_scope_value(value) + parts.append(f"{scope}={safe_value}") + return "|".join(parts) + + +class FieldCacheEngine: + """Apply field-cache extraction and injection rules. + + The engine is isolated from request execution in Phase 3. It defaults to + copying payloads before injection so tests and future providers can reason + about mutations explicitly. + """ + + def __init__(self, rules: Iterable[FieldCacheRule], store: Optional[FieldCacheStore] = None) -> None: + self.rules = tuple(rules) + self.store = store or InMemoryFieldCacheStore() + self._validate_rules() + + def _validate_rules(self) -> None: + names: set[str] = set() + for rule in self.rules: + if rule.name in names: + raise ValueError(f"Duplicate field-cache rule name: {rule.name}") + names.add(rule.name) + parse_path(rule.path) + if rule.inject: + parse_path(rule.inject.path) + if rule.mode == "per_tool_call" and not rule.metadata.get("tool_call_id_path"): + raise ValueError("per_tool_call field-cache mode requires metadata.tool_call_id_path") + + async def extract( + self, + source: str, + payload: Any, + context: FieldCacheContext, + *, + transaction_logger: Optional[Any] = None, + ) -> list[FieldCacheOperation]: + operations: list[FieldCacheOperation] = [] + for rule in self._rules_for_source(source): + operation = FieldCacheOperation(rule_name=rule.name, cache_key=build_cache_key(rule, context)) + if not operation.cache_key: + operation.skipped = True + operation.reason = "missing_required_scope" + operations.append(operation) + self._trace(transaction_logger, "after_field_cache_extraction", payload, rule, operation, source=source) + continue + try: + values = extract_path(payload, rule.path) + operation.matched = len(values) + operation.sample_values = values[:3] + if values: + await self._store_values(rule, operation.cache_key, values) + operation.changed = True + except Exception as exc: + self._log_error(transaction_logger, "field_cache_extract", exc, payload, rule) + raise + operations.append(operation) + self._trace(transaction_logger, "after_field_cache_extraction", payload, rule, operation, source=source) + return operations + + async def inject( + self, + target: str, + payload: Any, + context: FieldCacheContext, + *, + transaction_logger: Optional[Any] = None, + mutate: bool = False, + ) -> tuple[Any, list[FieldCacheOperation]]: + updated = payload if mutate else deepcopy(payload) + operations: list[FieldCacheOperation] = [] + for rule in self._rules_for_injection(target): + operation = FieldCacheOperation(rule_name=rule.name, cache_key=build_cache_key(rule, context)) + if not rule.inject: + continue + if not operation.cache_key: + operation.skipped = True + operation.reason = "missing_required_scope" + operations.append(operation) + self._trace(transaction_logger, "after_field_cache_injection", updated, rule, operation, target=target) + continue + try: + cached = await self.store.get(operation.cache_key) + if cached is None: + operation.reason = "cache_miss" + operations.append(operation) + self._trace(transaction_logger, "after_field_cache_injection", updated, rule, operation, target=target) + continue + operation.hit = True + value = cached if rule.inject.as_list or rule.mode == "all" else _last_value(cached) + operation.changed = inject_path( + updated, + rule.inject.path, + value, + when_missing_only=rule.inject.when_missing_only, + ) + operation.sample_values = value if isinstance(value, list) else [value] + except Exception as exc: + self._log_error(transaction_logger, "field_cache_inject", exc, updated, rule) + raise + operations.append(operation) + self._trace(transaction_logger, "after_field_cache_injection", updated, rule, operation, target=target) + return updated, operations + + def _rules_for_source(self, source: str) -> list[FieldCacheRule]: + return [rule for rule in self.rules if rule.enabled and rule.source == source] + + def _rules_for_injection(self, target: str) -> list[FieldCacheRule]: + return [rule for rule in self.rules if rule.enabled and rule.inject and rule.inject.target == target] + + async def _store_values(self, rule: FieldCacheRule, cache_key: str, values: list[Any]) -> None: + if rule.mode == "all": + await self.store.append(cache_key, values) + return + if rule.mode in {"last", "last_user_turn", "last_assistant_turn"}: + await self.store.set(cache_key, values[-1]) + return + if rule.mode == "per_tool_call": + tool_path = rule.metadata.get("tool_call_id_path") + stored = {} + for value in values: + tool_ids = extract_path(value, tool_path) if tool_path else [] + if tool_ids: + stored[str(tool_ids[0])] = value + await self.store.set(cache_key, stored) + return + raise ValueError(f"Unsupported field-cache mode: {rule.mode}") + + def _trace( + self, + transaction_logger: Optional[Any], + pass_name: str, + payload: Any, + rule: FieldCacheRule, + operation: FieldCacheOperation, + **extra_metadata: Any, + ) -> None: + if not transaction_logger: + return + transaction_logger.log_transform_pass( + pass_name, + payload, + direction="stream" if rule.source == "stream_event" else "request" if "injection" in pass_name else "response", + stage="adapter", + metadata={ + "rule_name": rule.name, + "source": rule.source, + "path": rule.path, + "mode": rule.mode, + "scope": list(rule.scope), + "cache_key": operation.cache_key, + "matched": operation.matched, + "changed": operation.changed, + "hit": operation.hit, + "skipped": operation.skipped, + "reason": operation.reason, + "sample_values": operation.sample_values[:3], + **extra_metadata, + }, + snapshot=rule.source != "stream_event", + ) + + def _log_error(self, transaction_logger: Optional[Any], pass_name: str, error: BaseException, payload: Any, rule: FieldCacheRule) -> None: + if not transaction_logger: + return + transaction_logger.log_transform_error( + pass_name, + error, + payload=payload, + stage="adapter", + metadata={"rule_name": rule.name, "path": rule.path, "mode": rule.mode}, + ) + + +def _last_value(value: Any) -> Any: + if isinstance(value, list): + return value[-1] if value else None + return value diff --git a/src/rotator_library/field_cache/store.py b/src/rotator_library/field_cache/store.py new file mode 100644 index 000000000..9394bebdd --- /dev/null +++ b/src/rotator_library/field_cache/store.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Async stores for field-cache values.""" + +from __future__ import annotations + +import json +from copy import deepcopy +from typing import Any, Protocol + +from ..protocols import serialize_value + + +class FieldCacheStore(Protocol): + """Minimal async store interface used by `FieldCacheEngine`.""" + + async def get(self, key: str) -> Any: ... + + async def set(self, key: str, value: Any) -> None: ... + + async def append(self, key: str, values: list[Any]) -> list[Any]: ... + + async def clear(self) -> None: ... + + +class InMemoryFieldCacheStore: + """Simple process-local store for tests and lightweight runtime use.""" + + def __init__(self) -> None: + self._values: dict[str, Any] = {} + + async def get(self, key: str) -> Any: + return deepcopy(self._values.get(key)) + + async def set(self, key: str, value: Any) -> None: + self._values[key] = deepcopy(value) + + async def append(self, key: str, values: list[Any]) -> list[Any]: + current = self._values.get(key) + if not isinstance(current, list): + current = [] + current = deepcopy(current) + deepcopy(values) + self._values[key] = current + return deepcopy(current) + + async def clear(self) -> None: + self._values.clear() + + +class ProviderCacheFieldStore: + """Field-cache store backed by an injected `ProviderCache` instance. + + The wrapper does not create `ProviderCache` itself because that class starts + background async tasks during initialization. Providers or later config code + should own that lifecycle and pass an initialized cache here. + """ + + def __init__(self, provider_cache: Any) -> None: + self._cache = provider_cache + + async def get(self, key: str) -> Any: + raw = await self._cache.retrieve_async(key) + if raw is None: + return None + return json.loads(raw) + + async def set(self, key: str, value: Any) -> None: + await self._cache.store_async(key, json.dumps(serialize_value(value), ensure_ascii=False)) + + async def append(self, key: str, values: list[Any]) -> list[Any]: + current = await self.get(key) + if not isinstance(current, list): + current = [] + current = current + serialize_value(values) + await self.set(key, current) + return current + + async def clear(self) -> None: + await self._cache.clear() diff --git a/tests/test_field_cache_engine.py b/tests/test_field_cache_engine.py new file mode 100644 index 000000000..98d6dbbf2 --- /dev/null +++ b/tests/test_field_cache_engine.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.field_cache import ( + FieldCacheContext, + FieldCacheEngine, + FieldCacheInjection, + FieldCacheRule, + InMemoryFieldCacheStore, + ProviderCacheFieldStore, + build_cache_key, +) + + +def _reasoning_rule(mode: str = "last", scope=("provider", "model", "session")) -> FieldCacheRule: + return FieldCacheRule( + name="reasoning_content", + source="response", + path="choices.*.message.reasoning_content", + mode=mode, + scope=scope, + inject=FieldCacheInjection(target="request", path="messages[-1].reasoning_content"), + ) + + +def _context(**overrides) -> FieldCacheContext: + values = {"provider": "openai", "model": "gpt-test", "session_id": "session-a", "classifier": "global"} + values.update(overrides) + return FieldCacheContext(**values) + + +@pytest.mark.asyncio +async def test_extract_response_value_and_inject_into_next_request() -> None: + engine = FieldCacheEngine([_reasoning_rule()]) + response = {"choices": [{"message": {"reasoning_content": "hidden"}}]} + request = {"messages": [{"role": "user", "content": "hi"}]} + + operations = await engine.extract("response", response, _context()) + updated, injection_operations = await engine.inject("request", request, _context()) + + assert operations[0].matched == 1 + assert injection_operations[0].hit is True + assert updated["messages"][-1]["reasoning_content"] == "hidden" + assert "reasoning_content" not in request["messages"][-1] + + +@pytest.mark.asyncio +async def test_last_mode_overwrites_prior_value() -> None: + engine = FieldCacheEngine([_reasoning_rule()]) + + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "first"}}]}, _context()) + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "second"}}]}, _context()) + updated, _ = await engine.inject("request", {"messages": [{}]}, _context()) + + assert updated["messages"][-1]["reasoning_content"] == "second" + + +@pytest.mark.asyncio +async def test_all_mode_appends_values() -> None: + engine = FieldCacheEngine([_reasoning_rule(mode="all")]) + + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "first"}}]}, _context()) + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "second"}}]}, _context()) + updated, _ = await engine.inject("request", {"messages": [{}]}, _context()) + + assert updated["messages"][-1]["reasoning_content"] == ["first", "second"] + + +@pytest.mark.asyncio +async def test_scope_isolation_by_session_and_classifier() -> None: + rule = _reasoning_rule(scope=("provider", "model", "session", "classifier")) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "a"}}]}, _context(session_id="session-a")) + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "b"}}]}, _context(session_id="session-b")) + + updated, _ = await engine.inject("request", {"messages": [{}]}, _context(session_id="session-a")) + + assert updated["messages"][-1]["reasoning_content"] == "a" + assert build_cache_key(rule, _context(session_id="session-a")) != build_cache_key(rule, _context(session_id="session-b")) + + +@pytest.mark.asyncio +async def test_missing_session_scope_skips_by_default() -> None: + engine = FieldCacheEngine([_reasoning_rule()]) + + operations = await engine.extract("response", {"choices": [{"message": {"reasoning_content": "x"}}]}, _context(session_id=None)) + + assert operations[0].skipped is True + assert operations[0].reason == "missing_required_scope" + + +@pytest.mark.asyncio +async def test_missing_path_is_noop() -> None: + engine = FieldCacheEngine([_reasoning_rule()]) + + operations = await engine.extract("response", {"choices": []}, _context()) + + assert operations[0].matched == 0 + assert operations[0].changed is False + + +@pytest.mark.asyncio +async def test_stream_event_extraction() -> None: + rule = FieldCacheRule( + name="provider_session_id", + source="stream_event", + path="metadata.provider_session_id", + scope=("provider", "model", "session"), + inject=FieldCacheInjection(target="request", path="metadata.provider_session_id"), + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("stream_event", {"metadata": {"provider_session_id": "sid_1"}}, _context()) + updated, _ = await engine.inject("request", {"metadata": {}}, _context()) + + assert updated["metadata"]["provider_session_id"] == "sid_1" + + +def test_per_tool_call_requires_tool_call_id_path() -> None: + with pytest.raises(ValueError): + FieldCacheEngine([ + FieldCacheRule(name="tool_state", source="response", path="tool_calls.*", mode="per_tool_call") + ]) + + +@pytest.mark.asyncio +async def test_provider_cache_field_store_wraps_json_string_cache() -> None: + class FakeProviderCache: + def __init__(self) -> None: + self.values = {} + + async def retrieve_async(self, key: str): + return self.values.get(key) + + async def store_async(self, key: str, value: str) -> None: + json.loads(value) + self.values[key] = value + + async def clear(self) -> None: + self.values.clear() + + store = ProviderCacheFieldStore(FakeProviderCache()) + + await store.set("key", {"value": 1}) + assert await store.get("key") == {"value": 1} + await store.append("key", [{"value": 2}]) + assert await store.get("key") == [{"value": 2}] + + +@pytest.mark.asyncio +async def test_in_memory_store_returns_deep_copies() -> None: + store = InMemoryFieldCacheStore() + await store.set("key", {"nested": []}) + value = await store.get("key") + value["nested"].append("mutated") + + assert await store.get("key") == {"nested": []} From 790e52e2d4a48fcc49e4aa12ef1f6ba2fe4bebe7 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 23:57:51 +0200 Subject: [PATCH 017/182] test(field-cache): cover adapter and cache trace passes Adds the missing before_field_cache_extraction and before_field_cache_injection trace passes so field-cache operations now emit both before and after states. Adds trace-focused tests for adapter chains, field-cache extraction/injection, rule metadata, cache hits, mutation flags, and transform_log_error emission on failed injection. Tests: python -m pytest tests/test_field_cache_trace.py tests/test_field_cache_engine.py tests/test_adapter_registry.py --- src/rotator_library/field_cache/engine.py | 2 + tests/test_field_cache_trace.py | 88 +++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 tests/test_field_cache_trace.py diff --git a/src/rotator_library/field_cache/engine.py b/src/rotator_library/field_cache/engine.py index 6d06c36e1..f0f328410 100644 --- a/src/rotator_library/field_cache/engine.py +++ b/src/rotator_library/field_cache/engine.py @@ -88,6 +88,7 @@ async def extract( operations: list[FieldCacheOperation] = [] for rule in self._rules_for_source(source): operation = FieldCacheOperation(rule_name=rule.name, cache_key=build_cache_key(rule, context)) + self._trace(transaction_logger, "before_field_cache_extraction", payload, rule, operation, source=source) if not operation.cache_key: operation.skipped = True operation.reason = "missing_required_scope" @@ -123,6 +124,7 @@ async def inject( operation = FieldCacheOperation(rule_name=rule.name, cache_key=build_cache_key(rule, context)) if not rule.inject: continue + self._trace(transaction_logger, "before_field_cache_injection", updated, rule, operation, target=target) if not operation.cache_key: operation.skipped = True operation.reason = "missing_required_scope" diff --git a/tests/test_field_cache_trace.py b/tests/test_field_cache_trace.py new file mode 100644 index 000000000..4ea423673 --- /dev/null +++ b/tests/test_field_cache_trace.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.adapters import AdapterContext, get_adapter, run_adapter_chain +from rotator_library.field_cache import FieldCacheContext, FieldCacheEngine, FieldCacheInjection, FieldCacheRule +from rotator_library.transaction_logger import TransactionLogger + + +def _trace_entries(log_dir): + return [json.loads(line) for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + + +@pytest.mark.asyncio +async def test_adapter_chain_emits_before_after_trace_entries(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + context = AdapterContext( + provider="openai", + model="gpt-test", + protocol="openai_chat", + credential_id="cred_1", + transport="http", + transaction_logger=logger, + adapter_config={"model_override": {"model": "native"}}, + ) + + result = await run_adapter_chain([get_adapter("model_override")], {"model": "public"}, context, stage="request") + + entries = _trace_entries(logger.log_dir) + assert result["model"] == "native" + assert [entry["pass_name"] for entry in entries] == ["before_adapter_chain", "after_adapter", "after_adapter_chain"] + assert entries[1]["metadata"]["adapter"] == "model_override" + assert entries[1]["metadata"]["changed"] is True + assert entries[1]["credential_id"] == "cred_1" + + +@pytest.mark.asyncio +async def test_field_cache_extract_and_inject_emit_before_after_trace_entries(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + rule = FieldCacheRule( + name="reasoning_content", + source="response", + path="choices.*.message.reasoning_content", + inject=FieldCacheInjection(target="request", path="messages[-1].reasoning_content"), + ) + engine = FieldCacheEngine([rule]) + context = FieldCacheContext(provider="openai", model="gpt-test", session_id="session_1", classifier="global") + + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "hidden"}}]}, context, transaction_logger=logger) + updated, _ = await engine.inject("request", {"messages": [{"role": "user"}]}, context, transaction_logger=logger) + + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert updated["messages"][-1]["reasoning_content"] == "hidden" + assert pass_names == [ + "before_field_cache_extraction", + "after_field_cache_extraction", + "before_field_cache_injection", + "after_field_cache_injection", + ] + assert entries[1]["metadata"]["rule_name"] == "reasoning_content" + assert entries[1]["metadata"]["matched"] == 1 + assert entries[3]["metadata"]["hit"] is True + assert entries[3]["metadata"]["changed"] is True + + +@pytest.mark.asyncio +async def test_field_cache_errors_emit_transform_log_error(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + rule = FieldCacheRule( + name="bad_injection", + source="response", + path="choices.*.message.reasoning_content", + inject=FieldCacheInjection(target="request", path="messages.*.reasoning_content"), + ) + engine = FieldCacheEngine([rule]) + context = FieldCacheContext(provider="openai", model="gpt-test", session_id="session_1", classifier="global") + + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "hidden"}}]}, context, transaction_logger=logger) + with pytest.raises(Exception): + await engine.inject("request", {"messages": [{"role": "user"}]}, context, transaction_logger=logger) + + entries = _trace_entries(logger.log_dir) + error_entry = next(entry for entry in entries if entry["pass_name"] == "transform_log_error") + assert error_entry["data"]["failed_pass_name"] == "field_cache_inject" + assert error_entry["metadata"]["rule_name"] == "bad_injection" From bcb4a580f2568361ed0fff6076a00fe4045ccd3f Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sat, 30 May 2026 23:59:23 +0200 Subject: [PATCH 018/182] feat(providers): declare protocol adapter hooks Adds optional provider declarations for native protocol name, ordered adapter names, adapter config, and field-cache rules, all defaulting to empty/no-op behavior so existing providers remain on the current execution path until they opt in. These methods are the Phase 3 bridge that later provider work will use to attach native protocols, adapter chains, and provider-specific field-cache rules per model. Tests: python -m pytest tests/test_provider_protocol_declarations.py tests/test_adapter_registry.py tests/test_field_cache_engine.py tests/test_field_cache_paths.py tests/test_field_cache_trace.py --- .../providers/provider_interface.py | 49 +++++++++++++++++++ tests/test_provider_protocol_declarations.py | 39 +++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 tests/test_provider_protocol_declarations.py diff --git a/src/rotator_library/providers/provider_interface.py b/src/rotator_library/providers/provider_interface.py index fb5811bf4..8b4696cad 100644 --- a/src/rotator_library/providers/provider_interface.py +++ b/src/rotator_library/providers/provider_interface.py @@ -270,6 +270,13 @@ class ProviderInterface(ABC, metaclass=SingletonABCMeta): Union[int, Tuple[int, ...], str], Dict[str, Dict[str, Any]] ] = {} + # Native protocol/adapter declarations introduced for the experimental + # protocol stack. Defaults are intentionally no-op so existing providers keep + # the LiteLLM-backed execution path until they opt into native protocols. + protocol_name: Optional[str] = None + adapter_names: Tuple[str, ...] = () + field_cache_rules: Tuple[Any, ...] = () + @abstractmethod async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: """ @@ -313,6 +320,48 @@ def get_session_tracking_hints( """ return None + def get_protocol_name(self, model: str = "") -> Optional[str]: + """Return the native protocol adapter name this provider prefers. + + Providers may override this method when protocol choice varies by model. + Returning ``None`` keeps the current fallback execution behavior. Later + provider phases will use this as the bridge from provider declarations to + native protocol parsing/building. + """ + + return self.protocol_name + + def get_adapter_names(self, model: str = "") -> Tuple[str, ...]: + """Return ordered adapter names for this provider/model. + + The order is significant and is preserved by the adapter chain runner. + Providers can override this for model-specific quirks without mutating the + global adapter registry. + """ + + return tuple(self.adapter_names) + + def get_adapter_config(self, model: str = "") -> Dict[str, Dict[str, Any]]: + """Return adapter-specific config keyed by adapter name. + + Config is intentionally a plain dict so custom providers can define it in + env/JSON later without importing adapter classes. Phase 10 will add formal + config loading and validation. + """ + + return {} + + def get_field_cache_rules(self, model: str = "") -> Tuple[Any, ...]: + """Return field-cache rules for provider-specific protocol state. + + Rules preserve provider state such as reasoning content, thought + signatures, prompt cache keys, provider session IDs, and response IDs. + They are not a replacement for ``SessionTracker``; session tracking still + decides continuity and credential affinity. + """ + + return tuple(self.field_cache_rules) + async def acompletion( self, client: httpx.AsyncClient, **kwargs ) -> Union[ diff --git a/tests/test_provider_protocol_declarations.py b/tests/test_provider_protocol_declarations.py new file mode 100644 index 000000000..2465a7201 --- /dev/null +++ b/tests/test_provider_protocol_declarations.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from rotator_library.field_cache import FieldCacheRule +from rotator_library.providers.provider_interface import ProviderInterface + + +class DeclarationProvider(ProviderInterface): + protocol_name = "openai_chat" + adapter_names = ("model_override", "suppress_developer_role") + field_cache_rules = ( + FieldCacheRule(name="reasoning_content", source="response", path="choices.*.message.reasoning_content"), + ) + + async def get_models(self, api_key, client): + return [] + + +def test_provider_interface_defaults_are_noop_for_protocol_stack() -> None: + provider = DeclarationProvider() + + assert provider.get_protocol_name("model") == "openai_chat" + assert provider.get_adapter_names("model") == ("model_override", "suppress_developer_role") + assert provider.get_adapter_config("model") == {} + assert provider.get_field_cache_rules("model")[0].name == "reasoning_content" + + +def test_provider_interface_methods_can_be_model_specific() -> None: + class ModelSpecificProvider(DeclarationProvider): + def get_protocol_name(self, model: str = ""): + return "responses" if "response" in model else super().get_protocol_name(model) + + def get_adapter_config(self, model: str = ""): + return {"model_override": {"model": f"native/{model}"}} + + provider = ModelSpecificProvider() + + assert provider.get_protocol_name("response-model") == "responses" + assert provider.get_protocol_name("chat-model") == "openai_chat" + assert provider.get_adapter_config("chat-model") == {"model_override": {"model": "native/chat-model"}} From 27aa7c1a924f9735ba2016625b83e66f1971a6c7 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 00:08:23 +0200 Subject: [PATCH 019/182] fix(field-cache): close Phase 3 review gaps Adds the planned field_rename adapter, fixes field-cache trace direction for stream-sourced request injection, caps trace sample values, and documents the current limits of turn/tool-cache modes. Expands coverage for credential/provider scope isolation, stream-sourced injection trace direction, large sample truncation, field_rename behavior, and plain provider no-op protocol defaults. Tests: python -m pytest tests/test_adapter_registry.py tests/test_field_cache_paths.py tests/test_field_cache_engine.py tests/test_field_cache_trace.py tests/test_provider_protocol_declarations.py Tests: python -m pytest tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py tests/test_session_tracking.py tests/test_selection_engine.py --- src/rotator_library/adapters/builtin.py | 61 ++++++++++++++++++++ src/rotator_library/field_cache/engine.py | 35 +++++++++-- tests/test_adapter_registry.py | 22 +++++++ tests/test_field_cache_engine.py | 39 +++++++++++++ tests/test_field_cache_trace.py | 20 +++++++ tests/test_provider_protocol_declarations.py | 14 +++++ 6 files changed, 187 insertions(+), 4 deletions(-) diff --git a/src/rotator_library/adapters/builtin.py b/src/rotator_library/adapters/builtin.py index 08c588688..7406729b0 100644 --- a/src/rotator_library/adapters/builtin.py +++ b/src/rotator_library/adapters/builtin.py @@ -8,6 +8,7 @@ from copy import deepcopy from typing import Any +from ..field_cache.paths import extract_path, inject_path from .base import AdapterContext, PayloadAdapter @@ -101,3 +102,63 @@ async def transform_response(self, payload: Any, context: AdapterContext) -> Any message[output_field] = message[source_field] break return updated + + +class FieldRenameAdapter(PayloadAdapter): + """Copy configured values between paths on raw payloads. + + Config shape: + `{ "rules": [{ "source_path": "a.b", "target_path": "c.d", "stage": "request", "move": false }] }` + + This adapter is conservative by design: it copies the last matched value by + default, delegates target ambiguity checks to `inject_path`, and only removes + the source for simple dotted-key paths when `move=true`. + """ + + name = "field_rename" + aliases = ("field_copy",) + + async def transform_request(self, payload: Any, context: AdapterContext) -> Any: + return self._transform_stage(payload, context, "request") + + async def transform_response(self, payload: Any, context: AdapterContext) -> Any: + return self._transform_stage(payload, context, "response") + + async def transform_stream_event(self, payload: Any, context: AdapterContext) -> Any: + return self._transform_stage(payload, context, "stream_event") + + def _transform_stage(self, payload: Any, context: AdapterContext, stage: str) -> Any: + if not isinstance(payload, dict): + return payload + updated = deepcopy(payload) + for rule in context.config_for(self.name).get("rules", []): + if rule.get("stage", stage) != stage: + continue + values = extract_path(updated, rule["source_path"]) + if not values: + continue + value = values if rule.get("as_list") else values[-1] + inject_path( + updated, + rule["target_path"], + value, + when_missing_only=bool(rule.get("when_missing_only", False)), + ) + if rule.get("move"): + _delete_simple_path(updated, rule["source_path"]) + return updated + + +def _delete_simple_path(payload: dict[str, Any], path: str) -> None: + """Delete a simple dotted dict path after a conservative move operation.""" + + parts = path.split(".") + if any("[" in part or part == "*" for part in parts): + return + current: Any = payload + for part in parts[:-1]: + if not isinstance(current, dict): + return + current = current.get(part) + if isinstance(current, dict): + current.pop(parts[-1], None) diff --git a/src/rotator_library/field_cache/engine.py b/src/rotator_library/field_cache/engine.py index f0f328410..58b43e962 100644 --- a/src/rotator_library/field_cache/engine.py +++ b/src/rotator_library/field_cache/engine.py @@ -57,7 +57,9 @@ class FieldCacheEngine: The engine is isolated from request execution in Phase 3. It defaults to copying payloads before injection so tests and future providers can reason - about mutations explicitly. + about mutations explicitly. Turn-specific and per-tool modes are validated + but limited until later phases add richer conversation indexing and + per-tool-call injection targets. """ def __init__(self, rules: Iterable[FieldCacheRule], store: Optional[FieldCacheStore] = None) -> None: @@ -98,7 +100,7 @@ async def extract( try: values = extract_path(payload, rule.path) operation.matched = len(values) - operation.sample_values = values[:3] + operation.sample_values = _sample_values(values) if values: await self._store_values(rule, operation.cache_key, values) operation.changed = True @@ -146,7 +148,7 @@ async def inject( value, when_missing_only=rule.inject.when_missing_only, ) - operation.sample_values = value if isinstance(value, list) else [value] + operation.sample_values = _sample_values(value if isinstance(value, list) else [value]) except Exception as exc: self._log_error(transaction_logger, "field_cache_inject", exc, updated, rule) raise @@ -192,7 +194,7 @@ def _trace( transaction_logger.log_transform_pass( pass_name, payload, - direction="stream" if rule.source == "stream_event" else "request" if "injection" in pass_name else "response", + direction=_trace_direction(pass_name, rule.source, extra_metadata), stage="adapter", metadata={ "rule_name": rule.name, @@ -228,3 +230,28 @@ def _last_value(value: Any) -> Any: if isinstance(value, list): return value[-1] if value else None return value + + +def _trace_direction(pass_name: str, source: str, metadata: dict[str, Any]) -> str: + if "injection" in pass_name: + target = metadata.get("target") + if target in {"stream_event", "unified_stream_event"}: + return "stream" + if target in {"request", "unified_request", "metadata"}: + return "request" + return "response" + if source in {"stream_event", "unified_stream_event"}: + return "stream" + if source in {"request", "unified_request"}: + return "request" + return "response" + + +def _sample_values(values: list[Any], *, max_items: int = 3, max_text: int = 500) -> list[Any]: + samples: list[Any] = [] + for value in values[:max_items]: + if isinstance(value, str) and len(value) > max_text: + samples.append(f"{value[:max_text]}...") + else: + samples.append(value) + return samples diff --git a/tests/test_adapter_registry.py b/tests/test_adapter_registry.py index 90050dc31..59965f4d5 100644 --- a/tests/test_adapter_registry.py +++ b/tests/test_adapter_registry.py @@ -19,7 +19,9 @@ def test_adapter_registry_auto_discovers_builtins_and_aliases() -> None: assert "noop" in adapters assert "model_override" in adapters + assert "field_rename" in adapters assert resolve_adapter_name("passthrough") == "noop" + assert resolve_adapter_name("field_copy") == "field_rename" assert get_adapter_class("override_model").name == "model_override" assert get_adapter("none") is get_adapter("noop") @@ -84,6 +86,26 @@ async def test_reasoning_content_adapter_copies_common_reasoning_field() -> None assert result["choices"][0]["message"]["reasoning_content"] == "hidden" +@pytest.mark.asyncio +async def test_field_rename_adapter_copies_and_moves_configured_fields() -> None: + payload = {"old": {"field": "value"}, "messages": [{}]} + context = AdapterContext( + adapter_config={ + "field_rename": { + "rules": [ + {"source_path": "old.field", "target_path": "messages[-1].new_field", "move": True} + ] + } + } + ) + + result = await run_adapter_chain([get_adapter("field_rename")], payload, context, stage="request") + + assert result["messages"][-1]["new_field"] == "value" + assert "field" not in result["old"] + assert payload["old"]["field"] == "value" + + @pytest.mark.asyncio async def test_adapter_chain_order_is_preserved() -> None: payload = {"model": "public", "messages": [{"role": "developer", "content": "rules"}]} diff --git a/tests/test_field_cache_engine.py b/tests/test_field_cache_engine.py index 98d6dbbf2..6a47545a6 100644 --- a/tests/test_field_cache_engine.py +++ b/tests/test_field_cache_engine.py @@ -83,6 +83,30 @@ async def test_scope_isolation_by_session_and_classifier() -> None: assert build_cache_key(rule, _context(session_id="session-a")) != build_cache_key(rule, _context(session_id="session-b")) +@pytest.mark.asyncio +async def test_scope_isolation_by_credential_and_provider() -> None: + rule = _reasoning_rule(scope=("provider", "model", "credential")) + engine = FieldCacheEngine([rule]) + + await engine.extract( + "response", + {"choices": [{"message": {"reasoning_content": "cred-a"}}]}, + _context(provider="openai", credential_id="credential-a"), + ) + await engine.extract( + "response", + {"choices": [{"message": {"reasoning_content": "cred-b"}}]}, + _context(provider="openai", credential_id="credential-b"), + ) + + updated, _ = await engine.inject("request", {"messages": [{}]}, _context(provider="openai", credential_id="credential-a")) + + assert updated["messages"][-1]["reasoning_content"] == "cred-a" + assert build_cache_key(rule, _context(provider="openai", credential_id="credential-a")) != build_cache_key( + rule, _context(provider="other", credential_id="credential-a") + ) + + @pytest.mark.asyncio async def test_missing_session_scope_skips_by_default() -> None: engine = FieldCacheEngine([_reasoning_rule()]) @@ -120,6 +144,21 @@ async def test_stream_event_extraction() -> None: assert updated["metadata"]["provider_session_id"] == "sid_1" +@pytest.mark.asyncio +async def test_trace_sample_values_are_truncated() -> None: + rule = _reasoning_rule() + engine = FieldCacheEngine([rule]) + long_value = "x" * 700 + + operations = await engine.extract( + "response", + {"choices": [{"message": {"reasoning_content": long_value}}]}, + _context(), + ) + + assert operations[0].sample_values[0].endswith("...") + + def test_per_tool_call_requires_tool_call_id_path() -> None: with pytest.raises(ValueError): FieldCacheEngine([ diff --git a/tests/test_field_cache_trace.py b/tests/test_field_cache_trace.py index 4ea423673..775f033bc 100644 --- a/tests/test_field_cache_trace.py +++ b/tests/test_field_cache_trace.py @@ -66,6 +66,26 @@ async def test_field_cache_extract_and_inject_emit_before_after_trace_entries(tm assert entries[3]["metadata"]["changed"] is True +@pytest.mark.asyncio +async def test_stream_sourced_rule_injection_trace_uses_request_direction(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + rule = FieldCacheRule( + name="provider_session_id", + source="stream_event", + path="metadata.provider_session_id", + inject=FieldCacheInjection(target="request", path="metadata.provider_session_id"), + ) + engine = FieldCacheEngine([rule]) + context = FieldCacheContext(provider="openai", model="gpt-test", session_id="session_1", classifier="global") + + await engine.extract("stream_event", {"metadata": {"provider_session_id": "sid_1"}}, context, transaction_logger=logger) + await engine.inject("request", {"metadata": {}}, context, transaction_logger=logger) + + entries = _trace_entries(logger.log_dir) + injection_entries = [entry for entry in entries if "injection" in entry["pass_name"]] + assert {entry["direction"] for entry in injection_entries} == {"request"} + + @pytest.mark.asyncio async def test_field_cache_errors_emit_transform_log_error(tmp_path) -> None: logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) diff --git a/tests/test_provider_protocol_declarations.py b/tests/test_provider_protocol_declarations.py index 2465a7201..a16ee56f3 100644 --- a/tests/test_provider_protocol_declarations.py +++ b/tests/test_provider_protocol_declarations.py @@ -15,6 +15,20 @@ async def get_models(self, api_key, client): return [] +class BareProvider(ProviderInterface): + async def get_models(self, api_key, client): + return [] + + +def test_provider_interface_plain_defaults_are_noop() -> None: + provider = BareProvider() + + assert provider.get_protocol_name("model") is None + assert provider.get_adapter_names("model") == () + assert provider.get_adapter_config("model") == {} + assert provider.get_field_cache_rules("model") == () + + def test_provider_interface_defaults_are_noop_for_protocol_stack() -> None: provider = DeclarationProvider() From 4caf56a9caf6732183f23aa2ad7b664518e6403a Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 00:16:12 +0200 Subject: [PATCH 020/182] docs(experimental): plan responses api Adds the Phase 4 plan for Responses routes, response storage, previous_response_id continuation, bridge execution through the current client path, HTTP SSE conversion, WebSocket extension seams, tests, risks, and review checkpoints. Reports remain uncommitted for user-facing review only. --- docs/experimental/phase-4-responses-api.md | 285 +++++++++++++++++++++ 1 file changed, 285 insertions(+) create mode 100644 docs/experimental/phase-4-responses-api.md diff --git a/docs/experimental/phase-4-responses-api.md b/docs/experimental/phase-4-responses-api.md new file mode 100644 index 000000000..901510066 --- /dev/null +++ b/docs/experimental/phase-4-responses-api.md @@ -0,0 +1,285 @@ +# Phase 4 Plan: Responses API + +## Goal + +Add a first-class OpenAI-compatible Responses API surface with durable response storage, `previous_response_id` continuation support, HTTP SSE streaming, and an explicit WebSocket transport seam for future implementation. Phase 4 should use the Phase 1 Responses protocol, Phase 2 transform trace logging, and Phase 3 field-cache/adapter foundations without forcing all providers onto native protocols yet. + +## Non-Goals + +- Do not migrate every provider to native protocol execution in this phase. +- Do not remove or replace `/v1/chat/completions`. +- Do not replace `UsageManager`, `SessionTracker`, retry-after parsing, or current credential rotation. +- Do not implement full multi-user/security controls. +- Do not add SQLite or any database. +- Do not implement full provider-specific Responses APIs for Claude Code/Codex/Copilot yet. +- Do not implement WebSocket runtime handling yet; provide transport interfaces/seams and tests for HTTP SSE behavior. + +## Current Code Context + +- FastAPI routes live in `src/proxy_app/main.py`. +- `RotatingClient.acompletion()` is the stable execution entry point today. +- `ResponsesProtocol` already parses/formats Responses request/response/stream shapes and declares future `websocket` support. +- `TransactionLogger` can record new transform passes and errors. +- `FieldCacheEngine` and `ProviderCacheFieldStore` can preserve response IDs/output items but need targeted key deletion for Responses delete endpoints. +- Existing stream wrappers aggregate chat-completions SSE chunks, but Responses SSE needs its own event formatting instead of chat completion chunk output. +- Runtime execution should stay conservative: Phase 4 can bridge Responses requests to the existing completion path until native provider protocols are wired in later phases. + +## Files To Add + +- `src/rotator_library/responses/__init__.py` +- `src/rotator_library/responses/types.py` +- `src/rotator_library/responses/store.py` +- `src/rotator_library/responses/bridge.py` +- `src/rotator_library/responses/service.py` +- `src/rotator_library/responses/streaming.py` +- `tests/test_responses_store.py` +- `tests/test_responses_bridge.py` +- `tests/test_responses_service.py` +- `tests/test_responses_routes.py` +- `tests/test_responses_streaming.py` + +## Files Likely To Touch + +- `src/proxy_app/main.py` +- `src/rotator_library/client/rotating_client.py` +- `src/rotator_library/field_cache/store.py` if targeted delete is added to the store protocol. +- `src/rotator_library/__init__.py` only if lazy public exports are useful. +- Existing protocol tests only if the Responses adapter needs an additive event/output fix. + +## Route/API Scope + +`POST /v1/responses`: + +- Accept raw Responses JSON. +- Support non-streaming response. +- Support `stream: true` with HTTP SSE. +- Validate required `model`. +- Accept `input`, `instructions`, `tools`, `tool_choice`, `reasoning`, `metadata`, `include`, `store`, `previous_response_id`, and common generation params. + +`GET /v1/responses/{response_id}`: + +- Return stored response object when available. +- Return 404-compatible error if missing or deleted. + +`DELETE /v1/responses/{response_id}`: + +- Delete stored response metadata/output. +- Return `{ "id": ..., "object": "response.deleted", "deleted": true }` or equivalent compatible shape. + +`GET /v1/responses/{response_id}/input_items`: + +- Return stored input items for clients that need continuation inspection. + +Optional if simple: + +- `POST /v1/responses/{response_id}/cancel` for active stream cancellation can be planned but not fully implemented unless active stream tracking lands cleanly. + +## Response Storage + +`StoredResponse` fields: + +- `id` +- `created_at` +- `model` +- `status` +- `request` +- `response` +- `input_items` +- `output_items` +- `usage` +- `metadata` +- `session_id` +- `scope_key` +- `classifier` +- `expires_at` optional + +`ResponsesStore` protocol: + +- `save(stored_response)` +- `get(response_id)` +- `delete(response_id)` +- `list_input_items(response_id)` + +Store implementations: + +- `InMemoryResponsesStore` for tests and runtime default if no persistent cache is configured. +- `ProviderCacheResponsesStore` as an optional wrapper around injected `ProviderCache`. +- JSON serialization through `serialize_value()`. +- No global `ProviderCache` construction at import time. +- No SQLite. + +Targeted deletion: + +- Add `delete(key)` to `FieldCacheStore` only if reused directly. +- Prefer a dedicated Responses store so field cache remains small and focused. + +## Response IDs + +- Generate IDs like `resp_` when upstream response lacks an ID. +- Preserve upstream IDs when present. +- Store every response when request `store` is true or omitted if compatibility expects default storage. +- If `store: false`, return the response but do not persist it; `previous_response_id` cannot refer to it later. +- Include `previous_response_id` in stored metadata for lineage/debugging. + +## `previous_response_id` + +- On request with `previous_response_id`, load the stored parent response. +- If not found, return a 404/400-compatible invalid request error rather than silently ignoring. +- Build continuation context conservatively: + - Keep original current request input. + - Add parent response output items into protocol/service metadata for traceability. + - For bridge execution through chat completions, convert parent response message items into prior assistant messages where safe. + - Do not inject unknown output item types into chat messages unless `ResponsesProtocol` can represent them safely. +- Record transform trace pass: + - `responses_previous_response_loaded` + - Include parent ID, output count, input item count, and whether bridge context was expanded. +- This should line up with Phase 3 field-cache storage but does not need to fully rely on field-cache rules yet. + +## Bridge Execution + +`ResponsesBridge` responsibilities: + +- Convert `UnifiedRequest` from `ResponsesProtocol.parse_request()` into chat-completions kwargs for current `RotatingClient.acompletion()`. +- Map `model` to `model`. +- Map `instructions`/system blocks to a system message. +- Map input messages to chat messages. +- Map tool definitions to OpenAI tool shape when possible. +- Map generation params to existing compatible kwargs. +- Map `stream` to `stream`. +- Preserve unsupported Responses fields in trace metadata and request metadata, not silently discard them. +- Rebuild Responses output from chat completion responses using `ResponsesProtocol.format_response()`. + +This bridge is temporary compatibility, not the final native provider path. Docstrings/comments should explain that native provider execution will replace the bridge for covered providers in later phases. + +## Service + +`ResponsesService` responsibilities: + +- Own `ResponsesProtocol`, `ResponsesBridge`, `ResponsesStore`, and optional transform logging. +- `create_response(raw_request, client, request=None)` +- `stream_response(raw_request, client, request=None)` +- `get_response(response_id)` +- `delete_response(response_id)` +- `list_input_items(response_id)` + +Service transform passes: + +- `raw_responses_request` +- `parsed_unified_request` +- `responses_previous_response_loaded` +- `responses_bridge_chat_request` +- `raw_responses_provider_response` or `raw_chat_bridge_response` +- `parsed_unified_response` +- `stored_responses_response` +- `final_responses_response` + +Errors use `log_transform_error()`. + +## Streaming + +- HTTP SSE first. +- Convert chat-completions stream chunks into Responses SSE events. +- Required event names: + - `response.created` + - `response.output_item.added` + - `response.output_text.delta` + - `response.output_item.done` + - `response.completed` + - `response.failed` on errors + - final `data: [DONE]` only if compatibility requires it; prefer Responses-style events and document the chosen behavior. +- Preserve raw stream chunk trace: + - `raw_chat_bridge_stream_chunk` + - `parsed_unified_stream_event` + - `formatted_responses_stream_event` + - `stored_responses_stream_response` +- Accumulate final output items so streamed responses are retrievable by `GET /v1/responses/{id}` after completion. +- Handle client disconnect without crashing the server. +- Do not add the broader stream stall/backpressure overhaul; that belongs to the later streaming phase. + +## WebSocket Seam + +- Add transport-neutral interfaces: + - `ResponsesTransport` or small formatter abstraction. + - `ResponsesSSEFormatter`. + - `ResponsesWebSocketFormatter` placeholder or protocol with `NotImplementedError`. +- Keep protocol/service logic independent from `StreamingResponse`. +- Add tests that formatter/service API can select `sse` now and recognizes `websocket` as a future transport without implementing a WebSocket route. +- Do not expose a WebSocket endpoint yet unless it is a no-op documented placeholder; prefer no route over a misleading route. + +## FastAPI Integration + +- Add routes near existing OpenAI-compatible routes in `main.py`. +- Use existing `verify_api_key`. +- Use existing raw request logging where applicable. +- Use `JSONResponse` for normal responses. +- Use `StreamingResponse(..., media_type="text/event-stream")` for streams. +- Use OpenAI-compatible error response shape for validation/missing response IDs. +- Store service on app state during lifespan if it needs shared storage, or construct module-level in-memory store carefully if no async lifecycle is needed. Prefer `app.state.responses_service`. + +## Tests + +Store tests: + +- save/get/delete response +- input items list +- missing response returns None +- JSON serialization with provider-cache wrapper if implemented +- no SQLite/import-time cache construction + +Bridge tests: + +- Responses input string/list converts to chat messages +- instructions become system message +- tool definitions preserve function call shape +- unsupported fields are preserved in metadata/trace +- chat completion response formats back to Responses output + +Service tests: + +- non-stream create stores response +- `store: false` does not persist +- `previous_response_id` loads parent +- missing previous response raises useful error +- transform trace passes emitted + +Route tests: + +- `POST /v1/responses` non-stream success with fake client +- `POST /v1/responses` missing model returns 400 +- `GET /v1/responses/{id}` success and 404 +- `DELETE /v1/responses/{id}` success +- `GET /v1/responses/{id}/input_items` + +Streaming tests: + +- stream emits Responses SSE event names in order +- deltas aggregate into stored final response +- stream errors emit `response.failed` +- client disconnect path is safe if testable + +Regression tests: + +- Phase 1 protocol tests +- Phase 2 transform logging tests +- Phase 3 adapter/cache tests +- `test_session_tracking.py` +- `test_selection_engine.py` + +## Commit Checkpoints + +1. Add Responses store and tests. +2. Add Responses bridge and tests. +3. Add Responses service with non-stream create/get/delete/input-items and tests. +4. Add HTTP routes and route tests. +5. Add HTTP SSE formatter/streaming conversion and tests. +6. Add WebSocket transport seam tests/docstrings without runtime route. +7. Run focused and regression tests. +8. Review with `explore` and `explore-heavy`, fix findings, and write the uncommitted Phase 4 report. + +## Risks And Mitigations + +- Responses API is broader than chat completions. Keep the bridge explicit about unsupported fields and preserve them in metadata/trace rather than pretending full native support. +- `previous_response_id` can leak context across users/scopes if storage keys are global. Include request/session/scope/classifier metadata and leave full multi-user isolation for later, but avoid credential secrets in storage keys. +- Streaming conversion can lose item fidelity. Accumulate output text and tool/reasoning items explicitly; preserve raw chunks in trace. +- In-memory store will not survive restarts. This is acceptable for Phase 4 default; provider-cache-backed persistence can be enabled where lifecycle is available. +- WebSocket support can be over-promised. Expose the abstraction seam only, not a functioning route, until the later transport phase. From c881602561378d1ef4e6e12eba9ef8477e2ddf60 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 00:17:45 +0200 Subject: [PATCH 021/182] feat(responses): add response storage Adds the Phase 4 Responses storage foundation with StoredResponse, local response ID generation, an in-memory store, and a ProviderCache-backed wrapper that accepts an injected cache instead of constructing one globally. The store supports save, get, delete, and input item listing, preserves JSON-safe response metadata for previous_response_id continuation, and avoids SQLite or new persistence dependencies. Tests: python -m pytest tests/test_responses_store.py --- src/rotator_library/responses/__init__.py | 15 +++ src/rotator_library/responses/store.py | 100 +++++++++++++++++++ src/rotator_library/responses/types.py | 98 +++++++++++++++++++ tests/test_responses_store.py | 111 ++++++++++++++++++++++ 4 files changed, 324 insertions(+) create mode 100644 src/rotator_library/responses/__init__.py create mode 100644 src/rotator_library/responses/store.py create mode 100644 src/rotator_library/responses/types.py create mode 100644 tests/test_responses_store.py diff --git a/src/rotator_library/responses/__init__.py b/src/rotator_library/responses/__init__.py new file mode 100644 index 000000000..07959eedc --- /dev/null +++ b/src/rotator_library/responses/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Responses API service, storage, and streaming helpers.""" + +from .store import InMemoryResponsesStore, ProviderCacheResponsesStore, ResponsesStore +from .types import StoredResponse, generate_response_id + +__all__ = [ + "InMemoryResponsesStore", + "ProviderCacheResponsesStore", + "ResponsesStore", + "StoredResponse", + "generate_response_id", +] diff --git a/src/rotator_library/responses/store.py b/src/rotator_library/responses/store.py new file mode 100644 index 000000000..549ee64ba --- /dev/null +++ b/src/rotator_library/responses/store.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Storage backends for Responses API objects.""" + +from __future__ import annotations + +import json +from copy import deepcopy +from typing import Any, Optional, Protocol + +from .types import StoredResponse + + +class ResponsesStore(Protocol): + """Minimal async store for response retrieval and continuation.""" + + async def save(self, response: StoredResponse) -> None: ... + + async def get(self, response_id: str) -> Optional[StoredResponse]: ... + + async def delete(self, response_id: str) -> bool: ... + + async def list_input_items(self, response_id: str) -> Optional[list[Any]]: ... + + +class InMemoryResponsesStore: + """Process-local Responses store. + + This is the Phase 4 default because it has no async lifecycle and avoids a + new persistence dependency. A provider-cache-backed store can be injected by + later configuration code when disk persistence is desired. + """ + + def __init__(self) -> None: + self._responses: dict[str, StoredResponse] = {} + + async def save(self, response: StoredResponse) -> None: + self._responses[response.id] = StoredResponse.from_dict(response.to_dict()) + + async def get(self, response_id: str) -> Optional[StoredResponse]: + response = self._responses.get(response_id) + if response is None: + return None + if response.is_expired(): + self._responses.pop(response_id, None) + return None + return StoredResponse.from_dict(response.to_dict()) + + async def delete(self, response_id: str) -> bool: + return self._responses.pop(response_id, None) is not None + + async def list_input_items(self, response_id: str) -> Optional[list[Any]]: + response = await self.get(response_id) + if response is None: + return None + return deepcopy(response.input_items) + + +class ProviderCacheResponsesStore: + """Responses store backed by an injected `ProviderCache` instance. + + The wrapper does not instantiate `ProviderCache`, because that class starts + background tasks. The caller owns cache lifecycle and shutdown. + """ + + def __init__(self, provider_cache: Any, *, prefix: str = "responses") -> None: + self._cache = provider_cache + self._prefix = prefix + + async def save(self, response: StoredResponse) -> None: + await self._cache.store_async(self._key(response.id), json.dumps(response.to_dict(), ensure_ascii=False)) + + async def get(self, response_id: str) -> Optional[StoredResponse]: + raw = await self._cache.retrieve_async(self._key(response_id)) + if raw is None: + return None + response = StoredResponse.from_dict(json.loads(raw)) + if response.is_expired(): + await self.delete(response_id) + return None + return response + + async def delete(self, response_id: str) -> bool: + delete = getattr(self._cache, "delete_async", None) + if delete: + return bool(await delete(self._key(response_id))) + # ProviderCache currently exposes clear(), not key-level deletion. When + # key deletion is unavailable, avoid clearing unrelated provider state. + return False + + async def list_input_items(self, response_id: str) -> Optional[list[Any]]: + response = await self.get(response_id) + if response is None: + return None + return deepcopy(response.input_items) + + def _key(self, response_id: str) -> str: + safe_id = response_id.replace("/", "_").replace("\\", "_").replace(":", "_") + return f"{self._prefix}:{safe_id}" diff --git a/src/rotator_library/responses/types.py b/src/rotator_library/responses/types.py new file mode 100644 index 000000000..5edb78297 --- /dev/null +++ b/src/rotator_library/responses/types.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Data types for the Responses API compatibility layer.""" + +from __future__ import annotations + +import secrets +import time +from dataclasses import dataclass, field +from typing import Any, Optional + +from ..protocols import serialize_value + + +def generate_response_id() -> str: + """Return a local Responses-compatible identifier. + + Upstream IDs are preserved when providers return them. This helper is only + used by the bridge path when the current chat-completions backend has no + native Responses ID to expose. + """ + + return f"resp_{secrets.token_urlsafe(18).replace('-', '').replace('_', '')[:24]}" + + +@dataclass +class StoredResponse: + """Persisted response object used for retrieval and continuation. + + The shape stores both client-facing response data and enough request/session + metadata for `previous_response_id` debugging. It deliberately avoids storing + credential secrets; callers should pass only stable identifiers if they need + credential correlation later. + """ + + id: str + model: str + status: str + response: dict[str, Any] + request: dict[str, Any] = field(default_factory=dict) + input_items: list[Any] = field(default_factory=list) + output_items: list[Any] = field(default_factory=list) + usage: Optional[dict[str, Any]] = None + metadata: dict[str, Any] = field(default_factory=dict) + created_at: float = field(default_factory=time.time) + session_id: Optional[str] = None + scope_key: Optional[str] = None + classifier: Optional[str] = None + expires_at: Optional[float] = None + + def is_expired(self, now: Optional[float] = None) -> bool: + """Return whether the response should be treated as unavailable.""" + + return self.expires_at is not None and (now if now is not None else time.time()) >= self.expires_at + + def to_dict(self) -> dict[str, Any]: + """Serialize with JSON-safe values for disk/cache persistence.""" + + return serialize_value( + { + "id": self.id, + "created_at": self.created_at, + "model": self.model, + "status": self.status, + "request": self.request, + "response": self.response, + "input_items": self.input_items, + "output_items": self.output_items, + "usage": self.usage, + "metadata": self.metadata, + "session_id": self.session_id, + "scope_key": self.scope_key, + "classifier": self.classifier, + "expires_at": self.expires_at, + } + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "StoredResponse": + """Rehydrate a stored response from a JSON-compatible dict.""" + + return cls( + id=str(data["id"]), + created_at=float(data.get("created_at") or time.time()), + model=str(data.get("model") or ""), + status=str(data.get("status") or "completed"), + request=dict(data.get("request") or {}), + response=dict(data.get("response") or {}), + input_items=list(data.get("input_items") or []), + output_items=list(data.get("output_items") or []), + usage=data.get("usage") if isinstance(data.get("usage"), dict) else None, + metadata=dict(data.get("metadata") or {}), + session_id=data.get("session_id"), + scope_key=data.get("scope_key"), + classifier=data.get("classifier"), + expires_at=data.get("expires_at"), + ) diff --git a/tests/test_responses_store.py b/tests/test_responses_store.py new file mode 100644 index 000000000..0f8ff720c --- /dev/null +++ b/tests/test_responses_store.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import json +import time + +import pytest + +from rotator_library.responses import InMemoryResponsesStore, ProviderCacheResponsesStore, StoredResponse, generate_response_id + + +def _stored(response_id: str = "resp_test") -> StoredResponse: + return StoredResponse( + id=response_id, + model="gpt-test", + status="completed", + request={"model": "gpt-test", "input": "hello"}, + response={"id": response_id, "object": "response", "output": []}, + input_items=[{"type": "message", "role": "user", "content": "hello"}], + output_items=[{"type": "message", "role": "assistant", "content": []}], + usage={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}, + metadata={"previous_response_id": None}, + ) + + +def test_generate_response_id_uses_responses_prefix() -> None: + assert generate_response_id().startswith("resp_") + assert generate_response_id() != generate_response_id() + + +@pytest.mark.asyncio +async def test_in_memory_store_save_get_delete_and_input_items() -> None: + store = InMemoryResponsesStore() + stored = _stored() + + await store.save(stored) + loaded = await store.get(stored.id) + input_items = await store.list_input_items(stored.id) + deleted = await store.delete(stored.id) + + assert loaded is not None + assert loaded.to_dict() == stored.to_dict() + assert input_items == stored.input_items + assert deleted is True + assert await store.get(stored.id) is None + + +@pytest.mark.asyncio +async def test_in_memory_store_returns_copies_and_expires() -> None: + store = InMemoryResponsesStore() + stored = _stored("resp_expiring") + stored.expires_at = time.time() + 100 + await store.save(stored) + + loaded = await store.get(stored.id) + assert loaded is not None + loaded.response["mutated"] = True + assert (await store.get(stored.id)).response.get("mutated") is None + + stored.expires_at = time.time() - 1 + await store.save(stored) + assert await store.get(stored.id) is None + + +@pytest.mark.asyncio +async def test_provider_cache_store_serializes_json_and_does_not_clear_without_key_delete() -> None: + class FakeProviderCache: + def __init__(self) -> None: + self.values = {} + + async def store_async(self, key: str, value: str) -> None: + json.loads(value) + self.values[key] = value + + async def retrieve_async(self, key: str): + return self.values.get(key) + + cache = FakeProviderCache() + store = ProviderCacheResponsesStore(cache) + stored = _stored("resp_provider_cache") + + await store.save(stored) + loaded = await store.get(stored.id) + + assert loaded is not None + assert loaded.id == stored.id + assert await store.list_input_items(stored.id) == stored.input_items + assert await store.delete(stored.id) is False + assert cache.values + + +@pytest.mark.asyncio +async def test_provider_cache_store_uses_key_delete_when_available() -> None: + class FakeProviderCache: + def __init__(self) -> None: + self.values = {} + + async def store_async(self, key: str, value: str) -> None: + self.values[key] = value + + async def retrieve_async(self, key: str): + return self.values.get(key) + + async def delete_async(self, key: str) -> bool: + return self.values.pop(key, None) is not None + + store = ProviderCacheResponsesStore(FakeProviderCache()) + stored = _stored("resp_delete") + await store.save(stored) + + assert await store.delete(stored.id) is True + assert await store.get(stored.id) is None From f3541ab73547b2c091525dd13d35c96f4df24d1f Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 00:26:24 +0200 Subject: [PATCH 022/182] feat(responses): add chat bridge Adds the temporary Responses-to-chat bridge for Phase 4, converting parsed Responses requests into current chat-completions kwargs and converting chat-completion responses back into Responses objects. The bridge preserves previous_response_id metadata, parent response messages, tool definitions, generation parameters, and unsupported extension fields for trace/debugging until native provider execution is wired in later phases. Tests: python -m pytest tests/test_responses_bridge.py tests/test_responses_store.py --- src/rotator_library/responses/__init__.py | 2 + src/rotator_library/responses/bridge.py | 242 ++++++++++++++++++++++ tests/test_responses_bridge.py | 108 ++++++++++ 3 files changed, 352 insertions(+) create mode 100644 src/rotator_library/responses/bridge.py create mode 100644 tests/test_responses_bridge.py diff --git a/src/rotator_library/responses/__init__.py b/src/rotator_library/responses/__init__.py index 07959eedc..eda3207e5 100644 --- a/src/rotator_library/responses/__init__.py +++ b/src/rotator_library/responses/__init__.py @@ -3,12 +3,14 @@ """Responses API service, storage, and streaming helpers.""" +from .bridge import ResponsesBridge from .store import InMemoryResponsesStore, ProviderCacheResponsesStore, ResponsesStore from .types import StoredResponse, generate_response_id __all__ = [ "InMemoryResponsesStore", "ProviderCacheResponsesStore", + "ResponsesBridge", "ResponsesStore", "StoredResponse", "generate_response_id", diff --git a/src/rotator_library/responses/bridge.py b/src/rotator_library/responses/bridge.py new file mode 100644 index 000000000..eb919731a --- /dev/null +++ b/src/rotator_library/responses/bridge.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Bridge between Responses requests and the current chat-completions executor.""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Optional + +from ..protocols import ContentBlock, ToolDefinition, UnifiedMessage, UnifiedRequest, serialize_value +from ..protocols.responses import ResponsesProtocol +from .types import generate_response_id + +_CHAT_GENERATION_KEYS = { + "frequency_penalty", + "logit_bias", + "logprobs", + "max_completion_tokens", + "max_tokens", + "max_output_tokens", + "metadata", + "n", + "parallel_tool_calls", + "presence_penalty", + "reasoning", + "reasoning_effort", + "response_format", + "seed", + "stop", + "stream_options", + "temperature", + "tool_choice", + "top_logprobs", + "top_p", + "user", +} + + +class ResponsesBridge: + """Temporary compatibility bridge from Responses to chat completions. + + Later provider phases should call native Responses-capable providers directly. + Until then, this bridge makes `/v1/responses` useful while preserving enough + metadata and trace detail to debug fields that cannot be represented in chat. + """ + + def __init__(self, protocol: Optional[ResponsesProtocol] = None) -> None: + self.protocol = protocol or ResponsesProtocol() + + def to_chat_kwargs( + self, + unified: UnifiedRequest, + *, + parent_response: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + """Convert a parsed Responses request to chat-completions kwargs.""" + + messages: list[dict[str, Any]] = [] + system_text = _blocks_to_text(unified.system) + if system_text: + messages.append({"role": "system", "content": system_text}) + if parent_response: + messages.extend(_parent_output_to_messages(parent_response.get("output") or [])) + messages.extend(_message_to_chat(message) for message in unified.messages) + kwargs: dict[str, Any] = { + "model": unified.model, + "messages": messages, + "stream": unified.stream, + } + if unified.tools: + kwargs["tools"] = [_tool_to_chat(tool) for tool in unified.tools] + for key, value in unified.generation_params.items(): + if key in _CHAT_GENERATION_KEYS: + kwargs[_chat_generation_key(key)] = deepcopy(value) + if unified.metadata: + kwargs.setdefault("metadata", deepcopy(unified.metadata)) + unsupported = { + "previous_response_id": unified.previous_response_id, + "extra": deepcopy(unified.extra), + } + kwargs["_responses_bridge"] = {k: v for k, v in unsupported.items() if v} + return kwargs + + def from_chat_response( + self, + chat_response: Any, + unified_request: UnifiedRequest, + *, + response_id: Optional[str] = None, + ) -> dict[str, Any]: + """Convert a chat-completions response into a Responses object.""" + + response = _as_dict(chat_response) + output = [] + for index, choice in enumerate(response.get("choices") or []): + if not isinstance(choice, dict): + continue + message = choice.get("message") or {} + output.extend(_chat_message_to_output_items(message, index)) + responses_payload = { + "id": response_id or response.get("id") or generate_response_id(), + "object": "response", + "created_at": response.get("created") or response.get("created_at"), + "model": response.get("model") or unified_request.model, + "status": _status_from_chat(response), + "output": output, + "usage": _usage_to_responses(response.get("usage")), + } + if unified_request.metadata: + responses_payload["metadata"] = deepcopy(unified_request.metadata) + return self.protocol.format_response(self.protocol.parse_response(responses_payload)) + + +def _chat_generation_key(key: str) -> str: + if key == "max_output_tokens": + return "max_tokens" + return key + + +def _message_to_chat(message: UnifiedMessage) -> dict[str, Any]: + payload = {"role": message.role, "content": _blocks_to_chat_content(message.content)} + if message.name: + payload["name"] = message.name + if message.tool_call_id: + payload["tool_call_id"] = message.tool_call_id + if message.tool_calls: + payload["tool_calls"] = [call.to_dict() for call in message.tool_calls] + return payload + + +def _blocks_to_chat_content(blocks: list[ContentBlock]) -> Any: + if not blocks: + return "" + if len(blocks) == 1 and blocks[0].text is not None: + return blocks[0].text + content = [] + for block in blocks: + if block.text is not None: + content.append({"type": "text", "text": block.text}) + elif block.source is not None: + content.append({"type": block.type, "source": deepcopy(block.source)}) + elif block.raw is not None: + content.append(deepcopy(block.raw)) + return content or "" + + +def _blocks_to_text(blocks: list[ContentBlock]) -> str: + return "".join(block.text or "" for block in blocks if block.text) + + +def _tool_to_chat(tool: ToolDefinition) -> dict[str, Any]: + if tool.type == "function": + return { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description or "", + "parameters": deepcopy(tool.input_schema), + }, + } + payload = {"type": tool.type, "name": tool.name} + payload.update(deepcopy(tool.extra)) + return payload + + +def _parent_output_to_messages(output: list[Any]) -> list[dict[str, Any]]: + messages = [] + for item in output: + if not isinstance(item, dict) or item.get("type") != "message": + continue + text = _responses_content_to_text(item.get("content") or []) + if text: + messages.append({"role": item.get("role") or "assistant", "content": text}) + return messages + + +def _chat_message_to_output_items(message: dict[str, Any], index: int) -> list[dict[str, Any]]: + items = [] + content = message.get("content") + if content is not None: + items.append( + { + "id": f"msg_{index}", + "type": "message", + "role": message.get("role") or "assistant", + "content": [{"type": "output_text", "text": content if isinstance(content, str) else serialize_value(content)}], + } + ) + for tool_call in message.get("tool_calls") or []: + function = tool_call.get("function") if isinstance(tool_call, dict) else None + if isinstance(function, dict): + items.append( + { + "id": tool_call.get("id"), + "type": "function_call", + "call_id": tool_call.get("id"), + "name": function.get("name"), + "arguments": function.get("arguments"), + } + ) + return items + + +def _responses_content_to_text(content: list[Any]) -> str: + parts = [] + for block in content: + if isinstance(block, dict): + text = block.get("text") + if text: + parts.append(str(text)) + elif isinstance(block, str): + parts.append(block) + return "".join(parts) + + +def _status_from_chat(response: dict[str, Any]) -> str: + for choice in response.get("choices") or []: + finish_reason = choice.get("finish_reason") if isinstance(choice, dict) else None + if finish_reason in {"length", "content_filter"}: + return "incomplete" + return "completed" + + +def _usage_to_responses(usage: Any) -> Any: + if not isinstance(usage, dict): + return usage + return { + "input_tokens": usage.get("prompt_tokens", usage.get("input_tokens", 0)), + "output_tokens": usage.get("completion_tokens", usage.get("output_tokens", 0)), + "total_tokens": usage.get("total_tokens", 0), + } + + +def _as_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return deepcopy(value) + if hasattr(value, "model_dump"): + return value.model_dump() + if hasattr(value, "dict"): + return value.dict() + return serialize_value(value) if isinstance(serialize_value(value), dict) else {} diff --git a/tests/test_responses_bridge.py b/tests/test_responses_bridge.py new file mode 100644 index 000000000..9c2fdc836 --- /dev/null +++ b/tests/test_responses_bridge.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from rotator_library.protocols.responses import ResponsesProtocol +from rotator_library.responses import ResponsesBridge + + +def test_bridge_converts_responses_request_to_chat_kwargs() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request( + { + "model": "gpt-test", + "instructions": "Follow rules.", + "input": "Hello", + "max_output_tokens": 20, + "metadata": {"trace": "yes"}, + } + ) + + kwargs = bridge.to_chat_kwargs(unified) + + assert kwargs["model"] == "gpt-test" + assert kwargs["messages"] == [ + {"role": "system", "content": "Follow rules."}, + {"role": "user", "content": "Hello"}, + ] + assert kwargs["max_tokens"] == 20 + assert kwargs["metadata"] == {"trace": "yes"} + + +def test_bridge_adds_parent_response_messages_for_previous_response_id() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Continue", "previous_response_id": "resp_parent"}) + parent = {"output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Earlier"}]}]} + + kwargs = bridge.to_chat_kwargs(unified, parent_response=parent) + + assert kwargs["messages"] == [ + {"role": "assistant", "content": "Earlier"}, + {"role": "user", "content": "Continue"}, + ] + assert kwargs["_responses_bridge"]["previous_response_id"] == "resp_parent" + + +def test_bridge_preserves_tool_definitions() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request( + { + "model": "gpt-test", + "input": "Use tool", + "tools": [{"type": "function", "name": "lookup", "description": "Lookup", "parameters": {"type": "object"}}], + } + ) + + kwargs = bridge.to_chat_kwargs(unified) + + assert kwargs["tools"] == [ + {"type": "function", "function": {"name": "lookup", "description": "Lookup", "parameters": {"type": "object"}}} + ] + + +def test_bridge_converts_chat_response_to_responses_payload() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Hello"}) + chat_response = { + "id": "chat_1", + "model": "gpt-test", + "created": 123, + "choices": [{"message": {"role": "assistant", "content": "Hi"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, + } + + response = bridge.from_chat_response(chat_response, unified) + + assert response["id"] == "chat_1" + assert response["object"] == "response" + assert response["status"] == "completed" + assert response["output"][0]["content"] == [{"type": "output_text", "text": "Hi"}] + assert response["usage"]["input_tokens"] == 1 + assert response["usage"]["output_tokens"] == 2 + assert response["usage"]["total_tokens"] == 3 + + +def test_bridge_converts_chat_tool_calls_to_responses_output_items() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Hello"}) + chat_response = { + "model": "gpt-test", + "choices": [ + { + "message": { + "role": "assistant", + "tool_calls": [{"id": "call_1", "function": {"name": "lookup", "arguments": "{}"}}], + } + } + ], + } + + response = bridge.from_chat_response(chat_response, unified, response_id="resp_tool") + + assert response["id"] == "resp_tool" + assert response["output"] == [ + {"id": "call_1", "type": "function_call", "call_id": "call_1", "name": "lookup", "arguments": "{}"} + ] From f1839cdc974a96fde72c3bdc706737615c889055 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 00:28:10 +0200 Subject: [PATCH 023/182] feat(responses): add response service Adds the non-streaming Responses service around the protocol adapter, bridge, and response store with validation, previous_response_id loading, get/delete/input-items helpers, and transform trace passes. The service keeps Phase 4 runtime conservative by bridging through the existing chat completion client path while preserving response storage and lineage metadata for later native provider work. Tests: python -m pytest tests/test_responses_service.py tests/test_responses_bridge.py tests/test_responses_store.py --- src/rotator_library/responses/__init__.py | 3 + src/rotator_library/responses/service.py | 192 ++++++++++++++++++++++ tests/test_responses_service.py | 122 ++++++++++++++ 3 files changed, 317 insertions(+) create mode 100644 src/rotator_library/responses/service.py create mode 100644 tests/test_responses_service.py diff --git a/src/rotator_library/responses/__init__.py b/src/rotator_library/responses/__init__.py index eda3207e5..6a20a30e7 100644 --- a/src/rotator_library/responses/__init__.py +++ b/src/rotator_library/responses/__init__.py @@ -4,6 +4,7 @@ """Responses API service, storage, and streaming helpers.""" from .bridge import ResponsesBridge +from .service import ResponsesService, ResponsesServiceError from .store import InMemoryResponsesStore, ProviderCacheResponsesStore, ResponsesStore from .types import StoredResponse, generate_response_id @@ -11,6 +12,8 @@ "InMemoryResponsesStore", "ProviderCacheResponsesStore", "ResponsesBridge", + "ResponsesService", + "ResponsesServiceError", "ResponsesStore", "StoredResponse", "generate_response_id", diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py new file mode 100644 index 000000000..fc1131476 --- /dev/null +++ b/src/rotator_library/responses/service.py @@ -0,0 +1,192 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Service layer for the OpenAI-compatible Responses API.""" + +from __future__ import annotations + +import time +from copy import deepcopy +from typing import Any, Optional + +from ..protocols import ProtocolContext +from ..protocols.responses import ResponsesProtocol +from .bridge import ResponsesBridge +from .store import InMemoryResponsesStore, ResponsesStore +from .types import StoredResponse + + +class ResponsesServiceError(ValueError): + """Error with an HTTP-compatible status code for proxy routes.""" + + def __init__(self, message: str, *, status_code: int = 400, error_type: str = "invalid_request_error") -> None: + self.status_code = status_code + self.error_type = error_type + super().__init__(message) + + +class ResponsesService: + """Create, store, retrieve, and delete Responses API objects. + + Phase 4 deliberately bridges through the existing chat-completions execution + path. Native Responses-capable provider execution will replace the bridge for + covered providers in later phases without changing the route/storage surface. + """ + + def __init__( + self, + *, + protocol: Optional[ResponsesProtocol] = None, + bridge: Optional[ResponsesBridge] = None, + store: Optional[ResponsesStore] = None, + ) -> None: + self.protocol = protocol or ResponsesProtocol() + self.bridge = bridge or ResponsesBridge(self.protocol) + self.store = store or InMemoryResponsesStore() + + async def create_response( + self, + raw_request: dict[str, Any], + client: Any, + *, + request: Optional[Any] = None, + transaction_logger: Optional[Any] = None, + ) -> dict[str, Any]: + """Create a non-streaming Responses object through the chat bridge.""" + + if not raw_request.get("model"): + raise ResponsesServiceError("'model' is required", status_code=400) + if raw_request.get("stream"): + raise ResponsesServiceError("Use stream_response for streaming requests", status_code=400) + + self._trace(transaction_logger, "raw_responses_request", raw_request, direction="request", stage="client") + unified = self.protocol.parse_request(raw_request, ProtocolContext(source_protocol="responses")) + self._trace(transaction_logger, "parsed_unified_request", unified.to_dict(), direction="request", stage="protocol") + + parent = await self._load_previous_response(unified.previous_response_id, transaction_logger) + chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_response=parent.response if parent else None) + bridge_metadata = chat_kwargs.pop("_responses_bridge", {}) + self._trace( + transaction_logger, + "responses_bridge_chat_request", + chat_kwargs, + direction="request", + stage="adapter", + metadata={"bridge_metadata": bridge_metadata}, + ) + + chat_response = await client.acompletion(request=request, **chat_kwargs) + self._trace(transaction_logger, "raw_chat_bridge_response", self._response_to_dict(chat_response), direction="response", stage="provider") + + response_payload = self.bridge.from_chat_response(chat_response, unified) + self._trace(transaction_logger, "parsed_unified_response", response_payload, direction="response", stage="protocol") + + if raw_request.get("store", True): + stored = self._stored_response(raw_request, response_payload, parent) + await self.store.save(stored) + self._trace(transaction_logger, "stored_responses_response", stored.to_dict(), direction="metadata", stage="final") + + self._trace(transaction_logger, "final_responses_response", response_payload, direction="response", stage="final") + return response_payload + + async def get_response(self, response_id: str) -> dict[str, Any]: + """Return a stored response payload or raise a 404-compatible error.""" + + stored = await self.store.get(response_id) + if stored is None: + raise ResponsesServiceError(f"Response not found: {response_id}", status_code=404, error_type="not_found_error") + return deepcopy(stored.response) + + async def delete_response(self, response_id: str) -> dict[str, Any]: + """Delete a stored response and return a compatible deletion object.""" + + deleted = await self.store.delete(response_id) + if not deleted: + raise ResponsesServiceError(f"Response not found: {response_id}", status_code=404, error_type="not_found_error") + return {"id": response_id, "object": "response.deleted", "deleted": True} + + async def list_input_items(self, response_id: str) -> dict[str, Any]: + """Return stored input items for a response continuation.""" + + items = await self.store.list_input_items(response_id) + if items is None: + raise ResponsesServiceError(f"Response not found: {response_id}", status_code=404, error_type="not_found_error") + return {"object": "list", "data": items} + + async def _load_previous_response(self, response_id: Optional[str], transaction_logger: Optional[Any]) -> Optional[StoredResponse]: + if not response_id: + return None + parent = await self.store.get(response_id) + if parent is None: + raise ResponsesServiceError(f"Previous response not found: {response_id}", status_code=404, error_type="not_found_error") + self._trace( + transaction_logger, + "responses_previous_response_loaded", + parent.to_dict(), + direction="metadata", + stage="adapter", + metadata={ + "previous_response_id": response_id, + "output_count": len(parent.output_items), + "input_item_count": len(parent.input_items), + "bridge_context_expanded": True, + }, + ) + return parent + + def _stored_response( + self, + raw_request: dict[str, Any], + response_payload: dict[str, Any], + parent: Optional[StoredResponse], + ) -> StoredResponse: + return StoredResponse( + id=str(response_payload["id"]), + created_at=float(response_payload.get("created_at") or time.time()), + model=str(response_payload.get("model") or raw_request.get("model") or ""), + status=str(response_payload.get("status") or "completed"), + request=deepcopy(raw_request), + response=deepcopy(response_payload), + input_items=_input_items(raw_request), + output_items=deepcopy(response_payload.get("output") or []), + usage=deepcopy(response_payload.get("usage")) if isinstance(response_payload.get("usage"), dict) else None, + metadata={"previous_response_id": parent.id if parent else raw_request.get("previous_response_id")}, + ) + + @staticmethod + def _response_to_dict(response: Any) -> Any: + if isinstance(response, dict): + return deepcopy(response) + if hasattr(response, "model_dump"): + return response.model_dump() + if hasattr(response, "dict"): + return response.dict() + return repr(response) + + @staticmethod + def _trace( + transaction_logger: Optional[Any], + pass_name: str, + data: Any, + *, + direction: str, + stage: str, + metadata: Optional[dict[str, Any]] = None, + ) -> None: + if not transaction_logger: + return + transaction_logger.log_transform_pass( + pass_name, + data, + direction=direction, + stage=stage, + protocol="responses", + metadata=metadata or {}, + ) + + +def _input_items(raw_request: dict[str, Any]) -> list[Any]: + value = raw_request.get("input") + if value is None: + return [] + return deepcopy(value if isinstance(value, list) else [value]) diff --git a/tests/test_responses_service.py b/tests/test_responses_service.py new file mode 100644 index 000000000..0408958c6 --- /dev/null +++ b/tests/test_responses_service.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.responses import InMemoryResponsesStore, ResponsesService, ResponsesServiceError, StoredResponse +from rotator_library.transaction_logger import TransactionLogger + + +class FakeClient: + def __init__(self) -> None: + self.calls = [] + + async def acompletion(self, **kwargs): + self.calls.append(kwargs) + return { + "id": "chat_response_1", + "model": kwargs["model"], + "choices": [{"message": {"role": "assistant", "content": "Hello back"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, + } + + +def _trace_entries(log_dir): + return [json.loads(line) for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + + +@pytest.mark.asyncio +async def test_create_response_stores_non_streaming_response() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + client = FakeClient() + + response = await service.create_response({"model": "gpt-test", "input": "Hello"}, client) + + assert response["id"] == "chat_response_1" + assert response["output"][0]["content"][0]["text"] == "Hello back" + assert (await store.get("chat_response_1")) is not None + assert client.calls[0]["messages"] == [{"role": "user", "content": "Hello"}] + + +@pytest.mark.asyncio +async def test_store_false_does_not_persist_response() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + response = await service.create_response({"model": "gpt-test", "input": "Hello", "store": False}, FakeClient()) + + assert await store.get(response["id"]) is None + + +@pytest.mark.asyncio +async def test_previous_response_id_loads_parent_context() -> None: + store = InMemoryResponsesStore() + await store.save( + StoredResponse( + id="resp_parent", + model="gpt-test", + status="completed", + request={"input": "Earlier"}, + response={ + "id": "resp_parent", + "object": "response", + "model": "gpt-test", + "status": "completed", + "output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Earlier"}]}], + }, + input_items=["Earlier"], + output_items=[{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Earlier"}]}], + ) + ) + client = FakeClient() + service = ResponsesService(store=store) + + await service.create_response({"model": "gpt-test", "input": "Continue", "previous_response_id": "resp_parent"}, client) + + assert client.calls[0]["messages"] == [ + {"role": "assistant", "content": "Earlier"}, + {"role": "user", "content": "Continue"}, + ] + + +@pytest.mark.asyncio +async def test_missing_previous_response_id_raises_not_found() -> None: + service = ResponsesService(store=InMemoryResponsesStore()) + + with pytest.raises(ResponsesServiceError) as exc_info: + await service.create_response({"model": "gpt-test", "input": "Continue", "previous_response_id": "missing"}, FakeClient()) + + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_get_delete_and_list_input_items() -> None: + service = ResponsesService(store=InMemoryResponsesStore()) + response = await service.create_response({"model": "gpt-test", "input": ["Hello"]}, FakeClient()) + + assert (await service.get_response(response["id"]))["id"] == response["id"] + assert await service.list_input_items(response["id"]) == {"object": "list", "data": ["Hello"]} + assert await service.delete_response(response["id"]) == {"id": response["id"], "object": "response.deleted", "deleted": True} + with pytest.raises(ResponsesServiceError): + await service.get_response(response["id"]) + + +@pytest.mark.asyncio +async def test_service_emits_transform_trace_passes(tmp_path) -> None: + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + service = ResponsesService(store=InMemoryResponsesStore()) + + await service.create_response({"model": "gpt-test", "input": "Hello"}, FakeClient(), transaction_logger=logger) + + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert pass_names == [ + "raw_responses_request", + "parsed_unified_request", + "responses_bridge_chat_request", + "raw_chat_bridge_response", + "parsed_unified_response", + "stored_responses_response", + "final_responses_response", + ] From e7f0e7d9753dfd2c5ed4d9972fad55253a7a2479 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 00:30:01 +0200 Subject: [PATCH 024/182] feat(responses): add non-stream routes Adds FastAPI routes for POST /v1/responses, GET /v1/responses/{id}, DELETE /v1/responses/{id}, and GET /v1/responses/{id}/input_items using the Phase 4 ResponsesService. The create route currently handles non-streaming requests through the bridge and returns a documented 501 for streaming until the SSE checkpoint lands next. Tests: python -m pytest tests/test_responses_routes.py tests/test_responses_service.py tests/test_responses_bridge.py tests/test_responses_store.py --- src/proxy_app/main.py | 102 +++++++++++++++++++++++++++++++++ tests/test_responses_routes.py | 72 +++++++++++++++++++++++ 2 files changed, 174 insertions(+) create mode 100644 tests/test_responses_routes.py diff --git a/src/proxy_app/main.py b/src/proxy_app/main.py index 4aa689720..88e54d48f 100644 --- a/src/proxy_app/main.py +++ b/src/proxy_app/main.py @@ -140,6 +140,7 @@ from proxy_app.request_logger import log_request_to_console from proxy_app.batch_manager import EmbeddingBatcher from proxy_app.detailed_logger import RawIOLogger + from rotator_library.responses import ResponsesService, ResponsesServiceError print(" → Discovering provider plugins...") # Provider lazy loading happens during import, so time it here @@ -592,6 +593,10 @@ async def process_credential(provider: str, path: str, provider_instance): # print(f"🔑 Credentials loaded: {_total_summary} (API: {_api_summary} | OAuth: {_oauth_summary})") client.background_refresher.start() # Start the background task app.state.rotating_client = client + # Phase 4 Responses API compatibility service. It currently bridges through + # the existing chat-completions client path; later native providers can reuse + # the same route/storage surface without changing clients. + app.state.responses_service = ResponsesService() # Warn if no provider credentials are configured if not client.all_credentials: @@ -663,6 +668,16 @@ def get_embedding_batcher(request: Request) -> EmbeddingBatcher: return request.app.state.embedding_batcher +def get_responses_service(request: Request) -> ResponsesService: + """Dependency to get the Responses API service instance from app state.""" + + service = getattr(request.app.state, "responses_service", None) + if service is None: + service = ResponsesService() + request.app.state.responses_service = service + return service + + async def verify_api_key(auth: str = Depends(api_key_header)): """Dependency to verify the proxy API key.""" # If PROXY_API_KEY is not set or empty, skip verification (open access) @@ -1011,6 +1026,93 @@ async def chat_completions( raise HTTPException(status_code=500, detail=str(e)) +def _responses_error_response(error: ResponsesServiceError) -> dict[str, Any]: + """Return an OpenAI-compatible error payload for Responses routes.""" + + return {"error": {"message": str(error), "type": error.error_type, "code": error.status_code}} + + +@app.post("/v1/responses") +async def responses_create( + request: Request, + client: RotatingClient = Depends(get_rotating_client), + service: ResponsesService = Depends(get_responses_service), + _=Depends(verify_api_key), +): + """OpenAI-compatible Responses API create endpoint. + + Phase 4 non-streaming requests bridge through the current completion engine. + Streaming requests are handled by the dedicated SSE checkpoint in this phase. + """ + + logger = RawIOLogger() if ENABLE_RAW_LOGGING else None + try: + request_data = await request.json() + except json.JSONDecodeError: + raise HTTPException(status_code=400, detail="Invalid JSON in request body.") + if logger: + logger.log_request(headers=dict(request.headers), body=request_data) + try: + if request_data.get("stream"): + raise ResponsesServiceError("Responses streaming is not enabled in this checkpoint", status_code=501, error_type="not_implemented_error") + result = await service.create_response(request_data, client, request=request) + if logger: + logger.log_final_response(status_code=200, headers=None, body=result) + return JSONResponse(content=result) + except ResponsesServiceError as e: + payload = _responses_error_response(e) + if logger: + logger.log_final_response(status_code=e.status_code, headers=None, body=payload) + raise HTTPException(status_code=e.status_code, detail=payload) + except Exception as e: + logging.error(f"Responses endpoint error: {e}") + if logger: + logger.log_final_response(status_code=500, headers=None, body={"error": str(e)}) + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/v1/responses/{response_id}") +async def responses_get( + response_id: str, + service: ResponsesService = Depends(get_responses_service), + _=Depends(verify_api_key), +): + """Retrieve a stored Responses object by ID.""" + + try: + return JSONResponse(content=await service.get_response(response_id)) + except ResponsesServiceError as e: + raise HTTPException(status_code=e.status_code, detail=_responses_error_response(e)) + + +@app.delete("/v1/responses/{response_id}") +async def responses_delete( + response_id: str, + service: ResponsesService = Depends(get_responses_service), + _=Depends(verify_api_key), +): + """Delete a stored Responses object by ID.""" + + try: + return JSONResponse(content=await service.delete_response(response_id)) + except ResponsesServiceError as e: + raise HTTPException(status_code=e.status_code, detail=_responses_error_response(e)) + + +@app.get("/v1/responses/{response_id}/input_items") +async def responses_input_items( + response_id: str, + service: ResponsesService = Depends(get_responses_service), + _=Depends(verify_api_key), +): + """Return stored input items for a Responses object.""" + + try: + return JSONResponse(content=await service.list_input_items(response_id)) + except ResponsesServiceError as e: + raise HTTPException(status_code=e.status_code, detail=_responses_error_response(e)) + + # --- Anthropic Messages API Endpoint --- @app.post("/v1/messages") async def anthropic_messages( diff --git a/tests/test_responses_routes.py b/tests/test_responses_routes.py new file mode 100644 index 000000000..8333d0d62 --- /dev/null +++ b/tests/test_responses_routes.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from fastapi.testclient import TestClient + +from proxy_app import main as proxy_main +from rotator_library.responses import InMemoryResponsesStore, ResponsesService + + +class FakeClient: + async def acompletion(self, **kwargs): + return { + "id": "chat_route_1", + "model": kwargs["model"], + "choices": [{"message": {"role": "assistant", "content": "route ok"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, + } + + +def _client() -> TestClient: + proxy_main.PROXY_API_KEY = None + proxy_main.ENABLE_RAW_LOGGING = False + proxy_main.app.state.rotating_client = FakeClient() + proxy_main.app.state.responses_service = ResponsesService(store=InMemoryResponsesStore()) + return TestClient(proxy_main.app) + + +def test_post_responses_non_stream_success() -> None: + client = _client() + + response = client.post("/v1/responses", json={"model": "gpt-test", "input": "hello"}) + + assert response.status_code == 200 + body = response.json() + assert body["id"] == "chat_route_1" + assert body["object"] == "response" + assert body["output"][0]["content"][0]["text"] == "route ok" + + +def test_post_responses_missing_model_returns_400() -> None: + client = _client() + + response = client.post("/v1/responses", json={"input": "hello"}) + + assert response.status_code == 400 + assert response.json()["detail"]["error"]["type"] == "invalid_request_error" + + +def test_get_delete_and_input_items_routes() -> None: + client = _client() + created = client.post("/v1/responses", json={"model": "gpt-test", "input": ["hello"]}).json() + + get_response = client.get(f"/v1/responses/{created['id']}") + input_items = client.get(f"/v1/responses/{created['id']}/input_items") + deleted = client.delete(f"/v1/responses/{created['id']}") + missing = client.get(f"/v1/responses/{created['id']}") + + assert get_response.status_code == 200 + assert get_response.json()["id"] == created["id"] + assert input_items.status_code == 200 + assert input_items.json() == {"object": "list", "data": ["hello"]} + assert deleted.status_code == 200 + assert deleted.json() == {"id": created["id"], "object": "response.deleted", "deleted": True} + assert missing.status_code == 404 + + +def test_post_responses_stream_checkpoint_returns_501_until_streaming_slice() -> None: + client = _client() + + response = client.post("/v1/responses", json={"model": "gpt-test", "input": "hello", "stream": True}) + + assert response.status_code == 501 + assert response.json()["detail"]["error"]["type"] == "not_implemented_error" From 1b02727188295b3b7bc020ba6e5b5fe00e6383ba Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 00:34:07 +0200 Subject: [PATCH 025/182] feat(responses): stream HTTP SSE events Adds Responses HTTP SSE formatting, chat-stream conversion, streamed response accumulation/storage, response.failed events on stream errors, and a WebSocket formatter seam that is explicit but not exposed as a runtime route. Updates POST /v1/responses to return text/event-stream for stream=true while preserving the existing non-stream route behavior. Tests: python -m pytest tests/test_responses_streaming.py tests/test_responses_routes.py tests/test_responses_service.py tests/test_responses_bridge.py tests/test_responses_store.py --- src/proxy_app/main.py | 6 +- src/rotator_library/responses/__init__.py | 3 + src/rotator_library/responses/service.py | 129 ++++++++++++++++++++- src/rotator_library/responses/streaming.py | 118 +++++++++++++++++++ tests/test_responses_routes.py | 14 ++- tests/test_responses_streaming.py | 79 +++++++++++++ 6 files changed, 344 insertions(+), 5 deletions(-) create mode 100644 src/rotator_library/responses/streaming.py create mode 100644 tests/test_responses_streaming.py diff --git a/src/proxy_app/main.py b/src/proxy_app/main.py index 88e54d48f..60c5cabbe 100644 --- a/src/proxy_app/main.py +++ b/src/proxy_app/main.py @@ -1054,7 +1054,11 @@ async def responses_create( logger.log_request(headers=dict(request.headers), body=request_data) try: if request_data.get("stream"): - raise ResponsesServiceError("Responses streaming is not enabled in this checkpoint", status_code=501, error_type="not_implemented_error") + return StreamingResponse( + service.stream_response(request_data, client, request=request), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, + ) result = await service.create_response(request_data, client, request=request) if logger: logger.log_final_response(status_code=200, headers=None, body=result) diff --git a/src/rotator_library/responses/__init__.py b/src/rotator_library/responses/__init__.py index 6a20a30e7..04de91582 100644 --- a/src/rotator_library/responses/__init__.py +++ b/src/rotator_library/responses/__init__.py @@ -6,6 +6,7 @@ from .bridge import ResponsesBridge from .service import ResponsesService, ResponsesServiceError from .store import InMemoryResponsesStore, ProviderCacheResponsesStore, ResponsesStore +from .streaming import ResponsesSSEFormatter, ResponsesWebSocketFormatter from .types import StoredResponse, generate_response_id __all__ = [ @@ -14,7 +15,9 @@ "ResponsesBridge", "ResponsesService", "ResponsesServiceError", + "ResponsesSSEFormatter", "ResponsesStore", + "ResponsesWebSocketFormatter", "StoredResponse", "generate_response_id", ] diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index fc1131476..f5e9e0157 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -7,13 +7,25 @@ import time from copy import deepcopy -from typing import Any, Optional +from typing import Any, AsyncGenerator, Optional from ..protocols import ProtocolContext from ..protocols.responses import ResponsesProtocol from .bridge import ResponsesBridge from .store import InMemoryResponsesStore, ResponsesStore +from .streaming import ( + ResponsesSSEFormatter, + ResponsesStreamState, + output_item_added_payload, + output_item_done_payload, + output_text_delta_payload, + parse_chat_sse_chunk, + response_completed_payload, + response_created_payload, + response_failed_payload, +) from .types import StoredResponse +from .types import generate_response_id class ResponsesServiceError(ValueError): @@ -89,6 +101,85 @@ async def create_response( self._trace(transaction_logger, "final_responses_response", response_payload, direction="response", stage="final") return response_payload + async def stream_response( + self, + raw_request: dict[str, Any], + client: Any, + *, + request: Optional[Any] = None, + transaction_logger: Optional[Any] = None, + ) -> AsyncGenerator[str, None]: + """Stream a Responses API request as HTTP SSE events.""" + + if not raw_request.get("model"): + raise ResponsesServiceError("'model' is required", status_code=400) + formatter = ResponsesSSEFormatter() + stream_request = dict(raw_request) + stream_request["stream"] = True + self._trace(transaction_logger, "raw_responses_request", stream_request, direction="request", stage="client") + unified = self.protocol.parse_request(stream_request, ProtocolContext(source_protocol="responses", transport="sse")) + self._trace(transaction_logger, "parsed_unified_request", unified.to_dict(), direction="request", stage="protocol") + parent = await self._load_previous_response(unified.previous_response_id, transaction_logger) + chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_response=parent.response if parent else None) + bridge_metadata = chat_kwargs.pop("_responses_bridge", {}) + chat_kwargs["stream"] = True + self._trace( + transaction_logger, + "responses_bridge_chat_request", + chat_kwargs, + direction="request", + stage="adapter", + metadata={"bridge_metadata": bridge_metadata, "transport": "sse"}, + ) + + response_id = generate_response_id() + state = ResponsesStreamState(response_id=response_id, model=unified.model) + usage = None + item_started = False + yield formatter.format_event("response.created", response_created_payload(response_id, unified.model)) + try: + chat_stream = await client.acompletion(request=request, **chat_kwargs) + async for raw_chunk in chat_stream: + self._trace(transaction_logger, "raw_chat_bridge_stream_chunk", raw_chunk, direction="stream", stage="provider") + chunk = parse_chat_sse_chunk(raw_chunk) + if not chunk or chunk.get("type") == "done": + continue + self._trace(transaction_logger, "parsed_unified_stream_event", chunk, direction="stream", stage="protocol") + if chunk.get("usage"): + usage = chunk["usage"] + delta = _chunk_text_delta(chunk) + if not delta: + continue + if not item_started: + item_started = True + added = output_item_added_payload(state) + self._trace(transaction_logger, "formatted_responses_stream_event", added, direction="stream", stage="final") + yield formatter.format_event("response.output_item.added", added) + state = ResponsesStreamState( + response_id=state.response_id, + model=state.model, + output_text=state.output_text + delta, + output_item_id=state.output_item_id, + ) + event = output_text_delta_payload(state, delta) + self._trace(transaction_logger, "formatted_responses_stream_event", event, direction="stream", stage="final") + yield formatter.format_event("response.output_text.delta", event) + + if not item_started: + yield formatter.format_event("response.output_item.added", output_item_added_payload(state)) + done_item = output_item_done_payload(state) + yield formatter.format_event("response.output_item.done", done_item) + completed = response_completed_payload(state, _usage_to_responses_stream(usage)) + await self._store_stream_response(stream_request, completed, parent) + self._trace(transaction_logger, "stored_responses_stream_response", completed, direction="metadata", stage="final") + yield formatter.format_event("response.completed", completed) + yield formatter.done() + except Exception as exc: + failed = response_failed_payload(response_id, unified.model, {"message": str(exc), "type": exc.__class__.__name__}) + self._log_transform_error(transaction_logger, "responses_stream", exc, stream_request) + yield formatter.format_event("response.failed", failed) + yield formatter.done() + async def get_response(self, response_id: str) -> dict[str, Any]: """Return a stored response payload or raise a 404-compatible error.""" @@ -184,9 +275,45 @@ def _trace( metadata=metadata or {}, ) + @staticmethod + def _log_transform_error(transaction_logger: Optional[Any], pass_name: str, error: BaseException, payload: Any) -> None: + if transaction_logger: + transaction_logger.log_transform_error(pass_name, error, payload=payload, stage="adapter", protocol="responses") + + async def _store_stream_response( + self, + raw_request: dict[str, Any], + response_payload: dict[str, Any], + parent: Optional[StoredResponse], + ) -> None: + if not raw_request.get("store", True): + return + await self.store.save(self._stored_response(raw_request, response_payload, parent)) + def _input_items(raw_request: dict[str, Any]) -> list[Any]: value = raw_request.get("input") if value is None: return [] return deepcopy(value if isinstance(value, list) else [value]) + + +def _chunk_text_delta(chunk: dict[str, Any]) -> str: + choices = chunk.get("choices") if isinstance(chunk.get("choices"), list) else [] + if not choices: + return "" + delta = choices[0].get("delta") if isinstance(choices[0], dict) else None + if not isinstance(delta, dict): + return "" + content = delta.get("content") + return content if isinstance(content, str) else "" + + +def _usage_to_responses_stream(usage: Any) -> Any: + if not isinstance(usage, dict): + return usage + return { + "input_tokens": usage.get("prompt_tokens", usage.get("input_tokens", 0)), + "output_tokens": usage.get("completion_tokens", usage.get("output_tokens", 0)), + "total_tokens": usage.get("total_tokens", 0), + } diff --git a/src/rotator_library/responses/streaming.py b/src/rotator_library/responses/streaming.py new file mode 100644 index 000000000..d358ab65a --- /dev/null +++ b/src/rotator_library/responses/streaming.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""HTTP SSE formatting for the Responses API compatibility layer.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Any + +from ..protocols import serialize_value + + +@dataclass(frozen=True) +class ResponsesStreamState: + """Mutable-by-replacement state accumulated from chat stream chunks.""" + + response_id: str + model: str + output_text: str = "" + output_item_id: str = "msg_0" + + +class ResponsesSSEFormatter: + """Format Responses API events as HTTP Server-Sent Events.""" + + transport = "sse" + + def format_event(self, event_name: str, payload: dict[str, Any]) -> str: + return f"event: {event_name}\ndata: {json.dumps(serialize_value(payload), ensure_ascii=False)}\n\n" + + def done(self) -> str: + """Return the final compatibility sentinel used by many SSE clients.""" + + return "data: [DONE]\n\n" + + +class ResponsesWebSocketFormatter: + """Placeholder transport seam for future WebSocket Responses support.""" + + transport = "websocket" + future_supported = True + + def format_event(self, event_name: str, payload: dict[str, Any]) -> str: + raise NotImplementedError("Responses WebSocket transport is planned but not implemented yet") + + +def parse_chat_sse_chunk(chunk: Any) -> dict[str, Any] | None: + """Decode a chat-completions stream chunk into a dict if possible.""" + + if isinstance(chunk, dict): + return chunk + if not isinstance(chunk, str): + return None + text = chunk.strip() + if not text: + return None + if text.startswith("data:"): + text = text[len("data:") :].strip() + if text == "[DONE]": + return {"type": "done"} + try: + return json.loads(text) + except json.JSONDecodeError: + return None + + +def response_created_payload(response_id: str, model: str) -> dict[str, Any]: + return {"id": response_id, "object": "response", "status": "in_progress", "model": model, "output": []} + + +def output_item_added_payload(state: ResponsesStreamState) -> dict[str, Any]: + return { + "response_id": state.response_id, + "output_index": 0, + "item": {"id": state.output_item_id, "type": "message", "role": "assistant", "content": []}, + } + + +def output_text_delta_payload(state: ResponsesStreamState, delta: str) -> dict[str, Any]: + return { + "response_id": state.response_id, + "item_id": state.output_item_id, + "output_index": 0, + "content_index": 0, + "delta": delta, + } + + +def output_item_done_payload(state: ResponsesStreamState) -> dict[str, Any]: + return { + "response_id": state.response_id, + "output_index": 0, + "item": { + "id": state.output_item_id, + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": state.output_text}], + }, + } + + +def response_completed_payload(state: ResponsesStreamState, usage: Any = None) -> dict[str, Any]: + payload = { + "id": state.response_id, + "object": "response", + "status": "completed", + "model": state.model, + "output": [output_item_done_payload(state)["item"]], + } + if usage is not None: + payload["usage"] = usage + return payload + + +def response_failed_payload(response_id: str, model: str, error: Any) -> dict[str, Any]: + return {"id": response_id, "object": "response", "status": "failed", "model": model, "error": serialize_value(error)} diff --git a/tests/test_responses_routes.py b/tests/test_responses_routes.py index 8333d0d62..4079fde5f 100644 --- a/tests/test_responses_routes.py +++ b/tests/test_responses_routes.py @@ -8,6 +8,13 @@ class FakeClient: async def acompletion(self, **kwargs): + if kwargs.get("stream"): + async def chunks(): + yield 'data: {"choices":[{"delta":{"content":"route"}}]}\n\n' + yield 'data: {"choices":[{"delta":{"content":" stream"}}]}\n\n' + yield "data: [DONE]\n\n" + + return chunks() return { "id": "chat_route_1", "model": kwargs["model"], @@ -63,10 +70,11 @@ def test_get_delete_and_input_items_routes() -> None: assert missing.status_code == 404 -def test_post_responses_stream_checkpoint_returns_501_until_streaming_slice() -> None: +def test_post_responses_stream_returns_sse_events() -> None: client = _client() response = client.post("/v1/responses", json={"model": "gpt-test", "input": "hello", "stream": True}) - assert response.status_code == 501 - assert response.json()["detail"]["error"]["type"] == "not_implemented_error" + assert response.status_code == 200 + assert "event: response.created" in response.text + assert "event: response.completed" in response.text diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py new file mode 100644 index 000000000..2979bcfac --- /dev/null +++ b/tests/test_responses_streaming.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import pytest + +from rotator_library.responses import InMemoryResponsesStore, ResponsesSSEFormatter, ResponsesService, ResponsesWebSocketFormatter + + +class FakeStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield 'data: {"id":"chat_stream","model":"gpt-test","choices":[{"delta":{"role":"assistant"}}]}\n\n' + yield 'data: {"id":"chat_stream","model":"gpt-test","choices":[{"delta":{"content":"Hel"}}]}\n\n' + yield 'data: {"id":"chat_stream","model":"gpt-test","choices":[{"delta":{"content":"lo"}}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}\n\n' + yield "data: [DONE]\n\n" + + return chunks() + + +class FailingStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield 'data: {"choices":[{"delta":{"content":"before"}}]}\n\n' + raise RuntimeError("stream exploded") + + return chunks() + + +@pytest.mark.asyncio +async def test_stream_response_emits_responses_sse_events_and_stores_final_response() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, FakeStreamingClient())] + + event_text = "".join(events) + assert "event: response.created" in event_text + assert "event: response.output_item.added" in event_text + assert "event: response.output_text.delta" in event_text + assert "event: response.output_item.done" in event_text + assert "event: response.completed" in event_text + assert event_text.endswith("data: [DONE]\n\n") + + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.output_items[0]["content"][0]["text"] == "Hello" + assert stored.usage == {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3} + + +@pytest.mark.asyncio +async def test_stream_response_store_false_does_not_persist() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True, "store": False}, FakeStreamingClient())] + + response_id = events[0].split('"id": "')[1].split('"')[0] + assert await store.get(response_id) is None + + +@pytest.mark.asyncio +async def test_stream_response_errors_emit_failed_event() -> None: + service = ResponsesService(store=InMemoryResponsesStore()) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, FailingStreamingClient())] + + event_text = "".join(events) + assert "event: response.failed" in event_text + assert "stream exploded" in event_text + assert event_text.endswith("data: [DONE]\n\n") + + +def test_transport_formatters_expose_sse_and_websocket_seam() -> None: + assert ResponsesSSEFormatter().transport == "sse" + websocket = ResponsesWebSocketFormatter() + assert websocket.transport == "websocket" + assert websocket.future_supported is True + with pytest.raises(NotImplementedError): + websocket.format_event("response.created", {}) From 05108ff7f7e8aefc16751e38a78291c55688cf43 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 00:39:56 +0200 Subject: [PATCH 026/182] fix(responses): close Phase 4 review gaps Wires Responses routes into the transform trace logger when request logging is enabled, adds coverage for unsupported Responses fields preserved in bridge metadata, and strengthens streaming tests to assert SSE event order. Tests: python -m pytest tests/test_responses_store.py tests/test_responses_bridge.py tests/test_responses_service.py tests/test_responses_routes.py tests/test_responses_streaming.py Tests: python -m pytest tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py tests/test_adapter_registry.py tests/test_field_cache_paths.py tests/test_field_cache_engine.py tests/test_field_cache_trace.py tests/test_provider_protocol_declarations.py tests/test_session_tracking.py tests/test_selection_engine.py --- src/proxy_app/main.py | 6 ++++-- tests/test_responses_bridge.py | 10 ++++++++++ tests/test_responses_streaming.py | 17 ++++++++++++----- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/proxy_app/main.py b/src/proxy_app/main.py index 60c5cabbe..f2aa4dd1f 100644 --- a/src/proxy_app/main.py +++ b/src/proxy_app/main.py @@ -141,6 +141,7 @@ from proxy_app.batch_manager import EmbeddingBatcher from proxy_app.detailed_logger import RawIOLogger from rotator_library.responses import ResponsesService, ResponsesServiceError + from rotator_library.transaction_logger import TransactionLogger print(" → Discovering provider plugins...") # Provider lazy loading happens during import, so time it here @@ -1052,14 +1053,15 @@ async def responses_create( raise HTTPException(status_code=400, detail="Invalid JSON in request body.") if logger: logger.log_request(headers=dict(request.headers), body=request_data) + transaction_logger = TransactionLogger("responses", request_data.get("model", "unknown")) if ENABLE_REQUEST_LOGGING else None try: if request_data.get("stream"): return StreamingResponse( - service.stream_response(request_data, client, request=request), + service.stream_response(request_data, client, request=request, transaction_logger=transaction_logger), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, ) - result = await service.create_response(request_data, client, request=request) + result = await service.create_response(request_data, client, request=request, transaction_logger=transaction_logger) if logger: logger.log_final_response(status_code=200, headers=None, body=result) return JSONResponse(content=result) diff --git a/tests/test_responses_bridge.py b/tests/test_responses_bridge.py index 9c2fdc836..ecb50b0ad 100644 --- a/tests/test_responses_bridge.py +++ b/tests/test_responses_bridge.py @@ -61,6 +61,16 @@ def test_bridge_preserves_tool_definitions() -> None: ] +def test_bridge_preserves_unsupported_fields_for_trace_metadata() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Hi", "custom_unsupported": 42}) + + kwargs = bridge.to_chat_kwargs(unified) + + assert kwargs["_responses_bridge"]["extra"]["custom_unsupported"] == 42 + + def test_bridge_converts_chat_response_to_responses_payload() -> None: protocol = ResponsesProtocol() bridge = ResponsesBridge(protocol) diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index 2979bcfac..4fc5fed80 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -5,6 +5,10 @@ from rotator_library.responses import InMemoryResponsesStore, ResponsesSSEFormatter, ResponsesService, ResponsesWebSocketFormatter +def _event_names(events: list[str]) -> list[str]: + return [line.removeprefix("event: ") for event in events for line in event.splitlines() if line.startswith("event: ")] + + class FakeStreamingClient: async def acompletion(self, **kwargs): async def chunks(): @@ -33,11 +37,14 @@ async def test_stream_response_emits_responses_sse_events_and_stores_final_respo events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, FakeStreamingClient())] event_text = "".join(events) - assert "event: response.created" in event_text - assert "event: response.output_item.added" in event_text - assert "event: response.output_text.delta" in event_text - assert "event: response.output_item.done" in event_text - assert "event: response.completed" in event_text + assert _event_names(events) == [ + "response.created", + "response.output_item.added", + "response.output_text.delta", + "response.output_text.delta", + "response.output_item.done", + "response.completed", + ] assert event_text.endswith("data: [DONE]\n\n") response_id = events[0].split('"id": "')[1].split('"')[0] From 456a6bb7823b5c09c9670b554902c6e6261112bb Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 00:48:19 +0200 Subject: [PATCH 027/182] docs(experimental): plan provider protocol overhaul Adds the Phase 5 plan for native provider execution, provider declarations, HTTP and streaming seams, priority provider order, Antigravity restoration constraints, Gemini CLI parity review, fallback policy, transform tracing, field-cache rules, tests, and review checkpoints. Reports remain uncommitted for user-facing review only. --- .../phase-5-provider-protocol-overhaul.md | 263 ++++++++++++++++++ 1 file changed, 263 insertions(+) create mode 100644 docs/experimental/phase-5-provider-protocol-overhaul.md diff --git a/docs/experimental/phase-5-provider-protocol-overhaul.md b/docs/experimental/phase-5-provider-protocol-overhaul.md new file mode 100644 index 000000000..c05d94ad7 --- /dev/null +++ b/docs/experimental/phase-5-provider-protocol-overhaul.md @@ -0,0 +1,263 @@ +# Phase 5 Plan: Native Provider Protocol Overhaul + +## Goal + +Make provider implementations opt into the native protocol, adapter, field-cache, and Responses foundations added in Phases 1-4, then add or restore the priority providers in the requested order: Claude Code, Codex, Copilot, Antigravity, and Gemini CLI parity review. This phase should create a reusable provider-native execution seam first, then implement providers incrementally behind explicit declarations so LiteLLM remains fallback-only for uncovered cases. + +## Non-Goals + +- Do not replace `UsageManager`, `SessionTracker`, `SelectionEngine`, or retry-after parsing. +- Do not implement routing/fallback groups yet; Phase 6 owns ordered fallback chains. +- Do not implement full multi-user/security isolation. +- Do not add SQLite or any database. +- Do not rewrite the entire executor in one pass. +- Do not migrate every existing provider at once. +- Do not copy obsolete retired-provider behavior blindly. +- Do not remove LiteLLM fallback; make fallback explicit and observable. + +## Current Code Context + +- Phase 1 provides reusable protocol adapters: OpenAI Chat, Anthropic Messages, Gemini, Responses, and LiteLLM fallback. +- Phase 2 provides transaction transform tracing and provider trace entries. +- Phase 3 provides provider declarations, adapter chains, field-cache rules, and scoped cache engine. +- Phase 4 provides `/v1/responses`, response storage, HTTP SSE, and WebSocket seam. +- `ProviderInterface` already has optional `get_protocol_name()`, `get_adapter_names()`, `get_adapter_config()`, and `get_field_cache_rules()`. +- Existing custom providers with `has_custom_logic()` include Gemini CLI, Deepseek, and retired Antigravity/IFlow/Qwen-style providers. +- Gemini CLI is large and custom; Phase 5 should review and target improvements rather than destabilize it. +- Antigravity exists only under `src/rotator_library/providers/_retired/` and must be compared against current behavior before restoration. +- No Claude Code, Codex, or Copilot provider files exist in active providers today. + +## Files To Add + +- `src/rotator_library/native_provider/__init__.py` +- `src/rotator_library/native_provider/context.py` +- `src/rotator_library/native_provider/executor.py` +- `src/rotator_library/native_provider/http.py` +- `src/rotator_library/native_provider/streaming.py` +- `src/rotator_library/providers/claude_code_provider.py` +- `src/rotator_library/providers/codex_provider.py` +- `src/rotator_library/providers/copilot_provider.py` +- Possibly `src/rotator_library/providers/antigravity_provider.py` +- `tests/test_native_provider_executor.py` +- `tests/test_native_provider_streaming.py` +- `tests/test_claude_code_provider.py` +- `tests/test_codex_provider.py` +- `tests/test_copilot_provider.py` +- `tests/test_antigravity_provider_restore.py` if restored in this phase +- `tests/test_gemini_cli_protocol_declarations.py` + +## Files Likely To Touch + +- `src/rotator_library/providers/provider_interface.py` only for small optional native execution methods if needed. +- `src/rotator_library/client/executor.py` only if a minimal seam is needed to route declared native providers without breaking current providers. +- `src/rotator_library/client/models.py` only if model/provider resolution needs explicit provider declarations. +- `src/rotator_library/client/request_builder.py` only if provider declarations need context fields before execution. +- `src/rotator_library/responses/service.py` only if native Responses-capable providers can bypass the chat bridge cleanly. +- Existing provider files only for declaration additions or targeted Gemini CLI fixes. + +## Native Provider Execution Seam + +- Add `NativeProviderContext` with provider, model, credential identifier, headers, request context, protocol name, adapter names, field-cache rules, transport, transaction logger, session fields, scope key, classifier, and provider metadata. +- Add `NativeProviderExecutor` that can: + - Resolve protocol via `provider.get_protocol_name(model)`. + - Build `AdapterContext` and run `before_adapter_chain` / `after_adapter` / `after_adapter_chain`. + - Run `FieldCacheEngine.inject()` before provider request. + - Build native HTTP request payload from the selected protocol. + - Send request through a small HTTP transport wrapper. + - Parse provider response through selected protocol. + - Run `FieldCacheEngine.extract()` after response or stream events. + - Format back to the requested client protocol. + - Emit transform trace passes for every step. +- The executor must be provider-opt-in. If `get_protocol_name()` returns `None` or `"litellm_fallback"`, current behavior stays unchanged. +- It must be independently testable with mocked HTTP clients and fake providers. + +## Provider Interface Additions + +Prefer using existing declarations first. Add optional methods only if needed: + +- `get_native_endpoint(model, operation)` +- `get_native_headers(credential_identifier, model, operation)` +- `build_native_request_options(model, operation)` +- `supports_native_operation(operation, model)` +- `should_use_native_protocol(operation, model)` + +Defaults must preserve current behavior. Docstrings must explain provider overrides are encouraged when a base protocol is close but not exact. + +## HTTP Transport + +- Use injected `httpx.AsyncClient`. +- Support JSON POST first. +- Support SSE response iteration for streaming. +- Preserve raw request/response bodies for transform trace. +- Do not add WebSocket runtime transport yet; keep the seam compatible with Phase 4 WebSocket formatter. +- Convert provider HTTP errors into existing error classification paths where possible. + +## Streaming + +- Native stream parser should call protocol stream parsing where available. +- Emit transform passes: + - `native_provider_request` + - `raw_native_provider_stream_chunk` + - `parsed_native_stream_event` + - `after_field_cache_stream_extraction` + - `formatted_client_stream_event` +- Do not fallback after visible output. Phase 6 owns fallback policy; Phase 5 only makes stream behavior observable and safe. + +## Responses Integration + +- For providers with native Responses support, allow `ResponsesService` to bypass `ResponsesBridge`. +- Add a service-level native executor hook only if it can be done without coupling service to provider internals. +- If not clean in Phase 5, leave `/v1/responses` bridge as default and expose provider-native Responses in provider tests/foundation for Phase 6/8 wiring. +- Preserve `/v1/responses` route behavior from Phase 4. + +## Priority Provider Plan + +### Claude Code + +- Add provider first. +- Determine whether it is Anthropic Messages-compatible, OpenAI Chat-compatible, or a dedicated endpoint. +- Start with declarations and mocked native execution tests. +- Preserve Claude reasoning/thinking fields via field-cache rules if present. +- Suppress/transform unsupported roles through adapters instead of bespoke monolithic code. +- Add tests for auth headers, model list behavior, request transform, response transform, streaming text, and field-cache rule declarations. +- If live API details are uncertain, implement an integration path with explicit env names and mocked endpoint behavior rather than guessing secrets or undocumented flows. + +### Codex + +- Add provider second. +- Treat as likely OpenAI/Responses-compatible until provider-specific evidence says otherwise. +- Prefer native Responses protocol if supported; otherwise OpenAI Chat protocol. +- Add tests for protocol selection, Responses bypass or bridge compatibility, auth headers, model naming, and explicit no-LiteLLM path when native mode is declared. + +### Copilot + +- Add provider third. +- Use native protocol declarations and OAuth/header helpers. +- Keep credential acquisition/refresh minimal and mocked unless existing credential machinery already supports it. +- Add field-cache rules for conversation/session IDs if required by provider behavior. +- Add tests for model list, auth header, protocol selection, request conversion, and streaming. + +### Antigravity + +- Compare `src/rotator_library/providers/_retired/antigravity_provider.py`, auth base, quota tracker, and device profile utilities before restoring. +- Restore only valid/current behavior. +- Avoid fragile or obsolete device-profile behavior unless current service behavior requires it. +- Extract reusable pieces: + - model mapping + - auth headers + - schema cleanup + - thinking/reasoning preservation + - quota parsing + - SSE handling +- Do not resurrect the whole monolith unchanged. +- Add a restored provider only after tests describe the stable subset. +- If the current service cannot be validated safely, write an explicit integration-path provider skeleton and defer live-specific behavior. + +### Gemini CLI Parity Review + +- Review active `gemini_cli_provider.py` against the new protocol/adapter/cache foundation. +- Add provider declarations where safe: + - likely protocol `gemini` + - adapters for model override / field rename / role suppression only if needed + - field-cache rules for thought signatures and provider session fields if they match current behavior. +- Do not rewrite Gemini CLI in this phase. +- Add targeted tests proving declarations do not change current behavior. +- Fix only clear parity gaps found during review. + +## LiteLLM Fallback Policy + +- Providers with native declarations should not silently fall back to LiteLLM on the primary path. +- If fallback is allowed, it must be explicit: + - protocol name `"litellm_fallback"` + - trace pass `native_provider_litellm_fallback` + - metadata reason. +- Tests must assert no accidental LiteLLM fallback for providers declared native. + +## Transform Trace Requirements + +Provider-native calls should log: + +- `native_protocol_selected` +- `before_adapter_chain` +- `after_adapter` +- `after_field_cache_injection` +- `native_provider_request` +- `raw_native_provider_response` +- `parsed_native_provider_response` +- `after_field_cache_extraction` +- `final_client_response` + +Provider-native stream calls should log: + +- `native_protocol_selected` +- `native_provider_stream_request` +- `raw_native_provider_stream_chunk` +- `parsed_native_stream_event` +- `after_field_cache_stream_extraction` +- `formatted_client_stream_event` + +Errors should use `log_transform_error()` with provider/protocol/pass context. + +## Field-Cache Requirements + +- Use Phase 3 rules for: + - reasoning content + - thought signatures + - provider session IDs + - prompt cache keys + - previous/response IDs where provider-specific +- Scope must include at least provider + model + classifier + session for conversation-affecting values. +- Credential scope must be added for values tied to an account/token. +- Tests must prove no cross-provider/model/session/credential leakage for provider rules. + +## Testing Plan + +Native executor tests: + +- protocol selection +- adapter chain order +- field-cache injection/extraction +- non-stream request/response trace passes +- stream trace passes +- explicit fallback behavior +- provider opt-in leaves current fallback path untouched for undeclared providers + +Provider tests: + +- registration/discovery for Claude Code, Codex, Copilot, and restored Antigravity if added. +- env var naming and auth header construction via mocks. +- model list parsing via mocked HTTP. +- request payload build and response parse via mocked HTTP. +- SSE stream conversion via mocked chunks. +- field-cache declarations. +- no live credentials required. + +Regression tests: + +- Phase 1 protocol tests. +- Phase 2 logging tests. +- Phase 3 adapter/field-cache tests. +- Phase 4 Responses tests. +- `test_session_tracking.py`. +- `test_selection_engine.py`. + +## Commit Checkpoints + +1. Add native provider context/executor/HTTP transport with tests. +2. Add native streaming support with tests. +3. Add Claude Code provider or integration skeleton with tests. +4. Add Codex provider or integration skeleton with tests. +5. Add Copilot provider or integration skeleton with tests. +6. Compare and restore the safe Antigravity subset, or write a documented deferral with tests. +7. Add Gemini CLI declaration/parity fixes with tests. +8. Run focused and regression tests. +9. Review with `explore` and `explore-heavy`, fix findings, and write the uncommitted Phase 5 report. + +## Risks And Mitigations + +- Provider APIs may be undocumented or volatile. Use mocked behavior and explicit integration seams rather than guessing hidden flows. +- A native executor could destabilize current traffic. Keep native execution opt-in and leave undeclared providers unchanged. +- Restoring Antigravity could reintroduce brittle device/profile logic. Restore only tested stable subsets. +- Streaming fidelity for tool/reasoning deltas may vary by provider. Preserve raw chunks in trace and keep provider-specific parsers override-friendly. +- Field-cache leakage would be serious. Add scope tests per provider rule. +- LiteLLM fallback could hide native failures. Make fallback explicit, traced, and tested. From d3028ad97ecbea669190c403d45841716baf9e54 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 00:51:46 +0200 Subject: [PATCH 028/182] feat(native-provider): add opt-in executor foundation Adds the Phase 5 native provider foundation with execution context, HTTP transport wrapper, and non-streaming executor that runs protocol selection, adapter chains, field-cache injection/extraction, provider HTTP calls, and transform tracing. The foundation is not wired into live request execution yet, preserving current provider behavior while giving priority provider work a testable native path. Tests: python -m pytest tests/test_native_provider_executor.py tests/test_responses_store.py tests/test_responses_bridge.py tests/test_responses_service.py tests/test_responses_routes.py tests/test_responses_streaming.py --- .../native_provider/__init__.py | 10 ++ .../native_provider/context.py | 83 ++++++++++++++++ .../native_provider/executor.py | 97 ++++++++++++++++++ src/rotator_library/native_provider/http.py | 29 ++++++ tests/test_native_provider_executor.py | 98 +++++++++++++++++++ 5 files changed, 317 insertions(+) create mode 100644 src/rotator_library/native_provider/__init__.py create mode 100644 src/rotator_library/native_provider/context.py create mode 100644 src/rotator_library/native_provider/executor.py create mode 100644 src/rotator_library/native_provider/http.py create mode 100644 tests/test_native_provider_executor.py diff --git a/src/rotator_library/native_provider/__init__.py b/src/rotator_library/native_provider/__init__.py new file mode 100644 index 000000000..cd63fed30 --- /dev/null +++ b/src/rotator_library/native_provider/__init__.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Opt-in native provider execution helpers.""" + +from .context import NativeProviderContext +from .executor import NativeProviderExecutor +from .http import NativeHTTPTransport + +__all__ = ["NativeHTTPTransport", "NativeProviderContext", "NativeProviderExecutor"] diff --git a/src/rotator_library/native_provider/context.py b/src/rotator_library/native_provider/context.py new file mode 100644 index 000000000..52cc9130e --- /dev/null +++ b/src/rotator_library/native_provider/context.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Context objects for native provider execution.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional + +from ..adapters import AdapterContext +from ..field_cache import FieldCacheContext, FieldCacheRule +from ..protocols import ProtocolContext + + +@dataclass +class NativeProviderContext: + """Metadata needed to execute a provider through native protocol helpers. + + This context intentionally mirrors the trace, adapter, protocol, and + field-cache contexts so provider-native execution can remain opt-in and + testable without changing the existing LiteLLM-backed path. + """ + + provider: str + model: str + protocol_name: str + endpoint: str + headers: dict[str, str] = field(default_factory=dict) + credential_id: Optional[str] = None + session_id: Optional[str] = None + scope_key: Optional[str] = None + classifier: Optional[str] = None + transport: str = "http" + adapter_names: tuple[str, ...] = () + adapter_config: dict[str, dict[str, Any]] = field(default_factory=dict) + field_cache_rules: tuple[FieldCacheRule, ...] = () + transaction_logger: Optional[Any] = None + metadata: dict[str, Any] = field(default_factory=dict) + + def protocol_context(self, *, target_protocol: Optional[str] = None) -> ProtocolContext: + """Build a protocol context for parse/build/format passes.""" + + return ProtocolContext( + provider=self.provider, + model=self.model, + source_protocol=self.protocol_name, + target_protocol=target_protocol or self.protocol_name, + session_id=self.session_id, + credential_stable_id=self.credential_id, + transport=self.transport, + metadata=dict(self.metadata), + ) + + def adapter_context(self) -> AdapterContext: + """Build an adapter context for provider payload adapters.""" + + return AdapterContext( + provider=self.provider, + model=self.model, + protocol=self.protocol_name, + credential_id=self.credential_id, + session_id=self.session_id, + scope_key=self.scope_key, + classifier=self.classifier, + transport=self.transport, + metadata=dict(self.metadata), + adapter_config=dict(self.adapter_config), + transaction_logger=self.transaction_logger, + ) + + def field_cache_context(self) -> FieldCacheContext: + """Build a field-cache context with provider isolation metadata.""" + + return FieldCacheContext( + provider=self.provider, + model=self.model, + credential_id=self.credential_id, + session_id=self.session_id, + conversation_id=self.scope_key, + classifier=self.classifier, + metadata=dict(self.metadata), + ) diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py new file mode 100644 index 000000000..ff48e44a2 --- /dev/null +++ b/src/rotator_library/native_provider/executor.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Opt-in executor for provider-native protocol calls.""" + +from __future__ import annotations + +from typing import Any + +from ..adapters import get_adapter, run_adapter_chain +from ..field_cache import FieldCacheEngine +from ..protocols import get_protocol +from .context import NativeProviderContext +from .http import NativeHTTPTransport + + +class NativeProviderExecutor: + """Run one native provider request through protocol/adapter/cache passes. + + This executor is intentionally not wired into the live client path yet. Phase + 5 providers can test native behavior here first, then later phases can route + declared providers into it without disturbing undeclared providers. + """ + + def __init__(self, *, field_cache_store: Any = None) -> None: + self.field_cache_store = field_cache_store + + async def execute(self, raw_request: dict[str, Any], context: NativeProviderContext, transport: NativeHTTPTransport) -> dict[str, Any]: + """Execute a non-streaming native provider request.""" + + logger = context.transaction_logger + protocol = get_protocol(context.protocol_name) + self._trace(context, "native_protocol_selected", {"protocol": protocol.name}, direction="metadata", stage="protocol") + try: + protocol_context = context.protocol_context() + unified_request = protocol.parse_request(raw_request, protocol_context) + provider_request = protocol.build_request(unified_request, protocol_context) + adapters = [get_adapter(name) for name in context.adapter_names] + provider_request = await run_adapter_chain(adapters, provider_request, context.adapter_context(), stage="request") + cache_engine = FieldCacheEngine(context.field_cache_rules, store=self.field_cache_store) + provider_request, _ = await cache_engine.inject( + "request", + provider_request, + context.field_cache_context(), + transaction_logger=logger, + ) + self._trace(context, "native_provider_request", provider_request, direction="request", stage="provider") + raw_response = await transport.post_json(context.endpoint, headers=context.headers, payload=provider_request) + self._trace(context, "raw_native_provider_response", raw_response, direction="response", stage="provider") + unified_response = protocol.parse_response(raw_response, protocol_context) + provider_response = protocol.format_response(unified_response, protocol_context) + self._trace(context, "parsed_native_provider_response", provider_response, direction="response", stage="protocol") + provider_response = await run_adapter_chain(adapters, provider_response, context.adapter_context(), stage="response") + await cache_engine.extract("response", provider_response, context.field_cache_context(), transaction_logger=logger) + self._trace(context, "final_client_response", provider_response, direction="response", stage="final") + return provider_response + except Exception as exc: + if logger: + logger.log_transform_error( + "native_provider_execute", + exc, + payload=raw_request, + stage="provider", + protocol=context.protocol_name, + metadata={"provider": context.provider, "model": context.model}, + ) + raise + + @staticmethod + def _trace( + context: NativeProviderContext, + pass_name: str, + data: Any, + *, + direction: str, + stage: str, + metadata: dict[str, Any] | None = None, + ) -> None: + if not context.transaction_logger: + return + context.transaction_logger.log_transform_pass( + pass_name, + data, + direction=direction, + stage=stage, + protocol=context.protocol_name, + credential_id=context.credential_id, + transport=context.transport, + metadata={ + "provider": context.provider, + "model": context.model, + "session_id": context.session_id, + "scope_key": context.scope_key, + "classifier": context.classifier, + **(metadata or {}), + }, + ) diff --git a/src/rotator_library/native_provider/http.py b/src/rotator_library/native_provider/http.py new file mode 100644 index 000000000..9678b569c --- /dev/null +++ b/src/rotator_library/native_provider/http.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Small HTTP transport wrapper for native provider calls.""" + +from __future__ import annotations + +from typing import Any + + +class NativeHTTPTransport: + """Execute provider-native JSON HTTP requests through an injected client.""" + + def __init__(self, client: Any) -> None: + self.client = client + + async def post_json(self, endpoint: str, *, headers: dict[str, str], payload: dict[str, Any]) -> Any: + """POST JSON and return a decoded response body. + + The wrapper keeps HTTP behavior easy to mock. It does not own retries or + credential rotation; those remain in the existing executor/usage layer. + """ + + response = await self.client.post(endpoint, headers=headers, json=payload) + if hasattr(response, "raise_for_status"): + response.raise_for_status() + if hasattr(response, "json"): + return response.json() + return response diff --git a/tests/test_native_provider_executor.py b/tests/test_native_provider_executor.py new file mode 100644 index 000000000..d8dbfaec8 --- /dev/null +++ b/tests/test_native_provider_executor.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.field_cache import FieldCacheInjection, FieldCacheRule +from rotator_library.native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor +from rotator_library.transaction_logger import TransactionLogger + + +class FakeHTTPResponse: + def __init__(self, payload): + self.payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self.payload + + +class FakeHTTPClient: + def __init__(self, response): + self.response = response + self.calls = [] + + async def post(self, endpoint, *, headers, json): + self.calls.append({"endpoint": endpoint, "headers": headers, "json": json}) + return FakeHTTPResponse(self.response) + + +def _trace_entries(log_dir): + return [json.loads(line) for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + + +@pytest.mark.asyncio +async def test_native_provider_executor_runs_protocol_adapter_cache_and_trace(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + rule = FieldCacheRule( + name="reasoning", + source="response", + path="choices.0.message.reasoning_content", + inject=FieldCacheInjection(target="request", path="metadata.reasoning_content"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + headers={"authorization": "Bearer test"}, + adapter_names=("model_override",), + adapter_config={"model_override": {"model": "provider/gpt-test"}}, + field_cache_rules=(rule,), + transaction_logger=logger, + ) + response = { + "id": "chat_1", + "model": "provider/gpt-test", + "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "hidden"}}], + } + client = FakeHTTPClient(response) + + result = await NativeProviderExecutor().execute( + {"model": "gpt-test", "messages": [{"role": "user", "content": "hi"}]}, + context, + NativeHTTPTransport(client), + ) + + assert result["id"] == "chat_1" + assert client.calls[0]["json"]["model"] == "provider/gpt-test" + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "native_protocol_selected" in pass_names + assert "native_provider_request" in pass_names + assert "raw_native_provider_response" in pass_names + assert "parsed_native_provider_response" in pass_names + assert "after_field_cache_extraction" in pass_names + assert "final_client_response" in pass_names + + +@pytest.mark.asyncio +async def test_native_provider_executor_logs_transform_errors(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + adapter_names=("missing_adapter",), + transaction_logger=logger, + ) + + with pytest.raises(KeyError): + await NativeProviderExecutor().execute({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(FakeHTTPClient({}))) + + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "transform_log_error" in pass_names From 6eef5fa32b0b181e6c1781a6fdd00a56a48ee0ef Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 00:53:27 +0200 Subject: [PATCH 029/182] feat(native-provider): add streaming foundation Adds opt-in native provider streaming support with streaming-capable transport seam, raw chunk tracing, protocol stream parsing, field-cache stream extraction, formatted client stream events, and transform error logging. The streaming foundation remains isolated from live provider routing so existing providers keep current behavior while Phase 5 provider implementations gain a mocked native stream path. Tests: python -m pytest tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_responses_store.py tests/test_responses_bridge.py tests/test_responses_service.py tests/test_responses_routes.py tests/test_responses_streaming.py tests/test_protocol_responses.py tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py --- .../native_provider/__init__.py | 3 +- .../native_provider/executor.py | 49 ++++++++++++- src/rotator_library/native_provider/http.py | 15 ++++ .../native_provider/streaming.py | 16 +++++ tests/test_native_provider_streaming.py | 70 +++++++++++++++++++ 5 files changed, 151 insertions(+), 2 deletions(-) create mode 100644 src/rotator_library/native_provider/streaming.py create mode 100644 tests/test_native_provider_streaming.py diff --git a/src/rotator_library/native_provider/__init__.py b/src/rotator_library/native_provider/__init__.py index cd63fed30..dc63572ad 100644 --- a/src/rotator_library/native_provider/__init__.py +++ b/src/rotator_library/native_provider/__init__.py @@ -6,5 +6,6 @@ from .context import NativeProviderContext from .executor import NativeProviderExecutor from .http import NativeHTTPTransport +from .streaming import stream_event_payload -__all__ = ["NativeHTTPTransport", "NativeProviderContext", "NativeProviderExecutor"] +__all__ = ["NativeHTTPTransport", "NativeProviderContext", "NativeProviderExecutor", "stream_event_payload"] diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index ff48e44a2..cb57df946 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -5,13 +5,14 @@ from __future__ import annotations -from typing import Any +from typing import Any, AsyncGenerator from ..adapters import get_adapter, run_adapter_chain from ..field_cache import FieldCacheEngine from ..protocols import get_protocol from .context import NativeProviderContext from .http import NativeHTTPTransport +from .streaming import stream_event_payload class NativeProviderExecutor: @@ -66,6 +67,50 @@ async def execute(self, raw_request: dict[str, Any], context: NativeProviderCont ) raise + async def stream(self, raw_request: dict[str, Any], context: NativeProviderContext, transport: NativeHTTPTransport) -> AsyncGenerator[Any, None]: + """Execute a streaming native provider request and yield client events.""" + + logger = context.transaction_logger + protocol = get_protocol(context.protocol_name) + self._trace(context, "native_protocol_selected", {"protocol": protocol.name}, direction="metadata", stage="protocol") + try: + protocol_context = context.protocol_context() + request_payload = dict(raw_request) + request_payload["stream"] = True + unified_request = protocol.parse_request(request_payload, protocol_context) + provider_request = protocol.build_request(unified_request, protocol_context) + adapters = [get_adapter(name) for name in context.adapter_names] + provider_request = await run_adapter_chain(adapters, provider_request, context.adapter_context(), stage="request") + cache_engine = FieldCacheEngine(context.field_cache_rules, store=self.field_cache_store) + provider_request, _ = await cache_engine.inject( + "request", + provider_request, + context.field_cache_context(), + transaction_logger=logger, + ) + self._trace(context, "native_provider_stream_request", provider_request, direction="request", stage="provider") + async for raw_chunk in transport.stream_json_lines(context.endpoint, headers=context.headers, payload=provider_request): + self._trace(context, "raw_native_provider_stream_chunk", raw_chunk, direction="stream", stage="provider") + event = protocol.parse_stream_event(raw_chunk, protocol_context) + event_payload = stream_event_payload(event) + self._trace(context, "parsed_native_stream_event", event_payload, direction="stream", stage="protocol") + await cache_engine.extract("stream_event", event_payload, context.field_cache_context(), transaction_logger=logger) + formatted = protocol.format_stream_event(event, protocol_context) + self._trace(context, "formatted_client_stream_event", formatted, direction="stream", stage="final", snapshot=False) + yield formatted + except Exception as exc: + if logger: + logger.log_transform_error( + "native_provider_stream", + exc, + payload=raw_request, + stage="provider", + protocol=context.protocol_name, + transport=context.transport, + metadata={"provider": context.provider, "model": context.model}, + ) + raise + @staticmethod def _trace( context: NativeProviderContext, @@ -75,6 +120,7 @@ def _trace( direction: str, stage: str, metadata: dict[str, Any] | None = None, + snapshot: bool = True, ) -> None: if not context.transaction_logger: return @@ -94,4 +140,5 @@ def _trace( "classifier": context.classifier, **(metadata or {}), }, + snapshot=snapshot, ) diff --git a/src/rotator_library/native_provider/http.py b/src/rotator_library/native_provider/http.py index 9678b569c..b322875ea 100644 --- a/src/rotator_library/native_provider/http.py +++ b/src/rotator_library/native_provider/http.py @@ -27,3 +27,18 @@ async def post_json(self, endpoint: str, *, headers: dict[str, str], payload: di if hasattr(response, "json"): return response.json() return response + + async def stream_json_lines(self, endpoint: str, *, headers: dict[str, str], payload: dict[str, Any]): + """Yield provider stream chunks from an injected streaming-capable client. + + Tests and provider-specific clients can expose `stream_json_lines()` to + avoid binding this foundation to one HTTP client's streaming API. A later + streaming phase can add richer `httpx.stream()` support without changing + native provider executor semantics. + """ + + if hasattr(self.client, "stream_json_lines"): + async for chunk in self.client.stream_json_lines(endpoint, headers=headers, json=payload): + yield chunk + return + raise NotImplementedError("Injected native HTTP client does not expose stream_json_lines") diff --git a/src/rotator_library/native_provider/streaming.py b/src/rotator_library/native_provider/streaming.py new file mode 100644 index 000000000..592b34704 --- /dev/null +++ b/src/rotator_library/native_provider/streaming.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Streaming helpers for provider-native execution.""" + +from __future__ import annotations + +from typing import Any + + +def stream_event_payload(event: Any) -> Any: + """Return a JSON-safe payload for stream field-cache and trace passes.""" + + if hasattr(event, "to_dict"): + return event.to_dict() + return event diff --git a/tests/test_native_provider_streaming.py b/tests/test_native_provider_streaming.py new file mode 100644 index 000000000..1b08f9312 --- /dev/null +++ b/tests/test_native_provider_streaming.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.field_cache import FieldCacheRule +from rotator_library.native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor +from rotator_library.transaction_logger import TransactionLogger + + +class FakeStreamingClient: + def __init__(self, chunks): + self.chunks = chunks + self.calls = [] + + async def stream_json_lines(self, endpoint, *, headers, json): + self.calls.append({"endpoint": endpoint, "headers": headers, "json": json}) + for chunk in self.chunks: + yield chunk + + +def _trace_entries(log_dir): + return [json.loads(line) for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + + +@pytest.mark.asyncio +async def test_native_provider_stream_traces_and_yields_formatted_events(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=(FieldCacheRule(name="stream_reasoning", source="stream_event", path="raw.choices.0.delta.reasoning_content", allow_missing_session=True),), + transaction_logger=logger, + ) + chunks = [ + {"choices": [{"delta": {"content": "hi", "reasoning_content": "hidden"}}]}, + "[DONE]", + ] + client = FakeStreamingClient(chunks) + + events = [event async for event in NativeProviderExecutor().stream({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(client))] + + assert events == chunks + assert client.calls[0]["json"]["stream"] is True + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "native_provider_stream_request" in pass_names + assert pass_names.count("raw_native_provider_stream_chunk") == 2 + assert pass_names.count("parsed_native_stream_event") == 2 + assert "after_field_cache_extraction" in pass_names + assert pass_names.count("formatted_client_stream_event") == 2 + + +@pytest.mark.asyncio +async def test_native_provider_stream_logs_errors(tmp_path) -> None: + class BrokenClient: + async def stream_json_lines(self, endpoint, *, headers, json): + raise RuntimeError("broken stream") + yield None + + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext(provider="native", model="gpt-test", protocol_name="openai_chat", endpoint="https://example.test/chat", transaction_logger=logger) + + with pytest.raises(RuntimeError): + [event async for event in NativeProviderExecutor().stream({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(BrokenClient()))] + + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "transform_log_error" in pass_names From 34483efaf020be7391a150ccaf284c58a3447090 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 00:55:59 +0200 Subject: [PATCH 030/182] feat(providers): add Claude Code native skeleton Adds the first priority Phase 5 provider as an explicit native integration skeleton with Anthropic Messages protocol declaration, adapter config, thinking-signature field-cache rule, native header/endpoint helpers, and mock-friendly model discovery. This does not assume undocumented live behavior or wire the provider into the runtime native executor yet; it establishes a tested provider declaration path for later native routing. Tests: python -m pytest tests/test_claude_code_provider.py tests/test_provider_protocol_declarations.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_protocol_anthropic_messages.py tests/test_adapter_registry.py tests/test_field_cache_engine.py --- .../providers/claude_code_provider.py | 93 +++++++++++++++++++ tests/test_claude_code_provider.py | 80 ++++++++++++++++ 2 files changed, 173 insertions(+) create mode 100644 src/rotator_library/providers/claude_code_provider.py create mode 100644 tests/test_claude_code_provider.py diff --git a/src/rotator_library/providers/claude_code_provider.py b/src/rotator_library/providers/claude_code_provider.py new file mode 100644 index 000000000..10a21d3df --- /dev/null +++ b/src/rotator_library/providers/claude_code_provider.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Claude Code provider integration skeleton for native protocol execution.""" + +from __future__ import annotations + +import os +from typing import Any, List + +import httpx + +from ..field_cache import FieldCacheInjection, FieldCacheRule +from .provider_interface import ProviderInterface + +DEFAULT_API_BASE = "https://api.anthropic.com" +FALLBACK_MODELS = ["claude_code/claude-sonnet-4-5", "claude_code/claude-opus-4-5"] + + +class ClaudeCodeProvider(ProviderInterface): + """Provider declaration for Claude Code style native requests. + + The provider starts as an explicit integration path rather than a guessed live + implementation. It declares protocol/adapters/cache rules and exposes mocked + auth/model helpers so later native wiring can use it without reintroducing a + monolithic provider transform. + """ + + provider_env_name = "claude_code" + protocol_name = "anthropic_messages" + adapter_names = ("suppress_developer_role",) + field_cache_rules = ( + FieldCacheRule( + name="claude_code_thinking_signature", + source="response", + path="content.*.signature", + mode="all", + scope=("provider", "model", "credential", "session"), + inject=FieldCacheInjection(target="request", path="metadata.thinking_signatures", as_list=True), + metadata={"purpose": "preserve Claude thinking signatures for follow-up requests"}, + ), + ) + default_rotation_mode = "sequential" + + async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: + """Fetch provider models with a conservative fallback list. + + Model discovery is intentionally mock-friendly. If the configured service + does not expose a standard `/v1/models` response, provider work can later + override this without changing protocol declarations. + """ + + try: + response = await client.get(f"{self.get_api_base().rstrip('/')}/v1/models", headers=self.get_native_headers(api_key), timeout=30) + response.raise_for_status() + models = [item.get("id") for item in response.json().get("data", []) if isinstance(item, dict) and item.get("id")] + if models: + return [self._with_prefix(model) for model in models] + except Exception: + return list(FALLBACK_MODELS) + return list(FALLBACK_MODELS) + + def get_api_base(self) -> str: + """Return the configured Claude Code API base URL.""" + + return os.getenv("CLAUDE_CODE_API_BASE", DEFAULT_API_BASE) + + def get_native_headers(self, credential_identifier: str, model: str = "", operation: str = "messages") -> dict[str, str]: + """Return headers for native mocked HTTP requests.""" + + return { + "Authorization": f"Bearer {credential_identifier}", + "anthropic-version": os.getenv("CLAUDE_CODE_ANTHROPIC_VERSION", "2023-06-01"), + "content-type": "application/json", + } + + def get_native_endpoint(self, model: str = "", operation: str = "messages") -> str: + """Return the provider endpoint for a native operation.""" + + if operation == "models": + return f"{self.get_api_base().rstrip('/')}/v1/models" + return f"{self.get_api_base().rstrip('/')}/v1/messages" + + def get_adapter_config(self, model: str = "") -> dict[str, dict[str, Any]]: + """Configure adapters without hardcoding provider transforms.""" + + return {"suppress_developer_role": {"replacement_role": "user"}} + + @staticmethod + def _with_prefix(model: str) -> str: + if model.startswith("claude_code/"): + return model + return f"claude_code/{model}" diff --git a/tests/test_claude_code_provider.py b/tests/test_claude_code_provider.py new file mode 100644 index 000000000..ffea0fb18 --- /dev/null +++ b/tests/test_claude_code_provider.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import pytest + +from rotator_library.providers import PROVIDER_PLUGINS +from rotator_library.providers.claude_code_provider import ClaudeCodeProvider + + +class FakeResponse: + def __init__(self, payload): + self.payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self.payload + + +class FakeClient: + def __init__(self, payload): + self.payload = payload + self.calls = [] + + async def get(self, url, *, headers, timeout): + self.calls.append({"url": url, "headers": headers, "timeout": timeout}) + return FakeResponse(self.payload) + + +def test_claude_code_provider_is_discovered() -> None: + assert "claude_code" in PROVIDER_PLUGINS + + +def test_claude_code_provider_declares_native_protocol_adapters_and_cache_rules() -> None: + provider = ClaudeCodeProvider() + + assert provider.get_protocol_name("claude-sonnet-4-5") == "anthropic_messages" + assert provider.get_adapter_names("claude-sonnet-4-5") == ("suppress_developer_role",) + assert provider.get_adapter_config("claude-sonnet-4-5") == {"suppress_developer_role": {"replacement_role": "user"}} + rules = provider.get_field_cache_rules("claude-sonnet-4-5") + assert rules[0].name == "claude_code_thinking_signature" + assert rules[0].scope == ("provider", "model", "credential", "session") + + +def test_claude_code_provider_builds_native_headers_and_endpoint(monkeypatch) -> None: + monkeypatch.setenv("CLAUDE_CODE_API_BASE", "https://claude-code.test") + monkeypatch.setenv("CLAUDE_CODE_ANTHROPIC_VERSION", "2099-01-01") + provider = ClaudeCodeProvider() + + assert provider.get_native_endpoint(operation="messages") == "https://claude-code.test/v1/messages" + assert provider.get_native_headers("token") == { + "Authorization": "Bearer token", + "anthropic-version": "2099-01-01", + "content-type": "application/json", + } + + +@pytest.mark.asyncio +async def test_claude_code_provider_get_models_uses_mocked_models_endpoint(monkeypatch) -> None: + monkeypatch.setenv("CLAUDE_CODE_API_BASE", "https://claude-code.test") + client = FakeClient({"data": [{"id": "claude-sonnet-test"}]}) + provider = ClaudeCodeProvider() + + models = await provider.get_models("token", client) + + assert models == ["claude_code/claude-sonnet-test"] + assert client.calls[0]["url"] == "https://claude-code.test/v1/models" + assert client.calls[0]["headers"]["Authorization"] == "Bearer token" + + +@pytest.mark.asyncio +async def test_claude_code_provider_get_models_falls_back_on_errors() -> None: + class BrokenClient: + async def get(self, *args, **kwargs): + raise RuntimeError("offline") + + assert await ClaudeCodeProvider().get_models("token", BrokenClient()) == [ + "claude_code/claude-sonnet-4-5", + "claude_code/claude-opus-4-5", + ] From d0abac0cd8b2230bbc8cf0e6eced3961f2447c66 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 00:57:05 +0200 Subject: [PATCH 031/182] feat(providers): add Codex native skeleton Adds the second priority Phase 5 provider as a Responses-protocol integration skeleton with native header/endpoint helpers, mocked model discovery, and a response-id field-cache rule for continuation. The provider remains declarative and test-driven so native routing can opt into it later without silent LiteLLM fallback. Tests: python -m pytest tests/test_codex_provider.py tests/test_claude_code_provider.py tests/test_protocol_responses.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py --- .../providers/codex_provider.py | 72 +++++++++++++++++++ tests/test_codex_provider.py | 71 ++++++++++++++++++ 2 files changed, 143 insertions(+) create mode 100644 src/rotator_library/providers/codex_provider.py create mode 100644 tests/test_codex_provider.py diff --git a/src/rotator_library/providers/codex_provider.py b/src/rotator_library/providers/codex_provider.py new file mode 100644 index 000000000..68848902b --- /dev/null +++ b/src/rotator_library/providers/codex_provider.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Codex provider integration skeleton for native Responses execution.""" + +from __future__ import annotations + +import os +from typing import List + +import httpx + +from ..field_cache import FieldCacheInjection, FieldCacheRule +from .provider_interface import ProviderInterface + +DEFAULT_API_BASE = "https://api.openai.com" +FALLBACK_MODELS = ["codex/codex-mini-latest", "codex/gpt-5.1-codex"] + + +class CodexProvider(ProviderInterface): + """Provider declaration for Codex-style native Responses requests.""" + + provider_env_name = "codex" + protocol_name = "responses" + adapter_names: tuple[str, ...] = () + field_cache_rules = ( + FieldCacheRule( + name="codex_previous_response_id", + source="response", + path="id", + scope=("provider", "model", "credential", "session"), + inject=FieldCacheInjection(target="request", path="previous_response_id", when_missing_only=True), + metadata={"purpose": "preserve Responses continuation IDs for Codex sessions"}, + ), + ) + default_rotation_mode = "sequential" + + async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: + """Fetch Codex-visible models with fallback names for offline tests.""" + + try: + response = await client.get(self.get_native_endpoint(operation="models"), headers=self.get_native_headers(api_key), timeout=30) + response.raise_for_status() + models = [item.get("id") for item in response.json().get("data", []) if isinstance(item, dict) and item.get("id")] + codex_models = [model for model in models if "codex" in model.lower()] + if codex_models: + return [self._with_prefix(model) for model in codex_models] + except Exception: + return list(FALLBACK_MODELS) + return list(FALLBACK_MODELS) + + def get_api_base(self) -> str: + """Return the configured Codex API base URL.""" + + return os.getenv("CODEX_API_BASE", DEFAULT_API_BASE) + + def get_native_headers(self, credential_identifier: str, model: str = "", operation: str = "responses") -> dict[str, str]: + """Return headers for Codex native HTTP calls.""" + + return {"Authorization": f"Bearer {credential_identifier}", "content-type": "application/json"} + + def get_native_endpoint(self, model: str = "", operation: str = "responses") -> str: + """Return the native Codex endpoint for an operation.""" + + suffix = "/v1/models" if operation == "models" else "/v1/responses" + return f"{self.get_api_base().rstrip('/')}{suffix}" + + @staticmethod + def _with_prefix(model: str) -> str: + if model.startswith("codex/"): + return model + return f"codex/{model}" diff --git a/tests/test_codex_provider.py b/tests/test_codex_provider.py new file mode 100644 index 000000000..779e4b4ba --- /dev/null +++ b/tests/test_codex_provider.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import pytest + +from rotator_library.providers import PROVIDER_PLUGINS +from rotator_library.providers.codex_provider import CodexProvider + + +class FakeResponse: + def __init__(self, payload): + self.payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self.payload + + +class FakeClient: + def __init__(self, payload): + self.payload = payload + self.calls = [] + + async def get(self, url, *, headers, timeout): + self.calls.append({"url": url, "headers": headers, "timeout": timeout}) + return FakeResponse(self.payload) + + +def test_codex_provider_is_discovered() -> None: + assert "codex" in PROVIDER_PLUGINS + + +def test_codex_provider_declares_responses_protocol_and_cache_rule() -> None: + provider = CodexProvider() + + assert provider.get_protocol_name("codex-mini-latest") == "responses" + assert provider.get_adapter_names("codex-mini-latest") == () + rules = provider.get_field_cache_rules("codex-mini-latest") + assert rules[0].name == "codex_previous_response_id" + assert rules[0].inject.path == "previous_response_id" + + +def test_codex_provider_builds_native_headers_and_endpoint(monkeypatch) -> None: + monkeypatch.setenv("CODEX_API_BASE", "https://codex.test") + provider = CodexProvider() + + assert provider.get_native_endpoint(operation="responses") == "https://codex.test/v1/responses" + assert provider.get_native_endpoint(operation="models") == "https://codex.test/v1/models" + assert provider.get_native_headers("token") == {"Authorization": "Bearer token", "content-type": "application/json"} + + +@pytest.mark.asyncio +async def test_codex_provider_get_models_filters_codex_models(monkeypatch) -> None: + monkeypatch.setenv("CODEX_API_BASE", "https://codex.test") + client = FakeClient({"data": [{"id": "gpt-5.1-codex"}, {"id": "gpt-5.1"}]}) + provider = CodexProvider() + + models = await provider.get_models("token", client) + + assert models == ["codex/gpt-5.1-codex"] + assert client.calls[0]["url"] == "https://codex.test/v1/models" + + +@pytest.mark.asyncio +async def test_codex_provider_get_models_falls_back_on_errors() -> None: + class BrokenClient: + async def get(self, *args, **kwargs): + raise RuntimeError("offline") + + assert await CodexProvider().get_models("token", BrokenClient()) == ["codex/codex-mini-latest", "codex/gpt-5.1-codex"] From eec7544aeee0ab3d4476c914f09c620ed35e7afb Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 00:58:10 +0200 Subject: [PATCH 032/182] feat(providers): add Copilot native skeleton Adds the third priority Phase 5 provider as an OpenAI Chat-compatible native integration skeleton with auth/endpoint helpers, mocked model discovery, developer-role adapter declaration, and no guessed field-cache state. This keeps Copilot native support explicit and testable without assuming hidden session behavior or wiring it into live routing yet. Tests: python -m pytest tests/test_copilot_provider.py tests/test_codex_provider.py tests/test_claude_code_provider.py tests/test_protocol_openai_chat.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py --- .../providers/copilot_provider.py | 75 +++++++++++++++++++ tests/test_copilot_provider.py | 75 +++++++++++++++++++ 2 files changed, 150 insertions(+) create mode 100644 src/rotator_library/providers/copilot_provider.py create mode 100644 tests/test_copilot_provider.py diff --git a/src/rotator_library/providers/copilot_provider.py b/src/rotator_library/providers/copilot_provider.py new file mode 100644 index 000000000..bfe47a2c8 --- /dev/null +++ b/src/rotator_library/providers/copilot_provider.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Copilot provider integration skeleton for native OpenAI Chat execution.""" + +from __future__ import annotations + +import os +from typing import List + +import httpx + +from .provider_interface import ProviderInterface + +DEFAULT_API_BASE = "https://api.githubcopilot.com" +FALLBACK_MODELS = ["copilot/gpt-4.1", "copilot/claude-sonnet-4-5"] + + +class CopilotProvider(ProviderInterface): + """Provider declaration for Copilot-style OpenAI-compatible chat calls. + + The skeleton intentionally avoids inventing field-cache rules until a stable + provider session/conversation field is identified. This keeps Copilot native + support explicit and testable without guessing hidden behavior. + """ + + provider_env_name = "copilot" + protocol_name = "openai_chat" + adapter_names = ("suppress_developer_role",) + field_cache_rules: tuple = () + default_rotation_mode = "sequential" + + async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: + """Fetch Copilot-visible models with a safe fallback list.""" + + try: + response = await client.get(self.get_native_endpoint(operation="models"), headers=self.get_native_headers(api_key), timeout=30) + response.raise_for_status() + models = [item.get("id") for item in response.json().get("data", []) if isinstance(item, dict) and item.get("id")] + if models: + return [self._with_prefix(model) for model in models] + except Exception: + return list(FALLBACK_MODELS) + return list(FALLBACK_MODELS) + + def get_api_base(self) -> str: + """Return the configured Copilot API base URL.""" + + return os.getenv("COPILOT_API_BASE", DEFAULT_API_BASE) + + def get_native_headers(self, credential_identifier: str, model: str = "", operation: str = "chat") -> dict[str, str]: + """Return headers for Copilot native HTTP calls.""" + + return { + "Authorization": f"Bearer {credential_identifier}", + "content-type": "application/json", + "Copilot-Integration-Id": os.getenv("COPILOT_INTEGRATION_ID", "llm-api-key-proxy"), + } + + def get_native_endpoint(self, model: str = "", operation: str = "chat") -> str: + """Return the Copilot endpoint for a native operation.""" + + suffix = "/models" if operation == "models" else "/chat/completions" + return f"{self.get_api_base().rstrip('/')}{suffix}" + + def get_adapter_config(self, model: str = "") -> dict[str, dict[str, str]]: + """Configure role suppression declaratively for OpenAI-compatible chat.""" + + return {"suppress_developer_role": {"replacement_role": "system"}} + + @staticmethod + def _with_prefix(model: str) -> str: + if model.startswith("copilot/"): + return model + return f"copilot/{model}" diff --git a/tests/test_copilot_provider.py b/tests/test_copilot_provider.py new file mode 100644 index 000000000..a60bc9cc4 --- /dev/null +++ b/tests/test_copilot_provider.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import pytest + +from rotator_library.providers import PROVIDER_PLUGINS +from rotator_library.providers.copilot_provider import CopilotProvider + + +class FakeResponse: + def __init__(self, payload): + self.payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self.payload + + +class FakeClient: + def __init__(self, payload): + self.payload = payload + self.calls = [] + + async def get(self, url, *, headers, timeout): + self.calls.append({"url": url, "headers": headers, "timeout": timeout}) + return FakeResponse(self.payload) + + +def test_copilot_provider_is_discovered() -> None: + assert "copilot" in PROVIDER_PLUGINS + + +def test_copilot_provider_declares_openai_chat_protocol_and_adapter() -> None: + provider = CopilotProvider() + + assert provider.get_protocol_name("gpt-4.1") == "openai_chat" + assert provider.get_adapter_names("gpt-4.1") == ("suppress_developer_role",) + assert provider.get_adapter_config("gpt-4.1") == {"suppress_developer_role": {"replacement_role": "system"}} + assert provider.get_field_cache_rules("gpt-4.1") == () + + +def test_copilot_provider_builds_native_headers_and_endpoint(monkeypatch) -> None: + monkeypatch.setenv("COPILOT_API_BASE", "https://copilot.test") + monkeypatch.setenv("COPILOT_INTEGRATION_ID", "proxy-test") + provider = CopilotProvider() + + assert provider.get_native_endpoint(operation="chat") == "https://copilot.test/chat/completions" + assert provider.get_native_endpoint(operation="models") == "https://copilot.test/models" + assert provider.get_native_headers("token") == { + "Authorization": "Bearer token", + "content-type": "application/json", + "Copilot-Integration-Id": "proxy-test", + } + + +@pytest.mark.asyncio +async def test_copilot_provider_get_models_uses_mocked_endpoint(monkeypatch) -> None: + monkeypatch.setenv("COPILOT_API_BASE", "https://copilot.test") + client = FakeClient({"data": [{"id": "gpt-4.1"}]}) + provider = CopilotProvider() + + models = await provider.get_models("token", client) + + assert models == ["copilot/gpt-4.1"] + assert client.calls[0]["url"] == "https://copilot.test/models" + + +@pytest.mark.asyncio +async def test_copilot_provider_get_models_falls_back_on_errors() -> None: + class BrokenClient: + async def get(self, *args, **kwargs): + raise RuntimeError("offline") + + assert await CopilotProvider().get_models("token", BrokenClient()) == ["copilot/gpt-4.1", "copilot/claude-sonnet-4-5"] From ae7dbabb3f886a3d093c63c98a321cbd96cc5ce6 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 01:02:47 +0200 Subject: [PATCH 033/182] feat(providers): restore Antigravity native skeleton Restores a tested stable subset from the retired Antigravity provider: model aliases, available-model filtering, static required headers, internal endpoint shapes, quota groups, and thought-signature field-cache declaration. Brittle retired device-profile/fingerprint logic and monolithic request transforms are intentionally not restored; future live behavior must be added behind tests. Tests: python -m pytest tests/test_antigravity_provider_restore.py tests/test_copilot_provider.py tests/test_codex_provider.py tests/test_claude_code_provider.py tests/test_protocol_gemini.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py --- .../providers/antigravity_provider.py | 161 ++++++++++++++++++ tests/test_antigravity_provider_restore.py | 87 ++++++++++ 2 files changed, 248 insertions(+) create mode 100644 src/rotator_library/providers/antigravity_provider.py create mode 100644 tests/test_antigravity_provider_restore.py diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py new file mode 100644 index 000000000..c3c6c5571 --- /dev/null +++ b/src/rotator_library/providers/antigravity_provider.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Antigravity provider integration skeleton restored from safe retired pieces.""" + +from __future__ import annotations + +import os +from typing import Any, List, Optional + +import httpx + +from ..field_cache import FieldCacheInjection, FieldCacheRule +from .provider_interface import ProviderInterface + +BASE_URLS = [ + "https://daily-cloudcode-pa.sandbox.googleapis.com/v1internal", + "https://daily-cloudcode-pa.googleapis.com/v1internal", + "https://cloudcode-pa.googleapis.com/v1internal", +] +ANTIGRAVITY_HEADERS = { + "User-Agent": "antigravity/1.15.8 windows/amd64", + "X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1", + "Client-Metadata": '{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}', +} +AVAILABLE_MODELS = [ + "gemini-2.5-flash", + "gemini-2.5-flash-lite", + "gemini-3-pro-preview", + "gemini-3-flash", + "claude-sonnet-4.5", + "claude-opus-4.5", + "claude-opus-4.6", +] +MODEL_ALIAS_MAP = { + "rev19-uic3-1p": "gemini-2.5-computer-use-preview-10-2025", + "gemini-3-pro-image": "gemini-3-pro-image-preview", + "gemini-3-pro-low": "gemini-3-pro-preview", + "gemini-3-pro-high": "gemini-3-pro-preview", + "claude-sonnet-4-5": "claude-sonnet-4.5", + "claude-opus-4-5": "claude-opus-4.5", + "claude-opus-4-6": "claude-opus-4.6", +} +MODEL_ALIAS_REVERSE = {public: internal for internal, public in MODEL_ALIAS_MAP.items()} +EXCLUDED_MODELS = {"chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-2.5-pro"} + + +class AntigravityProvider(ProviderInterface): + """Safe restored Antigravity integration path. + + The retired provider contained valuable model/header/endpoint knowledge mixed + with fragile device-profile and monolithic transform logic. This active + skeleton restores only stable declarations and helpers so native provider + work can proceed behind tests before any live routing is enabled. + """ + + provider_env_name = "antigravity" + protocol_name = "gemini" + adapter_names = ("field_rename",) + field_cache_rules = ( + FieldCacheRule( + name="antigravity_thought_signature", + source="response", + path="candidates.*.content.parts.*.thoughtSignature", + mode="all", + scope=("provider", "model", "credential", "session"), + inject=FieldCacheInjection(target="request", path="metadata.thoughtSignatures", as_list=True), + metadata={"purpose": "preserve Gemini thought signatures across Antigravity turns"}, + ), + ) + model_quota_groups = { + "gemini": ["gemini-2.5-flash", "gemini-2.5-flash-lite", "gemini-3-pro-preview", "gemini-3-flash"], + "claude": ["claude-sonnet-4.5", "claude-opus-4.5", "claude-opus-4.6"], + } + default_rotation_mode = "sequential" + + async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: + """Fetch available Antigravity models or return the restored safe list.""" + + try: + response = await client.post(self.get_native_endpoint(operation="models"), headers=self.get_native_headers(api_key), json={}, timeout=30) + response.raise_for_status() + models = self._models_from_response(response.json()) + if models: + return [self._with_prefix(model) for model in models] + except Exception: + return [self._with_prefix(model) for model in AVAILABLE_MODELS] + return [self._with_prefix(model) for model in AVAILABLE_MODELS] + + def get_api_base(self) -> str: + """Return the first configured Antigravity base URL.""" + + configured = os.getenv("ANTIGRAVITY_API_BASE") + return configured.rstrip("/") if configured else BASE_URLS[0] + + def get_native_headers(self, credential_identifier: str, model: str = "", operation: str = "generate") -> dict[str, str]: + """Return static Antigravity headers plus bearer auth. + + Retired per-device fingerprint headers are intentionally not restored; + they are brittle and should only return if a current service requirement + is verified with tests. + """ + + return { + "Authorization": f"Bearer {credential_identifier}", + "Content-Type": "application/json", + "Accept": "text/event-stream", + **ANTIGRAVITY_HEADERS, + } + + def get_native_endpoint(self, model: str = "", operation: str = "generate") -> str: + """Return Antigravity internal operation endpoints.""" + + if operation == "models": + return f"{self.get_api_base()}:fetchAvailableModels" + return f"{self.get_api_base()}:streamGenerateContent?alt=sse" + + def get_adapter_config(self, model: str = "") -> dict[str, dict[str, Any]]: + """Configure minimal payload path copies needed by native tests.""" + + return {"field_rename": {"rules": []}} + + def get_model_tier_requirement(self, model: str) -> Optional[int]: + """Antigravity exposes no restored model-tier restriction.""" + + return None + + def normalize_model_for_tracking(self, model: str) -> str: + """Normalize internal names to public aliases while preserving prefix.""" + + if "/" in model: + provider, clean_model = model.split("/", 1) + return f"{provider}/{self._api_to_user_model(clean_model)}" + return self._api_to_user_model(model) + + @staticmethod + def _with_prefix(model: str) -> str: + if model.startswith("antigravity/"): + return model + return f"antigravity/{model}" + + @staticmethod + def _alias_to_internal(alias: str) -> str: + return MODEL_ALIAS_REVERSE.get(alias, alias) + + @staticmethod + def _api_to_user_model(internal: str) -> str: + return MODEL_ALIAS_MAP.get(internal, internal) + + def _models_from_response(self, payload: dict[str, Any]) -> list[str]: + raw_models: list[str] = [] + if isinstance(payload.get("models"), dict): + raw_models.extend(payload["models"].keys()) + if isinstance(payload.get("data"), list): + raw_models.extend(item.get("id") for item in payload["data"] if isinstance(item, dict) and item.get("id")) + result = [] + for model in raw_models: + public = self._api_to_user_model(str(model)) + if public not in EXCLUDED_MODELS and public in AVAILABLE_MODELS and public not in result: + result.append(public) + return result diff --git a/tests/test_antigravity_provider_restore.py b/tests/test_antigravity_provider_restore.py new file mode 100644 index 000000000..c6fd5d349 --- /dev/null +++ b/tests/test_antigravity_provider_restore.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import pytest + +from rotator_library.providers import PROVIDER_PLUGINS +from rotator_library.providers.antigravity_provider import AntigravityProvider + + +class FakeResponse: + def __init__(self, payload): + self.payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self.payload + + +class FakeClient: + def __init__(self, payload): + self.payload = payload + self.calls = [] + + async def post(self, url, *, headers, json, timeout): + self.calls.append({"url": url, "headers": headers, "json": json, "timeout": timeout}) + return FakeResponse(self.payload) + + +def test_antigravity_provider_is_discovered() -> None: + assert "antigravity" in PROVIDER_PLUGINS + + +def test_antigravity_provider_restores_safe_declarations() -> None: + provider = AntigravityProvider() + + assert provider.get_protocol_name("gemini-3-flash") == "gemini" + assert provider.get_adapter_names("gemini-3-flash") == ("field_rename",) + assert provider.get_model_tier_requirement("antigravity/gemini-3-flash") is None + rules = provider.get_field_cache_rules("gemini-3-flash") + assert rules[0].name == "antigravity_thought_signature" + assert rules[0].scope == ("provider", "model", "credential", "session") + + +def test_antigravity_provider_builds_static_headers_without_device_profile(monkeypatch) -> None: + monkeypatch.setenv("ANTIGRAVITY_API_BASE", "https://antigravity.test/v1internal") + provider = AntigravityProvider() + + headers = provider.get_native_headers("token") + + assert provider.get_native_endpoint(operation="generate") == "https://antigravity.test/v1internal:streamGenerateContent?alt=sse" + assert provider.get_native_endpoint(operation="models") == "https://antigravity.test/v1internal:fetchAvailableModels" + assert headers["Authorization"] == "Bearer token" + assert headers["X-Goog-Api-Client"] == "google-cloud-sdk vscode_cloudshelleditor/0.1" + assert "X-Client-Device-Id" not in headers + + +def test_antigravity_model_aliases_and_tracking_normalization() -> None: + provider = AntigravityProvider() + + assert provider._alias_to_internal("claude-sonnet-4.5") == "claude-sonnet-4-5" + assert provider.normalize_model_for_tracking("antigravity/claude-sonnet-4-5") == "antigravity/claude-sonnet-4.5" + + +@pytest.mark.asyncio +async def test_antigravity_get_models_filters_and_aliases_mocked_response(monkeypatch) -> None: + monkeypatch.setenv("ANTIGRAVITY_API_BASE", "https://antigravity.test/v1internal") + client = FakeClient({"models": {"gemini-3-pro-low": {}, "chat_20706": {}, "claude-opus-4-6": {}}}) + provider = AntigravityProvider() + + models = await provider.get_models("token", client) + + assert models == ["antigravity/gemini-3-pro-preview", "antigravity/claude-opus-4.6"] + assert client.calls[0]["url"] == "https://antigravity.test/v1internal:fetchAvailableModels" + assert client.calls[0]["headers"]["Authorization"] == "Bearer token" + + +@pytest.mark.asyncio +async def test_antigravity_get_models_falls_back_on_errors() -> None: + class BrokenClient: + async def post(self, *args, **kwargs): + raise RuntimeError("offline") + + models = await AntigravityProvider().get_models("token", BrokenClient()) + + assert "antigravity/gemini-3-flash" in models + assert "antigravity/claude-opus-4.6" in models From 4c2b34b8eeb947e9a3a6e7dffb442c22bbc220b8 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 01:04:16 +0200 Subject: [PATCH 034/182] feat(providers): declare Gemini CLI native protocol metadata Adds safe Phase 5 protocol metadata to the existing Gemini CLI provider: Gemini protocol declaration and a thoughtSignature field-cache rule mirroring current custom behavior. The provider still reports custom logic and keeps its existing execution path unchanged; the metadata is for future opt-in native routing. Tests: python -m pytest tests/test_gemini_cli_protocol_declarations.py tests/test_protocol_gemini.py tests/test_field_cache_engine.py tests/test_provider_protocol_declarations.py tests/test_antigravity_provider_restore.py --- .../providers/gemini_cli_provider.py | 19 +++++++++++++ .../test_gemini_cli_protocol_declarations.py | 27 +++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 tests/test_gemini_cli_protocol_declarations.py diff --git a/src/rotator_library/providers/gemini_cli_provider.py b/src/rotator_library/providers/gemini_cli_provider.py index fe0080ca3..a861a095c 100644 --- a/src/rotator_library/providers/gemini_cli_provider.py +++ b/src/rotator_library/providers/gemini_cli_provider.py @@ -13,6 +13,7 @@ from .provider_interface import ProviderInterface, QuotaGroupMap, UsageResetConfigDef from .gemini_auth_base import GeminiAuthBase from .provider_cache import ProviderCache +from ..field_cache import FieldCacheInjection, FieldCacheRule from .utilities.gemini_cli_quota_tracker import GeminiCliQuotaTracker from .utilities.gemini_shared_utils import ( env_bool, @@ -156,6 +157,24 @@ class GeminiCliProvider( # Provider name for env var lookups (QUOTA_GROUPS_GEMINI_CLI_*) provider_env_name: str = "gemini_cli" + # Native protocol declarations for the Phase 5 opt-in execution seam. The + # existing custom Gemini CLI execution path remains active; these metadata + # hooks let future native routing reuse the Gemini protocol and field-cache + # engine without rewriting this provider first. + protocol_name: str = "gemini" + adapter_names: tuple[str, ...] = () + field_cache_rules = ( + FieldCacheRule( + name="gemini_cli_thought_signature", + source="response", + path="candidates.*.content.parts.*.thoughtSignature", + mode="all", + scope=("provider", "model", "credential", "session"), + inject=FieldCacheInjection(target="request", path="metadata.thoughtSignatures", as_list=True), + metadata={"purpose": "mirror existing Gemini CLI thoughtSignature preservation in the native cache seam"}, + ), + ) + # Tier name -> priority mapping (from centralized tier utilities) # Lower numbers = higher priority (ULTRA=1 > PRO=2 > FREE=3) tier_priorities = TIER_PRIORITIES diff --git a/tests/test_gemini_cli_protocol_declarations.py b/tests/test_gemini_cli_protocol_declarations.py new file mode 100644 index 000000000..2a99adfe1 --- /dev/null +++ b/tests/test_gemini_cli_protocol_declarations.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import pytest + +from rotator_library.providers.gemini_cli_provider import GeminiCliProvider + + +@pytest.mark.asyncio +async def test_gemini_cli_declares_gemini_protocol_without_changing_custom_logic() -> None: + provider = GeminiCliProvider() + + assert provider.has_custom_logic() is True + assert provider.get_protocol_name("gemini-3-flash-preview") == "gemini" + assert provider.get_adapter_names("gemini-3-flash-preview") == () + + +@pytest.mark.asyncio +async def test_gemini_cli_declares_thought_signature_cache_rule() -> None: + provider = GeminiCliProvider() + + rules = provider.get_field_cache_rules("gemini-3-flash-preview") + + assert len(rules) == 1 + assert rules[0].name == "gemini_cli_thought_signature" + assert rules[0].path == "candidates.*.content.parts.*.thoughtSignature" + assert rules[0].inject.path == "metadata.thoughtSignatures" + assert rules[0].scope == ("provider", "model", "credential", "session") From 42168cefaf232db814815ea2d9b244dfbe9e3f67 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 01:13:55 +0200 Subject: [PATCH 035/182] fix(native-provider): align provider adapters and trace passes Fixes Phase 5 review findings by changing Claude Code and Copilot suppress-developer-role config to the adapter's mode key, adding effective adapter behavior tests, removing Antigravity's no-op field_rename declaration, and adding explicit native field-cache injection/stream-extraction trace passes. Tests: python -m pytest tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py tests/test_adapter_registry.py --- src/rotator_library/native_provider/executor.py | 10 ++++++++++ .../providers/antigravity_provider.py | 4 ++-- .../providers/claude_code_provider.py | 2 +- .../providers/copilot_provider.py | 2 +- tests/test_antigravity_provider_restore.py | 3 ++- tests/test_claude_code_provider.py | 17 ++++++++++++++++- tests/test_copilot_provider.py | 17 ++++++++++++++++- tests/test_native_provider_executor.py | 1 + tests/test_native_provider_streaming.py | 1 + 9 files changed, 50 insertions(+), 7 deletions(-) diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index cb57df946..cddc8d60f 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -45,6 +45,7 @@ async def execute(self, raw_request: dict[str, Any], context: NativeProviderCont context.field_cache_context(), transaction_logger=logger, ) + self._trace(context, "after_field_cache_injection", provider_request, direction="request", stage="adapter") self._trace(context, "native_provider_request", provider_request, direction="request", stage="provider") raw_response = await transport.post_json(context.endpoint, headers=context.headers, payload=provider_request) self._trace(context, "raw_native_provider_response", raw_response, direction="response", stage="provider") @@ -88,6 +89,7 @@ async def stream(self, raw_request: dict[str, Any], context: NativeProviderConte context.field_cache_context(), transaction_logger=logger, ) + self._trace(context, "after_field_cache_injection", provider_request, direction="request", stage="adapter") self._trace(context, "native_provider_stream_request", provider_request, direction="request", stage="provider") async for raw_chunk in transport.stream_json_lines(context.endpoint, headers=context.headers, payload=provider_request): self._trace(context, "raw_native_provider_stream_chunk", raw_chunk, direction="stream", stage="provider") @@ -95,6 +97,14 @@ async def stream(self, raw_request: dict[str, Any], context: NativeProviderConte event_payload = stream_event_payload(event) self._trace(context, "parsed_native_stream_event", event_payload, direction="stream", stage="protocol") await cache_engine.extract("stream_event", event_payload, context.field_cache_context(), transaction_logger=logger) + self._trace( + context, + "after_field_cache_stream_extraction", + {"source": "stream_event"}, + direction="stream", + stage="adapter", + snapshot=False, + ) formatted = protocol.format_stream_event(event, protocol_context) self._trace(context, "formatted_client_stream_event", formatted, direction="stream", stage="final", snapshot=False) yield formatted diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py index c3c6c5571..eac7e0562 100644 --- a/src/rotator_library/providers/antigravity_provider.py +++ b/src/rotator_library/providers/antigravity_provider.py @@ -56,7 +56,7 @@ class AntigravityProvider(ProviderInterface): provider_env_name = "antigravity" protocol_name = "gemini" - adapter_names = ("field_rename",) + adapter_names: tuple[str, ...] = () field_cache_rules = ( FieldCacheRule( name="antigravity_thought_signature", @@ -118,7 +118,7 @@ def get_native_endpoint(self, model: str = "", operation: str = "generate") -> s def get_adapter_config(self, model: str = "") -> dict[str, dict[str, Any]]: """Configure minimal payload path copies needed by native tests.""" - return {"field_rename": {"rules": []}} + return {} def get_model_tier_requirement(self, model: str) -> Optional[int]: """Antigravity exposes no restored model-tier restriction.""" diff --git a/src/rotator_library/providers/claude_code_provider.py b/src/rotator_library/providers/claude_code_provider.py index 10a21d3df..12a03fb64 100644 --- a/src/rotator_library/providers/claude_code_provider.py +++ b/src/rotator_library/providers/claude_code_provider.py @@ -84,7 +84,7 @@ def get_native_endpoint(self, model: str = "", operation: str = "messages") -> s def get_adapter_config(self, model: str = "") -> dict[str, dict[str, Any]]: """Configure adapters without hardcoding provider transforms.""" - return {"suppress_developer_role": {"replacement_role": "user"}} + return {"suppress_developer_role": {"mode": "user"}} @staticmethod def _with_prefix(model: str) -> str: diff --git a/src/rotator_library/providers/copilot_provider.py b/src/rotator_library/providers/copilot_provider.py index bfe47a2c8..073bb2efa 100644 --- a/src/rotator_library/providers/copilot_provider.py +++ b/src/rotator_library/providers/copilot_provider.py @@ -66,7 +66,7 @@ def get_native_endpoint(self, model: str = "", operation: str = "chat") -> str: def get_adapter_config(self, model: str = "") -> dict[str, dict[str, str]]: """Configure role suppression declaratively for OpenAI-compatible chat.""" - return {"suppress_developer_role": {"replacement_role": "system"}} + return {"suppress_developer_role": {"mode": "system"}} @staticmethod def _with_prefix(model: str) -> str: diff --git a/tests/test_antigravity_provider_restore.py b/tests/test_antigravity_provider_restore.py index c6fd5d349..d2a79d23d 100644 --- a/tests/test_antigravity_provider_restore.py +++ b/tests/test_antigravity_provider_restore.py @@ -35,7 +35,8 @@ def test_antigravity_provider_restores_safe_declarations() -> None: provider = AntigravityProvider() assert provider.get_protocol_name("gemini-3-flash") == "gemini" - assert provider.get_adapter_names("gemini-3-flash") == ("field_rename",) + assert provider.get_adapter_names("gemini-3-flash") == () + assert provider.get_adapter_config("gemini-3-flash") == {} assert provider.get_model_tier_requirement("antigravity/gemini-3-flash") is None rules = provider.get_field_cache_rules("gemini-3-flash") assert rules[0].name == "antigravity_thought_signature" diff --git a/tests/test_claude_code_provider.py b/tests/test_claude_code_provider.py index ffea0fb18..ef910067e 100644 --- a/tests/test_claude_code_provider.py +++ b/tests/test_claude_code_provider.py @@ -2,6 +2,7 @@ import pytest +from rotator_library.adapters import AdapterContext, get_adapter, run_adapter_chain from rotator_library.providers import PROVIDER_PLUGINS from rotator_library.providers.claude_code_provider import ClaudeCodeProvider @@ -36,12 +37,26 @@ def test_claude_code_provider_declares_native_protocol_adapters_and_cache_rules( assert provider.get_protocol_name("claude-sonnet-4-5") == "anthropic_messages" assert provider.get_adapter_names("claude-sonnet-4-5") == ("suppress_developer_role",) - assert provider.get_adapter_config("claude-sonnet-4-5") == {"suppress_developer_role": {"replacement_role": "user"}} + assert provider.get_adapter_config("claude-sonnet-4-5") == {"suppress_developer_role": {"mode": "user"}} rules = provider.get_field_cache_rules("claude-sonnet-4-5") assert rules[0].name == "claude_code_thinking_signature" assert rules[0].scope == ("provider", "model", "credential", "session") +@pytest.mark.asyncio +async def test_claude_code_adapter_config_converts_developer_role_to_user() -> None: + provider = ClaudeCodeProvider() + + result = await run_adapter_chain( + [get_adapter("suppress_developer_role")], + {"messages": [{"role": "developer", "content": "rules"}]}, + AdapterContext(adapter_config=provider.get_adapter_config("claude-sonnet-4-5")), + stage="request", + ) + + assert result["messages"][0]["role"] == "user" + + def test_claude_code_provider_builds_native_headers_and_endpoint(monkeypatch) -> None: monkeypatch.setenv("CLAUDE_CODE_API_BASE", "https://claude-code.test") monkeypatch.setenv("CLAUDE_CODE_ANTHROPIC_VERSION", "2099-01-01") diff --git a/tests/test_copilot_provider.py b/tests/test_copilot_provider.py index a60bc9cc4..bc0f3c929 100644 --- a/tests/test_copilot_provider.py +++ b/tests/test_copilot_provider.py @@ -2,6 +2,7 @@ import pytest +from rotator_library.adapters import AdapterContext, get_adapter, run_adapter_chain from rotator_library.providers import PROVIDER_PLUGINS from rotator_library.providers.copilot_provider import CopilotProvider @@ -36,10 +37,24 @@ def test_copilot_provider_declares_openai_chat_protocol_and_adapter() -> None: assert provider.get_protocol_name("gpt-4.1") == "openai_chat" assert provider.get_adapter_names("gpt-4.1") == ("suppress_developer_role",) - assert provider.get_adapter_config("gpt-4.1") == {"suppress_developer_role": {"replacement_role": "system"}} + assert provider.get_adapter_config("gpt-4.1") == {"suppress_developer_role": {"mode": "system"}} assert provider.get_field_cache_rules("gpt-4.1") == () +@pytest.mark.asyncio +async def test_copilot_adapter_config_converts_developer_role_to_system() -> None: + provider = CopilotProvider() + + result = await run_adapter_chain( + [get_adapter("suppress_developer_role")], + {"messages": [{"role": "developer", "content": "rules"}]}, + AdapterContext(adapter_config=provider.get_adapter_config("gpt-4.1")), + stage="request", + ) + + assert result["messages"][0]["role"] == "system" + + def test_copilot_provider_builds_native_headers_and_endpoint(monkeypatch) -> None: monkeypatch.setenv("COPILOT_API_BASE", "https://copilot.test") monkeypatch.setenv("COPILOT_INTEGRATION_ID", "proxy-test") diff --git a/tests/test_native_provider_executor.py b/tests/test_native_provider_executor.py index d8dbfaec8..14ca38850 100644 --- a/tests/test_native_provider_executor.py +++ b/tests/test_native_provider_executor.py @@ -72,6 +72,7 @@ async def test_native_provider_executor_runs_protocol_adapter_cache_and_trace(tm assert client.calls[0]["json"]["model"] == "provider/gpt-test" pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] assert "native_protocol_selected" in pass_names + assert "after_field_cache_injection" in pass_names assert "native_provider_request" in pass_names assert "raw_native_provider_response" in pass_names assert "parsed_native_provider_response" in pass_names diff --git a/tests/test_native_provider_streaming.py b/tests/test_native_provider_streaming.py index 1b08f9312..62f94473a 100644 --- a/tests/test_native_provider_streaming.py +++ b/tests/test_native_provider_streaming.py @@ -50,6 +50,7 @@ async def test_native_provider_stream_traces_and_yields_formatted_events(tmp_pat assert pass_names.count("raw_native_provider_stream_chunk") == 2 assert pass_names.count("parsed_native_stream_event") == 2 assert "after_field_cache_extraction" in pass_names + assert "after_field_cache_stream_extraction" in pass_names assert pass_names.count("formatted_client_stream_event") == 2 From b38054ca54a17c6ac3c9cffdc043d794118f8247 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 01:21:49 +0200 Subject: [PATCH 036/182] docs(experimental): plan routing fallback groups Adds the Phase 6 plan for ordered fallback groups, target resolution, fallback policy, native/custom/LiteLLM execution selection, RequestExecutor integration, streaming fallback rules, trace requirements, tests, and review checkpoints. Reports remain uncommitted for user-facing review only. --- .../phase-6-routing-fallback-groups.md | 294 ++++++++++++++++++ 1 file changed, 294 insertions(+) create mode 100644 docs/experimental/phase-6-routing-fallback-groups.md diff --git a/docs/experimental/phase-6-routing-fallback-groups.md b/docs/experimental/phase-6-routing-fallback-groups.md new file mode 100644 index 000000000..d546e8efc --- /dev/null +++ b/docs/experimental/phase-6-routing-fallback-groups.md @@ -0,0 +1,294 @@ +# Phase 6 Plan: Routing And Fallback Groups + +## Goal + +Add explicit ordered routing/fallback groups and a provider-native routing seam that can choose between the current LiteLLM execution path, provider custom logic, and the Phase 5 `NativeProviderExecutor`. Fallback groups are the priority deliverable: a model can resolve to an ordered chain of concrete targets, and the executor can try the next target on eligible failures without replacing the existing credential rotation, usage tracking, session tracking, or retry-after parsing. + +## Non-Goals + +- Do not replace `UsageManager`, `SelectionEngine`, `SessionTracker`, or retry-after parsing. +- Do not add full target-group selector language before fallback groups work. +- Do not implement multi-user/security isolation. +- Do not implement a complete config-file system; Phase 10 owns config polish. +- Do not route every provider natively by default. +- Do not fallback after visible streamed output has been emitted. +- Do not silently fallback to LiteLLM for native providers unless the fallback chain explicitly includes a LiteLLM target. +- Do not make Claude Code/Codex/Copilot/Antigravity live-production claims beyond mocked/tested routing behavior. + +## Current Code Context + +- `RequestExecutor` is the live retry/rotation path. +- `RequestExecutor` already gets provider plugin instances, usage managers, credentials, request context, transaction logger, and LiteLLM provider params. +- Phase 5 `NativeProviderExecutor` is isolated and opt-in. +- Providers can declare `protocol_name`, adapter names, field-cache rules, native endpoints, and native headers. +- Model resolution currently maps a single requested model to one provider/model pair. +- Existing retry loops rotate credentials inside one provider; Phase 6 must add model/provider target fallback around that without breaking per-provider credential rotation. +- Existing transaction trace can log routing decisions. +- Existing streaming handler can rotate credentials before output, but stream fallback after output must remain disallowed. + +## Files To Add + +- `src/rotator_library/routing/__init__.py` +- `src/rotator_library/routing/types.py` +- `src/rotator_library/routing/config.py` +- `src/rotator_library/routing/resolver.py` +- `src/rotator_library/routing/policy.py` +- `src/rotator_library/routing/executor.py` or `attempts.py` +- `tests/test_routing_config.py` +- `tests/test_fallback_resolver.py` +- `tests/test_fallback_policy.py` +- `tests/test_request_executor_fallback_groups.py` +- `tests/test_request_executor_native_routing.py` +- `tests/test_streaming_fallback_policy.py` + +## Files Likely To Touch + +- `src/rotator_library/core/types.py` to add optional routing/fallback fields to `RequestContext`. +- `src/rotator_library/client/request_builder.py` to resolve fallback groups before execution. +- `src/rotator_library/client/models.py` if model resolution needs a group-aware helper. +- `src/rotator_library/client/executor.py` to wrap existing per-provider retry loops in target attempts and call native executor for eligible targets. +- `src/rotator_library/client/rotating_client.py` only if passing route configuration into components requires it. +- `src/rotator_library/providers/provider_interface.py` only if a small native-support method is needed. +- `src/rotator_library/transaction_logger.py` only if adding route correlation helpers is cleaner. + +## Routing Types + +`RouteTarget`: + +- `name`: stable target identifier. +- `provider`: provider key. +- `model`: concrete model name, with provider prefix optional but normalized. +- `protocol`: optional override; defaults to provider declaration. +- `execution`: `auto`, `native`, `custom`, or `litellm_fallback`. +- `priority`: optional numeric order. +- `weight`: future target-group selector support, ignored for ordered fallback. +- `conditions`: optional metadata for later selectors. +- `metadata`. + +`FallbackGroup`: + +- `name`. +- `targets`: ordered `RouteTarget` list. +- `failover_on`: error categories that permit next-target fallback. +- `stop_on`: error categories that must stop. +- `streaming_policy`: e.g. `pre_output_only`. +- `max_targets`: optional guard. +- `metadata`. + +`RoutingDecision`: + +- `requested_model`. +- `group_name` optional. +- `targets`. +- `selected_target_index`. +- `reason`. + +`RouteAttemptResult`: + +- target, success/failure, error classification, emitted_output flag, usage summary. + +## Configuration Plan + +- Phase 6 supports env-first fallback definitions with a small parser. +- Optional JSON config can be parsed if simple, but deeper config merging belongs to Phase 10. +- Env examples: + - `FALLBACK_GROUPS=sonnet_chain,codex_chain` + - `FALLBACK_GROUP_SONNET_CHAIN=claude_code/claude-sonnet-4-5,copilot/claude-sonnet-4-5,anthropic/claude-3-5-sonnet-latest` + - `FALLBACK_GROUP_CODEX_CHAIN=codex/gpt-5.1-codex,openai/gpt-5.1` + - `MODEL_ROUTE_CLAUDE_SONNET=group:sonnet_chain` + - `MODEL_ROUTE_CODEX=group:codex_chain` +- Also allow programmatic construction in tests. +- Config parser must validate: + - group names are unique. + - targets have provider/model. + - empty chains are invalid. + - cycles through group aliases are rejected. + - provider names are not silently guessed when ambiguous. +- No secrets in route config. + +## Target Resolution + +- If requested model maps to a fallback group, return all group targets in order. +- If requested model is already `provider/model`, return a single target. +- If requested model maps through existing model definitions, preserve existing behavior. +- Target model should be normalized with provider prefix before entering `RequestExecutor`. +- Record transform trace pass `routing_decision` with requested model, group, target count, selected target names, execution modes, and no credentials. + +## Fallback Policy + +Try next target only on eligible failures: + +- rate limit/quota/capacity. +- provider unavailable/server error. +- transient connection errors. +- native provider unsupported operation when explicit fallback target exists. + +Do not fallback on: + +- auth errors for all credentials unless the group target explicitly marks auth fallback safe. +- validation/permanent request errors. +- pre-request callback failures. +- client cancellation. +- streaming errors after visible output. + +Respect existing per-provider credential rotation before moving to the next target: + +- one target attempt should let current `RequestExecutor` exhaust eligible credentials according to existing retry logic. +- after a target is exhausted, the fallback layer can advance to the next target. + +Keep error accumulator context across targets so final failure includes all target errors. + +## Native, Custom, And LiteLLM Execution Choice + +`execution=auto`: + +- if plugin has `has_custom_logic()`, use current plugin `acompletion()`. +- else if provider has native protocol declaration and native endpoint/header methods, use `NativeProviderExecutor`. +- else use current LiteLLM call path. + +`execution=native`: require native support or fail this target with a classified unsupported-operation error. + +`execution=custom`: require `has_custom_logic()`. + +`execution=litellm_fallback`: force current LiteLLM path and emit `native_provider_litellm_fallback` / `routing_litellm_fallback` trace metadata. + +Tests must prove native-declared providers do not silently use LiteLLM unless the target says `litellm_fallback`. + +## RequestExecutor Integration + +- Keep the existing `_execute_non_streaming()` and `_execute_streaming()` credential retry loops as the per-target implementation. +- Add a wrapper method that accepts a `RequestContext` with `routing_targets`. +- For each target: + - clone/update context provider/model/usage manager key/credentials for that target. + - record `routing_target_attempt_started`. + - run the existing target execution path. + - on success record `routing_target_attempt_succeeded` and return. + - on failure classify and ask fallback policy if next target is allowed. + - record `routing_target_attempt_failed`. +- Avoid mutating the original request context in-place. +- If no group exists, execute exactly as today. + +## Native Provider Execution Integration + +- Add a small execution branch inside the target attempt where the selected target resolves to native execution. +- Construct `NativeProviderContext` from: + - provider plugin declarations. + - selected model. + - credential identifier/token. + - request context session/scope/classifier. + - transaction logger. + - endpoint/headers from provider methods. +- Use Phase 5 `NativeProviderExecutor` for mocked/tested providers. +- Preserve usage recording through the existing credential context. If native responses do not provide full usage yet, record normalized available usage only and leave cost normalization for Phase 9. +- For custom providers with `has_custom_logic()`, keep existing plugin `acompletion()` branch. +- For undeclared providers, keep LiteLLM path. + +## Streaming Integration + +- Fallback to next target is allowed only before any visible chunk is yielded. +- Track `emitted_output` in the stream wrapper. +- If a stream errors before output, fallback policy may advance to next target. +- If a stream errors after output, propagate the error; do not switch models mid-stream. +- Record: + - `routing_stream_target_attempt_started` + - `routing_stream_target_attempt_failed` + - `routing_stream_target_attempt_succeeded` + - `routing_stream_fallback_blocked_after_output`. +- Keep current stream retry policy for same-target/same-provider errors before group fallback. + +## Transaction Trace Requirements + +- `routing_decision` +- `routing_target_attempt_started` +- `routing_target_attempt_failed` +- `routing_target_attempt_succeeded` +- `routing_fallback_selected` +- `routing_fallback_exhausted` +- `routing_litellm_fallback` +- `routing_native_execution_selected` +- stream equivalents where applicable. +- Trace metadata includes group, target index, provider, model, execution mode, classification, and reason. It must not include raw credentials. + +## Target Groups As Optional Richer Layer + +- Phase 6 focuses on ordered fallback groups. +- Add type seams for future target groups: + - `TargetSelector` + - `TargetGroup` + - `TargetSelectionPolicy` +- Do not implement weighted/latency/cost selector behavior unless it is trivial and isolated. +- Document that Phase 6 fallback groups are deterministic ordered chains. + +## Tests + +Config tests: + +- parse env fallback group. +- invalid empty group. +- duplicate group names. +- provider/model parsing. +- explicit `litellm_fallback` execution target. + +Resolver tests: + +- requested alias maps to group. +- provider/model remains single target. +- target order is preserved. +- existing model prefix behavior stays intact. + +Policy tests: + +- fallback on rate limit/server/transient. +- stop on auth/permanent/pre-request/cancel. +- streaming fallback allowed before output. +- streaming fallback blocked after output. + +Request executor tests: + +- first target success uses no fallback. +- first target rate limits all credentials then second target succeeds. +- final error includes both target failures. +- no fallback group preserves old behavior. +- transaction trace includes routing passes. + +Native routing tests: + +- native-declared provider target uses `NativeProviderExecutor`. +- native-declared provider does not use LiteLLM unless target execution is `litellm_fallback`. +- custom provider still uses plugin `acompletion()`. +- undeclared provider still uses LiteLLM. + +Streaming tests: + +- pre-output target failure falls through to next target. +- post-output failure does not fall through. +- trace records blocked fallback after output. + +Regression: + +- Phase 1 protocol tests. +- Phase 2 logging tests. +- Phase 3 adapter/cache tests. +- Phase 4 Responses tests. +- Phase 5 provider tests. +- `test_session_tracking.py`. +- `test_selection_engine.py`. + +## Commit Checkpoints + +1. Add routing types, config parser, resolver, and policy with tests. +2. Add request context routing fields and target cloning helpers with tests. +3. Integrate non-streaming fallback group wrapper around existing executor path with tests. +4. Integrate native/custom/LiteLLM execution mode selection with tests. +5. Integrate streaming pre-output fallback policy with tests. +6. Add trace pass coverage and run regressions. +7. Review with `explore` and `explore-heavy`, fix findings, and write the uncommitted Phase 6 report. + +## Risks And Mitigations + +- Fallback could double-spend or corrupt usage. Mitigation: each target uses existing credential context and usage manager; group fallback starts only after target failure. +- Fallback could mask permanent request bugs. Mitigation: policy stops on validation/permanent errors. +- Streaming fallback could corrupt client output. Mitigation: disallow fallback after visible output. +- Native routing could accidentally bypass LiteLLM behavior for existing providers. Mitigation: default no group/no native route is unchanged; native path is explicit/declared. +- Config could become too broad. Mitigation: small env parser now; richer JSON config later in Phase 10. +- Session affinity could leak across provider pools. Mitigation: keep `SessionTracker` namespace behavior and clone target context with provider/model-specific namespace where needed. From ce9d085e6d5f405cf9bb35b7f64433cfb33e5095 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 01:23:44 +0200 Subject: [PATCH 037/182] feat(routing): add fallback group primitives Adds Phase 6 routing dataclasses, env-first fallback group parser, deterministic resolver, and fallback policy for ordered provider/model chains. This checkpoint is isolated from live request execution and only establishes the tested foundation for later RequestExecutor integration. Tests: python -m pytest tests/test_routing_config.py tests/test_fallback_resolver.py tests/test_fallback_policy.py --- src/rotator_library/routing/__init__.py | 23 +++++ src/rotator_library/routing/config.py | 69 +++++++++++++++ src/rotator_library/routing/policy.py | 31 +++++++ src/rotator_library/routing/resolver.py | 32 +++++++ src/rotator_library/routing/types.py | 113 ++++++++++++++++++++++++ tests/test_fallback_policy.py | 40 +++++++++ tests/test_fallback_resolver.py | 44 +++++++++ tests/test_routing_config.py | 48 ++++++++++ 8 files changed, 400 insertions(+) create mode 100644 src/rotator_library/routing/__init__.py create mode 100644 src/rotator_library/routing/config.py create mode 100644 src/rotator_library/routing/policy.py create mode 100644 src/rotator_library/routing/resolver.py create mode 100644 src/rotator_library/routing/types.py create mode 100644 tests/test_fallback_policy.py create mode 100644 tests/test_fallback_resolver.py create mode 100644 tests/test_routing_config.py diff --git a/src/rotator_library/routing/__init__.py b/src/rotator_library/routing/__init__.py new file mode 100644 index 000000000..8807b05f4 --- /dev/null +++ b/src/rotator_library/routing/__init__.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Routing and fallback group primitives.""" + +from .config import RoutingConfigError, load_routing_config_from_env, parse_route_target +from .policy import FallbackPolicy +from .resolver import FallbackResolver +from .types import FallbackGroup, RouteTarget, RoutingConfig, RoutingDecision, TargetGroup, TargetSelector + +__all__ = [ + "FallbackGroup", + "FallbackPolicy", + "FallbackResolver", + "RouteTarget", + "RoutingConfig", + "RoutingConfigError", + "RoutingDecision", + "TargetGroup", + "TargetSelector", + "load_routing_config_from_env", + "parse_route_target", +] diff --git a/src/rotator_library/routing/config.py b/src/rotator_library/routing/config.py new file mode 100644 index 000000000..71a149bb2 --- /dev/null +++ b/src/rotator_library/routing/config.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Environment parser for Phase 6 fallback routing configuration.""" + +from __future__ import annotations + +import os +from collections.abc import Mapping + +from .types import FallbackGroup, RouteTarget, RoutingConfig + + +class RoutingConfigError(ValueError): + """Raised when fallback routing configuration is invalid.""" + + +def parse_route_target(spec: str) -> RouteTarget: + """Parse `provider/model` with optional `@execution` suffix. + + The suffix is intentionally small because Phase 6 prioritizes ordered + fallback groups. Rich selector syntax belongs to the config polish phase. + """ + + text = spec.strip() + if not text: + raise RoutingConfigError("route target cannot be empty") + target_text, _, execution = text.partition("@") + if "/" not in target_text: + raise RoutingConfigError(f"route target requires provider/model: {spec}") + provider, model = target_text.split("/", 1) + if not provider or not model: + raise RoutingConfigError(f"route target requires provider/model: {spec}") + return RouteTarget(provider=provider.strip(), model=model.strip(), execution=(execution.strip() or "auto")) + + +def load_routing_config_from_env(env: Mapping[str, str] | None = None) -> RoutingConfig: + """Load fallback groups and model-route aliases from environment variables.""" + + source = env if env is not None else os.environ + group_names = _csv(source.get("FALLBACK_GROUPS", "")) + if len(group_names) != len(set(group_names)): + raise RoutingConfigError("fallback group names must be unique") + groups: dict[str, FallbackGroup] = {} + for name in group_names: + key = f"FALLBACK_GROUP_{_env_key(name)}" + target_specs = _csv(source.get(key, "")) + if not target_specs: + raise RoutingConfigError(f"fallback group {name} has no targets") + groups[name] = FallbackGroup(name=name, targets=tuple(parse_route_target(spec) for spec in target_specs)) + + model_routes: dict[str, str] = {} + for key, value in source.items(): + if not key.startswith("MODEL_ROUTE_"): + continue + model_alias = key[len("MODEL_ROUTE_") :].lower() + route = value.strip() + if route.startswith("group:") and route[len("group:") :] not in groups: + raise RoutingConfigError(f"model route {key} references unknown fallback group {route}") + model_routes[model_alias] = route + return RoutingConfig(fallback_groups=groups, model_routes=model_routes) + + +def _csv(value: str) -> list[str]: + return [part.strip() for part in value.split(",") if part.strip()] + + +def _env_key(value: str) -> str: + return value.upper().replace("-", "_") diff --git a/src/rotator_library/routing/policy.py b/src/rotator_library/routing/policy.py new file mode 100644 index 000000000..92ea5241d --- /dev/null +++ b/src/rotator_library/routing/policy.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Fallback eligibility policy for ordered route chains.""" + +from __future__ import annotations + +from .types import DEFAULT_FAILOVER_ON, DEFAULT_STOP_ON, FallbackGroup + + +class FallbackPolicy: + """Decide whether a failed target may advance to the next target.""" + + def should_fallback( + self, + error_type: str, + *, + group: FallbackGroup | None = None, + emitted_output: bool = False, + stream: bool = False, + ) -> bool: + """Return whether fallback is allowed for a classified failure.""" + + normalized = error_type.lower() + active_stop = group.stop_on if group else DEFAULT_STOP_ON + active_failover = group.failover_on if group else DEFAULT_FAILOVER_ON + if stream and emitted_output: + return False + if normalized in active_stop: + return False + return normalized in active_failover diff --git a/src/rotator_library/routing/resolver.py b/src/rotator_library/routing/resolver.py new file mode 100644 index 000000000..745a6070a --- /dev/null +++ b/src/rotator_library/routing/resolver.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Resolve requested models to direct targets or fallback groups.""" + +from __future__ import annotations + +from .config import RoutingConfigError, parse_route_target +from .types import RoutingConfig, RoutingDecision + + +class FallbackResolver: + """Resolve a model name using deterministic fallback group rules.""" + + def __init__(self, config: RoutingConfig) -> None: + self.config = config + + def resolve(self, requested_model: str) -> RoutingDecision: + """Return the ordered targets for a requested model.""" + + route = self.config.model_routes.get(requested_model.lower()) + if route and route.startswith("group:"): + group_name = route[len("group:") :] + group = self.config.fallback_groups.get(group_name) + if not group: + raise RoutingConfigError(f"unknown fallback group {group_name}") + return RoutingDecision(requested_model=requested_model, group_name=group.name, targets=group.targets, reason="model_route_group") + if route: + return RoutingDecision(requested_model=requested_model, targets=(parse_route_target(route),), reason="model_route_target") + if "/" in requested_model: + return RoutingDecision(requested_model=requested_model, targets=(parse_route_target(requested_model),), reason="direct_provider_model") + raise RoutingConfigError(f"model {requested_model!r} is not provider-prefixed and has no route") diff --git a/src/rotator_library/routing/types.py b/src/rotator_library/routing/types.py new file mode 100644 index 000000000..bd6ca04a9 --- /dev/null +++ b/src/rotator_library/routing/types.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Typed route targets and fallback groups.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal, Protocol + +ExecutionMode = Literal["auto", "native", "custom", "litellm_fallback"] +StreamingFallbackPolicy = Literal["pre_output_only", "never"] + +DEFAULT_FAILOVER_ON = frozenset({"rate_limit", "quota", "capacity", "server_error", "api_connection", "transient"}) +DEFAULT_STOP_ON = frozenset({"auth", "authentication", "validation", "permanent", "pre_request_callback", "cancelled"}) + + +@dataclass(frozen=True) +class RouteTarget: + """One concrete provider/model execution target in a fallback chain.""" + + provider: str + model: str + name: str = "" + protocol: str | None = None + execution: ExecutionMode = "auto" + priority: int | None = None + weight: float | None = None + conditions: dict[str, Any] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + if not self.provider or not self.model: + raise ValueError("route targets require provider and model") + if self.execution not in {"auto", "native", "custom", "litellm_fallback"}: + raise ValueError(f"unsupported execution mode: {self.execution}") + if not self.name: + object.__setattr__(self, "name", f"{self.provider}/{self.model}") + + @property + def prefixed_model(self) -> str: + """Return `provider/model` without double-prefixing an already-prefixed model.""" + + return self.model if self.model.startswith(f"{self.provider}/") else f"{self.provider}/{self.model}" + + +@dataclass(frozen=True) +class FallbackGroup: + """Deterministic ordered chain of route targets.""" + + name: str + targets: tuple[RouteTarget, ...] + failover_on: frozenset[str] = DEFAULT_FAILOVER_ON + stop_on: frozenset[str] = DEFAULT_STOP_ON + streaming_policy: StreamingFallbackPolicy = "pre_output_only" + max_targets: int | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + if not self.name: + raise ValueError("fallback group name is required") + if not self.targets: + raise ValueError("fallback groups require at least one target") + if self.max_targets is not None and self.max_targets <= 0: + raise ValueError("max_targets must be positive") + if self.max_targets is not None and len(self.targets) > self.max_targets: + raise ValueError("fallback group target count exceeds max_targets") + + +@dataclass(frozen=True) +class RoutingConfig: + """Routing configuration loaded from env or tests.""" + + fallback_groups: dict[str, FallbackGroup] = field(default_factory=dict) + model_routes: dict[str, str] = field(default_factory=dict) + + +@dataclass(frozen=True) +class RoutingDecision: + """Resolved routing plan for a requested model.""" + + requested_model: str + targets: tuple[RouteTarget, ...] + group_name: str | None = None + selected_target_index: int = 0 + reason: str = "direct" + + +@dataclass(frozen=True) +class RouteAttemptResult: + """Result summary for one attempted route target.""" + + target: RouteTarget + success: bool + error_type: str | None = None + emitted_output: bool = False + usage: dict[str, Any] = field(default_factory=dict) + + +class TargetSelector(Protocol): + """Future target-group selector seam; Phase 6 keeps ordered fallback only.""" + + def select(self, targets: tuple[RouteTarget, ...]) -> RouteTarget: + """Return one target from a richer target group.""" + + +@dataclass(frozen=True) +class TargetGroup: + """Future richer target group; not used for Phase 6 ordered fallback.""" + + name: str + targets: tuple[RouteTarget, ...] + selector: str = "ordered" diff --git a/tests/test_fallback_policy.py b/tests/test_fallback_policy.py new file mode 100644 index 000000000..2d36892cd --- /dev/null +++ b/tests/test_fallback_policy.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from rotator_library.routing import FallbackPolicy, parse_route_target +from rotator_library.routing.types import FallbackGroup + + +def test_policy_falls_back_on_retryable_categories() -> None: + policy = FallbackPolicy() + + assert policy.should_fallback("rate_limit") is True + assert policy.should_fallback("server_error") is True + assert policy.should_fallback("api_connection") is True + + +def test_policy_stops_on_permanent_categories() -> None: + policy = FallbackPolicy() + + assert policy.should_fallback("auth") is False + assert policy.should_fallback("validation") is False + assert policy.should_fallback("pre_request_callback") is False + + +def test_policy_blocks_stream_fallback_after_visible_output() -> None: + assert FallbackPolicy().should_fallback("rate_limit", stream=True, emitted_output=True) is False + + +def test_policy_allows_stream_fallback_before_visible_output() -> None: + assert FallbackPolicy().should_fallback("rate_limit", stream=True, emitted_output=False) is True + + +def test_policy_respects_group_overrides() -> None: + group = FallbackGroup( + name="auth_safe", + targets=(parse_route_target("a/model"), parse_route_target("b/model")), + failover_on=frozenset({"auth"}), + stop_on=frozenset({"validation"}), + ) + + assert FallbackPolicy().should_fallback("auth", group=group) is True + assert FallbackPolicy().should_fallback("validation", group=group) is False diff --git a/tests/test_fallback_resolver.py b/tests/test_fallback_resolver.py new file mode 100644 index 000000000..674a5d559 --- /dev/null +++ b/tests/test_fallback_resolver.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import pytest + +from rotator_library.routing import FallbackResolver, RoutingConfigError, load_routing_config_from_env + + +def test_resolver_maps_alias_to_fallback_group_in_order() -> None: + config = load_routing_config_from_env( + { + "FALLBACK_GROUPS": "code_chain", + "FALLBACK_GROUP_CODE_CHAIN": "codex/gpt-5.1-codex,openai/gpt-5.1", + "MODEL_ROUTE_CODEX": "group:code_chain", + } + ) + + decision = FallbackResolver(config).resolve("codex") + + assert decision.group_name == "code_chain" + assert [target.prefixed_model for target in decision.targets] == ["codex/gpt-5.1-codex", "openai/gpt-5.1"] + assert decision.reason == "model_route_group" + + +def test_resolver_keeps_provider_prefixed_model_as_direct_target() -> None: + decision = FallbackResolver(load_routing_config_from_env({})).resolve("openai/gpt-5.1") + + assert decision.group_name is None + assert decision.targets[0].provider == "openai" + assert decision.targets[0].model == "gpt-5.1" + assert decision.reason == "direct_provider_model" + + +def test_resolver_maps_alias_to_single_target() -> None: + config = load_routing_config_from_env({"MODEL_ROUTE_FAST": "openai/gpt-5.1"}) + + decision = FallbackResolver(config).resolve("fast") + + assert decision.targets[0].prefixed_model == "openai/gpt-5.1" + assert decision.reason == "model_route_target" + + +def test_resolver_rejects_unprefixed_model_without_route() -> None: + with pytest.raises(RoutingConfigError): + FallbackResolver(load_routing_config_from_env({})).resolve("gpt-5.1") diff --git a/tests/test_routing_config.py b/tests/test_routing_config.py new file mode 100644 index 000000000..3d856d61b --- /dev/null +++ b/tests/test_routing_config.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import pytest + +from rotator_library.routing import RoutingConfigError, load_routing_config_from_env, parse_route_target + + +def test_parse_route_target_supports_execution_suffix() -> None: + target = parse_route_target("codex/gpt-5.1-codex@native") + + assert target.provider == "codex" + assert target.model == "gpt-5.1-codex" + assert target.execution == "native" + assert target.prefixed_model == "codex/gpt-5.1-codex" + + +def test_parse_route_target_rejects_missing_provider() -> None: + with pytest.raises(RoutingConfigError): + parse_route_target("gpt-5.1") + + +def test_load_routing_config_from_env_parses_group_and_model_route() -> None: + config = load_routing_config_from_env( + { + "FALLBACK_GROUPS": "sonnet_chain", + "FALLBACK_GROUP_SONNET_CHAIN": "claude_code/claude-sonnet-4-5,copilot/claude-sonnet-4-5@litellm_fallback", + "MODEL_ROUTE_CLAUDE_SONNET": "group:sonnet_chain", + } + ) + + assert tuple(config.fallback_groups) == ("sonnet_chain",) + assert config.fallback_groups["sonnet_chain"].targets[1].execution == "litellm_fallback" + assert config.model_routes["claude_sonnet"] == "group:sonnet_chain" + + +def test_load_routing_config_rejects_empty_group() -> None: + with pytest.raises(RoutingConfigError): + load_routing_config_from_env({"FALLBACK_GROUPS": "empty", "FALLBACK_GROUP_EMPTY": ""}) + + +def test_load_routing_config_rejects_duplicate_group_names() -> None: + with pytest.raises(RoutingConfigError): + load_routing_config_from_env({"FALLBACK_GROUPS": "a,a", "FALLBACK_GROUP_A": "openai/gpt"}) + + +def test_load_routing_config_rejects_unknown_group_route() -> None: + with pytest.raises(RoutingConfigError): + load_routing_config_from_env({"MODEL_ROUTE_CODEX": "group:missing"}) From 036eecf113fade9862574900261dcab718e14704 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 01:24:53 +0200 Subject: [PATCH 038/182] feat(routing): add target context cloning Adds optional routing metadata to RequestContext and a clone_context_for_target helper that creates immutable per-target request contexts for fallback attempts without mutating the original request. Tests: python -m pytest tests/test_routing_config.py tests/test_fallback_resolver.py tests/test_fallback_policy.py tests/test_routing_attempts.py tests/test_session_tracking.py tests/test_selection_engine.py --- src/rotator_library/core/types.py | 4 ++ src/rotator_library/routing/__init__.py | 2 + src/rotator_library/routing/attempts.py | 40 ++++++++++++++++++++ tests/test_routing_attempts.py | 50 +++++++++++++++++++++++++ 4 files changed, 96 insertions(+) create mode 100644 src/rotator_library/routing/attempts.py create mode 100644 tests/test_routing_attempts.py diff --git a/src/rotator_library/core/types.py b/src/rotator_library/core/types.py index c502592ec..1096932c9 100644 --- a/src/rotator_library/core/types.py +++ b/src/rotator_library/core/types.py @@ -89,6 +89,10 @@ class RequestContext: provider_config: Optional[Dict[str, Any]] = None credential_secrets: Dict[str, str] = field(default_factory=dict) classifier: Optional[str] = None + routing_targets: Optional[Any] = None + routing_group_name: Optional[str] = None + routing_target_index: int = 0 + routing_attempt_history: List[Dict[str, Any]] = field(default_factory=list) @dataclass diff --git a/src/rotator_library/routing/__init__.py b/src/rotator_library/routing/__init__.py index 8807b05f4..110ef3217 100644 --- a/src/rotator_library/routing/__init__.py +++ b/src/rotator_library/routing/__init__.py @@ -4,6 +4,7 @@ """Routing and fallback group primitives.""" from .config import RoutingConfigError, load_routing_config_from_env, parse_route_target +from .attempts import clone_context_for_target from .policy import FallbackPolicy from .resolver import FallbackResolver from .types import FallbackGroup, RouteTarget, RoutingConfig, RoutingDecision, TargetGroup, TargetSelector @@ -18,6 +19,7 @@ "RoutingDecision", "TargetGroup", "TargetSelector", + "clone_context_for_target", "load_routing_config_from_env", "parse_route_target", ] diff --git a/src/rotator_library/routing/attempts.py b/src/rotator_library/routing/attempts.py new file mode 100644 index 000000000..827279dc9 --- /dev/null +++ b/src/rotator_library/routing/attempts.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Helpers for creating per-target request contexts.""" + +from __future__ import annotations + +from dataclasses import replace +from typing import Any, Sequence + +from ..core.types import RequestContext +from .types import RouteTarget + + +def clone_context_for_target( + context: RequestContext, + target: RouteTarget, + *, + credentials: Sequence[str] | None = None, + usage_manager_key: str | None = None, + target_index: int = 0, +) -> RequestContext: + """Return a target-specific request context without mutating the original. + + Fallback routing must preserve the original request context for traceability + and for safe retry decisions. This helper copies the request kwargs, updates + the selected model/provider, and carries routing metadata for executor traces. + """ + + kwargs: dict[str, Any] = dict(context.kwargs) + kwargs["model"] = target.prefixed_model + return replace( + context, + model=target.prefixed_model, + provider=target.provider, + kwargs=kwargs, + credentials=list(credentials) if credentials is not None else list(context.credentials), + usage_manager_key=usage_manager_key if usage_manager_key is not None else target.provider, + routing_target_index=target_index, + ) diff --git a/tests/test_routing_attempts.py b/tests/test_routing_attempts.py new file mode 100644 index 000000000..1a8a978b9 --- /dev/null +++ b/tests/test_routing_attempts.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from rotator_library.core.types import RequestContext +from rotator_library.routing import clone_context_for_target, parse_route_target + + +def test_clone_context_for_target_updates_model_provider_without_mutating_original() -> None: + original = RequestContext( + model="requested", + provider="original", + kwargs={"model": "requested", "messages": []}, + streaming=False, + credentials=["cred-a"], + deadline=123.0, + usage_manager_key="original", + ) + target = parse_route_target("codex/gpt-5.1-codex@native") + + cloned = clone_context_for_target(original, target, credentials=["cred-b"], target_index=1) + + assert cloned.model == "codex/gpt-5.1-codex" + assert cloned.provider == "codex" + assert cloned.kwargs["model"] == "codex/gpt-5.1-codex" + assert cloned.credentials == ["cred-b"] + assert cloned.usage_manager_key == "codex" + assert cloned.routing_target_index == 1 + assert original.model == "requested" + assert original.kwargs["model"] == "requested" + assert original.credentials == ["cred-a"] + + +def test_clone_context_for_target_preserves_request_metadata() -> None: + original = RequestContext( + model="requested", + provider="original", + kwargs={"model": "requested"}, + streaming=True, + credentials=["cred-a"], + deadline=123.0, + session_id="session-1", + classifier="global", + routing_group_name="chain", + ) + + cloned = clone_context_for_target(original, parse_route_target("openai/gpt-5.1")) + + assert cloned.streaming is True + assert cloned.session_id == "session-1" + assert cloned.classifier == "global" + assert cloned.routing_group_name == "chain" From 589603ca33f7b5c6d9f071a7307d79d6335dc804 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 01:26:39 +0200 Subject: [PATCH 039/182] feat(routing): add fallback attempt runner Adds an isolated ordered fallback runner that uses the Phase 6 fallback policy with injected per-target attempts, including retryable fallback, permanent-stop, and stream-after-output blocking behavior. This prepares RequestExecutor integration without changing live execution behavior yet. Tests: python -m pytest tests/test_routing_config.py tests/test_fallback_resolver.py tests/test_fallback_policy.py tests/test_routing_attempts.py tests/test_fallback_attempt_runner.py --- src/rotator_library/routing/__init__.py | 3 + src/rotator_library/routing/executor.py | 57 +++++++++++++++++++ tests/test_fallback_attempt_runner.py | 75 +++++++++++++++++++++++++ 3 files changed, 135 insertions(+) create mode 100644 src/rotator_library/routing/executor.py create mode 100644 tests/test_fallback_attempt_runner.py diff --git a/src/rotator_library/routing/__init__.py b/src/rotator_library/routing/__init__.py index 110ef3217..9892616bc 100644 --- a/src/rotator_library/routing/__init__.py +++ b/src/rotator_library/routing/__init__.py @@ -5,12 +5,15 @@ from .config import RoutingConfigError, load_routing_config_from_env, parse_route_target from .attempts import clone_context_for_target +from .executor import FallbackAttemptRunner, FallbackExhaustedError from .policy import FallbackPolicy from .resolver import FallbackResolver from .types import FallbackGroup, RouteTarget, RoutingConfig, RoutingDecision, TargetGroup, TargetSelector __all__ = [ "FallbackGroup", + "FallbackAttemptRunner", + "FallbackExhaustedError", "FallbackPolicy", "FallbackResolver", "RouteTarget", diff --git a/src/rotator_library/routing/executor.py b/src/rotator_library/routing/executor.py new file mode 100644 index 000000000..868d50267 --- /dev/null +++ b/src/rotator_library/routing/executor.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Ordered fallback attempt runner.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any + +from .policy import FallbackPolicy +from .types import RouteAttemptResult, RouteTarget, RoutingDecision + +AttemptCallback = Callable[[RouteTarget, int], Awaitable[Any]] + + +class FallbackExhaustedError(Exception): + """Raised when every eligible fallback target fails.""" + + def __init__(self, decision: RoutingDecision, attempts: tuple[RouteAttemptResult, ...]) -> None: + self.decision = decision + self.attempts = attempts + summary = ", ".join(f"{attempt.target.name}:{attempt.error_type}" for attempt in attempts) + super().__init__(f"fallback group exhausted for {decision.requested_model}: {summary}") + + +class FallbackAttemptRunner: + """Run ordered route targets using an injected per-target attempt callback. + + The runner is independent from `RequestExecutor`. Phase 6 can unit-test + fallback control flow here, then wire the callback to existing per-target + credential retry logic without duplicating policy decisions. + """ + + def __init__(self, policy: FallbackPolicy | None = None) -> None: + self.policy = policy or FallbackPolicy() + + async def run(self, decision: RoutingDecision, attempt: AttemptCallback, *, stream: bool = False) -> Any: + """Try targets in order until one succeeds or fallback is exhausted.""" + + attempts: list[RouteAttemptResult] = [] + group = None + for index, target in enumerate(decision.targets): + try: + return await attempt(target, index) + except Exception as exc: + error_type = _error_type(exc) + emitted_output = bool(getattr(exc, "emitted_output", False)) + attempts.append(RouteAttemptResult(target=target, success=False, error_type=error_type, emitted_output=emitted_output)) + has_next = index < len(decision.targets) - 1 + if not has_next or not self.policy.should_fallback(error_type, group=group, emitted_output=emitted_output, stream=stream): + raise FallbackExhaustedError(decision, tuple(attempts)) from exc + raise FallbackExhaustedError(decision, tuple(attempts)) + + +def _error_type(error: BaseException) -> str: + return str(getattr(error, "error_type", error.__class__.__name__)).lower() diff --git a/tests/test_fallback_attempt_runner.py b/tests/test_fallback_attempt_runner.py new file mode 100644 index 000000000..0ba95adc9 --- /dev/null +++ b/tests/test_fallback_attempt_runner.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import pytest + +from rotator_library.routing import FallbackAttemptRunner, FallbackExhaustedError, RoutingDecision, parse_route_target + + +class ClassifiedFailure(Exception): + def __init__(self, error_type: str, *, emitted_output: bool = False) -> None: + super().__init__(error_type) + self.error_type = error_type + self.emitted_output = emitted_output + + +def _decision() -> RoutingDecision: + return RoutingDecision( + requested_model="code", + group_name="code_chain", + targets=(parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")), + reason="model_route_group", + ) + + +@pytest.mark.asyncio +async def test_attempt_runner_returns_first_success_without_fallback() -> None: + calls = [] + + async def attempt(target, index): + calls.append(target.prefixed_model) + return {"target": target.prefixed_model} + + result = await FallbackAttemptRunner().run(_decision(), attempt) + + assert result == {"target": "codex/gpt-5.1-codex"} + assert calls == ["codex/gpt-5.1-codex"] + + +@pytest.mark.asyncio +async def test_attempt_runner_falls_back_on_retryable_error() -> None: + calls = [] + + async def attempt(target, index): + calls.append(target.prefixed_model) + if index == 0: + raise ClassifiedFailure("rate_limit") + return {"target": target.prefixed_model} + + result = await FallbackAttemptRunner().run(_decision(), attempt) + + assert result == {"target": "openai/gpt-5.1"} + assert calls == ["codex/gpt-5.1-codex", "openai/gpt-5.1"] + + +@pytest.mark.asyncio +async def test_attempt_runner_stops_on_permanent_error() -> None: + async def attempt(target, index): + raise ClassifiedFailure("validation") + + with pytest.raises(FallbackExhaustedError) as exc: + await FallbackAttemptRunner().run(_decision(), attempt) + + assert len(exc.value.attempts) == 1 + assert exc.value.attempts[0].error_type == "validation" + + +@pytest.mark.asyncio +async def test_attempt_runner_blocks_stream_fallback_after_output() -> None: + async def attempt(target, index): + raise ClassifiedFailure("rate_limit", emitted_output=True) + + with pytest.raises(FallbackExhaustedError) as exc: + await FallbackAttemptRunner().run(_decision(), attempt, stream=True) + + assert len(exc.value.attempts) == 1 + assert exc.value.attempts[0].emitted_output is True From 9293035af5e133c1d835454e157ddee7b1e914b4 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 01:37:13 +0200 Subject: [PATCH 040/182] feat(routing): integrate non-streaming fallback attempts Adds the first RequestExecutor integration seam for Phase 6: non-streaming ordered routing targets activate only when RequestContext.routing_targets is present, clone per-target contexts without mutating the original request, and emit routing trace passes for decision/start/failure/fallback/success/exhaustion. The no-routing execution path remains unchanged. Structured proxy all-credentials-exhausted responses are interpreted for retryable fallback when they summarize rate-limit or server-style failures. Tests: python -m pytest tests/test_request_executor_fallback_groups.py tests/test_routing_config.py tests/test_fallback_resolver.py tests/test_fallback_policy.py tests/test_routing_attempts.py tests/test_fallback_attempt_runner.py --- src/rotator_library/client/executor.py | 135 ++++++++++++++++++ .../test_request_executor_fallback_groups.py | 111 ++++++++++++++ 2 files changed, 246 insertions(+) create mode 100644 tests/test_request_executor_fallback_groups.py diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 589a6c33e..e5565d2bb 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -62,6 +62,8 @@ from ..request_sanitizer import sanitize_request_payload from ..transaction_logger import TransactionLogger from ..failure_logger import log_failure +from ..routing import FallbackPolicy, clone_context_for_target +from ..routing.types import RouteTarget from .types import RetryState, AvailabilityStats from .filters import CredentialFilter @@ -495,9 +497,102 @@ async def execute( """ if context.streaming: return self._execute_streaming(context) + elif context.routing_targets: + return await self._execute_non_streaming_with_fallback(context) else: return await self._execute_non_streaming(context) + async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> Any: + """Execute an ordered non-streaming fallback target chain. + + The normal single-target path remains `_execute_non_streaming()`. This + wrapper only runs when request building has populated `routing_targets`, + preserving existing behavior for all current requests. + """ + + targets = tuple(context.routing_targets or ()) + if not targets: + return await self._execute_non_streaming(context) + policy = FallbackPolicy() + last_failure: Any = None + self._log_routing_trace( + context, + "routing_decision", + {"requested_model": context.model, "target_count": len(targets)}, + metadata={"group": context.routing_group_name, "targets": [_target_trace(target) for target in targets]}, + ) + for index, target in enumerate(targets): + target_context = clone_context_for_target( + context, + target, + target_index=index, + usage_manager_key=target.provider, + ) + self._log_routing_trace( + context, + "routing_target_attempt_started", + _target_trace(target), + metadata={"target_index": index, "group": context.routing_group_name}, + ) + try: + result = await self._execute_non_streaming(target_context) + except Exception as exc: + last_failure = exc + error_type = _route_error_type(exc) + self._log_routing_trace( + context, + "routing_target_attempt_failed", + _target_trace(target), + metadata={"target_index": index, "error_type": error_type, "exception": exc.__class__.__name__}, + ) + if index >= len(targets) - 1 or not policy.should_fallback(error_type): + self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type}) + raise + self._log_routing_trace(context, "routing_fallback_selected", _target_trace(targets[index + 1]), metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type}) + continue + + error_type = _route_error_type_from_response(result) + if error_type: + last_failure = result + self._log_routing_trace( + context, + "routing_target_attempt_failed", + _target_trace(target), + metadata={"target_index": index, "error_type": error_type}, + ) + if index < len(targets) - 1 and policy.should_fallback(error_type): + self._log_routing_trace(context, "routing_fallback_selected", _target_trace(targets[index + 1]), metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type}) + continue + self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type}) + return result + + self._log_routing_trace( + context, + "routing_target_attempt_succeeded", + _target_trace(target), + metadata={"target_index": index}, + ) + return result + + if isinstance(last_failure, Exception): + raise last_failure + return last_failure + + @staticmethod + def _log_routing_trace(context: RequestContext, pass_name: str, data: Any, *, metadata: Optional[Dict[str, Any]] = None) -> None: + """Record routing trace entries without affecting request execution.""" + + if not context.transaction_logger: + return + context.transaction_logger.log_transform_pass( + pass_name, + data, + direction="metadata", + stage="routing", + metadata=metadata or {}, + snapshot=False, + ) + async def _prepare_execution( self, context: RequestContext, @@ -1558,3 +1653,43 @@ async def _transaction_logging_stream_wrapper( lib_logger.debug( f"Failed to assemble/log final streaming response: {e}" ) + + +def _target_trace(target: RouteTarget) -> Dict[str, Any]: + """Return non-secret route target metadata for transaction traces.""" + + return { + "name": target.name, + "provider": target.provider, + "model": target.prefixed_model, + "execution": target.execution, + "protocol": target.protocol, + } + + +def _route_error_type(error: BaseException) -> str: + """Map an exception to a fallback-policy error type.""" + + explicit = getattr(error, "error_type", None) + if explicit: + return str(explicit).lower() + classified = classify_error(error) + return classified.error_type + + +def _route_error_type_from_response(response: Any) -> Optional[str]: + """Infer retryability from the proxy's structured error response.""" + + if not isinstance(response, dict) or not isinstance(response.get("error"), dict): + return None + error = response["error"] + error_type = str(error.get("type", "")).lower() + details = error.get("details") if isinstance(error.get("details"), dict) else {} + normal_summary = str(details.get("normal_error_summary", "")).lower() + if any(token in normal_summary for token in ("rate_limit", "quota", "capacity")): + return "rate_limit" + if any(token in normal_summary for token in ("server_error", "api_connection", "transient")): + return "server_error" + if error_type in {"proxy_timeout", "proxy_all_credentials_exhausted"}: + return "rate_limit" + return None diff --git a/tests/test_request_executor_fallback_groups.py b/tests/test_request_executor_fallback_groups.py new file mode 100644 index 000000000..1188a5347 --- /dev/null +++ b/tests/test_request_executor_fallback_groups.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import json +from types import MethodType + +import pytest + +from rotator_library.client.executor import RequestExecutor +from rotator_library.core.types import RequestContext +from rotator_library.routing import parse_route_target +from rotator_library.transaction_logger import TransactionLogger + + +class ClassifiedFailure(Exception): + def __init__(self, error_type: str) -> None: + super().__init__(error_type) + self.error_type = error_type + + +def _context(*, routing_targets=None, logger=None) -> RequestContext: + return RequestContext( + model="code", + provider="requested", + kwargs={"model": "code", "messages": []}, + streaming=False, + credentials=["cred-a"], + deadline=9999999999.0, + transaction_logger=logger, + routing_targets=routing_targets, + routing_group_name="code_chain" if routing_targets else None, + ) + + +def _executor_with_attempts(attempts): + executor = RequestExecutor.__new__(RequestExecutor) + + async def fake_execute(self, context): + attempts.append(context) + if len(attempts) == 1: + raise ClassifiedFailure("rate_limit") + return {"id": "ok", "model": context.model} + + executor._execute_non_streaming = MethodType(fake_execute, executor) + return executor + + +@pytest.mark.asyncio +async def test_non_streaming_fallback_group_tries_next_target_on_retryable_error() -> None: + attempts = [] + targets = (parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")) + + result = await _executor_with_attempts(attempts)._execute_non_streaming_with_fallback(_context(routing_targets=targets)) + + assert result == {"id": "ok", "model": "openai/gpt-5.1"} + assert [attempt.provider for attempt in attempts] == ["codex", "openai"] + assert [attempt.kwargs["model"] for attempt in attempts] == ["codex/gpt-5.1-codex", "openai/gpt-5.1"] + + +@pytest.mark.asyncio +async def test_non_streaming_fallback_group_stops_on_permanent_error() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_execute(self, context): + attempts.append(context) + raise ClassifiedFailure("validation") + + executor._execute_non_streaming = MethodType(fake_execute, executor) + + with pytest.raises(ClassifiedFailure): + await executor._execute_non_streaming_with_fallback( + _context(routing_targets=(parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1"))) + ) + + assert len(attempts) == 1 + + +@pytest.mark.asyncio +async def test_non_streaming_fallback_group_handles_structured_error_response() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_execute(self, context): + attempts.append(context) + if len(attempts) == 1: + return {"error": {"type": "proxy_all_credentials_exhausted", "details": {"normal_error_summary": "2 rate_limit"}}} + return {"id": "ok", "model": context.model} + + executor._execute_non_streaming = MethodType(fake_execute, executor) + + result = await executor._execute_non_streaming_with_fallback( + _context(routing_targets=(parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1"))) + ) + + assert result == {"id": "ok", "model": "openai/gpt-5.1"} + + +@pytest.mark.asyncio +async def test_non_streaming_fallback_group_emits_routing_trace(tmp_path) -> None: + attempts = [] + logger = TransactionLogger("routing", "code", parent_dir=tmp_path) + targets = (parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")) + + await _executor_with_attempts(attempts)._execute_non_streaming_with_fallback(_context(routing_targets=targets, logger=logger)) + + pass_names = [json.loads(line)["pass_name"] for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + assert "routing_decision" in pass_names + assert pass_names.count("routing_target_attempt_started") == 2 + assert "routing_target_attempt_failed" in pass_names + assert "routing_fallback_selected" in pass_names + assert "routing_target_attempt_succeeded" in pass_names From 744cf599461e85c62c01f6903ab52f8b45de3f55 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 01:40:18 +0200 Subject: [PATCH 041/182] feat(routing): select native custom and fallback execution Adds routed execution-mode selection inside RequestExecutor: explicit LiteLLM fallback tracing, custom provider execution, native provider execution through NativeProviderExecutor, and unsupported native/custom errors that can participate in fallback policy. Native execution remains opt-in through routing targets and provider declarations; undeclared providers still use the existing LiteLLM path. Tests: python -m pytest tests/test_request_executor_native_routing.py tests/test_request_executor_fallback_groups.py tests/test_native_provider_executor.py tests/test_routing_config.py tests/test_fallback_policy.py tests/test_session_tracking.py tests/test_selection_engine.py --- src/rotator_library/client/executor.py | 145 ++++++++++++++++-- src/rotator_library/routing/types.py | 2 +- tests/test_request_executor_native_routing.py | 143 +++++++++++++++++ 3 files changed, 276 insertions(+), 14 deletions(-) create mode 100644 tests/test_request_executor_native_routing.py diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index e5565d2bb..6e989c526 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -64,6 +64,7 @@ from ..failure_logger import log_failure from ..routing import FallbackPolicy, clone_context_for_target from ..routing.types import RouteTarget +from ..native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor from .types import RetryState, AvailabilityStats from .filters import CredentialFilter @@ -77,6 +78,14 @@ lib_logger = logging.getLogger("rotator_library") +class RoutingExecutionError(RuntimeError): + """Internal error used when a routed target cannot use its requested mode.""" + + def __init__(self, message: str, error_type: str = "unsupported_operation") -> None: + super().__init__(message) + self.error_type = error_type + + class RequestExecutor: """ Unified retry/rotation logic for all request types. @@ -480,6 +489,99 @@ async def _run_pre_request_callback( raise PreRequestCallbackError(str(e)) from e lib_logger.warning(f"Pre-request callback failed: {e}") + async def _execute_provider_request( + self, + provider: str, + model: str, + plugin: Any, + credential_secret: str, + credential_id: str, + kwargs: Dict[str, Any], + context: RequestContext, + ) -> Any: + """Execute one provider request using routed execution-mode rules.""" + + target = _current_route_target(context) + execution = target.execution if target else "auto" + if execution == "litellm_fallback": + self._log_routing_trace( + context, + "routing_litellm_fallback", + _target_trace(target) if target else {"provider": provider, "model": model}, + ) + return await self._execute_litellm_request(kwargs, credential_secret) + + if execution == "custom" or (execution == "auto" and plugin and plugin.has_custom_logic()): + if not plugin or not plugin.has_custom_logic(): + raise RoutingExecutionError(f"Provider {provider} does not support custom execution") + kwargs["credential_identifier"] = credential_secret + return await plugin.acompletion(self._http_client, **kwargs) + + if execution == "native" or (execution == "auto" and _provider_native_protocol(plugin, model, target)): + native_context = self._build_native_provider_context( + provider, + model, + plugin, + credential_secret, + credential_id, + context, + target, + ) + self._log_routing_trace( + context, + "routing_native_execution_selected", + _target_trace(target) if target else {"provider": provider, "model": model}, + metadata={"protocol": native_context.protocol_name}, + ) + return await NativeProviderExecutor().execute(dict(kwargs), native_context, NativeHTTPTransport(self._http_client)) + + return await self._execute_litellm_request(kwargs, credential_secret) + + async def _execute_litellm_request(self, kwargs: Dict[str, Any], credential_secret: str) -> Any: + """Execute the existing LiteLLM request path.""" + + kwargs["api_key"] = credential_secret + self._apply_litellm_logger(kwargs) + kwargs.pop("transaction_context", None) + return await litellm.acompletion(**kwargs) + + def _build_native_provider_context( + self, + provider: str, + model: str, + plugin: Any, + credential_secret: str, + credential_id: str, + context: RequestContext, + target: Optional[RouteTarget], + ) -> NativeProviderContext: + """Build native provider context from provider declarations.""" + + if not plugin: + raise RoutingExecutionError(f"Provider {provider} has no plugin for native execution") + protocol_name = _provider_native_protocol(plugin, model, target) + if not protocol_name: + raise RoutingExecutionError(f"Provider {provider} has no native protocol declaration") + if not hasattr(plugin, "get_native_endpoint") or not hasattr(plugin, "get_native_headers"): + raise RoutingExecutionError(f"Provider {provider} has no native endpoint/header helpers") + endpoint = plugin.get_native_endpoint(model=model, operation="chat") + headers = plugin.get_native_headers(credential_secret, model=model, operation="chat") + return NativeProviderContext( + provider=provider, + model=model, + protocol_name=protocol_name, + endpoint=endpoint, + headers=headers, + credential_id=credential_id, + session_id=context.session_id, + scope_key=context.usage_manager_key, + classifier=context.classifier, + adapter_names=tuple(plugin.get_adapter_names(model) if hasattr(plugin, "get_adapter_names") else ()), + adapter_config=dict(plugin.get_adapter_config(model) if hasattr(plugin, "get_adapter_config") else {}), + field_cache_rules=tuple(plugin.get_field_cache_rules(model) if hasattr(plugin, "get_field_cache_rules") else ()), + transaction_logger=context.transaction_logger, + ) + async def execute( self, context: RequestContext, @@ -726,19 +828,15 @@ async def _execute_non_streaming( # Pre-request callback await self._run_pre_request_callback(context, kwargs) - # Make the API call - if plugin and plugin.has_custom_logic(): - kwargs["credential_identifier"] = credential_secret - response = await plugin.acompletion( - self._http_client, **kwargs - ) - else: - # Standard LiteLLM call - kwargs["api_key"] = credential_secret - self._apply_litellm_logger(kwargs) - # Remove internal context before litellm call - kwargs.pop("transaction_context", None) - response = await litellm.acompletion(**kwargs) + response = await self._execute_provider_request( + provider, + model, + plugin, + credential_secret, + cred_context.stable_id, + kwargs, + context, + ) # Success! Extract token usage if available ( @@ -1667,6 +1765,27 @@ def _target_trace(target: RouteTarget) -> Dict[str, Any]: } +def _current_route_target(context: RequestContext) -> Optional[RouteTarget]: + """Return the currently selected route target from context metadata.""" + + targets = tuple(context.routing_targets or ()) + if not targets: + return None + if context.routing_target_index < 0 or context.routing_target_index >= len(targets): + return None + return targets[context.routing_target_index] + + +def _provider_native_protocol(plugin: Any, model: str, target: Optional[RouteTarget]) -> Optional[str]: + """Resolve native protocol from target override or provider declaration.""" + + if target and target.protocol: + return target.protocol + if plugin and hasattr(plugin, "get_protocol_name"): + return plugin.get_protocol_name(model) + return None + + def _route_error_type(error: BaseException) -> str: """Map an exception to a fallback-policy error type.""" diff --git a/src/rotator_library/routing/types.py b/src/rotator_library/routing/types.py index bd6ca04a9..36bd4e056 100644 --- a/src/rotator_library/routing/types.py +++ b/src/rotator_library/routing/types.py @@ -11,7 +11,7 @@ ExecutionMode = Literal["auto", "native", "custom", "litellm_fallback"] StreamingFallbackPolicy = Literal["pre_output_only", "never"] -DEFAULT_FAILOVER_ON = frozenset({"rate_limit", "quota", "capacity", "server_error", "api_connection", "transient"}) +DEFAULT_FAILOVER_ON = frozenset({"rate_limit", "quota", "capacity", "server_error", "api_connection", "transient", "unsupported_operation"}) DEFAULT_STOP_ON = frozenset({"auth", "authentication", "validation", "permanent", "pre_request_callback", "cancelled"}) diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py new file mode 100644 index 000000000..dcb9a0d97 --- /dev/null +++ b/tests/test_request_executor_native_routing.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import pytest + +from rotator_library.client import executor as executor_module +from rotator_library.client.executor import RequestExecutor, RoutingExecutionError +from rotator_library.core.types import RequestContext +from rotator_library.routing import parse_route_target + + +class FakeNativeResponse: + def __init__(self, payload): + self.payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self.payload + + +class FakeHTTPClient: + def __init__(self): + self.calls = [] + + async def post(self, endpoint, *, headers, json): + self.calls.append({"endpoint": endpoint, "headers": headers, "json": json}) + return FakeNativeResponse({"id": "chat_native", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}) + + +class NativePlugin: + def has_custom_logic(self): + return False + + def get_protocol_name(self, model=""): + return "openai_chat" + + def get_native_endpoint(self, model="", operation="chat"): + return "https://native.test/chat" + + def get_native_headers(self, credential_identifier, model="", operation="chat"): + return {"Authorization": f"Bearer {credential_identifier}"} + + def get_adapter_names(self, model=""): + return () + + def get_adapter_config(self, model=""): + return {} + + def get_field_cache_rules(self, model=""): + return () + + +class CustomPlugin: + def __init__(self): + self.calls = [] + + def has_custom_logic(self): + return True + + async def acompletion(self, client, **kwargs): + self.calls.append(kwargs) + return {"id": "custom"} + + +def _context(target=None) -> RequestContext: + return RequestContext( + model="provider/gpt-test", + provider="provider", + kwargs={"model": "provider/gpt-test", "messages": [{"role": "user", "content": "hi"}]}, + streaming=False, + credentials=["cred"], + deadline=9999999999.0, + routing_targets=(target,) if target else None, + ) + + +def _executor(http_client=None) -> RequestExecutor: + executor = RequestExecutor.__new__(RequestExecutor) + executor._http_client = http_client or FakeHTTPClient() + executor._apply_litellm_logger = lambda kwargs: None + return executor + + +@pytest.mark.asyncio +async def test_native_declared_provider_uses_native_executor_in_auto_mode() -> None: + http_client = FakeHTTPClient() + target = parse_route_target("provider/gpt-test") + context = _context(target) + context.routing_target_index = 0 + + response = await _executor(http_client)._execute_provider_request( + "provider", "provider/gpt-test", NativePlugin(), "secret", "stable", dict(context.kwargs), context + ) + + assert response["id"] == "chat_native" + assert http_client.calls[0]["endpoint"] == "https://native.test/chat" + assert http_client.calls[0]["headers"]["Authorization"] == "Bearer secret" + + +@pytest.mark.asyncio +async def test_litellm_fallback_execution_is_explicit(monkeypatch) -> None: + calls = [] + + async def fake_acompletion(**kwargs): + calls.append(kwargs) + return {"id": "litellm"} + + monkeypatch.setattr(executor_module.litellm, "acompletion", fake_acompletion) + target = parse_route_target("provider/gpt-test@litellm_fallback") + context = _context(target) + context.routing_target_index = 0 + + response = await _executor()._execute_provider_request("provider", "provider/gpt-test", NativePlugin(), "secret", "stable", dict(context.kwargs), context) + + assert response == {"id": "litellm"} + assert calls[0]["api_key"] == "secret" + + +@pytest.mark.asyncio +async def test_custom_execution_mode_requires_custom_plugin() -> None: + target = parse_route_target("provider/gpt-test@custom") + context = _context(target) + context.routing_target_index = 0 + + plugin = CustomPlugin() + + response = await _executor()._execute_provider_request("provider", "provider/gpt-test", plugin, "secret", "stable", dict(context.kwargs), context) + + assert response == {"id": "custom"} + assert plugin.calls[0]["credential_identifier"] == "secret" + + +@pytest.mark.asyncio +async def test_native_execution_mode_fails_when_provider_has_no_native_declaration() -> None: + target = parse_route_target("provider/gpt-test@native") + context = _context(target) + context.routing_target_index = 0 + + with pytest.raises(RoutingExecutionError) as exc: + await _executor()._execute_provider_request("provider", "provider/gpt-test", None, "secret", "stable", dict(context.kwargs), context) + + assert exc.value.error_type == "unsupported_operation" From 9fdaf31fad35e7ee55222e33658551f3b25bc61f Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 01:41:56 +0200 Subject: [PATCH 042/182] feat(routing): add streaming fallback policy Adds streaming routing fallback around the existing stream executor. Fallback to the next target is allowed only before visible output, while failures after output are traced and propagated to avoid changing models mid-stream. Tests cover pre-output fallback, blocked post-output fallback, routing trace entries, and treating error SSE chunks as non-visible output. Tests: python -m pytest tests/test_streaming_fallback_policy.py tests/test_request_executor_fallback_groups.py tests/test_request_executor_native_routing.py tests/test_fallback_policy.py tests/test_native_provider_streaming.py tests/test_responses_streaming.py tests/test_session_tracking.py tests/test_selection_engine.py --- src/rotator_library/client/executor.py | 89 ++++++++++++++++++++ tests/test_streaming_fallback_policy.py | 105 ++++++++++++++++++++++++ 2 files changed, 194 insertions(+) create mode 100644 tests/test_streaming_fallback_policy.py diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 6e989c526..542b65327 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -597,6 +597,8 @@ async def execute( Returns: Response object or async generator for streaming """ + if context.streaming and context.routing_targets: + return self._execute_streaming_with_fallback(context) if context.streaming: return self._execute_streaming(context) elif context.routing_targets: @@ -680,6 +682,74 @@ async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> raise last_failure return last_failure + async def _execute_streaming_with_fallback(self, context: RequestContext) -> AsyncGenerator[str, None]: + """Execute streaming fallback targets with pre-output-only failover.""" + + targets = tuple(context.routing_targets or ()) + if not targets: + async for chunk in self._execute_streaming(context): + yield chunk + return + policy = FallbackPolicy() + self._log_routing_trace( + context, + "routing_decision", + {"requested_model": context.model, "target_count": len(targets), "stream": True}, + metadata={"group": context.routing_group_name, "targets": [_target_trace(target) for target in targets]}, + ) + for index, target in enumerate(targets): + emitted_output = False + target_context = clone_context_for_target( + context, + target, + target_index=index, + usage_manager_key=target.provider, + ) + self._log_routing_trace( + context, + "routing_stream_target_attempt_started", + _target_trace(target), + metadata={"target_index": index, "group": context.routing_group_name}, + ) + try: + async for chunk in self._execute_streaming(target_context): + if _stream_chunk_is_visible_output(chunk): + emitted_output = True + yield chunk + self._log_routing_trace( + context, + "routing_stream_target_attempt_succeeded", + _target_trace(target), + metadata={"target_index": index, "emitted_output": emitted_output}, + ) + return + except Exception as exc: + error_type = _route_error_type(exc) + self._log_routing_trace( + context, + "routing_stream_target_attempt_failed", + _target_trace(target), + metadata={"target_index": index, "error_type": error_type, "emitted_output": emitted_output}, + ) + if emitted_output: + self._log_routing_trace( + context, + "routing_stream_fallback_blocked_after_output", + _target_trace(target), + metadata={"target_index": index, "error_type": error_type}, + ) + raise + if index < len(targets) - 1 and policy.should_fallback(error_type, stream=True, emitted_output=False): + self._log_routing_trace( + context, + "routing_fallback_selected", + _target_trace(targets[index + 1]), + metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type, "stream": True}, + ) + continue + self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type, "stream": True}) + raise + @staticmethod def _log_routing_trace(context: RequestContext, pass_name: str, data: Any, *, metadata: Optional[Dict[str, Any]] = None) -> None: """Record routing trace entries without affecting request execution.""" @@ -1812,3 +1882,22 @@ def _route_error_type_from_response(response: Any) -> Optional[str]: if error_type in {"proxy_timeout", "proxy_all_credentials_exhausted"}: return "rate_limit" return None + + +def _stream_chunk_is_visible_output(chunk: str) -> bool: + """Return whether a stream chunk should block cross-target fallback.""" + + text = chunk.strip() + if not text or text == "data: [DONE]": + return False + if text.startswith("data:"): + payload = text[len("data:") :].strip() + if not payload or payload == "[DONE]": + return False + try: + data = json.loads(payload) + except json.JSONDecodeError: + return True + if isinstance(data, dict) and "error" in data: + return False + return True diff --git a/tests/test_streaming_fallback_policy.py b/tests/test_streaming_fallback_policy.py new file mode 100644 index 000000000..8e925a5ec --- /dev/null +++ b/tests/test_streaming_fallback_policy.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import json +from types import MethodType + +import pytest + +from rotator_library.client.executor import RequestExecutor +from rotator_library.core.types import RequestContext +from rotator_library.routing import parse_route_target +from rotator_library.transaction_logger import TransactionLogger + + +class StreamFailure(Exception): + def __init__(self, error_type: str) -> None: + super().__init__(error_type) + self.error_type = error_type + + +def _context(*, logger=None) -> RequestContext: + return RequestContext( + model="code", + provider="requested", + kwargs={"model": "code", "messages": [], "stream": True}, + streaming=True, + credentials=["cred-a"], + deadline=9999999999.0, + transaction_logger=logger, + routing_targets=(parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")), + routing_group_name="code_chain", + ) + + +@pytest.mark.asyncio +async def test_streaming_fallback_tries_next_target_before_output() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_stream(self, context): + attempts.append(context.provider) + if len(attempts) == 1: + raise StreamFailure("rate_limit") + yield 'data: {"choices":[{"delta":{"content":"ok"}}]}\n\n' + yield "data: [DONE]\n\n" + + executor._execute_streaming = MethodType(fake_stream, executor) + + chunks = [chunk async for chunk in executor._execute_streaming_with_fallback(_context())] + + assert attempts == ["codex", "openai"] + assert chunks[-1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_streaming_fallback_blocks_after_visible_output() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + + async def fake_stream(self, context): + yield 'data: {"choices":[{"delta":{"content":"partial"}}]}\n\n' + raise StreamFailure("rate_limit") + + executor._execute_streaming = MethodType(fake_stream, executor) + + with pytest.raises(StreamFailure): + [chunk async for chunk in executor._execute_streaming_with_fallback(_context())] + + +@pytest.mark.asyncio +async def test_streaming_fallback_trace_records_blocked_after_output(tmp_path) -> None: + executor = RequestExecutor.__new__(RequestExecutor) + logger = TransactionLogger("routing", "code", parent_dir=tmp_path) + + async def fake_stream(self, context): + yield 'data: {"choices":[{"delta":{"content":"partial"}}]}\n\n' + raise StreamFailure("rate_limit") + + executor._execute_streaming = MethodType(fake_stream, executor) + + with pytest.raises(StreamFailure): + [chunk async for chunk in executor._execute_streaming_with_fallback(_context(logger=logger))] + + pass_names = [json.loads(line)["pass_name"] for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + assert "routing_stream_target_attempt_started" in pass_names + assert "routing_stream_target_attempt_failed" in pass_names + assert "routing_stream_fallback_blocked_after_output" in pass_names + + +@pytest.mark.asyncio +async def test_streaming_fallback_treats_error_chunk_as_not_visible_output() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_stream(self, context): + attempts.append(context.provider) + if len(attempts) == 1: + yield 'data: {"error":{"type":"rate_limit"}}\n\n' + raise StreamFailure("rate_limit") + yield "data: [DONE]\n\n" + + executor._execute_streaming = MethodType(fake_stream, executor) + + chunks = [chunk async for chunk in executor._execute_streaming_with_fallback(_context())] + + assert attempts == ["codex", "openai"] + assert chunks[-1] == "data: [DONE]\n\n" From 94d800fddcb7b34bef3b96ef35b1c2a6eed2acb9 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 01:52:46 +0200 Subject: [PATCH 043/182] fix(routing): wire resolver and align fallback errors Fixes Phase 6 review findings by aligning fallback policy categories with classifier output, passing provider context into route error classification, wiring env fallback routes into RequestContextBuilder, carrying per-target request scopes into fallback attempts, and honoring group overrides in the isolated fallback runner. Adds tests for quota_exceeded fallback, classifier-aligned stop categories, request-builder env routing, per-target scopes, and group policy overrides. Tests: python -m pytest tests/test_routing_config.py tests/test_fallback_resolver.py tests/test_fallback_policy.py tests/test_routing_attempts.py tests/test_fallback_attempt_runner.py tests/test_request_executor_fallback_groups.py tests/test_request_executor_native_routing.py tests/test_streaming_fallback_policy.py tests/test_request_builder_routing.py && python -m pytest tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py tests/test_adapter_registry.py tests/test_field_cache_paths.py tests/test_field_cache_engine.py tests/test_field_cache_trace.py tests/test_provider_protocol_declarations.py tests/test_responses_store.py tests/test_responses_bridge.py tests/test_responses_service.py tests/test_responses_routes.py tests/test_responses_streaming.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py tests/test_session_tracking.py tests/test_selection_engine.py --- src/rotator_library/client/executor.py | 37 +++++++-- src/rotator_library/client/request_builder.py | 64 +++++++++++++++- src/rotator_library/routing/attempts.py | 4 + src/rotator_library/routing/executor.py | 15 +++- src/rotator_library/routing/types.py | 33 +++++++- tests/test_fallback_attempt_runner.py | 23 ++++++ tests/test_fallback_policy.py | 11 ++- tests/test_request_builder_routing.py | 76 +++++++++++++++++++ .../test_request_executor_fallback_groups.py | 21 +++++ 9 files changed, 268 insertions(+), 16 deletions(-) create mode 100644 tests/test_request_builder_routing.py diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 542b65327..12773bb56 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -630,7 +630,10 @@ async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> context, target, target_index=index, - usage_manager_key=target.provider, + credentials=_target_scope_value(target, "credentials", context.credentials), + usage_manager_key=_target_scope_value(target, "usage_manager_key", target.provider), + provider_config=_target_scope_value(target, "provider_config", context.provider_config), + credential_secrets=_target_scope_value(target, "credential_secrets", context.credential_secrets), ) self._log_routing_trace( context, @@ -642,7 +645,7 @@ async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> result = await self._execute_non_streaming(target_context) except Exception as exc: last_failure = exc - error_type = _route_error_type(exc) + error_type = _route_error_type(exc, target.provider) self._log_routing_trace( context, "routing_target_attempt_failed", @@ -703,7 +706,10 @@ async def _execute_streaming_with_fallback(self, context: RequestContext) -> Asy context, target, target_index=index, - usage_manager_key=target.provider, + credentials=_target_scope_value(target, "credentials", context.credentials), + usage_manager_key=_target_scope_value(target, "usage_manager_key", target.provider), + provider_config=_target_scope_value(target, "provider_config", context.provider_config), + credential_secrets=_target_scope_value(target, "credential_secrets", context.credential_secrets), ) self._log_routing_trace( context, @@ -724,7 +730,7 @@ async def _execute_streaming_with_fallback(self, context: RequestContext) -> Asy ) return except Exception as exc: - error_type = _route_error_type(exc) + error_type = _route_error_type(exc, target.provider) self._log_routing_trace( context, "routing_stream_target_attempt_failed", @@ -1856,13 +1862,24 @@ def _provider_native_protocol(plugin: Any, model: str, target: Optional[RouteTar return None -def _route_error_type(error: BaseException) -> str: +def _target_scope_value(target: RouteTarget, key: str, default: Any) -> Any: + """Read request-scope metadata attached by routing resolution.""" + + scope = target.metadata.get("request_scope") if isinstance(target.metadata, dict) else None + if isinstance(scope, dict) and key in scope: + return scope[key] + return default + + +def _route_error_type(error: BaseException, provider: Optional[str] = None) -> str: """Map an exception to a fallback-policy error type.""" + if isinstance(error, asyncio.CancelledError): + return "cancelled" explicit = getattr(error, "error_type", None) if explicit: return str(explicit).lower() - classified = classify_error(error) + classified = classify_error(error, provider) return classified.error_type @@ -1885,7 +1902,13 @@ def _route_error_type_from_response(response: Any) -> Optional[str]: def _stream_chunk_is_visible_output(chunk: str) -> bool: - """Return whether a stream chunk should block cross-target fallback.""" + """Return whether a stream chunk should block cross-target fallback. + + Only client-visible model output should lock the route. Empty frames, DONE + sentinels, and structured error frames are not considered visible output, so + a provider that fails before producing content can still fall through to the + next ordered target. + """ text = chunk.strip() if not text or text == "data: [DONE]": diff --git a/src/rotator_library/client/request_builder.py b/src/rotator_library/client/request_builder.py index 7fb6573b4..5c85794d8 100644 --- a/src/rotator_library/client/request_builder.py +++ b/src/rotator_library/client/request_builder.py @@ -8,6 +8,8 @@ from typing import Any, Awaitable, Callable, Dict, Optional from ..core.types import RequestContext +from ..routing import FallbackResolver, RoutingConfigError, load_routing_config_from_env +from ..routing.types import RouteTarget, RoutingDecision from ..transaction_logger import TransactionLogger @@ -51,6 +53,45 @@ def _provider_from_model(model: str) -> str: def _raise_no_provider(model: str) -> None: raise ValueError(f"Invalid model format or no credentials for provider: {model}") + def _resolve_routing_decision(self, model: str) -> Optional[RoutingDecision]: + """Resolve env-configured fallback routing, if any applies.""" + + config = load_routing_config_from_env() + if not config.fallback_groups and not config.model_routes: + return None + try: + decision = FallbackResolver(config).resolve(model) + if decision.reason == "direct_provider_model" and model.lower() not in config.model_routes: + return None + return decision + except RoutingConfigError: + if "/" in model: + return None + raise + + @staticmethod + def _with_request_scope(target: RouteTarget, scope: Dict[str, Any]) -> RouteTarget: + """Attach per-provider request scope to a route target without secrets in traces.""" + + metadata = dict(target.metadata) + metadata["request_scope"] = { + "credentials": list(scope["credentials"]), + "usage_manager_key": scope["usage_manager_key"], + "provider_config": scope["provider_config"], + "credential_secrets": dict(scope["credential_secrets"]), + } + return RouteTarget( + provider=target.provider, + model=target.model, + name=target.name, + protocol=target.protocol, + execution=target.execution, + priority=target.priority, + weight=target.weight, + conditions=dict(target.conditions), + metadata=metadata, + ) + async def _get_session_hints( self, provider: str, @@ -98,7 +139,9 @@ async def build_completion_context( kwargs ) model = kwargs.get("model", "") - provider = self._provider_from_model(model) + routing_decision = self._resolve_routing_decision(model) + routing_targets = routing_decision.targets if routing_decision else None + provider = routing_targets[0].provider if routing_targets else self._provider_from_model(model) if not provider: self._raise_no_provider(model) @@ -112,8 +155,23 @@ async def build_completion_context( if not scope["credentials"]: self._raise_no_provider(model) + if routing_targets: + scoped_targets = [] + for index, target in enumerate(routing_targets): + target_scope = scope if index == 0 else await self._resolve_scope_for_provider( + target.provider, + classifier, + request_api_keys, + request_providers, + private, + ) + if not target_scope["credentials"]: + self._raise_no_provider(target.prefixed_model) + scoped_targets.append(self._with_request_scope(target, target_scope)) + routing_targets = tuple(scoped_targets) + parent_log_dir = kwargs.pop("_parent_log_dir", None) - resolved_model = self._model_resolver.resolve_model_id(model, provider) + resolved_model = self._model_resolver.resolve_model_id(routing_targets[0].prefixed_model if routing_targets else model, provider) kwargs["model"] = resolved_model transaction_logger = None @@ -160,6 +218,8 @@ async def build_completion_context( provider_config=scope["provider_config"], credential_secrets=scope["credential_secrets"], classifier=scope["classifier"], + routing_targets=routing_targets, + routing_group_name=routing_decision.group_name if routing_decision else None, ) async def build_embedding_context( diff --git a/src/rotator_library/routing/attempts.py b/src/rotator_library/routing/attempts.py index 827279dc9..325b19bc8 100644 --- a/src/rotator_library/routing/attempts.py +++ b/src/rotator_library/routing/attempts.py @@ -18,6 +18,8 @@ def clone_context_for_target( *, credentials: Sequence[str] | None = None, usage_manager_key: str | None = None, + provider_config: dict[str, Any] | None = None, + credential_secrets: dict[str, str] | None = None, target_index: int = 0, ) -> RequestContext: """Return a target-specific request context without mutating the original. @@ -36,5 +38,7 @@ def clone_context_for_target( kwargs=kwargs, credentials=list(credentials) if credentials is not None else list(context.credentials), usage_manager_key=usage_manager_key if usage_manager_key is not None else target.provider, + provider_config=provider_config if provider_config is not None else context.provider_config, + credential_secrets=dict(credential_secrets) if credential_secrets is not None else dict(context.credential_secrets), routing_target_index=target_index, ) diff --git a/src/rotator_library/routing/executor.py b/src/rotator_library/routing/executor.py index 868d50267..0e0c7e9b2 100644 --- a/src/rotator_library/routing/executor.py +++ b/src/rotator_library/routing/executor.py @@ -9,7 +9,7 @@ from typing import Any from .policy import FallbackPolicy -from .types import RouteAttemptResult, RouteTarget, RoutingDecision +from .types import FallbackGroup, RouteAttemptResult, RouteTarget, RoutingDecision AttemptCallback = Callable[[RouteTarget, int], Awaitable[Any]] @@ -38,8 +38,19 @@ def __init__(self, policy: FallbackPolicy | None = None) -> None: async def run(self, decision: RoutingDecision, attempt: AttemptCallback, *, stream: bool = False) -> Any: """Try targets in order until one succeeds or fallback is exhausted.""" + return await self.run_group(decision, None, attempt, stream=stream) + + async def run_group( + self, + decision: RoutingDecision, + group: FallbackGroup | None, + attempt: AttemptCallback, + *, + stream: bool = False, + ) -> Any: + """Try targets while honoring optional group-specific policy overrides.""" + attempts: list[RouteAttemptResult] = [] - group = None for index, target in enumerate(decision.targets): try: return await attempt(target, index) diff --git a/src/rotator_library/routing/types.py b/src/rotator_library/routing/types.py index 36bd4e056..286d22b87 100644 --- a/src/rotator_library/routing/types.py +++ b/src/rotator_library/routing/types.py @@ -11,8 +11,37 @@ ExecutionMode = Literal["auto", "native", "custom", "litellm_fallback"] StreamingFallbackPolicy = Literal["pre_output_only", "never"] -DEFAULT_FAILOVER_ON = frozenset({"rate_limit", "quota", "capacity", "server_error", "api_connection", "transient", "unsupported_operation"}) -DEFAULT_STOP_ON = frozenset({"auth", "authentication", "validation", "permanent", "pre_request_callback", "cancelled"}) +DEFAULT_FAILOVER_ON = frozenset( + { + "rate_limit", + "quota_exceeded", + "server_error", + "api_connection", + "unsupported_operation", + # Human-friendly aliases for config files; classifier output uses the + # names above, but config authors should not need to know every internal + # error string. + "quota", + "capacity", + "transient", + } +) +DEFAULT_STOP_ON = frozenset( + { + "authentication", + "forbidden", + "invalid_request", + "context_window_exceeded", + "credential_reauth_needed", + "pre_request_callback_error", + "cancelled", + # Config aliases retained for readability. + "auth", + "validation", + "permanent", + "pre_request_callback", + } +) @dataclass(frozen=True) diff --git a/tests/test_fallback_attempt_runner.py b/tests/test_fallback_attempt_runner.py index 0ba95adc9..60219cd73 100644 --- a/tests/test_fallback_attempt_runner.py +++ b/tests/test_fallback_attempt_runner.py @@ -3,6 +3,7 @@ import pytest from rotator_library.routing import FallbackAttemptRunner, FallbackExhaustedError, RoutingDecision, parse_route_target +from rotator_library.routing.types import FallbackGroup class ClassifiedFailure(Exception): @@ -73,3 +74,25 @@ async def attempt(target, index): assert len(exc.value.attempts) == 1 assert exc.value.attempts[0].emitted_output is True + + +@pytest.mark.asyncio +async def test_attempt_runner_honors_group_policy_overrides() -> None: + group = FallbackGroup( + name="custom", + targets=_decision().targets, + failover_on=frozenset({"authentication"}), + stop_on=frozenset({"validation"}), + ) + calls = [] + + async def attempt(target, index): + calls.append(index) + if index == 0: + raise ClassifiedFailure("authentication") + return {"target": target.prefixed_model} + + result = await FallbackAttemptRunner().run_group(_decision(), group, attempt) + + assert result == {"target": "openai/gpt-5.1"} + assert calls == [0, 1] diff --git a/tests/test_fallback_policy.py b/tests/test_fallback_policy.py index 2d36892cd..eb7cb0e88 100644 --- a/tests/test_fallback_policy.py +++ b/tests/test_fallback_policy.py @@ -8,6 +8,7 @@ def test_policy_falls_back_on_retryable_categories() -> None: policy = FallbackPolicy() assert policy.should_fallback("rate_limit") is True + assert policy.should_fallback("quota_exceeded") is True assert policy.should_fallback("server_error") is True assert policy.should_fallback("api_connection") is True @@ -15,9 +16,13 @@ def test_policy_falls_back_on_retryable_categories() -> None: def test_policy_stops_on_permanent_categories() -> None: policy = FallbackPolicy() - assert policy.should_fallback("auth") is False - assert policy.should_fallback("validation") is False - assert policy.should_fallback("pre_request_callback") is False + assert policy.should_fallback("authentication") is False + assert policy.should_fallback("forbidden") is False + assert policy.should_fallback("invalid_request") is False + assert policy.should_fallback("context_window_exceeded") is False + assert policy.should_fallback("credential_reauth_needed") is False + assert policy.should_fallback("pre_request_callback_error") is False + assert policy.should_fallback("cancelled") is False def test_policy_blocks_stream_fallback_after_visible_output() -> None: diff --git a/tests/test_request_builder_routing.py b/tests/test_request_builder_routing.py new file mode 100644 index 000000000..07b7579d3 --- /dev/null +++ b/tests/test_request_builder_routing.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import pytest + +from rotator_library.client.request_builder import RequestContextBuilder + + +class FakeModelResolver: + def resolve_model_id(self, model, provider): + return model + + +class FakeSession: + session_id = "session" + affinity_key = "affinity" + possible_compaction = False + lineage_parent_session_id = None + tracking_namespace = "namespace" + + +class FakeSessionTracker: + def infer_session(self, *args, **kwargs): + return FakeSession() + + +async def _scope(provider, classifier, request_api_keys, request_providers, private): + return { + "credentials": [f"{provider}-cred"], + "usage_manager_key": provider, + "provider_config": {"provider": provider}, + "credential_secrets": {f"{provider}-cred": f"{provider}-secret"}, + "classifier": classifier or "global", + } + + +def _builder() -> RequestContextBuilder: + return RequestContextBuilder( + resolve_scope_for_provider=_scope, + model_resolver=FakeModelResolver(), + session_tracker=FakeSessionTracker(), + get_global_timeout=lambda: 30, + get_enable_request_logging=lambda: False, + ) + + +@pytest.mark.asyncio +async def test_request_builder_leaves_no_config_provider_model_unrouted(monkeypatch) -> None: + monkeypatch.delenv("FALLBACK_GROUPS", raising=False) + context = await _builder().build_completion_context(None, None, {"model": "openai/gpt-5.1", "messages": []}) + + assert context.routing_targets is None + assert context.provider == "openai" + assert context.credentials == ["openai-cred"] + + +@pytest.mark.asyncio +async def test_request_builder_populates_fallback_group_targets_from_env(monkeypatch) -> None: + monkeypatch.setenv("FALLBACK_GROUPS", "code_chain") + monkeypatch.setenv("FALLBACK_GROUP_CODE_CHAIN", "codex/gpt-5.1-codex,openai/gpt-5.1") + monkeypatch.setenv("MODEL_ROUTE_CODEX", "group:code_chain") + + context = await _builder().build_completion_context(None, None, {"model": "codex", "messages": []}) + + assert context.provider == "codex" + assert context.model == "codex/gpt-5.1-codex" + assert context.routing_group_name == "code_chain" + assert [target.prefixed_model for target in context.routing_targets] == ["codex/gpt-5.1-codex", "openai/gpt-5.1"] + assert context.routing_targets[1].metadata["request_scope"]["credentials"] == ["openai-cred"] + + +@pytest.mark.asyncio +async def test_request_builder_rejects_unprefixed_model_without_route(monkeypatch) -> None: + monkeypatch.delenv("FALLBACK_GROUPS", raising=False) + + with pytest.raises(ValueError): + await _builder().build_completion_context(None, None, {"model": "gpt-5.1", "messages": []}) diff --git a/tests/test_request_executor_fallback_groups.py b/tests/test_request_executor_fallback_groups.py index 1188a5347..82a85070f 100644 --- a/tests/test_request_executor_fallback_groups.py +++ b/tests/test_request_executor_fallback_groups.py @@ -56,6 +56,27 @@ async def test_non_streaming_fallback_group_tries_next_target_on_retryable_error assert [attempt.kwargs["model"] for attempt in attempts] == ["codex/gpt-5.1-codex", "openai/gpt-5.1"] +@pytest.mark.asyncio +async def test_non_streaming_fallback_group_falls_back_on_quota_exceeded() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_execute(self, context): + attempts.append(context.provider) + if len(attempts) == 1: + raise ClassifiedFailure("quota_exceeded") + return {"id": "ok", "model": context.model} + + executor._execute_non_streaming = MethodType(fake_execute, executor) + + result = await executor._execute_non_streaming_with_fallback( + _context(routing_targets=(parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1"))) + ) + + assert result == {"id": "ok", "model": "openai/gpt-5.1"} + assert attempts == ["codex", "openai"] + + @pytest.mark.asyncio async def test_non_streaming_fallback_group_stops_on_permanent_error() -> None: executor = RequestExecutor.__new__(RequestExecutor) From 2fbfca99c3dc2a32ada8b92552daddb3a7eceef4 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:03:32 +0200 Subject: [PATCH 044/182] docs(experimental): plan retry cooldown failover cleanup Adds the Phase 7 plan for retry policy cleanup, provider cooldown activation, fallback group policy wiring, cross-target error summaries, and focused/regression test coverage. --- .../phase-7-retry-cooldown-failover.md | 253 ++++++++++++++++++ 1 file changed, 253 insertions(+) create mode 100644 docs/experimental/phase-7-retry-cooldown-failover.md diff --git a/docs/experimental/phase-7-retry-cooldown-failover.md b/docs/experimental/phase-7-retry-cooldown-failover.md new file mode 100644 index 000000000..ffcfba144 --- /dev/null +++ b/docs/experimental/phase-7-retry-cooldown-failover.md @@ -0,0 +1,253 @@ +# Phase 7 Plan: Retry/Cooldown/Failover Cleanup + +## Goal + +Harden the retry and failover layer now that ordered routing exists. Phase 7 makes provider-level cooldown functional, reduces fallback policy duplication, preserves the current retry-after parser, and improves final error visibility across target chains without replacing `UsageManager`, `SessionTracker`, credential rotation, or the existing retry classifier. + +## Non-Goals + +- Do not replace `UsageManager`, `SelectionEngine`, `SessionTracker`, or the retry-after parser. +- Do not rewrite the entire executor. +- Do not implement the Phase 8 streaming transport/library overhaul. +- Do not implement native streaming dispatch unless a small safe seam is necessary. +- Do not add SQLite or any database. +- Do not implement full JSON config merging; Phase 10 owns config polish. +- Do not make fallback happen after visible streamed output. +- Do not silently fallback to LiteLLM unless the route target explicitly selects `litellm_fallback`. + +## Current Code Context + +- `CooldownManager` exists and `_wait_for_cooldown()` is called before target attempts, but `start_cooldown()` is not called by the executor, so provider cooldown is effectively dormant. +- Existing retry-after parsing lives in `error_handler.py` and is strong; keep it. +- Existing `RequestExecutor._handle_error_with_context()` handles non-streaming per-credential retries and rotations. +- Streaming has duplicated retry/rotation handling in several exception branches. +- Phase 6 added `FallbackPolicy`, `FallbackAttemptRunner`, `clone_context_for_target()`, and executor inline fallback loops. +- Phase 6 final review found no blockers but noted two useful cleanup targets: executor inline loops do not pass `FallbackGroup` overrides to `FallbackPolicy`, and `FallbackAttemptRunner` is tested but not used by live executor. +- Phase 6 response-based fallback maps exhausted structured errors to retryable categories, but final target-chain failures do not yet summarize all target failures. +- Provider-level cooldown should apply to provider-wide or IP/global throttling, not every small per-key retry-after. + +## Files To Add + +- `src/rotator_library/retry_policy.py` +- `tests/test_retry_policy.py` +- `tests/test_cooldown_activation.py` +- `tests/test_request_executor_fallback_error_summary.py` +- Possibly `tests/test_request_executor_group_policy.py` + +## Files Likely To Touch + +- `src/rotator_library/client/executor.py` +- `src/rotator_library/cooldown_manager.py` +- `src/rotator_library/routing/types.py` +- `src/rotator_library/routing/executor.py` +- `src/rotator_library/core/types.py` +- `src/rotator_library/error_handler.py` only if docstring/type comments need alignment; do not weaken parser behavior. +- Existing Phase 6 tests if group policy wiring changes expected trace order. + +## Retry Policy Foundation + +Add a small `retry_policy.py` module to centralize decisions that are currently duplicated: + +- `classify_route_error(error, provider)` +- `should_provider_cooldown(classified_error, *, small_cooldown_threshold, provider_cooldown_threshold)` +- `provider_cooldown_duration(classified_error, default_duration)` +- `should_retry_same_credential(classified_error, small_cooldown_threshold)` +- `should_rotate_credential(classified_error)` +- `is_target_failover_eligible(error_type, group=None, stream=False, emitted_output=False)` + +This module should call existing `classify_error()`, `should_retry_same_key()`, and `should_rotate_on_error()` rather than reimplementing them. + +It should document why small retry-after values stay on the same credential while large/provider-level values can activate provider cooldown. + +It should not own sleeping or mutation; it only returns decisions. + +## Cooldown Activation + +Add provider cooldown activation in non-streaming error handling: + +- If `classified.retry_after` exists and is above `SMALL_COOLDOWN_RETRY_THRESHOLD`, start provider cooldown for that duration when the error is provider-wide enough. +- Candidate categories: `rate_limit`, `server_error`, and possibly provider-specific `quota_exceeded` only when retry-after suggests global reset rather than per-key quota. +- Be conservative: avoid cooling down an entire provider for every credential quota if the error looks per-credential. + +Add env knobs: + +- `PROVIDER_COOLDOWN_MIN_SECONDS` default perhaps same as small-cooldown threshold or a modest value. +- `PROVIDER_COOLDOWN_DEFAULT_SECONDS` for retryable provider-wide errors without retry-after. +- `PROVIDER_COOLDOWN_ON_QUOTA` default false unless evidence indicates provider-wide quota. + +Existing `_wait_for_cooldown()` stays the wait path. + +`CooldownManager.start_cooldown()` should extend only when the new expiry is later than the current expiry, not shorten an active cooldown. + +Add trace pass: + +- `provider_cooldown_started` +- metadata: provider, duration, error_type, retry_after_present, reason. + +Logging failure to start cooldown should never fail the request. + +## Fallback Runner And Live Loop Cleanup + +Either wire `FallbackAttemptRunner` into live non-streaming and streaming wrappers or keep inline wrappers but pass a resolved group object. + +Minimal preferred approach for Phase 7: + +- Add `routing_group` optional field to `RequestContext` or attach group policy data to context. +- `RequestContextBuilder` should carry the `FallbackGroup` resolved from config when routing is active. +- Executor fallback wrappers pass `group=context.routing_group` to `FallbackPolicy.should_fallback()`. +- Keep inline wrappers if that avoids contorting trace behavior. + +If refactoring to `FallbackAttemptRunner` stays small, do it; otherwise leave runner as tested support and make inline loops group-aware. + +Add tests proving group overrides are honored in live executor wrappers for non-streaming and streaming. + +## Cross-Target Error Summary + +Add a target failure accumulator for fallback groups: + +- target name +- provider +- model +- execution mode +- error type +- message summary +- emitted_output for streams + +If all fallback targets fail, final client error should include `fallback_targets` details under `error.details`, with no credentials. + +Preserve existing per-target credential error summaries from `RequestErrorAccumulator`. + +Do not expose credential secrets or raw request data. + +Add trace pass: + +- `routing_fallback_exhausted` +- include accumulated target failures. + +Tests should cover: + +- all targets fail and final response contains both target failures. +- exception failure on first target and structured proxy error on second target are both summarized. +- no credentials in summary. + +## Provider Cooldown And Fallback Interaction + +- If a target enters provider cooldown and still fails/exhausts, fallback to next target may proceed if policy allows. +- If a provider is already cooling down, `_wait_for_cooldown()` should wait only if there is request budget; otherwise target should fail fast in a way fallback can interpret. +- Consider returning/raising a classified `provider_cooldown_budget_exceeded` or mapping to `rate_limit` for target fallback. +- Keep this minimal: do not replace capacity waiting or usage limits. + +## Streaming Retry/Fallback Cleanup + +- Keep Phase 6 invariant: fallback only before visible output. +- Deduplicate obvious streaming error branches only if the change is small and tests cover it. +- Add cooldown activation in streaming error handling where retry-after is available and no visible output has been emitted. +- Do not implement native streaming execution-mode dispatch in Phase 7 unless reviewers insist; Phase 8 owns streaming transport design. +- Add provider cooldown trace for streaming provider-wide throttles. + +## Error-Classification Alignment + +Add tests tying actual `classify_error()` output to fallback policy decisions for: + +- `rate_limit` +- `quota_exceeded` +- `server_error` +- `api_connection` +- `authentication` +- `forbidden` +- `invalid_request` +- `context_window_exceeded` +- `credential_reauth_needed` +- `pre_request_callback_error` +- `cancelled` +- `unsupported_operation` + +Preserve conservative `unknown` behavior unless a test shows a strong reason to fallback. + +## Stale Fallback Test Cleanup + +There is a stale ignored test file `tests/test_fallback_groups.py` that imports a non-existent older module. + +- Do not delete it unless it is tracked and relevant to current runs. +- If it is tracked or causing collection failures in broader test runs, replace/archive it with current routing tests in a focused commit. +- If ignored/untracked, leave it alone unless it blocks CI. + +## Transform Trace Requirements + +Keep existing Phase 6 traces: + +- `routing_decision` +- `routing_target_attempt_started` +- `routing_target_attempt_failed` +- `routing_target_attempt_succeeded` +- `routing_fallback_selected` +- `routing_fallback_exhausted` +- `routing_litellm_fallback` +- `routing_native_execution_selected` +- stream equivalents. + +Add: + +- `provider_cooldown_started` +- `provider_cooldown_skipped` +- `routing_group_policy_applied` if group overrides influence a decision. + +Trace metadata must not include credentials or secrets. + +## Testing Plan + +Retry policy tests: + +- existing classifier output maps to expected retry/rotate/fallback/cooldown decisions. +- small retry-after chooses same-credential retry, not provider cooldown. +- large retry-after can start provider cooldown. +- `unknown` remains conservative for target fallback. + +Cooldown tests: + +- `CooldownManager.start_cooldown()` extends but does not shorten cooldown. +- non-streaming large retry-after starts provider cooldown. +- small retry-after does not start provider cooldown. +- cooldown trace emitted. +- `_wait_for_cooldown()` respects deadline budget. + +Fallback group policy tests: + +- live non-streaming fallback honors group-specific `failover_on`. +- live non-streaming fallback honors group-specific `stop_on`. +- live streaming fallback honors group policy before output. + +Error summary tests: + +- all target failures are summarized in final error details. +- target summary excludes credentials. + +Regression tests: + +- Phase 1 protocol tests. +- Phase 2 transform logging tests. +- Phase 3 adapter/field-cache tests. +- Phase 4 Responses tests. +- Phase 5 provider/native tests. +- Phase 6 routing tests. +- `test_session_tracking.py`. +- `test_selection_engine.py`. + +## Commit Checkpoints + +1. Add retry policy helper module and classifier-alignment tests. +2. Harden `CooldownManager` extension semantics and tests. +3. Wire provider cooldown activation into non-streaming executor with trace tests. +4. Add group-policy awareness to live fallback wrappers with tests. +5. Add cross-target error summaries with tests. +6. Add streaming cooldown/group-policy cleanup tests if not covered earlier. +7. Run focused and regression tests. +8. Review with `explore` and `explore-heavy`; fix findings; write uncommitted Phase 7 report. + +## Risks And Mitigations + +- Provider cooldown could over-throttle healthy credentials. Mitigation: only activate on large retry-after/provider-wide categories, keep quota cooldown disabled by default unless configured. +- Refactoring fallback loops could break trace order. Mitigation: prefer small group-aware changes unless runner integration stays simple. +- Error summaries could leak credentials. Mitigation: summarize target/provider/model/error only; never include credentials. +- Streaming fallback could corrupt output if changed carelessly. Mitigation: preserve Phase 6 visible-output gate and add tests. +- Retry policy could drift from existing parser. Mitigation: call existing parser/classifier helpers instead of rewriting them. From 6498eef69b0af13347df00c49c0ae996dbe9d7db Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:05:11 +0200 Subject: [PATCH 045/182] feat(retry): add retry policy helpers Adds Phase 7 retry policy helpers that delegate to the existing classifier, retry-after parser, same-key retry, credential rotation, fallback policy, and conservative provider-cooldown decisions. Tests cover classifier-aligned fallback decisions, explicit route errors, cancellation, same-credential retry, rotation semantics, large retry-after cooldown eligibility, and conservative quota cooldown defaults. Tests: python -m pytest tests/test_retry_policy.py tests/test_fallback_policy.py --- src/rotator_library/retry_policy.py | 119 ++++++++++++++++++++++++++++ tests/test_retry_policy.py | 88 ++++++++++++++++++++ 2 files changed, 207 insertions(+) create mode 100644 src/rotator_library/retry_policy.py create mode 100644 tests/test_retry_policy.py diff --git a/src/rotator_library/retry_policy.py b/src/rotator_library/retry_policy.py new file mode 100644 index 000000000..6e133262e --- /dev/null +++ b/src/rotator_library/retry_policy.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Retry, cooldown, and target-failover policy helpers. + +This module centralizes decisions that sit above the existing error classifier. +It deliberately delegates parsing and classification to `error_handler.py` so the +proxy keeps its current retry-after parser and credential-rotation semantics. +""" + +from __future__ import annotations + +import asyncio +import os +from dataclasses import dataclass +from typing import Optional + +from .error_handler import ClassifiedError, classify_error, should_retry_same_key, should_rotate_on_error +from .routing import FallbackPolicy +from .routing.types import FallbackGroup + +DEFAULT_PROVIDER_COOLDOWN_DEFAULT_SECONDS = 30 + + +@dataclass(frozen=True) +class ProviderCooldownDecision: + """Decision describing whether a provider-level cooldown should start.""" + + should_start: bool + duration: int = 0 + reason: str = "not_applicable" + + +def classify_route_error(error: BaseException, provider: Optional[str] = None) -> str: + """Map an exception into the vocabulary consumed by fallback policy.""" + + if isinstance(error, asyncio.CancelledError): + return "cancelled" + explicit = getattr(error, "error_type", None) + if explicit: + return str(explicit).lower() + return classify_error(error, provider).error_type + + +def should_retry_same_credential(classified_error: ClassifiedError, small_cooldown_threshold: int) -> bool: + """Return whether the current credential should be retried before rotation.""" + + return should_retry_same_key(classified_error, small_cooldown_threshold) + + +def should_rotate_credential(classified_error: ClassifiedError) -> bool: + """Return whether a classified failure should rotate to another credential.""" + + return should_rotate_on_error(classified_error) + + +def decide_provider_cooldown( + classified_error: ClassifiedError, + *, + small_cooldown_threshold: int, + provider_cooldown_min_seconds: int, + default_duration: int = DEFAULT_PROVIDER_COOLDOWN_DEFAULT_SECONDS, + cooldown_on_quota: bool = False, +) -> ProviderCooldownDecision: + """Return whether a provider-wide cooldown should be activated. + + Small retry-after values are intentionally left to same-credential retry to + preserve cache/session locality. Larger retry-after values can indicate a + provider-wide or IP-level throttle and are therefore safe to coordinate via + the provider cooldown manager. + """ + + error_type = classified_error.error_type + if error_type == "quota_exceeded" and not cooldown_on_quota: + return ProviderCooldownDecision(False, reason="quota_cooldown_disabled") + if error_type not in {"rate_limit", "server_error", "api_connection", "quota_exceeded"}: + return ProviderCooldownDecision(False, reason="non_provider_cooldown_error") + + retry_after = classified_error.retry_after + if retry_after is not None: + if retry_after <= 0: + return ProviderCooldownDecision(False, reason="non_positive_retry_after") + if retry_after < small_cooldown_threshold: + return ProviderCooldownDecision(False, reason="small_retry_after") + if retry_after < provider_cooldown_min_seconds: + return ProviderCooldownDecision(False, reason="below_provider_cooldown_minimum") + return ProviderCooldownDecision(True, duration=int(retry_after), reason="retry_after") + + if error_type in {"server_error", "api_connection"} and default_duration >= provider_cooldown_min_seconds: + return ProviderCooldownDecision(True, duration=int(default_duration), reason="default_transient_cooldown") + return ProviderCooldownDecision(False, reason="missing_retry_after") + + +def provider_cooldown_env() -> tuple[int, int, bool]: + """Read provider-cooldown env controls with conservative defaults.""" + + min_seconds = _env_int("PROVIDER_COOLDOWN_MIN_SECONDS", 10) + default_seconds = _env_int("PROVIDER_COOLDOWN_DEFAULT_SECONDS", DEFAULT_PROVIDER_COOLDOWN_DEFAULT_SECONDS) + cooldown_on_quota = os.environ.get("PROVIDER_COOLDOWN_ON_QUOTA", "").strip().lower() in {"1", "true", "yes", "on"} + return max(0, min_seconds), max(0, default_seconds), cooldown_on_quota + + +def is_target_failover_eligible( + error_type: str, + *, + group: FallbackGroup | None = None, + stream: bool = False, + emitted_output: bool = False, +) -> bool: + """Return whether a target failure may advance to the next target.""" + + return FallbackPolicy().should_fallback(error_type, group=group, stream=stream, emitted_output=emitted_output) + + +def _env_int(name: str, default: int) -> int: + try: + return int(os.environ.get(name, default)) + except (TypeError, ValueError): + return default diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py new file mode 100644 index 000000000..704c95eef --- /dev/null +++ b/tests/test_retry_policy.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import asyncio + +from rotator_library.error_handler import ClassifiedError, PreRequestCallbackError +from rotator_library.retry_policy import ( + classify_route_error, + decide_provider_cooldown, + is_target_failover_eligible, + should_retry_same_credential, + should_rotate_credential, +) + + +class ExplicitRouteError(Exception): + error_type = "unsupported_operation" + + +def _classified(error_type: str, **kwargs) -> ClassifiedError: + return ClassifiedError(error_type, original_exception=Exception(error_type), **kwargs) + + +def test_classifier_output_maps_to_fallback_policy_categories() -> None: + retryable = ["rate_limit", "quota_exceeded", "server_error", "api_connection", "unsupported_operation"] + stopped = [ + "authentication", + "forbidden", + "invalid_request", + "context_window_exceeded", + "credential_reauth_needed", + "pre_request_callback_error", + "cancelled", + ] + + assert all(is_target_failover_eligible(error_type) for error_type in retryable) + assert all(not is_target_failover_eligible(error_type) for error_type in stopped) + assert not is_target_failover_eligible("unknown") + + +def test_classify_route_error_preserves_explicit_and_cancelled_types() -> None: + assert classify_route_error(ExplicitRouteError()) == "unsupported_operation" + assert classify_route_error(asyncio.CancelledError()) == "cancelled" + assert classify_route_error(PreRequestCallbackError("boom")) == "pre_request_callback_error" + + +def test_retry_and_rotation_helpers_delegate_to_existing_semantics() -> None: + small_rate_limit = _classified("rate_limit", retry_after=3) + invalid_request = _classified("invalid_request") + + assert should_retry_same_credential(small_rate_limit, small_cooldown_threshold=10) is True + assert should_rotate_credential(small_rate_limit) is True + assert should_rotate_credential(invalid_request) is False + + +def test_provider_cooldown_uses_large_retry_after_not_small_retry_after() -> None: + small = decide_provider_cooldown( + _classified("rate_limit", retry_after=3), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + ) + large = decide_provider_cooldown( + _classified("rate_limit", retry_after=60), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + ) + + assert small.should_start is False + assert small.reason == "small_retry_after" + assert large.should_start is True + assert large.duration == 60 + + +def test_provider_cooldown_is_conservative_for_quota_by_default() -> None: + disabled = decide_provider_cooldown( + _classified("quota_exceeded", retry_after=3600), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + ) + enabled = decide_provider_cooldown( + _classified("quota_exceeded", retry_after=3600), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + cooldown_on_quota=True, + ) + + assert disabled.should_start is False + assert disabled.reason == "quota_cooldown_disabled" + assert enabled.should_start is True From 39980380fda23295a1c83700bdf628c78d9bc710 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:05:49 +0200 Subject: [PATCH 046/182] fix(cooldown): preserve longer provider cooldowns Updates provider cooldown start behavior so a shorter incoming cooldown cannot shorten an already-active longer provider cooldown, which is important when multiple requests observe throttles concurrently. Tests: python -m pytest tests/test_cooldown_activation.py tests/test_retry_policy.py --- src/rotator_library/cooldown_manager.py | 8 +++++++- tests/test_cooldown_activation.py | 24 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 tests/test_cooldown_activation.py diff --git a/src/rotator_library/cooldown_manager.py b/src/rotator_library/cooldown_manager.py index 0d1bb63ee..184056f6f 100644 --- a/src/rotator_library/cooldown_manager.py +++ b/src/rotator_library/cooldown_manager.py @@ -28,9 +28,15 @@ async def start_cooldown(self, provider: str, duration: int): """ Initiates or extends a cooldown period for a provider. The cooldown is set to the current time plus the specified duration. + + A shorter new cooldown must not shorten an existing longer cooldown; + provider-wide throttles often arrive concurrently from several requests. """ async with self._lock: - self._cooldowns[provider] = time.time() + duration + new_expiry = time.time() + max(0, duration) + current_expiry = self._cooldowns.get(provider, 0) + if new_expiry > current_expiry: + self._cooldowns[provider] = new_expiry async def get_cooldown_remaining(self, provider: str) -> float: """ diff --git a/tests/test_cooldown_activation.py b/tests/test_cooldown_activation.py new file mode 100644 index 000000000..2ec8aba43 --- /dev/null +++ b/tests/test_cooldown_activation.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import asyncio + +import pytest + +from rotator_library.cooldown_manager import CooldownManager + + +@pytest.mark.asyncio +async def test_start_cooldown_extends_but_does_not_shorten() -> None: + manager = CooldownManager() + + await manager.start_cooldown("provider", 30) + initial = await manager.get_remaining_cooldown("provider") + await asyncio.sleep(0.01) + await manager.start_cooldown("provider", 1) + after_shorter = await manager.get_remaining_cooldown("provider") + await manager.start_cooldown("provider", 60) + after_longer = await manager.get_remaining_cooldown("provider") + + assert after_shorter > 25 + assert after_shorter <= initial + assert after_longer > after_shorter From 0ac6c5046ae2801073aa161b31b5d1ab82129f91 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:07:14 +0200 Subject: [PATCH 047/182] feat(retry): activate provider cooldowns Wires conservative provider-level cooldown activation into classified executor failures using the Phase 7 retry policy. Large provider-level retry-after values can start shared cooldowns while small retry-after values remain same-credential retry candidates. Adds transform trace entries for cooldown started/skipped and tests for large retry-after activation, small retry-after skipping, and cooldown extension behavior. Tests: python -m pytest tests/test_cooldown_activation.py tests/test_retry_policy.py tests/test_request_executor_fallback_groups.py tests/test_request_executor_native_routing.py tests/test_streaming_fallback_policy.py --- src/rotator_library/client/executor.py | 83 ++++++++++++++++++++++++++ tests/test_cooldown_activation.py | 66 ++++++++++++++++++++ 2 files changed, 149 insertions(+) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 12773bb56..25baed64b 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -62,6 +62,7 @@ from ..request_sanitizer import sanitize_request_payload from ..transaction_logger import TransactionLogger from ..failure_logger import log_failure +from ..retry_policy import decide_provider_cooldown, provider_cooldown_env from ..routing import FallbackPolicy, clone_context_for_target from ..routing.types import RouteTarget from ..native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor @@ -974,6 +975,7 @@ async def _execute_non_streaming( error_accumulator, retry_state, request_headers, + context, ) if action == ErrorAction.RETRY_SAME: @@ -1524,6 +1526,7 @@ async def _handle_error_with_context( error_accumulator: RequestErrorAccumulator, retry_state: RetryState, request_headers: Dict[str, Any], + context: Optional[RequestContext] = None, ) -> str: """ Handle an error and determine next action. @@ -1572,6 +1575,8 @@ async def _handle_error_with_context( cred_context.mark_failure(classified) return ErrorAction.FAIL + await self._maybe_start_provider_cooldown(provider, classified, context=context) + # Check if should retry same key (including small cooldown auto-retry) small_cooldown_threshold = int( os.environ.get( @@ -1617,6 +1622,84 @@ async def _handle_error_with_context( ) return ErrorAction.ROTATE + async def _maybe_start_provider_cooldown( + self, + provider: str, + classified: ClassifiedError, + *, + context: Optional[RequestContext], + ) -> None: + """Start provider-wide cooldown for large provider-level throttles. + + This is intentionally conservative: small retry-after values stay on the + same credential path, and quota cooldown is disabled by default because + most quota errors are per credential or account. + """ + + if not self._cooldown: + return + small_cooldown_threshold = int( + os.environ.get( + "SMALL_COOLDOWN_RETRY_THRESHOLD", DEFAULT_SMALL_COOLDOWN_RETRY_THRESHOLD + ) + ) + min_seconds, default_seconds, cooldown_on_quota = provider_cooldown_env() + decision = decide_provider_cooldown( + classified, + small_cooldown_threshold=small_cooldown_threshold, + provider_cooldown_min_seconds=min_seconds, + default_duration=default_seconds, + cooldown_on_quota=cooldown_on_quota, + ) + if not decision.should_start: + self._log_provider_cooldown_trace( + context, + "provider_cooldown_skipped", + provider, + classified, + decision.duration, + decision.reason, + ) + return + try: + await self._cooldown.start_cooldown(provider, decision.duration) + self._log_provider_cooldown_trace( + context, + "provider_cooldown_started", + provider, + classified, + decision.duration, + decision.reason, + ) + except Exception as exc: + lib_logger.debug("Failed to start provider cooldown for %s: %s", provider, exc) + + @staticmethod + def _log_provider_cooldown_trace( + context: Optional[RequestContext], + pass_name: str, + provider: str, + classified: ClassifiedError, + duration: int, + reason: str, + ) -> None: + if not context or not context.transaction_logger: + return + context.transaction_logger.log_transform_pass( + pass_name, + {"provider": provider, "error_type": classified.error_type, "duration": duration}, + direction="metadata", + stage="retry", + metadata={ + "provider": provider, + "duration": duration, + "error_type": classified.error_type, + "retry_after_present": classified.retry_after is not None, + "reason": reason, + }, + snapshot=False, + ) + def _record_session_response(self, context: RequestContext, response: Any) -> None: """Let the tracker learn anchors emitted by a successful response. diff --git a/tests/test_cooldown_activation.py b/tests/test_cooldown_activation.py index 2ec8aba43..2a384f4de 100644 --- a/tests/test_cooldown_activation.py +++ b/tests/test_cooldown_activation.py @@ -4,7 +4,19 @@ import pytest +from rotator_library.client.executor import RequestExecutor from rotator_library.cooldown_manager import CooldownManager +from rotator_library.core.types import RequestContext +from rotator_library.error_handler import ClassifiedError +from rotator_library.transaction_logger import TransactionLogger + + +class FakeCooldown: + def __init__(self) -> None: + self.started = [] + + async def start_cooldown(self, provider, duration): + self.started.append((provider, duration)) @pytest.mark.asyncio @@ -22,3 +34,57 @@ async def test_start_cooldown_extends_but_does_not_shorten() -> None: assert after_shorter > 25 assert after_shorter <= initial assert after_longer > after_shorter + + +def _classified(error_type: str, retry_after=None) -> ClassifiedError: + return ClassifiedError(error_type, original_exception=Exception(error_type), retry_after=retry_after) + + +def _executor(cooldown) -> RequestExecutor: + executor = RequestExecutor.__new__(RequestExecutor) + executor._cooldown = cooldown + return executor + + +def _context(logger) -> RequestContext: + return RequestContext( + model="openai/gpt-test", + provider="openai", + kwargs={"model": "openai/gpt-test"}, + streaming=False, + credentials=["cred"], + deadline=9999999999.0, + transaction_logger=logger, + ) + + +@pytest.mark.asyncio +async def test_large_retry_after_starts_provider_cooldown_and_traces(tmp_path, monkeypatch) -> None: + monkeypatch.setenv("SMALL_COOLDOWN_RETRY_THRESHOLD", "10") + monkeypatch.setenv("PROVIDER_COOLDOWN_MIN_SECONDS", "10") + cooldown = FakeCooldown() + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + + await _executor(cooldown)._maybe_start_provider_cooldown( + "openai", + _classified("rate_limit", retry_after=60), + context=_context(logger), + ) + + assert cooldown.started == [("openai", 60)] + trace = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "provider_cooldown_started" in trace + + +@pytest.mark.asyncio +async def test_small_retry_after_skips_provider_cooldown(monkeypatch) -> None: + monkeypatch.setenv("SMALL_COOLDOWN_RETRY_THRESHOLD", "10") + cooldown = FakeCooldown() + + await _executor(cooldown)._maybe_start_provider_cooldown( + "openai", + _classified("rate_limit", retry_after=3), + context=None, + ) + + assert cooldown.started == [] From 207b7a2837925dd0754878580b41013598655a7b Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:09:18 +0200 Subject: [PATCH 048/182] feat(routing): honor fallback group policies Carries the resolved fallback group on RequestContext and passes group-specific policy overrides into live non-streaming and streaming fallback decisions while preserving default behavior when no group is present. Adds tests for request-builder group propagation and live executor group override handling for non-streaming and streaming fallback. Tests: python -m pytest tests/test_request_executor_fallback_groups.py tests/test_streaming_fallback_policy.py tests/test_request_builder_routing.py tests/test_fallback_resolver.py tests/test_fallback_attempt_runner.py tests/test_routing_config.py tests/test_routing_attempts.py tests/test_request_executor_native_routing.py tests/test_retry_policy.py tests/test_cooldown_activation.py --- src/rotator_library/client/executor.py | 6 ++-- src/rotator_library/client/request_builder.py | 1 + src/rotator_library/core/types.py | 1 + src/rotator_library/routing/resolver.py | 2 +- src/rotator_library/routing/types.py | 1 + tests/test_request_builder_routing.py | 2 ++ .../test_request_executor_fallback_groups.py | 34 ++++++++++++++++++- tests/test_streaming_fallback_policy.py | 24 ++++++++++++- 8 files changed, 65 insertions(+), 6 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 25baed64b..d32eff5e4 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -653,7 +653,7 @@ async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> _target_trace(target), metadata={"target_index": index, "error_type": error_type, "exception": exc.__class__.__name__}, ) - if index >= len(targets) - 1 or not policy.should_fallback(error_type): + if index >= len(targets) - 1 or not policy.should_fallback(error_type, group=context.routing_group): self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type}) raise self._log_routing_trace(context, "routing_fallback_selected", _target_trace(targets[index + 1]), metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type}) @@ -668,7 +668,7 @@ async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> _target_trace(target), metadata={"target_index": index, "error_type": error_type}, ) - if index < len(targets) - 1 and policy.should_fallback(error_type): + if index < len(targets) - 1 and policy.should_fallback(error_type, group=context.routing_group): self._log_routing_trace(context, "routing_fallback_selected", _target_trace(targets[index + 1]), metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type}) continue self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type}) @@ -746,7 +746,7 @@ async def _execute_streaming_with_fallback(self, context: RequestContext) -> Asy metadata={"target_index": index, "error_type": error_type}, ) raise - if index < len(targets) - 1 and policy.should_fallback(error_type, stream=True, emitted_output=False): + if index < len(targets) - 1 and policy.should_fallback(error_type, group=context.routing_group, stream=True, emitted_output=False): self._log_routing_trace( context, "routing_fallback_selected", diff --git a/src/rotator_library/client/request_builder.py b/src/rotator_library/client/request_builder.py index 5c85794d8..df1b400b2 100644 --- a/src/rotator_library/client/request_builder.py +++ b/src/rotator_library/client/request_builder.py @@ -220,6 +220,7 @@ async def build_completion_context( classifier=scope["classifier"], routing_targets=routing_targets, routing_group_name=routing_decision.group_name if routing_decision else None, + routing_group=routing_decision.group if routing_decision else None, ) async def build_embedding_context( diff --git a/src/rotator_library/core/types.py b/src/rotator_library/core/types.py index 1096932c9..f9f07a44b 100644 --- a/src/rotator_library/core/types.py +++ b/src/rotator_library/core/types.py @@ -91,6 +91,7 @@ class RequestContext: classifier: Optional[str] = None routing_targets: Optional[Any] = None routing_group_name: Optional[str] = None + routing_group: Optional[Any] = None routing_target_index: int = 0 routing_attempt_history: List[Dict[str, Any]] = field(default_factory=list) diff --git a/src/rotator_library/routing/resolver.py b/src/rotator_library/routing/resolver.py index 745a6070a..f384095b0 100644 --- a/src/rotator_library/routing/resolver.py +++ b/src/rotator_library/routing/resolver.py @@ -24,7 +24,7 @@ def resolve(self, requested_model: str) -> RoutingDecision: group = self.config.fallback_groups.get(group_name) if not group: raise RoutingConfigError(f"unknown fallback group {group_name}") - return RoutingDecision(requested_model=requested_model, group_name=group.name, targets=group.targets, reason="model_route_group") + return RoutingDecision(requested_model=requested_model, group_name=group.name, group=group, targets=group.targets, reason="model_route_group") if route: return RoutingDecision(requested_model=requested_model, targets=(parse_route_target(route),), reason="model_route_target") if "/" in requested_model: diff --git a/src/rotator_library/routing/types.py b/src/rotator_library/routing/types.py index 286d22b87..42d2a5fbb 100644 --- a/src/rotator_library/routing/types.py +++ b/src/rotator_library/routing/types.py @@ -111,6 +111,7 @@ class RoutingDecision: requested_model: str targets: tuple[RouteTarget, ...] group_name: str | None = None + group: FallbackGroup | None = None selected_target_index: int = 0 reason: str = "direct" diff --git a/tests/test_request_builder_routing.py b/tests/test_request_builder_routing.py index 07b7579d3..ce4dbcf25 100644 --- a/tests/test_request_builder_routing.py +++ b/tests/test_request_builder_routing.py @@ -64,6 +64,8 @@ async def test_request_builder_populates_fallback_group_targets_from_env(monkeyp assert context.provider == "codex" assert context.model == "codex/gpt-5.1-codex" assert context.routing_group_name == "code_chain" + assert context.routing_group is not None + assert context.routing_group.name == "code_chain" assert [target.prefixed_model for target in context.routing_targets] == ["codex/gpt-5.1-codex", "openai/gpt-5.1"] assert context.routing_targets[1].metadata["request_scope"]["credentials"] == ["openai-cred"] diff --git a/tests/test_request_executor_fallback_groups.py b/tests/test_request_executor_fallback_groups.py index 82a85070f..c408c2a4d 100644 --- a/tests/test_request_executor_fallback_groups.py +++ b/tests/test_request_executor_fallback_groups.py @@ -8,6 +8,7 @@ from rotator_library.client.executor import RequestExecutor from rotator_library.core.types import RequestContext from rotator_library.routing import parse_route_target +from rotator_library.routing.types import FallbackGroup from rotator_library.transaction_logger import TransactionLogger @@ -17,7 +18,7 @@ def __init__(self, error_type: str) -> None: self.error_type = error_type -def _context(*, routing_targets=None, logger=None) -> RequestContext: +def _context(*, routing_targets=None, logger=None, routing_group=None) -> RequestContext: return RequestContext( model="code", provider="requested", @@ -28,6 +29,7 @@ def _context(*, routing_targets=None, logger=None) -> RequestContext: transaction_logger=logger, routing_targets=routing_targets, routing_group_name="code_chain" if routing_targets else None, + routing_group=routing_group, ) @@ -96,6 +98,36 @@ async def fake_execute(self, context): assert len(attempts) == 1 +@pytest.mark.asyncio +async def test_non_streaming_fallback_group_honors_group_override() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_execute(self, context): + attempts.append(context.provider) + if len(attempts) == 1: + raise ClassifiedFailure("authentication") + return {"id": "ok", "model": context.model} + + executor._execute_non_streaming = MethodType(fake_execute, executor) + + targets = (parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")) + result = await executor._execute_non_streaming_with_fallback( + _context( + routing_targets=targets, + routing_group=FallbackGroup( + name="code_chain", + targets=targets, + failover_on=frozenset({"authentication"}), + stop_on=frozenset({"validation"}), + ), + ) + ) + + assert result == {"id": "ok", "model": "openai/gpt-5.1"} + assert attempts == ["codex", "openai"] + + @pytest.mark.asyncio async def test_non_streaming_fallback_group_handles_structured_error_response() -> None: executor = RequestExecutor.__new__(RequestExecutor) diff --git a/tests/test_streaming_fallback_policy.py b/tests/test_streaming_fallback_policy.py index 8e925a5ec..77879d7c9 100644 --- a/tests/test_streaming_fallback_policy.py +++ b/tests/test_streaming_fallback_policy.py @@ -8,6 +8,7 @@ from rotator_library.client.executor import RequestExecutor from rotator_library.core.types import RequestContext from rotator_library.routing import parse_route_target +from rotator_library.routing.types import FallbackGroup from rotator_library.transaction_logger import TransactionLogger @@ -18,6 +19,7 @@ def __init__(self, error_type: str) -> None: def _context(*, logger=None) -> RequestContext: + targets = (parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")) return RequestContext( model="code", provider="requested", @@ -26,11 +28,31 @@ def _context(*, logger=None) -> RequestContext: credentials=["cred-a"], deadline=9999999999.0, transaction_logger=logger, - routing_targets=(parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")), + routing_targets=targets, routing_group_name="code_chain", + routing_group=FallbackGroup(name="code_chain", targets=targets, failover_on=frozenset({"authentication", "rate_limit"}), stop_on=frozenset({"validation"})), ) +@pytest.mark.asyncio +async def test_streaming_fallback_honors_group_override_before_output() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_stream(self, context): + attempts.append(context.provider) + if len(attempts) == 1: + raise StreamFailure("authentication") + yield "data: [DONE]\n\n" + + executor._execute_streaming = MethodType(fake_stream, executor) + + chunks = [chunk async for chunk in executor._execute_streaming_with_fallback(_context())] + + assert attempts == ["codex", "openai"] + assert chunks == ["data: [DONE]\n\n"] + + @pytest.mark.asyncio async def test_streaming_fallback_tries_next_target_before_output() -> None: executor = RequestExecutor.__new__(RequestExecutor) From f197d2ec2ef308308788b64da5f04a9a1dae3f9f Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:11:12 +0200 Subject: [PATCH 049/182] feat(routing): summarize fallback target failures Adds client-safe fallback target summaries to structured final errors when all fallback targets fail. Summaries include target/provider/model/execution/error type and intentionally omit raw provider error text and credentials. Tests: python -m pytest tests/test_request_executor_fallback_error_summary.py tests/test_request_executor_fallback_groups.py tests/test_streaming_fallback_policy.py tests/test_request_executor_native_routing.py tests/test_request_builder_routing.py tests/test_retry_policy.py tests/test_cooldown_activation.py --- src/rotator_library/client/executor.py | 46 +++++++++++++++-- ...request_executor_fallback_error_summary.py | 49 +++++++++++++++++++ 2 files changed, 92 insertions(+), 3 deletions(-) create mode 100644 tests/test_request_executor_fallback_error_summary.py diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index d32eff5e4..8496d13c4 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -620,6 +620,7 @@ async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> return await self._execute_non_streaming(context) policy = FallbackPolicy() last_failure: Any = None + target_failures: List[Dict[str, Any]] = [] self._log_routing_trace( context, "routing_decision", @@ -653,8 +654,9 @@ async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> _target_trace(target), metadata={"target_index": index, "error_type": error_type, "exception": exc.__class__.__name__}, ) + target_failures.append(_target_failure_summary(target, error_type, str(exc))) if index >= len(targets) - 1 or not policy.should_fallback(error_type, group=context.routing_group): - self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type}) + self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type, "fallback_targets": target_failures}) raise self._log_routing_trace(context, "routing_fallback_selected", _target_trace(targets[index + 1]), metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type}) continue @@ -668,11 +670,12 @@ async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> _target_trace(target), metadata={"target_index": index, "error_type": error_type}, ) + target_failures.append(_target_failure_summary(target, error_type, _route_error_message_from_response(result))) if index < len(targets) - 1 and policy.should_fallback(error_type, group=context.routing_group): self._log_routing_trace(context, "routing_fallback_selected", _target_trace(targets[index + 1]), metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type}) continue - self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type}) - return result + self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type, "fallback_targets": target_failures}) + return _with_fallback_summary(result, target_failures) self._log_routing_trace( context, @@ -1984,6 +1987,43 @@ def _route_error_type_from_response(response: Any) -> Optional[str]: return None +def _route_error_message_from_response(response: Any) -> str: + """Extract a short non-secret message from a structured proxy error.""" + + if not isinstance(response, dict) or not isinstance(response.get("error"), dict): + return "" + message = str(response["error"].get("message", "")) + return message[:150] + + +def _target_failure_summary(target: RouteTarget, error_type: str, message: str = "") -> Dict[str, Any]: + """Return a client-safe fallback target failure summary.""" + + return { + "target": target.name, + "provider": target.provider, + "model": target.prefixed_model, + "execution": target.execution, + "error_type": error_type, + # Provider error text can contain raw upstream payload fragments or + # credential-like identifiers. Keep the cross-target summary structural; + # detailed per-credential errors remain in the existing sanitized error + # accumulator. + "message": "", + } + + +def _with_fallback_summary(response: Any, target_failures: List[Dict[str, Any]]) -> Any: + """Attach fallback target summaries to a structured error response.""" + + if not target_failures or not isinstance(response, dict) or not isinstance(response.get("error"), dict): + return response + details = response["error"].setdefault("details", {}) + if isinstance(details, dict): + details["fallback_targets"] = list(target_failures) + return response + + def _stream_chunk_is_visible_output(chunk: str) -> bool: """Return whether a stream chunk should block cross-target fallback. diff --git a/tests/test_request_executor_fallback_error_summary.py b/tests/test_request_executor_fallback_error_summary.py new file mode 100644 index 000000000..bab3ea55b --- /dev/null +++ b/tests/test_request_executor_fallback_error_summary.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from types import MethodType + +import pytest + +from rotator_library.client.executor import RequestExecutor +from rotator_library.core.types import RequestContext +from rotator_library.routing import parse_route_target + + +def _context() -> RequestContext: + return RequestContext( + model="code", + provider="requested", + kwargs={"model": "code", "messages": []}, + streaming=False, + credentials=["cred-a"], + deadline=9999999999.0, + routing_targets=(parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")), + routing_group_name="code_chain", + ) + + +@pytest.mark.asyncio +async def test_fallback_summary_includes_all_structured_target_failures_without_credentials() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_execute(self, context): + attempts.append(context.provider) + return { + "error": { + "type": "proxy_all_credentials_exhausted", + "message": f"{context.provider} failed for cred-secret-value", + "details": {"normal_error_summary": "1 rate_limit"}, + } + } + + executor._execute_non_streaming = MethodType(fake_execute, executor) + + response = await executor._execute_non_streaming_with_fallback(_context()) + + fallback_targets = response["error"]["details"]["fallback_targets"] + assert attempts == ["codex", "openai"] + assert [failure["provider"] for failure in fallback_targets] == ["codex", "openai"] + assert [failure["error_type"] for failure in fallback_targets] == ["rate_limit", "rate_limit"] + assert "credential" not in str(fallback_targets).lower() + assert "cred-secret-value" not in str(fallback_targets) From f95466809c0433f8f790ca980e6596f153a35f28 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:17:20 +0200 Subject: [PATCH 050/182] fix(retry): activate cooldowns for stream failures Starts provider cooldowns for streaming failures only when the stream has not emitted visible model output, preserving the no-fallback-after-output invariant while coordinating provider-wide throttles seen during stream setup or error-only streams. Tests: python -m pytest tests/test_cooldown_activation.py tests/test_retry_policy.py tests/test_streaming_fallback_policy.py tests/test_request_executor_fallback_groups.py tests/test_request_executor_fallback_error_summary.py tests/test_request_executor_native_routing.py --- src/rotator_library/client/executor.py | 30 ++++++++++++++++++++++++++ tests/test_cooldown_activation.py | 8 ++++++- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 8496d13c4..0d80ad286 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -1196,6 +1196,12 @@ async def _execute_streaming( last_exception = e original = getattr(e, "data", e) classified = classify_error(original, provider) + if _can_start_stream_provider_cooldown( + last_streamed_chunk + ): + await self._maybe_start_provider_cooldown( + provider, classified, context=context + ) log_failure( api_key=cred, model=model, @@ -1281,6 +1287,12 @@ async def _execute_streaming( except (RateLimitError, httpx.HTTPStatusError) as e: last_exception = e classified = classify_error(e, provider) + if _can_start_stream_provider_cooldown( + last_streamed_chunk + ): + await self._maybe_start_provider_cooldown( + provider, classified, context=context + ) log_failure( api_key=cred, model=model, @@ -1356,6 +1368,12 @@ async def _execute_streaming( ) as e: last_exception = e classified = classify_error(e, provider) + if _can_start_stream_provider_cooldown( + last_streamed_chunk + ): + await self._maybe_start_provider_cooldown( + provider, classified, context=context + ) log_failure( api_key=cred, model=model, @@ -1392,6 +1410,12 @@ async def _execute_streaming( except Exception as e: last_exception = e classified = classify_error(e, provider) + if _can_start_stream_provider_cooldown( + last_streamed_chunk + ): + await self._maybe_start_provider_cooldown( + provider, classified, context=context + ) log_failure( api_key=cred, model=model, @@ -2047,3 +2071,9 @@ def _stream_chunk_is_visible_output(chunk: str) -> bool: if isinstance(data, dict) and "error" in data: return False return True + + +def _can_start_stream_provider_cooldown(last_streamed_chunk: Optional[str]) -> bool: + """Return whether a streaming failure occurred before visible output.""" + + return last_streamed_chunk is None or not _stream_chunk_is_visible_output(last_streamed_chunk) diff --git a/tests/test_cooldown_activation.py b/tests/test_cooldown_activation.py index 2a384f4de..a24f16941 100644 --- a/tests/test_cooldown_activation.py +++ b/tests/test_cooldown_activation.py @@ -4,7 +4,7 @@ import pytest -from rotator_library.client.executor import RequestExecutor +from rotator_library.client.executor import RequestExecutor, _can_start_stream_provider_cooldown from rotator_library.cooldown_manager import CooldownManager from rotator_library.core.types import RequestContext from rotator_library.error_handler import ClassifiedError @@ -88,3 +88,9 @@ async def test_small_retry_after_skips_provider_cooldown(monkeypatch) -> None: ) assert cooldown.started == [] + + +def test_streaming_provider_cooldown_gate_allows_only_pre_output_failures() -> None: + assert _can_start_stream_provider_cooldown(None) is True + assert _can_start_stream_provider_cooldown('data: {"error":{"type":"rate_limit"}}\n\n') is True + assert _can_start_stream_provider_cooldown('data: {"choices":[{"delta":{"content":"visible"}}]}\n\n') is False From 4f5d650596798f9d69f5467bc50a30eb0fa5d4d6 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:24:51 +0200 Subject: [PATCH 051/182] docs(experimental): plan streaming library upgrade Adds the Phase 8 plan for a transport-aware streaming library layer, stream lifecycle metrics, visible-output policy, streaming error decisions, native/Responses stream seams, and focused/regression test coverage. --- .../phase-8-streaming-library-upgrade.md | 329 ++++++++++++++++++ 1 file changed, 329 insertions(+) create mode 100644 docs/experimental/phase-8-streaming-library-upgrade.md diff --git a/docs/experimental/phase-8-streaming-library-upgrade.md b/docs/experimental/phase-8-streaming-library-upgrade.md new file mode 100644 index 000000000..86c1c7fbc --- /dev/null +++ b/docs/experimental/phase-8-streaming-library-upgrade.md @@ -0,0 +1,329 @@ +# Phase 8 Plan: Streaming Library Upgrade + +## Goal + +Turn streaming into a reusable, transport-aware library layer instead of scattered executor/provider-specific logic. Phase 8 should preserve the current HTTP SSE behavior, add measured stream lifecycle events, consolidate retry/error handling where safe, expose a WebSocket-ready transport seam, and prepare native provider streaming to share the same policies. It must keep the Phase 6/7 invariants: no fallback after visible output, provider cooldown only before visible output or error-only chunks, and no replacement of `UsageManager`, `SessionTracker`, or retry-after parsing. + +## Non-Goals + +- Do not replace `UsageManager`, `SessionTracker`, `SelectionEngine`, or retry-after parsing. +- Do not implement a live WebSocket route yet unless it is strictly a disabled/placeholder seam. +- Do not migrate every provider to native streaming in one pass. +- Do not alter non-streaming request behavior. +- Do not add SQLite or any database. +- Do not implement full cost/quota overhaul; Phase 9 owns usage/quota/cost. +- Do not change client-visible SSE formats except to fix bugs or add compatible error/end events. +- Do not fallback to another target after visible model output. + +## Current Code Context + +- `RequestExecutor._execute_streaming()` owns the active streaming credential retry loop and still has four duplicated exception branches. +- `RequestExecutor._execute_streaming_with_fallback()` wraps target fallback and already tracks visible output. +- Phase 7 added `_stream_chunk_is_visible_output()` and `_can_start_stream_provider_cooldown()` safety helpers. +- `StreamingHandler.wrap_stream()` converts LiteLLM chunks to chat-completions SSE strings, records usage after completion, records response anchors, and handles client disconnect. +- `stream_retry_policy.py` owns reasoning-only retry safety for post-output stream failures. +- `ResponsesService.stream_response()` converts chat stream chunks into Responses SSE events, stores final streamed responses, and has a `ResponsesWebSocketFormatter` placeholder. +- `NativeProviderExecutor.stream()` can stream mocked native provider JSON-line chunks and trace raw/parsed/formatted events, but live executor streaming does not route through native execution modes yet. +- Transform trace already has stream pass names from Phase 2 and later phases. +- There is no central stream metrics object for TTFB/TTFT/stalls/cancellation. + +## Files To Add + +- `src/rotator_library/streaming/__init__.py` +- `src/rotator_library/streaming/events.py` +- `src/rotator_library/streaming/metrics.py` +- `src/rotator_library/streaming/transport.py` +- `src/rotator_library/streaming/policy.py` +- `src/rotator_library/streaming/errors.py` +- `tests/test_stream_events.py` +- `tests/test_stream_metrics.py` +- `tests/test_stream_transport.py` +- `tests/test_stream_policy.py` +- `tests/test_streaming_error_handler.py` +- `tests/test_request_executor_stream_metrics.py` +- `tests/test_native_streaming_transport_seam.py` + +## Files Likely To Touch + +- `src/rotator_library/client/executor.py` +- `src/rotator_library/client/streaming.py` +- `src/rotator_library/client/stream_retry_policy.py` +- `src/rotator_library/native_provider/executor.py` +- `src/rotator_library/native_provider/streaming.py` +- `src/rotator_library/responses/streaming.py` +- `src/rotator_library/responses/service.py` +- `src/proxy_app/main.py` only if route headers/cancellation handling need a small SSE-compatible fix. + +## Streaming Event Model + +Add `StreamEvent` dataclass: + +- `event_type`: `started`, `raw_chunk`, `parsed_chunk`, `delta`, `reasoning_delta`, `tool_delta`, `usage`, `error`, `completed`, `cancelled`, `heartbeat`, `metadata`. +- `protocol`: `openai_chat`, `responses`, `anthropic_messages`, `gemini`, `native`, `litellm_fallback`, or provider-specific. +- `transport`: `sse`, `websocket`, `jsonl`, or future string. +- `data`: JSON-safe event payload. +- `raw`: optional raw chunk for trace only, sanitized/serialized. +- `metadata`: JSON-safe metadata. +- `visible_output`: bool. +- `timestamp_utc`. + +Add helpers: + +- `stream_event_from_sse_chunk()` +- `stream_event_to_sse()` +- `stream_event_to_websocket_message()` placeholder/seam. + +Keep formatters small and override-friendly. + +## Transport Abstraction + +Add `StreamTransportFormatter` base/protocol: + +- `format_event(event)` +- `format_error(error_event)` +- `format_done()` +- `is_terminal_event(event)` + +Add `SSEStreamFormatter` for existing HTTP SSE output. + +Add `WebSocketStreamFormatter` placeholder that formats JSON messages but is not wired to a route yet. + +Add `JSONLineStreamFormatter` for native provider/internal stream tests if useful. + +Responses API can keep `ResponsesSSEFormatter` but should share the transport interface or adapt to it. + +## Metrics And Lifecycle + +Add `StreamMetrics`: + +- `started_at` +- `first_byte_at` +- `first_visible_output_at` +- `last_chunk_at` +- `completed_at` +- `chunk_count` +- `visible_chunk_count` +- `error_count` +- `cancelled` +- properties: `ttfb_seconds`, `ttft_seconds`, `duration_seconds`, `idle_seconds`. + +Add `StreamMonitor`: + +- records raw chunk, formatted chunk, visible output, errors, completion, cancellation. +- can detect stall if `time_since_last_chunk > stall_timeout`. + +Env knobs: + +- `STREAM_TTFB_TIMEOUT_SECONDS` optional/disabled by default. +- `STREAM_STALL_TIMEOUT_SECONDS` optional/disabled by default. +- `STREAM_HEARTBEAT_SECONDS` optional/disabled by default. + +Phase 8 should add metrics and trace events first. Enforced timeouts can be opt-in to avoid behavior surprises. + +## Stream Policy + +- Move/re-export `can_retry_stream_after_error()` into the new streaming policy package, preserving current behavior. +- Keep `stream_retry_policy.py` as a compatibility wrapper if imports exist. +- Add visible-output detection policy: + - Chat-completions content/tool/function deltas are visible. + - Reasoning-only deltas are not visible for fallback only if the existing env allows reasoning-only retry. + - Error chunks and `[DONE]` are not visible. + - Responses `response.output_text.delta` is visible. + - Responses `response.failed` is not visible by itself. +- Add tests for malformed chunks failing closed. + +## Streaming Error Handling + +Add `StreamingErrorDecision` dataclass: + +- `classified` +- `action`: `retry_same`, `rotate`, `fail`, `fallback_allowed`, `fallback_blocked_after_output` +- `start_provider_cooldown`: bool +- `provider_cooldown_duration` +- `reason` + +Add helper that consumes: + +- exception +- provider +- last_streamed_chunk +- attempt +- max_retries +- deadline +- allow_reasoning_only_retry +- retry/cooldown env settings + +It should use existing `classify_error()`, `should_retry_same_key()`, `should_rotate_on_error()`, and Phase 7 retry policy. + +It should not sleep or mutate credential state; executor still owns those side effects. + +Refactor `_execute_streaming()` gradually: + +- First introduce helper and tests. +- Then replace duplicated branch decision logic where safe. +- Preserve exact output semantics and trace ordering. + +If full deduplication is too risky, use helper for cooldown/visibility/fallback decisions but keep branch structure. + +## Fallback And Cooldown Invariants + +- Fallback to next target is allowed only before visible output. +- Provider cooldown can start only before visible output or after error-only chunks. +- Same-credential retry after reasoning-only chunks remains controlled by `STREAM_RETRY_ON_REASONING_ONLY` behavior from current code. +- If a stream emits visible output then errors: + - no cross-target fallback. + - return/raise the current upstream stream failure behavior. + - emit `routing_stream_fallback_blocked_after_output` when in a fallback group. +- Preserve Phase 7 tests. + +## Native Provider Streaming + +- Add a streaming mode seam to `NativeProviderExecutor` that can return `StreamEvent`s or formatted SSE through the common formatter. +- Live `RequestExecutor` does not need to route native streaming fully unless safe, but the mode should be testable: + - `execution=native` streaming target can call native executor stream when provider declares native streaming support. + - if not supported, fail with `unsupported_operation` before output so fallback can choose next target. +- Add provider interface optional method only if needed: + - `supports_native_streaming(model, operation)` + - default false/no-op. +- This should prepare Claude Code/Codex/Copilot/Antigravity skeletons without claiming live support. + +## Responses Streaming Integration + +- Keep current `/v1/responses` SSE behavior. +- Share metrics/trace helpers where possible. +- Preserve stored final streamed response behavior. +- Add tests: + - Responses stream emits visible-output metrics. + - failed stream records error metrics. + - WebSocket formatter seam can format equivalent event payloads but no route is exposed. +- Do not change stored response shape unless necessary. + +## Transaction Trace Requirements + +Add or standardize stream trace passes: + +- `stream_started` +- `stream_first_byte` +- `stream_first_visible_output` +- `stream_stall_detected` +- `stream_heartbeat_sent` +- `stream_completed` +- `stream_cancelled` +- `stream_error_decision` +- `stream_metrics_final` + +Existing pass names stay: + +- `raw_stream_chunk` +- `parsed_stream_chunk` +- `assembled_stream_response` +- provider/native/responses stream pass names. + +Trace metrics must not include raw credentials or auth headers. + +Raw chunks should continue through existing redaction/serialization. + +## Client Disconnect And Cancellation + +- `StreamingHandler.wrap_stream()` already checks `request.is_disconnected()`. +- Add metrics event for cancellation/disconnect. +- Add trace `stream_cancelled`. +- Do not mark success on cancellation. +- Preserve current behavior for partial streams. + +## Heartbeats And Stall Detection + +Implement monitor primitives and tests first. + +Optional runtime heartbeat support: + +- if `STREAM_HEARTBEAT_SECONDS > 0`, emit SSE comment heartbeat `: keep-alive\n\n` or compatible event only when no provider chunk has arrived. +- default disabled to avoid client behavior changes. + +Optional stall detection: + +- if `STREAM_STALL_TIMEOUT_SECONDS > 0`, classify as transient stream failure before visible output or fail after visible output. +- default disabled. + +Tests can use fake clocks. + +## Testing Plan + +Stream event tests: + +- SSE chunk to event parsing. +- Responses event visibility. +- malformed chunk fails closed. +- event-to-SSE formatting. +- WebSocket formatter seam output shape. + +Metrics tests: + +- TTFB and TTFT calculations. +- chunk counts. +- cancellation. +- stall detection with fake clock. + +Policy tests: + +- current reasoning-only retry behavior preserved. +- visible output detection for text/tool deltas. +- error and done chunks not visible. + +Error handler tests: + +- large retry-after before output starts cooldown decision. +- visible output blocks fallback/cooldown. +- transient errors choose same-key retry before max retry. +- permanent errors fail. + +Executor tests: + +- streaming metrics trace passes emitted. +- cancellation trace emitted. +- pre-output fallback still works. +- post-output fallback remains blocked. +- provider cooldown still starts before output. + +Native stream tests: + +- native streaming emits common stream events/trace. +- unsupported native streaming fails before output and can fallback. + +Responses streaming tests: + +- existing route behavior unchanged. +- metrics final trace emitted. + +Regression tests: + +- Phase 1 protocol tests. +- Phase 2 transform logging tests. +- Phase 3 adapter/cache tests. +- Phase 4 Responses tests. +- Phase 5 provider/native tests. +- Phase 6 routing tests. +- Phase 7 retry/cooldown tests. +- `test_session_tracking.py`. +- `test_selection_engine.py`. + +## Commit Checkpoints + +1. Add streaming event/transport/metrics primitives with tests. +2. Move/re-export stream retry/visibility policy with tests. +3. Add streaming error decision helper with tests. +4. Integrate stream metrics and trace passes into `StreamingHandler` and/or executor with tests. +5. Refactor streaming error branches only as far as tests make safe. +6. Add native streaming seam/support detection tests. +7. Add Responses streaming metrics/transport seam tests. +8. Run focused and regression tests. +9. Review with `explore` and `explore-heavy`; fix findings; write uncommitted Phase 8 report. + +## Risks And Mitigations + +- Stream behavior is client-visible. Mitigation: keep SSE output format unchanged by default and add primitives before enforcement. +- Timeout/stall handling can break long reasoning streams. Mitigation: default stall/TTFB enforcement disabled; only metrics on by default. +- Refactoring executor streaming can regress retry semantics. Mitigation: keep branch structure unless shared helper is proven by tests. +- WebSocket support can be over-promised. Mitigation: formatter seam only, no route unless explicitly implemented later. +- Visible-output detection can be too permissive. Mitigation: fail closed on malformed/ambiguous chunks and preserve current tests. +- Metrics can leak data if raw chunks are logged. Mitigation: trace summaries and use existing redaction/serialization paths. From 9bd1ef74df12b8bcd8ad54131a953e1320a2a725 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:26:37 +0200 Subject: [PATCH 052/182] feat(streaming): add stream event primitives Adds transport-neutral stream events, SSE/WebSocket/JSONL formatters, and stream lifecycle metrics/monitoring primitives without changing executor runtime behavior. Tests cover visible chunk detection, malformed/error/done chunks, formatter output, TTFB/TTFT metrics, cancellation, and stall detection. Tests: python -m pytest tests/test_stream_events.py tests/test_stream_transport.py tests/test_stream_metrics.py --- src/rotator_library/streaming/__init__.py | 18 ++++ src/rotator_library/streaming/events.py | 110 +++++++++++++++++++++ src/rotator_library/streaming/metrics.py | 97 ++++++++++++++++++ src/rotator_library/streaming/transport.py | 65 ++++++++++++ tests/test_stream_events.py | 27 +++++ tests/test_stream_metrics.py | 42 ++++++++ tests/test_stream_transport.py | 24 +++++ 7 files changed, 383 insertions(+) create mode 100644 src/rotator_library/streaming/__init__.py create mode 100644 src/rotator_library/streaming/events.py create mode 100644 src/rotator_library/streaming/metrics.py create mode 100644 src/rotator_library/streaming/transport.py create mode 100644 tests/test_stream_events.py create mode 100644 tests/test_stream_metrics.py create mode 100644 tests/test_stream_transport.py diff --git a/src/rotator_library/streaming/__init__.py b/src/rotator_library/streaming/__init__.py new file mode 100644 index 000000000..785937e52 --- /dev/null +++ b/src/rotator_library/streaming/__init__.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Reusable streaming primitives shared by protocol and provider layers.""" + +from .events import StreamEvent, stream_event_from_sse_chunk +from .metrics import StreamMetrics, StreamMonitor +from .transport import JSONLineStreamFormatter, SSEStreamFormatter, WebSocketStreamFormatter + +__all__ = [ + "JSONLineStreamFormatter", + "SSEStreamFormatter", + "StreamEvent", + "StreamMetrics", + "StreamMonitor", + "WebSocketStreamFormatter", + "stream_event_from_sse_chunk", +] diff --git a/src/rotator_library/streaming/events.py b/src/rotator_library/streaming/events.py new file mode 100644 index 000000000..294b6b267 --- /dev/null +++ b/src/rotator_library/streaming/events.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Transport-neutral stream event model.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import UTC, datetime +import json +from typing import Any, Optional + +from ..protocols import serialize_value + + +@dataclass(frozen=True) +class StreamEvent: + """A normalized stream event before transport-specific formatting. + + The event model is intentionally broad so providers can attach native data + without inventing a new stream lifecycle object. Raw data is for trace/debug + use only and should still pass through transaction-log redaction when logged. + """ + + event_type: str + data: Any = field(default_factory=dict) + protocol: Optional[str] = None + transport: str = "sse" + raw: Any = None + metadata: dict[str, Any] = field(default_factory=dict) + visible_output: bool = False + timestamp_utc: str = field(default_factory=lambda: datetime.now(UTC).isoformat()) + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-safe representation for traces and transports.""" + + return { + "event_type": self.event_type, + "protocol": self.protocol, + "transport": self.transport, + "data": serialize_value(self.data), + "raw": serialize_value(self.raw), + "metadata": serialize_value(self.metadata), + "visible_output": self.visible_output, + "timestamp_utc": self.timestamp_utc, + } + + +def stream_event_from_sse_chunk(chunk: Any, *, protocol: str = "openai_chat") -> StreamEvent: + """Parse an SSE `data:` chunk into a normalized stream event. + + Malformed or non-JSON chunks are treated as metadata and fail closed for + visibility. This avoids accidental fallback after ambiguous client output. + """ + + if isinstance(chunk, dict): + data = chunk + raw = chunk + elif isinstance(chunk, str): + raw = chunk + text = chunk.strip() + if text.startswith("data:"): + text = text[len("data:") :].strip() + if text == "[DONE]": + return StreamEvent("completed", protocol=protocol, raw=chunk, data={"done": True}) + try: + data = json.loads(text) + except json.JSONDecodeError: + return StreamEvent("metadata", protocol=protocol, raw=chunk, data={"malformed": True}) + else: + return StreamEvent("metadata", protocol=protocol, raw=chunk, data={"unsupported_chunk": True}) + + if isinstance(data, dict) and data.get("error"): + return StreamEvent("error", protocol=protocol, raw=raw, data=data, visible_output=False) + visible = _openai_chat_visible(data) if protocol == "openai_chat" else False + event_type = "delta" if visible else "parsed_chunk" + return StreamEvent(event_type, protocol=protocol, raw=raw, data=data, visible_output=visible) + + +def _openai_chat_visible(data: Any) -> bool: + if not isinstance(data, dict): + return False + choices = data.get("choices") + if not isinstance(choices, list): + return False + for choice in choices: + if not isinstance(choice, dict): + continue + delta = choice.get("delta") or {} + message = choice.get("message") or {} + for source in (delta, message): + if not isinstance(source, dict): + continue + if _has_visible_text(source.get("content")) or _has_visible_text(source.get("text")): + return True + if source.get("tool_calls") or source.get("function_call"): + return True + return False + + +def _has_visible_text(value: Any) -> bool: + if isinstance(value, str): + return bool(value.strip()) + if isinstance(value, list): + for item in value: + if isinstance(item, str) and item.strip(): + return True + if isinstance(item, dict) and _has_visible_text(item.get("text")): + return True + return False diff --git a/src/rotator_library/streaming/metrics.py b/src/rotator_library/streaming/metrics.py new file mode 100644 index 000000000..64470d700 --- /dev/null +++ b/src/rotator_library/streaming/metrics.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Stream lifecycle metrics and stall monitoring.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Optional + +from .events import StreamEvent + + +Clock = Callable[[], float] + + +@dataclass +class StreamMetrics: + """Timing and lifecycle counters for one stream.""" + + started_at: float + first_byte_at: Optional[float] = None + first_visible_output_at: Optional[float] = None + last_chunk_at: Optional[float] = None + completed_at: Optional[float] = None + chunk_count: int = 0 + visible_chunk_count: int = 0 + error_count: int = 0 + cancelled: bool = False + + @property + def ttfb_seconds(self) -> Optional[float]: + return None if self.first_byte_at is None else self.first_byte_at - self.started_at + + @property + def ttft_seconds(self) -> Optional[float]: + return None if self.first_visible_output_at is None else self.first_visible_output_at - self.started_at + + @property + def duration_seconds(self) -> Optional[float]: + return None if self.completed_at is None else self.completed_at - self.started_at + + @property + def idle_seconds(self) -> Optional[float]: + if self.last_chunk_at is None or self.completed_at is None: + return None + return self.completed_at - self.last_chunk_at + + def to_dict(self) -> dict: + return { + "started_at": self.started_at, + "first_byte_at": self.first_byte_at, + "first_visible_output_at": self.first_visible_output_at, + "last_chunk_at": self.last_chunk_at, + "completed_at": self.completed_at, + "chunk_count": self.chunk_count, + "visible_chunk_count": self.visible_chunk_count, + "error_count": self.error_count, + "cancelled": self.cancelled, + "ttfb_seconds": self.ttfb_seconds, + "ttft_seconds": self.ttft_seconds, + "duration_seconds": self.duration_seconds, + "idle_seconds": self.idle_seconds, + } + + +class StreamMonitor: + """Record stream lifecycle events without changing stream behavior.""" + + def __init__(self, *, clock: Clock) -> None: + self._clock = clock + self.metrics = StreamMetrics(started_at=clock()) + + def record_event(self, event: StreamEvent) -> None: + now = self._clock() + if self.metrics.first_byte_at is None: + self.metrics.first_byte_at = now + self.metrics.last_chunk_at = now + self.metrics.chunk_count += 1 + if event.visible_output: + self.metrics.visible_chunk_count += 1 + if self.metrics.first_visible_output_at is None: + self.metrics.first_visible_output_at = now + if event.event_type == "error": + self.metrics.error_count += 1 + + def complete(self) -> None: + self.metrics.completed_at = self._clock() + + def cancel(self) -> None: + self.metrics.cancelled = True + self.metrics.completed_at = self._clock() + + def is_stalled(self, timeout_seconds: float) -> bool: + if timeout_seconds <= 0 or self.metrics.last_chunk_at is None: + return False + return self._clock() - self.metrics.last_chunk_at > timeout_seconds diff --git a/src/rotator_library/streaming/transport.py b/src/rotator_library/streaming/transport.py new file mode 100644 index 000000000..7472757df --- /dev/null +++ b/src/rotator_library/streaming/transport.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Transport formatters for normalized stream events.""" + +from __future__ import annotations + +import json + +from .events import StreamEvent + + +class SSEStreamFormatter: + """Format stream events as HTTP Server-Sent Events.""" + + transport = "sse" + + def format_event(self, event: StreamEvent) -> str: + return f"event: {event.event_type}\ndata: {json.dumps(event.to_dict(), ensure_ascii=False)}\n\n" + + def format_error(self, event: StreamEvent) -> str: + return self.format_event(event) + + def format_done(self) -> str: + return "data: [DONE]\n\n" + + def is_terminal_event(self, event: StreamEvent) -> bool: + return event.event_type in {"completed", "cancelled", "error"} + + +class WebSocketStreamFormatter: + """Format stream events as WebSocket JSON messages without exposing a route.""" + + transport = "websocket" + future_supported = True + + def format_event(self, event: StreamEvent) -> dict: + return {"type": event.event_type, "payload": event.to_dict()} + + def format_error(self, event: StreamEvent) -> dict: + return self.format_event(event) + + def format_done(self) -> dict: + return {"type": "completed", "payload": {"done": True}} + + def is_terminal_event(self, event: StreamEvent) -> bool: + return event.event_type in {"completed", "cancelled", "error"} + + +class JSONLineStreamFormatter: + """Format stream events as newline-delimited JSON for provider tests.""" + + transport = "jsonl" + + def format_event(self, event: StreamEvent) -> str: + return json.dumps(event.to_dict(), ensure_ascii=False) + "\n" + + def format_error(self, event: StreamEvent) -> str: + return self.format_event(event) + + def format_done(self) -> str: + return json.dumps({"event_type": "completed", "done": True}) + "\n" + + def is_terminal_event(self, event: StreamEvent) -> bool: + return event.event_type in {"completed", "cancelled", "error"} diff --git a/tests/test_stream_events.py b/tests/test_stream_events.py new file mode 100644 index 000000000..365e9cfbe --- /dev/null +++ b/tests/test_stream_events.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from rotator_library.streaming import stream_event_from_sse_chunk + + +def test_stream_event_from_sse_chunk_detects_visible_chat_delta() -> None: + event = stream_event_from_sse_chunk('data: {"choices":[{"delta":{"content":"hi"}}]}\n\n') + + assert event.event_type == "delta" + assert event.visible_output is True + + +def test_stream_event_from_sse_chunk_treats_error_and_done_as_not_visible() -> None: + error_event = stream_event_from_sse_chunk('data: {"error":{"type":"rate_limit"}}\n\n') + done_event = stream_event_from_sse_chunk("data: [DONE]\n\n") + + assert error_event.event_type == "error" + assert error_event.visible_output is False + assert done_event.event_type == "completed" + assert done_event.visible_output is False + + +def test_stream_event_from_sse_chunk_malformed_fails_closed() -> None: + event = stream_event_from_sse_chunk("data: not-json\n\n") + + assert event.event_type == "metadata" + assert event.visible_output is False diff --git a/tests/test_stream_metrics.py b/tests/test_stream_metrics.py new file mode 100644 index 000000000..5f082f0c1 --- /dev/null +++ b/tests/test_stream_metrics.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from rotator_library.streaming import StreamEvent, StreamMonitor + + +class FakeClock: + def __init__(self) -> None: + self.now = 0.0 + + def __call__(self) -> float: + return self.now + + def advance(self, seconds: float) -> None: + self.now += seconds + + +def test_stream_metrics_records_ttfb_ttft_and_counts() -> None: + clock = FakeClock() + monitor = StreamMonitor(clock=clock) + clock.advance(0.5) + monitor.record_event(StreamEvent("parsed_chunk")) + clock.advance(0.7) + monitor.record_event(StreamEvent("delta", visible_output=True)) + clock.advance(0.3) + monitor.complete() + + assert monitor.metrics.ttfb_seconds == 0.5 + assert monitor.metrics.ttft_seconds == 1.2 + assert monitor.metrics.duration_seconds == 1.5 + assert monitor.metrics.chunk_count == 2 + assert monitor.metrics.visible_chunk_count == 1 + + +def test_stream_monitor_records_cancel_and_stall() -> None: + clock = FakeClock() + monitor = StreamMonitor(clock=clock) + monitor.record_event(StreamEvent("parsed_chunk")) + clock.advance(5) + + assert monitor.is_stalled(2) is True + monitor.cancel() + assert monitor.metrics.cancelled is True diff --git a/tests/test_stream_transport.py b/tests/test_stream_transport.py new file mode 100644 index 000000000..be377f13f --- /dev/null +++ b/tests/test_stream_transport.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from rotator_library.streaming import JSONLineStreamFormatter, SSEStreamFormatter, StreamEvent, WebSocketStreamFormatter + + +def test_sse_formatter_outputs_named_event() -> None: + formatted = SSEStreamFormatter().format_event(StreamEvent("delta", data={"text": "hi"}, visible_output=True)) + + assert formatted.startswith("event: delta\n") + assert "data:" in formatted + + +def test_websocket_formatter_exposes_future_transport_shape() -> None: + formatted = WebSocketStreamFormatter().format_event(StreamEvent("delta", data={"text": "hi"})) + + assert formatted["type"] == "delta" + assert formatted["payload"]["transport"] == "sse" + + +def test_jsonl_formatter_outputs_one_line_json() -> None: + formatted = JSONLineStreamFormatter().format_event(StreamEvent("metadata", data={"ok": True})) + + assert formatted.endswith("\n") + assert "metadata" in formatted From e756795aff392eb35a806d8962458dc719b172b1 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:29:30 +0200 Subject: [PATCH 053/182] feat(streaming): centralize stream retry policy Moves stream retry and visible-output policy into the new streaming package while keeping the old client stream_retry_policy import path as a compatibility wrapper. Executor fallback visibility now delegates to the shared policy, preserving conservative malformed-chunk behavior and reasoning-only retry semantics. Tests: python -m pytest tests/test_stream_policy.py tests/test_stream_events.py tests/test_cooldown_activation.py tests/test_streaming_fallback_policy.py tests/test_request_executor_fallback_groups.py tests/test_request_executor_fallback_error_summary.py tests/test_retry_policy.py --- src/rotator_library/client/executor.py | 17 +-- .../client/stream_retry_policy.py | 76 +---------- src/rotator_library/streaming/__init__.py | 3 + src/rotator_library/streaming/policy.py | 123 ++++++++++++++++++ tests/test_stream_policy.py | 28 ++++ 5 files changed, 159 insertions(+), 88 deletions(-) create mode 100644 src/rotator_library/streaming/policy.py create mode 100644 tests/test_stream_policy.py diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 0d80ad286..5e2607bb6 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -71,7 +71,7 @@ from .filters import CredentialFilter from .transforms import ProviderTransforms from .streaming import StreamingHandler -from .stream_retry_policy import can_retry_stream_after_error +from ..streaming.policy import can_retry_stream_after_error, is_visible_stream_output if TYPE_CHECKING: from ..usage import UsageManager @@ -2057,20 +2057,7 @@ def _stream_chunk_is_visible_output(chunk: str) -> bool: next ordered target. """ - text = chunk.strip() - if not text or text == "data: [DONE]": - return False - if text.startswith("data:"): - payload = text[len("data:") :].strip() - if not payload or payload == "[DONE]": - return False - try: - data = json.loads(payload) - except json.JSONDecodeError: - return True - if isinstance(data, dict) and "error" in data: - return False - return True + return is_visible_stream_output(chunk) def _can_start_stream_provider_cooldown(last_streamed_chunk: Optional[str]) -> bool: diff --git a/src/rotator_library/client/stream_retry_policy.py b/src/rotator_library/client/stream_retry_policy.py index ccf0a593c..c776d5cdd 100644 --- a/src/rotator_library/client/stream_retry_policy.py +++ b/src/rotator_library/client/stream_retry_policy.py @@ -1,78 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-only # Copyright (c) 2026 Mirrowel -"""Retry-safety policy for stream errors after output has started.""" +"""Backward-compatible import path for stream retry policy.""" -import json -from typing import Any, Optional +from ..streaming.policy import can_retry_stream_after_error - -_REASONING_FIELDS = ( - "reasoning", - "reasoning_content", - "thinking", - "thinking_content", -) - - -def can_retry_stream_after_error( - last_streamed_chunk: Optional[str], - allow_reasoning_only_retry: bool, -) -> bool: - """ - Return whether a stream can be retried after an upstream error. - - Retrying is always safe before any chunk has been emitted. After that, it is - only allowed when explicitly enabled and the latest chunk is clearly - reasoning/thinking-only. Ambiguous chunks fail closed. - """ - if last_streamed_chunk is None: - return True - if not allow_reasoning_only_retry: - return False - - payload = last_streamed_chunk.strip() - if not payload.startswith("data:"): - return False - payload = payload[5:].strip() - if not payload or payload == "[DONE]": - return False - try: - data = json.loads(payload) - except json.JSONDecodeError: - return False - - has_reasoning = False - choices = data.get("choices") - if not isinstance(choices, list): - return False - - for choice in choices: - if not isinstance(choice, dict): - return False - for source in (choice, choice.get("delta"), choice.get("message")): - if not isinstance(source, dict): - continue - if ( - _has_text(source.get("content")) - or _has_text(source.get("text")) - or source.get("tool_calls") - or source.get("function_call") - ): - return False - if any(_has_text(source.get(key)) for key in _REASONING_FIELDS): - has_reasoning = True - - return has_reasoning - - -def _has_text(value: Any) -> bool: - if isinstance(value, str): - return bool(value.strip()) - if isinstance(value, list): - for item in value: - if isinstance(item, str) and item.strip(): - return True - if isinstance(item, dict) and _has_text(item.get("text")): - return True - return False +__all__ = ["can_retry_stream_after_error"] diff --git a/src/rotator_library/streaming/__init__.py b/src/rotator_library/streaming/__init__.py index 785937e52..9acaff50e 100644 --- a/src/rotator_library/streaming/__init__.py +++ b/src/rotator_library/streaming/__init__.py @@ -5,6 +5,7 @@ from .events import StreamEvent, stream_event_from_sse_chunk from .metrics import StreamMetrics, StreamMonitor +from .policy import can_retry_stream_after_error, is_visible_stream_output from .transport import JSONLineStreamFormatter, SSEStreamFormatter, WebSocketStreamFormatter __all__ = [ @@ -14,5 +15,7 @@ "StreamMetrics", "StreamMonitor", "WebSocketStreamFormatter", + "can_retry_stream_after_error", + "is_visible_stream_output", "stream_event_from_sse_chunk", ] diff --git a/src/rotator_library/streaming/policy.py b/src/rotator_library/streaming/policy.py new file mode 100644 index 000000000..d69bca69a --- /dev/null +++ b/src/rotator_library/streaming/policy.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Stream retry and visible-output policy.""" + +from __future__ import annotations + +import json +from typing import Any, Optional + +_REASONING_FIELDS = ( + "reasoning", + "reasoning_content", + "thinking", + "thinking_content", +) + + +def can_retry_stream_after_error(last_streamed_chunk: Optional[str], allow_reasoning_only_retry: bool) -> bool: + """Return whether an upstream stream can be retried after an error.""" + + if last_streamed_chunk is None: + return True + if not allow_reasoning_only_retry: + return False + data = _sse_json(last_streamed_chunk, malformed_is_visible=False) + if data is None: + return False + + has_reasoning = False + choices = data.get("choices") + if not isinstance(choices, list): + return False + for choice in choices: + if not isinstance(choice, dict): + return False + for source in (choice, choice.get("delta"), choice.get("message")): + if not isinstance(source, dict): + continue + if _has_visible_text(source.get("content")) or _has_visible_text(source.get("text")): + return False + if source.get("tool_calls") or source.get("function_call"): + return False + if any(_has_visible_text(source.get(key)) for key in _REASONING_FIELDS): + has_reasoning = True + return has_reasoning + + +def is_visible_stream_output(chunk: Optional[str], *, protocol: str = "openai_chat") -> bool: + """Return whether a formatted stream chunk should block fallback. + + Malformed or ambiguous chunks fail closed by counting as visible output. This + preserves the existing safety rule that route fallback must not happen after + a client may have received model output. + """ + + if chunk is None: + return False + data = _sse_json(chunk, malformed_is_visible=True) + if data is _MALFORMED_VISIBLE: + return True + if data is None: + return False + if data.get("error"): + return False + if protocol == "responses": + return _responses_visible(data) + return _openai_chat_visible(data) + + +_MALFORMED_VISIBLE = object() + + +def _sse_json(chunk: str, *, malformed_is_visible: bool) -> dict[str, Any] | object | None: + payload = chunk.strip() + if not payload.startswith("data:"): + return _MALFORMED_VISIBLE if malformed_is_visible and payload else None + payload = payload[5:].strip() + if not payload or payload == "[DONE]": + return None + try: + parsed = json.loads(payload) + except json.JSONDecodeError: + return _MALFORMED_VISIBLE if malformed_is_visible else None + return parsed if isinstance(parsed, dict) else (_MALFORMED_VISIBLE if malformed_is_visible else None) + + +def _openai_chat_visible(data: dict[str, Any]) -> bool: + choices = data.get("choices") + if not isinstance(choices, list): + return False + for choice in choices: + if not isinstance(choice, dict): + continue + for source in (choice.get("delta"), choice.get("message")): + if not isinstance(source, dict): + continue + if _has_visible_text(source.get("content")) or _has_visible_text(source.get("text")): + return True + if source.get("tool_calls") or source.get("function_call"): + return True + return False + + +def _responses_visible(data: dict[str, Any]) -> bool: + event_type = data.get("event_type") or data.get("type") + if event_type == "response.output_text.delta": + return bool(str(data.get("delta", "")).strip()) + if event_type == "response.failed": + return False + return False + + +def _has_visible_text(value: Any) -> bool: + if isinstance(value, str): + return bool(value.strip()) + if isinstance(value, list): + for item in value: + if isinstance(item, str) and item.strip(): + return True + if isinstance(item, dict) and _has_visible_text(item.get("text")): + return True + return False diff --git a/tests/test_stream_policy.py b/tests/test_stream_policy.py new file mode 100644 index 000000000..a7e1a9397 --- /dev/null +++ b/tests/test_stream_policy.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from rotator_library.client.stream_retry_policy import can_retry_stream_after_error as compat_retry_policy +from rotator_library.streaming.policy import can_retry_stream_after_error, is_visible_stream_output + + +def test_reasoning_only_retry_policy_is_preserved() -> None: + reasoning_chunk = 'data: {"choices":[{"delta":{"reasoning_content":"thinking"}}]}\n\n' + text_chunk = 'data: {"choices":[{"delta":{"content":"visible"}}]}\n\n' + + assert can_retry_stream_after_error(None, False) is True + assert can_retry_stream_after_error(reasoning_chunk, True) is True + assert can_retry_stream_after_error(reasoning_chunk, False) is False + assert can_retry_stream_after_error(text_chunk, True) is False + assert compat_retry_policy(reasoning_chunk, True) is True + + +def test_visible_output_detection_for_chat_chunks() -> None: + assert is_visible_stream_output('data: {"choices":[{"delta":{"content":"hello"}}]}\n\n') is True + assert is_visible_stream_output('data: {"choices":[{"delta":{"tool_calls":[{"id":"call_1"}]}}]}\n\n') is True + assert is_visible_stream_output('data: {"error":{"type":"rate_limit"}}\n\n') is False + assert is_visible_stream_output("data: [DONE]\n\n") is False + assert is_visible_stream_output("not-sse") is True + + +def test_visible_output_detection_for_responses_events() -> None: + assert is_visible_stream_output('data: {"event_type":"response.output_text.delta","delta":"hi"}\n\n', protocol="responses") is True + assert is_visible_stream_output('data: {"event_type":"response.failed","error":{"message":"x"}}\n\n', protocol="responses") is False From b6760d3c2cb8a90ad28b41d0f482084131af4316 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:30:41 +0200 Subject: [PATCH 054/182] feat(streaming): add stream error decisions Adds a side-effect-free StreamingErrorDecision helper that centralizes stream error classification, same-key retry, rotation, provider cooldown eligibility, and visible-output blocking while delegating to existing classifier and retry policy helpers. Tests cover small retry-after retry, large retry-after cooldown decisions, visible-output blocking, and reasoning-only retry behavior. Tests: python -m pytest tests/test_streaming_error_handler.py tests/test_stream_policy.py tests/test_retry_policy.py tests/test_cooldown_activation.py --- src/rotator_library/streaming/__init__.py | 3 + src/rotator_library/streaming/errors.py | 96 +++++++++++++++++++++++ tests/test_streaming_error_handler.py | 77 ++++++++++++++++++ 3 files changed, 176 insertions(+) create mode 100644 src/rotator_library/streaming/errors.py create mode 100644 tests/test_streaming_error_handler.py diff --git a/src/rotator_library/streaming/__init__.py b/src/rotator_library/streaming/__init__.py index 9acaff50e..209c068a8 100644 --- a/src/rotator_library/streaming/__init__.py +++ b/src/rotator_library/streaming/__init__.py @@ -4,6 +4,7 @@ """Reusable streaming primitives shared by protocol and provider layers.""" from .events import StreamEvent, stream_event_from_sse_chunk +from .errors import StreamingErrorDecision, decide_streaming_error_action from .metrics import StreamMetrics, StreamMonitor from .policy import can_retry_stream_after_error, is_visible_stream_output from .transport import JSONLineStreamFormatter, SSEStreamFormatter, WebSocketStreamFormatter @@ -12,10 +13,12 @@ "JSONLineStreamFormatter", "SSEStreamFormatter", "StreamEvent", + "StreamingErrorDecision", "StreamMetrics", "StreamMonitor", "WebSocketStreamFormatter", "can_retry_stream_after_error", + "decide_streaming_error_action", "is_visible_stream_output", "stream_event_from_sse_chunk", ] diff --git a/src/rotator_library/streaming/errors.py b/src/rotator_library/streaming/errors.py new file mode 100644 index 000000000..83c9624b5 --- /dev/null +++ b/src/rotator_library/streaming/errors.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Side-effect-free decisions for streaming failures.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from ..error_handler import ClassifiedError, classify_error, should_retry_same_key, should_rotate_on_error +from ..retry_policy import ProviderCooldownDecision, decide_provider_cooldown +from .policy import can_retry_stream_after_error, is_visible_stream_output + + +@dataclass(frozen=True) +class StreamingErrorDecision: + """Decision returned by streaming error policy before executor side effects.""" + + classified: ClassifiedError + action: str + start_provider_cooldown: bool = False + provider_cooldown_duration: int = 0 + reason: str = "" + + +def decide_streaming_error_action( + error: Exception, + *, + provider: str, + last_streamed_chunk: str | None, + attempt: int, + max_retries: int, + small_cooldown_threshold: int, + provider_cooldown_min_seconds: int, + provider_cooldown_default_seconds: int, + cooldown_on_quota: bool = False, + allow_reasoning_only_retry: bool = False, +) -> StreamingErrorDecision: + """Classify a stream failure without sleeping or mutating state. + + The executor remains responsible for logging, credential state, sleeping, + provider cooldown mutation, and fallback. This helper only makes the same + decision consistently across streaming exception branches. + """ + + classified = classify_error(error, provider) + cooldown = _cooldown_decision( + classified, + small_cooldown_threshold=small_cooldown_threshold, + provider_cooldown_min_seconds=provider_cooldown_min_seconds, + provider_cooldown_default_seconds=provider_cooldown_default_seconds, + cooldown_on_quota=cooldown_on_quota, + last_streamed_chunk=last_streamed_chunk, + ) + if not should_rotate_on_error(classified): + return _decision(classified, "fail", cooldown, "non_rotatable") + if not can_retry_stream_after_error(last_streamed_chunk, allow_reasoning_only_retry): + return _decision(classified, "fallback_blocked_after_output", cooldown, "visible_output") + if should_retry_same_key(classified, small_cooldown_threshold) and attempt < max_retries - 1: + return _decision(classified, "retry_same", cooldown, "retry_same_credential") + return _decision(classified, "rotate", cooldown, "rotate_credential") + + +def _cooldown_decision( + classified: ClassifiedError, + *, + small_cooldown_threshold: int, + provider_cooldown_min_seconds: int, + provider_cooldown_default_seconds: int, + cooldown_on_quota: bool, + last_streamed_chunk: str | None, +) -> ProviderCooldownDecision: + if is_visible_stream_output(last_streamed_chunk): + return ProviderCooldownDecision(False, reason="visible_output") + return decide_provider_cooldown( + classified, + small_cooldown_threshold=small_cooldown_threshold, + provider_cooldown_min_seconds=provider_cooldown_min_seconds, + default_duration=provider_cooldown_default_seconds, + cooldown_on_quota=cooldown_on_quota, + ) + + +def _decision( + classified: ClassifiedError, + action: str, + cooldown: ProviderCooldownDecision, + reason: str, +) -> StreamingErrorDecision: + return StreamingErrorDecision( + classified=classified, + action=action, + start_provider_cooldown=cooldown.should_start, + provider_cooldown_duration=cooldown.duration, + reason=reason if not cooldown.should_start else f"{reason};cooldown:{cooldown.reason}", + ) diff --git a/tests/test_streaming_error_handler.py b/tests/test_streaming_error_handler.py new file mode 100644 index 000000000..c60b822e9 --- /dev/null +++ b/tests/test_streaming_error_handler.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from litellm import RateLimitError + +from rotator_library.streaming.errors import decide_streaming_error_action + + +def _rate_limit(retry_after: int | None = None) -> Exception: + error = RateLimitError("rate limited", llm_provider="openai", model="gpt-test") + if retry_after is not None: + error.retry_after = retry_after + return error + + +def test_streaming_error_decision_retries_same_key_for_small_retry_after() -> None: + decision = decide_streaming_error_action( + _rate_limit(3), + provider="openai", + last_streamed_chunk=None, + attempt=0, + max_retries=2, + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + provider_cooldown_default_seconds=30, + ) + + assert decision.action == "retry_same" + assert decision.start_provider_cooldown is False + + +def test_streaming_error_decision_starts_cooldown_before_visible_output() -> None: + decision = decide_streaming_error_action( + _rate_limit(60), + provider="openai", + last_streamed_chunk=None, + attempt=1, + max_retries=2, + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + provider_cooldown_default_seconds=30, + ) + + assert decision.action == "rotate" + assert decision.start_provider_cooldown is True + assert decision.provider_cooldown_duration == 60 + + +def test_streaming_error_decision_blocks_after_visible_output_and_skips_cooldown() -> None: + decision = decide_streaming_error_action( + _rate_limit(60), + provider="openai", + last_streamed_chunk='data: {"choices":[{"delta":{"content":"hi"}}]}\n\n', + attempt=0, + max_retries=2, + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + provider_cooldown_default_seconds=30, + ) + + assert decision.action == "fallback_blocked_after_output" + assert decision.start_provider_cooldown is False + + +def test_streaming_error_decision_allows_reasoning_only_retry_when_enabled() -> None: + decision = decide_streaming_error_action( + _rate_limit(3), + provider="openai", + last_streamed_chunk='data: {"choices":[{"delta":{"reasoning_content":"thinking"}}]}\n\n', + attempt=0, + max_retries=2, + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + provider_cooldown_default_seconds=30, + allow_reasoning_only_retry=True, + ) + + assert decision.action == "retry_same" From 1eb3868a6567dc492c02aab3fba3296b23094200 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:33:27 +0200 Subject: [PATCH 055/182] feat(streaming): trace stream lifecycle metrics Adds additive stream lifecycle metrics and transform trace passes to StreamingHandler without changing emitted SSE chunks. The executor passes the transaction logger through so streams now record start, first byte, first visible output, completion, and final metrics. Tests: python -m pytest tests/test_request_executor_stream_metrics.py tests/test_stream_metrics.py tests/test_stream_events.py tests/test_streaming_fallback_policy.py tests/test_cooldown_activation.py tests/test_responses_streaming.py tests/test_native_provider_streaming.py --- src/rotator_library/client/executor.py | 1 + src/rotator_library/client/streaming.py | 86 +++++++++++++++++++ tests/test_request_executor_stream_metrics.py | 33 +++++++ 3 files changed, 120 insertions(+) create mode 100644 tests/test_request_executor_stream_metrics.py diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 5e2607bb6..f5052c1a6 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -1168,6 +1168,7 @@ async def _execute_streaming( response_callback=lambda response: self._record_session_response( context, response ), + transaction_logger=context.transaction_logger, ) lib_logger.info( diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py index 01900b50d..a5769f1df 100644 --- a/src/rotator_library/client/streaming.py +++ b/src/rotator_library/client/streaming.py @@ -17,6 +17,7 @@ import json import logging import re +import time from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional, TYPE_CHECKING import litellm @@ -24,6 +25,7 @@ from ..core.errors import StreamedAPIError, CredentialNeedsReauthError from ..core.types import ProcessedChunk from ..core.utils import normalize_usage_for_response +from ..streaming import StreamEvent, StreamMonitor, stream_event_from_sse_chunk if TYPE_CHECKING: from ..usage.manager import CredentialContext @@ -50,6 +52,7 @@ async def wrap_stream( cred_context: Optional["CredentialContext"] = None, skip_cost_calculation: bool = False, response_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + transaction_logger: Optional[Any] = None, ) -> AsyncGenerator[str, None]: """ Wrap a LiteLLM stream with error handling and usage tracking. @@ -81,6 +84,13 @@ async def wrap_stream( thinking_tokens = 0 assistant_parts: List[str] = [] tool_call_ids: List[str] = [] + monitor = StreamMonitor(clock=time.monotonic) + self._log_stream_lifecycle( + transaction_logger, + "stream_started", + monitor, + StreamEvent("started", protocol="openai_chat"), + ) # Use manual iteration to allow continue after partial JSON errors stream_iterator = stream.__aiter__() @@ -97,6 +107,20 @@ async def wrap_stream( chunk = await stream_iterator.__anext__() + raw_event = StreamEvent( + "raw_chunk", + protocol="openai_chat", + raw=chunk, + data={"chunk_index": monitor.metrics.chunk_count + 1}, + ) + if monitor.metrics.first_byte_at is None: + self._log_stream_lifecycle( + transaction_logger, + "stream_first_byte", + monitor, + raw_event, + ) + # Clear error buffer on successful chunk receipt error_buffer.reset() @@ -112,6 +136,19 @@ async def wrap_stream( assistant_parts, tool_call_ids, ) + event = stream_event_from_sse_chunk(processed.sse_string) + first_visible = ( + event.visible_output + and monitor.metrics.first_visible_output_at is None + ) + monitor.record_event(event) + if first_visible: + self._log_stream_lifecycle( + transaction_logger, + "stream_first_visible_output", + monitor, + event, + ) # Update tracking state if processed.has_tool_calls: @@ -221,6 +258,7 @@ async def wrap_stream( ) # Not a JSON-related error, re-raise + monitor.metrics.error_count += 1 raise except StreamedAPIError: @@ -268,9 +306,57 @@ async def wrap_stream( } ) + monitor.complete() + self._log_stream_lifecycle( + transaction_logger, + "stream_completed", + monitor, + StreamEvent("completed", protocol="openai_chat"), + ) + self._log_stream_lifecycle( + transaction_logger, + "stream_metrics_final", + monitor, + StreamEvent("metadata", protocol="openai_chat"), + ) + # Yield [DONE] for completed streams yield "data: [DONE]\n\n" + elif request and await request.is_disconnected(): + monitor.cancel() + self._log_stream_lifecycle( + transaction_logger, + "stream_cancelled", + monitor, + StreamEvent("cancelled", protocol="openai_chat"), + ) + + @staticmethod + def _log_stream_lifecycle( + transaction_logger: Optional[Any], + pass_name: str, + monitor: StreamMonitor, + event: StreamEvent, + ) -> None: + """Emit stream lifecycle metrics without affecting stream delivery.""" + + if not transaction_logger: + return + try: + transaction_logger.log_transform_pass( + pass_name, + {"event": event.to_dict(), "metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="client", + protocol=event.protocol, + transport="sse", + metadata={"event_type": event.event_type}, + snapshot=False, + ) + except Exception as exc: + lib_logger.debug("Stream lifecycle trace failed: %s", exc) + def _collect_session_response_anchors( self, sse_string: str, diff --git a/tests/test_request_executor_stream_metrics.py b/tests/test_request_executor_stream_metrics.py new file mode 100644 index 000000000..0bf1c74e9 --- /dev/null +++ b/tests/test_request_executor_stream_metrics.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.client.streaming import StreamingHandler +from rotator_library.transaction_logger import TransactionLogger + + +async def _chunks(): + yield {"id": "chunk_1", "choices": [{"delta": {"content": "hi"}}]} + yield {"id": "chunk_2", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}} + + +def _trace_passes(log_dir): + return [json.loads(line)["pass_name"] for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + + +@pytest.mark.asyncio +async def test_streaming_handler_emits_lifecycle_metrics_without_changing_output(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(_chunks(), "cred", "openai/gpt-test", transaction_logger=logger)] + + assert chunks[0].startswith("data: ") + assert chunks[-1] == "data: [DONE]\n\n" + pass_names = _trace_passes(logger.log_dir) + assert "stream_started" in pass_names + assert "stream_first_byte" in pass_names + assert "stream_first_visible_output" in pass_names + assert "stream_completed" in pass_names + assert "stream_metrics_final" in pass_names From 1374388fdfc92ba143d71ed36c7f5ec51ab2eec5 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:35:35 +0200 Subject: [PATCH 056/182] feat(streaming): add native streaming opt-in seam Adds explicit provider native-streaming support declarations defaulting to false, plus helper functions to safely query support and convert formatted native stream chunks into the common stream event model. Tests verify default false support, opt-in support, common event conversion, native stream regressions, and provider declaration regressions. Tests: python -m pytest tests/test_native_streaming_transport_seam.py tests/test_native_provider_streaming.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py --- .../native_provider/streaming.py | 29 +++++++++++++++++++ .../providers/provider_interface.py | 11 +++++++ tests/test_native_streaming_transport_seam.py | 25 ++++++++++++++++ 3 files changed, 65 insertions(+) create mode 100644 tests/test_native_streaming_transport_seam.py diff --git a/src/rotator_library/native_provider/streaming.py b/src/rotator_library/native_provider/streaming.py index 592b34704..b043a2338 100644 --- a/src/rotator_library/native_provider/streaming.py +++ b/src/rotator_library/native_provider/streaming.py @@ -7,6 +7,8 @@ from typing import Any +from ..streaming import StreamEvent, stream_event_from_sse_chunk + def stream_event_payload(event: Any) -> Any: """Return a JSON-safe payload for stream field-cache and trace passes.""" @@ -14,3 +16,30 @@ def stream_event_payload(event: Any) -> Any: if hasattr(event, "to_dict"): return event.to_dict() return event + + +def provider_supports_native_streaming(provider: Any, *, model: str = "", operation: str = "chat") -> bool: + """Return explicit provider native-streaming support. + + Providers must opt in. Missing methods and exceptions fail closed so routed + streaming can fallback before output rather than accidentally claiming native + streaming support. + """ + + method = getattr(provider, "supports_native_streaming", None) + if not method: + return False + try: + return bool(method(model=model, operation=operation)) + except TypeError: + return bool(method(model, operation)) + except Exception: + return False + + +def native_stream_event_from_formatted(formatted: Any, *, protocol: str = "openai_chat") -> StreamEvent: + """Convert a formatted native stream chunk into the common event seam.""" + + if isinstance(formatted, str): + return stream_event_from_sse_chunk(formatted, protocol=protocol) + return StreamEvent("parsed_chunk", protocol=protocol, data=formatted, raw=formatted) diff --git a/src/rotator_library/providers/provider_interface.py b/src/rotator_library/providers/provider_interface.py index 8b4696cad..c0df98738 100644 --- a/src/rotator_library/providers/provider_interface.py +++ b/src/rotator_library/providers/provider_interface.py @@ -276,6 +276,7 @@ class ProviderInterface(ABC, metaclass=SingletonABCMeta): protocol_name: Optional[str] = None adapter_names: Tuple[str, ...] = () field_cache_rules: Tuple[Any, ...] = () + native_streaming_supported: bool = False @abstractmethod async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: @@ -362,6 +363,16 @@ def get_field_cache_rules(self, model: str = "") -> Tuple[Any, ...]: return tuple(self.field_cache_rules) + def supports_native_streaming(self, model: str = "", operation: str = "chat") -> bool: + """Return whether this provider explicitly supports native streaming. + + The default is intentionally false. Phase 8 keeps live streaming + conservative: providers must opt in before routed streaming can use the + native stream executor instead of current custom/LiteLLM behavior. + """ + + return self.native_streaming_supported + async def acompletion( self, client: httpx.AsyncClient, **kwargs ) -> Union[ diff --git a/tests/test_native_streaming_transport_seam.py b/tests/test_native_streaming_transport_seam.py new file mode 100644 index 000000000..88b3f7a8a --- /dev/null +++ b/tests/test_native_streaming_transport_seam.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from rotator_library.native_provider.streaming import native_stream_event_from_formatted, provider_supports_native_streaming +from rotator_library.providers.provider_interface import ProviderInterface + + +class DefaultProvider(ProviderInterface): + async def get_models(self, api_key, client): + return [] + + +class StreamingProvider(DefaultProvider): + native_streaming_supported = True + + +def test_provider_native_streaming_support_defaults_false() -> None: + assert provider_supports_native_streaming(DefaultProvider(), model="gpt-test") is False + assert provider_supports_native_streaming(StreamingProvider(), model="gpt-test") is True + + +def test_native_formatted_sse_chunk_uses_common_stream_event_seam() -> None: + event = native_stream_event_from_formatted('data: {"choices":[{"delta":{"content":"hi"}}]}\n\n') + + assert event.visible_output is True + assert event.event_type == "delta" From 1c67bb8b2605903b5f95d8748c49dc00e13c5027 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:36:48 +0200 Subject: [PATCH 057/182] feat(streaming): trace responses stream metrics Adds common stream lifecycle metrics and trace passes to Responses streaming while preserving existing SSE output and stored final response behavior. Tests: python -m pytest tests/test_responses_streaming.py tests/test_responses_routes.py tests/test_request_executor_stream_metrics.py tests/test_stream_events.py tests/test_stream_transport.py tests/test_stream_metrics.py tests/test_stream_policy.py --- src/rotator_library/responses/service.py | 64 ++++++++++++++++++++++++ tests/test_responses_streaming.py | 16 ++++++ 2 files changed, 80 insertions(+) diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index f5e9e0157..35639c846 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -10,6 +10,7 @@ from typing import Any, AsyncGenerator, Optional from ..protocols import ProtocolContext +from ..streaming import StreamEvent, StreamMonitor from ..protocols.responses import ResponsesProtocol from .bridge import ResponsesBridge from .store import InMemoryResponsesStore, ResponsesStore @@ -136,10 +137,29 @@ async def stream_response( state = ResponsesStreamState(response_id=response_id, model=unified.model) usage = None item_started = False + monitor = StreamMonitor(clock=time.monotonic) + self._trace( + transaction_logger, + "stream_started", + {"event": StreamEvent("started", protocol="responses").to_dict(), "metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="client", + metadata={"transport": "sse"}, + ) yield formatter.format_event("response.created", response_created_payload(response_id, unified.model)) try: chat_stream = await client.acompletion(request=request, **chat_kwargs) async for raw_chunk in chat_stream: + if monitor.metrics.first_byte_at is None: + monitor.record_event(StreamEvent("raw_chunk", protocol="responses", raw=raw_chunk)) + self._trace( + transaction_logger, + "stream_first_byte", + {"metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="provider", + metadata={"transport": "sse"}, + ) self._trace(transaction_logger, "raw_chat_bridge_stream_chunk", raw_chunk, direction="stream", stage="provider") chunk = parse_chat_sse_chunk(raw_chunk) if not chunk or chunk.get("type") == "done": @@ -162,6 +182,24 @@ async def stream_response( output_item_id=state.output_item_id, ) event = output_text_delta_payload(state, delta) + first_visible = monitor.metrics.first_visible_output_at is None + monitor.record_event( + StreamEvent( + "delta", + protocol="responses", + data=event, + visible_output=True, + ) + ) + if first_visible: + self._trace( + transaction_logger, + "stream_first_visible_output", + {"event": event, "metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="final", + metadata={"transport": "sse"}, + ) self._trace(transaction_logger, "formatted_responses_stream_event", event, direction="stream", stage="final") yield formatter.format_event("response.output_text.delta", event) @@ -172,11 +210,37 @@ async def stream_response( completed = response_completed_payload(state, _usage_to_responses_stream(usage)) await self._store_stream_response(stream_request, completed, parent) self._trace(transaction_logger, "stored_responses_stream_response", completed, direction="metadata", stage="final") + monitor.complete() + self._trace( + transaction_logger, + "stream_completed", + {"metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="final", + metadata={"transport": "sse"}, + ) + self._trace( + transaction_logger, + "stream_metrics_final", + {"metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="final", + metadata={"transport": "sse"}, + ) yield formatter.format_event("response.completed", completed) yield formatter.done() except Exception as exc: + monitor.record_event(StreamEvent("error", protocol="responses", data={"error_type": exc.__class__.__name__})) failed = response_failed_payload(response_id, unified.model, {"message": str(exc), "type": exc.__class__.__name__}) self._log_transform_error(transaction_logger, "responses_stream", exc, stream_request) + self._trace( + transaction_logger, + "stream_metrics_final", + {"metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="final", + metadata={"transport": "sse", "failed": True}, + ) yield formatter.format_event("response.failed", failed) yield formatter.done() diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index 4fc5fed80..937d822e0 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -3,6 +3,7 @@ import pytest from rotator_library.responses import InMemoryResponsesStore, ResponsesSSEFormatter, ResponsesService, ResponsesWebSocketFormatter +from rotator_library.transaction_logger import TransactionLogger def _event_names(events: list[str]) -> list[str]: @@ -84,3 +85,18 @@ def test_transport_formatters_expose_sse_and_websocket_seam() -> None: assert websocket.future_supported is True with pytest.raises(NotImplementedError): websocket.format_event("response.created", {}) + + +@pytest.mark.asyncio +async def test_responses_stream_records_common_stream_metrics(tmp_path) -> None: + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + service = ResponsesService(store=InMemoryResponsesStore()) + + _ = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, FakeStreamingClient(), transaction_logger=logger)] + + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "stream_started" in trace_text + assert "stream_first_byte" in trace_text + assert "stream_first_visible_output" in trace_text + assert "stream_completed" in trace_text + assert "stream_metrics_final" in trace_text From 052a62c414f11cbc76409436586be51a441bc79f Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:47:49 +0200 Subject: [PATCH 058/182] docs(experimental): plan usage quota cost Adds the Phase 9 plan for normalized usage records, advisory cost breakdowns, quota snapshots, executor/stream/Responses/native accounting integration, and focused regression coverage. --- docs/experimental/phase-9-usage-quota-cost.md | 331 ++++++++++++++++++ 1 file changed, 331 insertions(+) create mode 100644 docs/experimental/phase-9-usage-quota-cost.md diff --git a/docs/experimental/phase-9-usage-quota-cost.md b/docs/experimental/phase-9-usage-quota-cost.md new file mode 100644 index 000000000..66cb18273 --- /dev/null +++ b/docs/experimental/phase-9-usage-quota-cost.md @@ -0,0 +1,331 @@ +# Phase 9 Plan: Usage, Quota, Cost + +## Goal + +Make usage, quota, and cost accounting consistent across LiteLLM fallback, native protocol adapters, Responses API, streaming, routing fallback chains, and provider-specific quota metadata while preserving the existing `UsageManager`, `SelectionEngine`, fair-cycle behavior, storage format, and provider quota reset logic. Phase 9 should formalize a normalized usage/cost layer that feeds the current usage engine rather than replacing it. + +## Non-Goals + +- Do not replace `UsageManager`, `TrackingEngine`, `LimitEngine`, `SelectionEngine`, `SessionTracker`, or the retry-after parser. +- Do not introduce SQLite or any database. +- Do not implement the full Phase 10 JSON/env config system. +- Do not change credential selection strategy semantics. +- Do not rewrite provider quota trackers. +- Do not require live provider credentials or live pricing API calls. +- Do not make cost enforcement mandatory; cost tracking is additive unless a provider already has limits. +- Do not change client-visible response usage fields except where normalization already happens today. + +## Current Code Context + +- `CredentialContext.mark_success()` already accepts prompt, completion, thinking, cache read/write tokens, approximate cost, and response headers. +- `TrackingEngine.record_usage()` stores request count, token buckets, output tokens, total tokens, and approximate cost into windows/totals. +- `RequestExecutor._extract_usage_tokens()` extracts usage from LiteLLM-like response objects, subtracting reasoning tokens from completion tokens when the provider includes reasoning inside completion. +- `RequestExecutor._calculate_cost()` uses LiteLLM cost helpers and returns `0.0` for providers that set `skip_cost_calculation`. +- `StreamingHandler.wrap_stream()` records stream usage after final chunk and currently calculates cost with `litellm.get_model_info()`. +- Phase 4 Responses streaming maps chat usage into Responses usage for storage. +- Phase 5 native protocols can produce provider-native usage shapes through protocol parse/format paths. +- Phase 6 routing can try multiple targets, but only the successful target should record usage. +- Phase 7 fallback summaries are client-safe and should not be polluted by raw usage/cost internals. +- Phase 8 stream metrics are timing/count-only and do not alter usage recording. + +## Files To Add + +- `src/rotator_library/usage/accounting.py` +- `src/rotator_library/usage/costs.py` +- `src/rotator_library/usage/quota.py` +- `tests/test_usage_accounting.py` +- `tests/test_usage_costs.py` +- `tests/test_usage_quota_snapshots.py` +- `tests/test_executor_usage_accounting.py` +- `tests/test_streaming_usage_accounting.py` +- `tests/test_responses_usage_accounting.py` +- `tests/test_native_usage_accounting.py` + +## Files Likely To Touch + +- `src/rotator_library/usage/types.py` +- `src/rotator_library/usage/manager.py` +- `src/rotator_library/usage/tracking/engine.py` +- `src/rotator_library/client/executor.py` +- `src/rotator_library/client/streaming.py` +- `src/rotator_library/responses/bridge.py` +- `src/rotator_library/responses/service.py` +- `src/rotator_library/native_provider/executor.py` +- `src/rotator_library/core/utils.py` only if current usage normalization helper needs a small adapter. +- Provider files only for optional cost/quota declarations, not broad rewrites. + +## Normalized Usage Model + +Add `UsageRecord` dataclass: + +- `input_tokens` +- `output_tokens` +- `completion_tokens` +- `reasoning_tokens` +- `cache_read_tokens` +- `cache_write_tokens` +- `total_tokens` +- `raw_total_tokens` +- `request_count` +- `source` +- `provider` +- `model` +- `metadata` + +Token semantics: + +- `input_tokens` means billable non-cache-read prompt tokens when provider separates cache-read tokens. +- `cache_read_tokens` means prompt/cache tokens read from provider cache. +- `cache_write_tokens` means prompt/cache tokens written or created. +- `completion_tokens` means visible/output text/tool tokens, excluding reasoning when the provider reports reasoning separately. +- `reasoning_tokens` means hidden thinking/reasoning tokens. +- `output_tokens = completion_tokens + reasoning_tokens`. +- `total_tokens = input_tokens + cache_read_tokens + cache_write_tokens + completion_tokens + reasoning_tokens`. +- `raw_total_tokens` preserves provider-reported total before normalization for debugging. + +Rationale: current code already avoids double-counting reasoning by subtracting thinking from completion when necessary. Phase 9 makes that rule explicit and reusable. + +## Usage Extraction + +Add `extract_usage_record(response_or_usage, provider=None, model=None, source=None)`. + +Support: + +- LiteLLM/OpenAI object usage attributes. +- dict usage fields. +- OpenAI `prompt_tokens_details.cached_tokens`. +- OpenAI `completion_tokens_details.reasoning_tokens`. +- Anthropic `input_tokens`, `output_tokens`, `cache_creation_input_tokens`, `cache_read_input_tokens`. +- Gemini `usageMetadata`, `promptTokenCount`, `candidatesTokenCount`, `thoughtsTokenCount`, `cachedContentTokenCount`, `totalTokenCount`. +- Responses `input_tokens`, `output_tokens`, `output_tokens_details.reasoning_tokens`, `input_tokens_details.cached_tokens`. +- Existing stream usage dicts after `normalize_usage_for_response()`. + +Unknown usage shapes return an empty `UsageRecord` with source metadata rather than raising in runtime paths. + +Tests must cover dicts, objects, nested details, and double-count prevention. + +## Cost Model + +Add `CostBreakdown` dataclass: + +- `input_cost` +- `cache_read_cost` +- `cache_write_cost` +- `output_cost` +- `reasoning_cost` +- `total_cost` +- `currency` +- `pricing_source` +- `metadata` + +Add `ModelPricing` dataclass: + +- per-token prices for input, cache read, cache write, output, reasoning. +- `currency`. +- `source`. + +Add `CostCalculator`: + +- prefers explicit provider/model pricing declarations. +- falls back to LiteLLM model info/completion_cost where applicable. +- returns zero cost for `skip_cost_calculation`. +- does not call network. + +Minimal Phase 9 provider declarations: + +- Optional `get_model_pricing(model)` on `ProviderInterface`, default `None`. +- Providers can later define native pricing without touching accounting code. + +Env pricing can be minimal and Phase-10-ready: + +- `MODEL_PRICE_{PROVIDER}_{MODEL}_INPUT` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_OUTPUT` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_CACHE_READ` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_CACHE_WRITE` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_REASONING` + +If env parsing is too broad, implement programmatic pricing tests and leave env polish to Phase 10. + +## UsageManager Integration + +- Preserve existing `mark_success()` signature for compatibility. +- Add optional `usage_record` and `cost_breakdown` parameters only if needed, but do not force every call site to change. +- Preferred minimal integration: + - Executor and streaming code use `UsageRecord` internally, then pass existing numeric fields to `mark_success()`. + - `approx_cost` remains a float in current storage. + - Add optional metadata later only if storage impact is small. +- Do not change persisted JSON shape unless additive and backward-compatible. +- Add trace pass: + - `usage_accounting_summary` + - Include normalized token fields, raw total, cost total, pricing source, provider/model, and source. + - No credential secrets. + +## Quota Snapshot Model + +Add `QuotaSnapshot` dataclass: + +- `provider` +- `model` +- `quota_group` +- `credential_id` optional stable/masked identifier. +- `window_name` +- `limit` +- `used` +- `remaining` +- `reset_at` +- `source` +- `metadata` + +Add helper to build snapshots from `UsageManager` state: + +- per model. +- per quota group. +- per credential where available. + +Keep `WindowLimitChecker` behavior unchanged. + +Tests should prove snapshots reflect group windows and model windows without affecting limits. + +Optional `UsageAPI` helper can expose snapshots for future UI/TUI work if small. + +## Routing/Fallback Usage Behavior + +- Only successful target attempts record usage. +- Failed target attempts continue to record failures through existing paths. +- Route fallback summaries should include optional `usage_recorded: false` for failed targets only if already tracked internally; do not expose raw costs. +- Tests: + - first target failure + second target success records usage only on second provider/credential. + - native target success records normalized usage. + - explicit LiteLLM fallback target records normalized usage. + +## Streaming Usage + +- Replace ad-hoc stream usage extraction/cost calculation with `UsageRecord` and `CostCalculator`. +- Preserve existing `mark_success()` numeric values and final `[DONE]` behavior. +- Preserve `skip_cost_calculation`. +- Add tests: + - stream final usage with cache/read/write/reasoning records expected buckets. + - stream with missing usage records zero usage but still marks success as today. + - stream metrics from Phase 8 remain independent from token usage. + +## Responses Usage + +- Use `UsageRecord` when converting/storing Responses usage. +- Ensure `previous_response_id` storage is not affected. +- Add tests: + - non-streaming Responses usage maps to normalized fields. + - streaming Responses usage stores expected input/output totals. + - reasoning details are not double counted. + +## Native Provider Usage + +- In `NativeProviderExecutor.execute()`, protocol-formatted provider responses should expose usage in a shape `extract_usage_record()` can understand. +- Add trace summary but do not require the isolated native executor to mutate `UsageManager`. +- Live routed native execution via `RequestExecutor` should record normalized usage because it receives the native response and passes through the same executor accounting. +- Tests: + - native OpenAI-style response usage records normalized fields. + - native Gemini usage metadata records thought/cached tokens. + - native Responses usage records reasoning/details. + +## Quota/Cost Reporting + +- Keep existing quota viewer/API behavior unchanged unless adding optional fields is safe. +- Add cost totals to existing total/window stats already present as `approx_cost`. +- Do not implement hard cost caps unless trivial and isolated; cost caps can be a Phase 10+ config feature. +- Add clear docs/comments that `approx_cost` is advisory and depends on available pricing. + +## Transform Trace Requirements + +Add: + +- `usage_accounting_summary` +- `usage_cost_calculated` +- `quota_snapshot_built` for explicit snapshot APIs/tests, not every request. + +Metadata: + +- provider, model, source, pricing source, skip-cost flag. + +Data: + +- normalized usage record and cost breakdown only. + +Do not log credentials, raw headers, or raw provider errors. + +## Testing Plan + +Accounting tests: + +- OpenAI/LiteLLM object usage. +- OpenAI dict usage with prompt/completion details. +- Anthropic usage. +- Gemini `usageMetadata`. +- Responses usage. +- reasoning double-count prevention. +- missing/unknown usage shape returns empty record. + +Cost tests: + +- explicit pricing calculates all buckets. +- provider `skip_cost_calculation` returns zero. +- LiteLLM fallback path still returns a float when LiteLLM has model info. +- missing pricing returns zero with `pricing_source="unavailable"`. + +Quota snapshot tests: + +- model window snapshot. +- group window snapshot. +- missing window snapshot is empty/no-op. + +Executor tests: + +- non-stream success uses normalized usage. +- trace emits usage accounting summary. +- fallback second target success records usage once. + +Streaming tests: + +- final stream usage uses normalized accounting. +- stream cost honors skip-cost provider. +- Phase 8 metrics still emitted. + +Responses tests: + +- stored Responses usage shape remains compatible. +- normalized trace emitted when transaction logger exists. + +Native tests: + +- native response usage shapes normalize. + +Regression tests: + +- Phase 1 protocol tests. +- Phase 2 transform logging tests. +- Phase 3 adapter/cache tests. +- Phase 4 Responses tests. +- Phase 5 provider/native tests. +- Phase 6 routing tests. +- Phase 7 retry/cooldown tests. +- Phase 8 streaming tests. +- `test_session_tracking.py`. +- `test_selection_engine.py`. + +## Commit Checkpoints + +1. Add `UsageRecord`, extraction helpers, and tests. +2. Add `CostBreakdown`, pricing helpers/calculator, and tests. +3. Wire executor non-streaming accounting and trace summary with tests. +4. Wire streaming accounting/cost calculation with tests. +5. Add quota snapshot helpers and tests. +6. Add Responses/native usage accounting coverage. +7. Run focused and regression tests. +8. Review with `explore` and `explore-heavy`; fix findings; write uncommitted Phase 9 report. + +## Risks And Mitigations + +- Token normalization could change quota accounting. Mitigation: preserve current numeric `mark_success()` buckets and test current OpenAI/LiteLLM behavior. +- Cost estimates may be wrong for providers with unknown pricing. Mitigation: return zero/unavailable rather than guessing; make `approx_cost` advisory. +- Persisted usage JSON compatibility could break. Mitigation: keep existing storage fields and only add optional helpers unless tests prove serialization is safe. +- Provider-specific usage shapes are broad. Mitigation: implement common native shapes now and leave provider overrides possible. +- Fallback chains could double-record usage. Mitigation: only successful attempt calls `mark_success()`; add tests for failure then success. From ec38cc0d7ef690f34421ffc9f65ba4a398cc27b5 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:49:50 +0200 Subject: [PATCH 059/182] feat(usage): add normalized usage records Adds UsageRecord and extraction helpers for OpenAI/LiteLLM, Anthropic, Gemini, Responses, dict/object usage shapes, cache buckets, and reasoning-token double-count prevention. The helper feeds the existing UsageManager buckets and does not change persistence, selection, or quota behavior. Tests: python -m pytest tests/test_usage_accounting.py tests/test_selection_engine.py --- src/rotator_library/usage/__init__.py | 3 + src/rotator_library/usage/accounting.py | 246 ++++++++++++++++++++++++ tests/test_usage_accounting.py | 108 +++++++++++ 3 files changed, 357 insertions(+) create mode 100644 src/rotator_library/usage/accounting.py create mode 100644 tests/test_usage_accounting.py diff --git a/src/rotator_library/usage/__init__.py b/src/rotator_library/usage/__init__.py index f8ae947fe..7dcf9060a 100644 --- a/src/rotator_library/usage/__init__.py +++ b/src/rotator_library/usage/__init__.py @@ -56,11 +56,14 @@ # Main facade (imports components above) from .manager import UsageManager, CredentialContext +from .accounting import UsageRecord, extract_usage_record __all__ = [ # Main public API "UsageManager", "CredentialContext", + "UsageRecord", + "extract_usage_record", # Types "WindowStats", "TotalStats", diff --git a/src/rotator_library/usage/accounting.py b/src/rotator_library/usage/accounting.py new file mode 100644 index 000000000..f045d9c6f --- /dev/null +++ b/src/rotator_library/usage/accounting.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Normalized usage accounting across provider protocols. + +The existing `UsageManager` remains the source of truth for persistence, +windows, and credential selection. This module only converts provider-specific +usage payloads into the numeric buckets that `CredentialContext.mark_success()` +already accepts. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional + +from ..protocols import serialize_value + + +@dataclass(frozen=True) +class UsageRecord: + """Provider-neutral token usage buckets. + + `completion_tokens` excludes `reasoning_tokens` when providers report hidden + reasoning separately. This preserves the current no-double-count behavior and + gives Phase 9+ cost logic clear input/output/reasoning buckets. + """ + + input_tokens: int = 0 + completion_tokens: int = 0 + reasoning_tokens: int = 0 + cache_read_tokens: int = 0 + cache_write_tokens: int = 0 + raw_total_tokens: int = 0 + request_count: int = 1 + source: str = "unknown" + provider: Optional[str] = None + model: Optional[str] = None + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def output_tokens(self) -> int: + return self.completion_tokens + self.reasoning_tokens + + @property + def total_tokens(self) -> int: + return ( + self.input_tokens + + self.cache_read_tokens + + self.cache_write_tokens + + self.completion_tokens + + self.reasoning_tokens + ) + + @property + def prompt_tokens_for_mark_success(self) -> int: + """Return non-cache-read prompt tokens for existing usage storage.""" + + return self.input_tokens + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-safe summary for transform traces and tests.""" + + return { + "input_tokens": self.input_tokens, + "completion_tokens": self.completion_tokens, + "reasoning_tokens": self.reasoning_tokens, + "cache_read_tokens": self.cache_read_tokens, + "cache_write_tokens": self.cache_write_tokens, + "output_tokens": self.output_tokens, + "total_tokens": self.total_tokens, + "raw_total_tokens": self.raw_total_tokens, + "request_count": self.request_count, + "source": self.source, + "provider": self.provider, + "model": self.model, + "metadata": serialize_value(self.metadata), + } + + +def extract_usage_record( + response_or_usage: Any, + *, + provider: Optional[str] = None, + model: Optional[str] = None, + source: str = "response", +) -> UsageRecord: + """Extract normalized usage from a response object or usage payload.""" + + usage = _unwrap_usage(response_or_usage) + if usage is None: + return UsageRecord(provider=provider, model=model, source=source) + data = _as_dict(usage) + if not data: + return UsageRecord(provider=provider, model=model, source=source) + + if "usageMetadata" in data and isinstance(data["usageMetadata"], dict): + data = data["usageMetadata"] + + if _looks_like_gemini(data): + return _from_gemini_usage(data, provider=provider, model=model, source=source) + if _looks_like_anthropic(data): + return _from_anthropic_usage(data, provider=provider, model=model, source=source) + return _from_openai_like_usage(data, provider=provider, model=model, source=source) + + +def _unwrap_usage(value: Any) -> Any: + if value is None: + return None + if isinstance(value, dict): + return value.get("usage", value) + return getattr(value, "usage", value) + + +def _as_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return value + if hasattr(value, "model_dump"): + dumped = value.model_dump() + return dumped if isinstance(dumped, dict) else {} + if hasattr(value, "dict"): + dumped = value.dict() + return dumped if isinstance(dumped, dict) else {} + result: dict[str, Any] = {} + for key in ( + "prompt_tokens", + "completion_tokens", + "total_tokens", + "prompt_tokens_details", + "completion_tokens_details", + "cache_read_tokens", + "cache_creation_tokens", + "input_tokens", + "output_tokens", + "input_tokens_details", + "output_tokens_details", + "cache_creation_input_tokens", + "cache_read_input_tokens", + "cached_tokens", + "cache_creation_tokens", + "cache_write_tokens", + "reasoning_tokens", + "thinking_tokens", + ): + if hasattr(value, key): + result[key] = getattr(value, key) + return result + + +def _from_openai_like_usage(data: dict[str, Any], *, provider: Optional[str], model: Optional[str], source: str) -> UsageRecord: + prompt_tokens = _int(data.get("prompt_tokens", data.get("input_tokens", 0))) + completion_tokens = _int(data.get("completion_tokens", data.get("output_tokens", 0))) + raw_total = _int(data.get("total_tokens", data.get("raw_total_tokens", 0))) + + prompt_details = _as_dict(data.get("prompt_tokens_details") or data.get("input_tokens_details") or {}) + completion_details = _as_dict(data.get("completion_tokens_details") or data.get("output_tokens_details") or {}) + cache_read = _int( + data.get( + "cache_read_tokens", + data.get("cached_tokens", prompt_details.get("cached_tokens", 0)), + ) + ) + cache_write = _int( + data.get( + "cache_creation_tokens", + data.get("cache_write_tokens", prompt_details.get("cache_creation_tokens", 0)), + ) + ) + reasoning = _int( + data.get( + "reasoning_tokens", + completion_details.get("reasoning_tokens", completion_details.get("thinking_tokens", 0)), + ) + ) + if reasoning and completion_tokens >= reasoning: + completion_tokens -= reasoning + input_tokens = max(0, prompt_tokens - cache_read) + return UsageRecord( + input_tokens=input_tokens, + completion_tokens=completion_tokens, + reasoning_tokens=reasoning, + cache_read_tokens=cache_read, + cache_write_tokens=cache_write, + raw_total_tokens=raw_total, + source=source, + provider=provider, + model=model, + metadata={"shape": "openai_like"}, + ) + + +def _from_anthropic_usage(data: dict[str, Any], *, provider: Optional[str], model: Optional[str], source: str) -> UsageRecord: + cache_read = _int(data.get("cache_read_input_tokens", data.get("cache_read_tokens", 0))) + cache_write = _int(data.get("cache_creation_input_tokens", data.get("cache_creation_tokens", 0))) + input_tokens = max(0, _int(data.get("input_tokens", 0)) - cache_read - cache_write) + output_tokens = _int(data.get("output_tokens", data.get("completion_tokens", 0))) + reasoning = _int(data.get("reasoning_tokens", data.get("thinking_tokens", 0))) + if reasoning and output_tokens >= reasoning: + output_tokens -= reasoning + return UsageRecord( + input_tokens=input_tokens, + completion_tokens=output_tokens, + reasoning_tokens=reasoning, + cache_read_tokens=cache_read, + cache_write_tokens=cache_write, + raw_total_tokens=_int(data.get("total_tokens", 0)), + source=source, + provider=provider, + model=model, + metadata={"shape": "anthropic"}, + ) + + +def _from_gemini_usage(data: dict[str, Any], *, provider: Optional[str], model: Optional[str], source: str) -> UsageRecord: + cache_read = _int(data.get("cachedContentTokenCount", data.get("cache_read_tokens", 0))) + prompt_tokens = _int(data.get("promptTokenCount", data.get("prompt_tokens", 0))) + reasoning = _int(data.get("thoughtsTokenCount", data.get("reasoning_tokens", 0))) + completion = _int(data.get("candidatesTokenCount", data.get("completion_tokens", 0))) + if reasoning and completion >= reasoning: + completion -= reasoning + return UsageRecord( + input_tokens=max(0, prompt_tokens - cache_read), + completion_tokens=completion, + reasoning_tokens=reasoning, + cache_read_tokens=cache_read, + raw_total_tokens=_int(data.get("totalTokenCount", data.get("total_tokens", 0))), + source=source, + provider=provider, + model=model, + metadata={"shape": "gemini"}, + ) + + +def _looks_like_gemini(data: dict[str, Any]) -> bool: + return any(key in data for key in ("promptTokenCount", "candidatesTokenCount", "thoughtsTokenCount", "cachedContentTokenCount")) + + +def _looks_like_anthropic(data: dict[str, Any]) -> bool: + return any(key in data for key in ("cache_creation_input_tokens", "cache_read_input_tokens")) and "input_tokens" in data + + +def _int(value: Any) -> int: + try: + return max(0, int(value or 0)) + except (TypeError, ValueError): + return 0 diff --git a/tests/test_usage_accounting.py b/tests/test_usage_accounting.py new file mode 100644 index 000000000..d74db9d6d --- /dev/null +++ b/tests/test_usage_accounting.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from rotator_library.usage.accounting import UsageRecord, extract_usage_record + + +def test_openai_dict_usage_extracts_cache_and_reasoning_without_double_counting() -> None: + record = extract_usage_record( + { + "usage": { + "prompt_tokens": 100, + "completion_tokens": 30, + "total_tokens": 130, + "prompt_tokens_details": {"cached_tokens": 40, "cache_creation_tokens": 5}, + "completion_tokens_details": {"reasoning_tokens": 10}, + } + }, + provider="openai", + model="gpt-test", + ) + + assert record.input_tokens == 60 + assert record.cache_read_tokens == 40 + assert record.cache_write_tokens == 5 + assert record.completion_tokens == 20 + assert record.reasoning_tokens == 10 + assert record.output_tokens == 30 + assert record.raw_total_tokens == 130 + + +def test_openai_object_usage_extracts_attributes() -> None: + response = SimpleNamespace( + usage=SimpleNamespace( + prompt_tokens=12, + completion_tokens=7, + prompt_tokens_details=SimpleNamespace(cached_tokens=2, cache_creation_tokens=1), + completion_tokens_details=SimpleNamespace(reasoning_tokens=3), + ) + ) + + record = extract_usage_record(response) + + assert record.input_tokens == 10 + assert record.cache_read_tokens == 2 + assert record.cache_write_tokens == 1 + assert record.completion_tokens == 4 + assert record.reasoning_tokens == 3 + + +def test_anthropic_usage_extracts_cache_buckets() -> None: + record = extract_usage_record( + { + "input_tokens": 100, + "output_tokens": 20, + "cache_creation_input_tokens": 8, + "cache_read_input_tokens": 12, + } + ) + + assert record.input_tokens == 80 + assert record.cache_read_tokens == 12 + assert record.cache_write_tokens == 8 + assert record.completion_tokens == 20 + assert record.metadata["shape"] == "anthropic" + + +def test_gemini_usage_metadata_extracts_thought_and_cached_tokens() -> None: + record = extract_usage_record( + { + "usageMetadata": { + "promptTokenCount": 80, + "candidatesTokenCount": 25, + "thoughtsTokenCount": 5, + "cachedContentTokenCount": 30, + "totalTokenCount": 110, + } + } + ) + + assert record.input_tokens == 50 + assert record.cache_read_tokens == 30 + assert record.completion_tokens == 20 + assert record.reasoning_tokens == 5 + assert record.metadata["shape"] == "gemini" + + +def test_responses_usage_extracts_nested_details() -> None: + record = extract_usage_record( + { + "input_tokens": 42, + "output_tokens": 13, + "input_tokens_details": {"cached_tokens": 10}, + "output_tokens_details": {"reasoning_tokens": 4}, + }, + source="responses", + ) + + assert record.input_tokens == 32 + assert record.cache_read_tokens == 10 + assert record.completion_tokens == 9 + assert record.reasoning_tokens == 4 + + +def test_unknown_usage_shape_returns_empty_record() -> None: + record = extract_usage_record({"not_usage": True}, provider="x", model="y") + + assert record == UsageRecord(provider="x", model="y", source="response", metadata={"shape": "openai_like"}) From e2fd82554985d52b18d1ae45aa4cf18716b5f27d Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:51:23 +0200 Subject: [PATCH 060/182] feat(usage): add advisory cost calculator Adds ModelPricing, CostBreakdown, and CostCalculator for normalized usage records. The calculator supports provider-declared pricing, LiteLLM model-info fallback, and provider skip-cost behavior without changing UsageManager persistence or selection. Tests: python -m pytest tests/test_usage_costs.py tests/test_usage_accounting.py tests/test_provider_protocol_declarations.py --- .../providers/provider_interface.py | 10 ++ src/rotator_library/usage/__init__.py | 4 + src/rotator_library/usage/costs.py | 136 ++++++++++++++++++ tests/test_usage_costs.py | 59 ++++++++ 4 files changed, 209 insertions(+) create mode 100644 src/rotator_library/usage/costs.py create mode 100644 tests/test_usage_costs.py diff --git a/src/rotator_library/providers/provider_interface.py b/src/rotator_library/providers/provider_interface.py index c0df98738..e73acebe0 100644 --- a/src/rotator_library/providers/provider_interface.py +++ b/src/rotator_library/providers/provider_interface.py @@ -373,6 +373,16 @@ def supports_native_streaming(self, model: str = "", operation: str = "chat") -> return self.native_streaming_supported + def get_model_pricing(self, model: str = "") -> Optional[Any]: + """Return optional local pricing metadata for advisory cost tracking. + + Providers can return `usage.costs.ModelPricing` or a compatible dict. + The default is `None`, which lets cost accounting safely fall back to + LiteLLM model metadata or report pricing as unavailable. + """ + + return None + async def acompletion( self, client: httpx.AsyncClient, **kwargs ) -> Union[ diff --git a/src/rotator_library/usage/__init__.py b/src/rotator_library/usage/__init__.py index 7dcf9060a..fe9f9616d 100644 --- a/src/rotator_library/usage/__init__.py +++ b/src/rotator_library/usage/__init__.py @@ -57,12 +57,16 @@ # Main facade (imports components above) from .manager import UsageManager, CredentialContext from .accounting import UsageRecord, extract_usage_record +from .costs import CostBreakdown, CostCalculator, ModelPricing __all__ = [ # Main public API "UsageManager", "CredentialContext", "UsageRecord", + "CostBreakdown", + "CostCalculator", + "ModelPricing", "extract_usage_record", # Types "WindowStats", diff --git a/src/rotator_library/usage/costs.py b/src/rotator_library/usage/costs.py new file mode 100644 index 000000000..55cf1ed92 --- /dev/null +++ b/src/rotator_library/usage/costs.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Advisory cost calculation for normalized usage records.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional + +import litellm + +from ..protocols import serialize_value +from .accounting import UsageRecord + + +@dataclass(frozen=True) +class ModelPricing: + """Per-token pricing for one provider/model. + + Prices are advisory and local-only; this module never calls a network pricing + endpoint. Providers can return this object from `get_model_pricing()` later. + """ + + input_cost_per_token: float = 0.0 + cache_read_cost_per_token: float = 0.0 + cache_write_cost_per_token: float = 0.0 + output_cost_per_token: float = 0.0 + reasoning_cost_per_token: float = 0.0 + currency: str = "USD" + source: str = "explicit" + + +@dataclass(frozen=True) +class CostBreakdown: + """Advisory request cost split by normalized usage bucket.""" + + input_cost: float = 0.0 + cache_read_cost: float = 0.0 + cache_write_cost: float = 0.0 + output_cost: float = 0.0 + reasoning_cost: float = 0.0 + currency: str = "USD" + pricing_source: str = "unavailable" + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def total_cost(self) -> float: + return self.input_cost + self.cache_read_cost + self.cache_write_cost + self.output_cost + self.reasoning_cost + + def to_dict(self) -> dict[str, Any]: + return { + "input_cost": self.input_cost, + "cache_read_cost": self.cache_read_cost, + "cache_write_cost": self.cache_write_cost, + "output_cost": self.output_cost, + "reasoning_cost": self.reasoning_cost, + "total_cost": self.total_cost, + "currency": self.currency, + "pricing_source": self.pricing_source, + "metadata": serialize_value(self.metadata), + } + + +class CostCalculator: + """Calculate advisory costs without replacing usage tracking.""" + + def __init__(self, *, provider_plugin: Any = None, use_litellm_fallback: bool = True) -> None: + self.provider_plugin = provider_plugin + self.use_litellm_fallback = use_litellm_fallback + + def calculate(self, usage: UsageRecord, *, model: str, response: Any = None) -> CostBreakdown: + """Return an advisory cost breakdown for a normalized usage record.""" + + if self.provider_plugin and getattr(self.provider_plugin, "skip_cost_calculation", False): + return CostBreakdown(pricing_source="skipped", metadata={"reason": "provider_skip_cost_calculation"}) + pricing = self._provider_pricing(model) + if pricing: + return _calculate_from_pricing(usage, pricing) + if self.use_litellm_fallback: + lite = self._litellm_cost(usage, model=model, response=response) + if lite: + return lite + return CostBreakdown(pricing_source="unavailable") + + def _provider_pricing(self, model: str) -> Optional[ModelPricing]: + if not self.provider_plugin: + return None + method = getattr(self.provider_plugin, "get_model_pricing", None) + if not method: + return None + pricing = method(model) + if isinstance(pricing, ModelPricing): + return pricing + if isinstance(pricing, dict): + return ModelPricing(**pricing) + return None + + @staticmethod + def _litellm_cost(usage: UsageRecord, *, model: str, response: Any = None) -> Optional[CostBreakdown]: + if response is not None: + try: + cost = litellm.completion_cost(completion_response=response, model=model) + if cost is not None: + return CostBreakdown(output_cost=float(cost), pricing_source="litellm_completion_cost") + except Exception: + pass + try: + model_info = litellm.get_model_info(model) + except Exception: + return None + input_price = float(model_info.get("input_cost_per_token") or 0.0) + output_price = float(model_info.get("output_cost_per_token") or 0.0) + if input_price == 0.0 and output_price == 0.0: + return None + pricing = ModelPricing( + input_cost_per_token=input_price, + cache_read_cost_per_token=float(model_info.get("cache_read_input_token_cost") or input_price), + cache_write_cost_per_token=float(model_info.get("cache_creation_input_token_cost") or input_price), + output_cost_per_token=output_price, + reasoning_cost_per_token=output_price, + source="litellm_model_info", + ) + return _calculate_from_pricing(usage, pricing) + + +def _calculate_from_pricing(usage: UsageRecord, pricing: ModelPricing) -> CostBreakdown: + return CostBreakdown( + input_cost=usage.input_tokens * pricing.input_cost_per_token, + cache_read_cost=usage.cache_read_tokens * pricing.cache_read_cost_per_token, + cache_write_cost=usage.cache_write_tokens * pricing.cache_write_cost_per_token, + output_cost=usage.completion_tokens * pricing.output_cost_per_token, + reasoning_cost=usage.reasoning_tokens * (pricing.reasoning_cost_per_token or pricing.output_cost_per_token), + currency=pricing.currency, + pricing_source=pricing.source, + ) diff --git a/tests/test_usage_costs.py b/tests/test_usage_costs.py new file mode 100644 index 000000000..3dce73ed0 --- /dev/null +++ b/tests/test_usage_costs.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from rotator_library.usage.accounting import UsageRecord +from rotator_library.usage.costs import CostCalculator, ModelPricing + + +class PricingProvider: + def get_model_pricing(self, model: str): + return ModelPricing( + input_cost_per_token=0.001, + cache_read_cost_per_token=0.0001, + cache_write_cost_per_token=0.0005, + output_cost_per_token=0.002, + reasoning_cost_per_token=0.003, + source="provider_test", + ) + + +class SkipCostProvider: + skip_cost_calculation = True + + +def test_explicit_pricing_calculates_all_usage_buckets() -> None: + usage = UsageRecord(input_tokens=10, cache_read_tokens=20, cache_write_tokens=5, completion_tokens=7, reasoning_tokens=3) + + cost = CostCalculator(provider_plugin=PricingProvider(), use_litellm_fallback=False).calculate(usage, model="test") + + assert cost.input_cost == 0.01 + assert cost.cache_read_cost == 0.002 + assert cost.cache_write_cost == 0.0025 + assert cost.output_cost == 0.014 + assert cost.reasoning_cost == 0.009000000000000001 + assert cost.pricing_source == "provider_test" + + +def test_skip_cost_provider_returns_zero_skipped_breakdown() -> None: + cost = CostCalculator(provider_plugin=SkipCostProvider()).calculate(UsageRecord(input_tokens=100), model="test") + + assert cost.total_cost == 0.0 + assert cost.pricing_source == "skipped" + + +def test_missing_pricing_returns_unavailable_zero() -> None: + cost = CostCalculator(use_litellm_fallback=False).calculate(UsageRecord(input_tokens=100), model="unknown-model") + + assert cost.total_cost == 0.0 + assert cost.pricing_source == "unavailable" + + +def test_litellm_model_info_fallback(monkeypatch) -> None: + monkeypatch.setattr( + "rotator_library.usage.costs.litellm.get_model_info", + lambda model: {"input_cost_per_token": 0.001, "output_cost_per_token": 0.002}, + ) + + cost = CostCalculator().calculate(UsageRecord(input_tokens=2, completion_tokens=3), model="gpt-test") + + assert cost.total_cost == 0.008 + assert cost.pricing_source == "litellm_model_info" From 4d194e477cfdf094f3f9cfd70ca2205a476fdca3 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:55:06 +0200 Subject: [PATCH 061/182] feat(usage): account for executor responses Wires normalized UsageRecord extraction and advisory CostBreakdown calculation into non-streaming executor success handling while preserving the existing CredentialContext.mark_success numeric buckets. Adds usage_accounting_summary trace output and keeps the previous tuple extraction behavior available through the normalized record. Tests: python -m pytest tests/test_executor_usage_accounting.py tests/test_usage_accounting.py tests/test_usage_costs.py tests/test_request_executor_fallback_groups.py tests/test_request_executor_native_routing.py tests/test_request_executor_fallback_error_summary.py tests/test_responses_service.py tests/test_native_provider_executor.py --- src/rotator_library/client/executor.py | 81 ++++++++++++++++++++++--- tests/test_executor_usage_accounting.py | 61 +++++++++++++++++++ 2 files changed, 133 insertions(+), 9 deletions(-) create mode 100644 tests/test_executor_usage_accounting.py diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index f5052c1a6..41c83fe63 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -66,6 +66,8 @@ from ..routing import FallbackPolicy, clone_context_for_target from ..routing.types import RouteTarget from ..native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor +from ..usage.accounting import UsageRecord, extract_usage_record +from ..usage.costs import CostBreakdown, CostCalculator from .types import RetryState, AvailabilityStats from .filters import CredentialFilter @@ -919,16 +921,15 @@ async def _execute_non_streaming( ) # Success! Extract token usage if available - ( - prompt_tokens, - completion_tokens, - prompt_tokens_cached, - prompt_tokens_cache_write, - thinking_tokens, - ) = self._extract_usage_tokens(response) - approx_cost = self._calculate_cost( - provider, model, response + usage_record, cost_breakdown = self._account_for_response_usage( + provider, model, response, context ) + prompt_tokens = usage_record.prompt_tokens_for_mark_success + completion_tokens = usage_record.completion_tokens + prompt_tokens_cached = usage_record.cache_read_tokens + prompt_tokens_cache_write = usage_record.cache_write_tokens + thinking_tokens = usage_record.reasoning_tokens + approx_cost = cost_breakdown.total_cost response_headers = self._extract_response_headers( response ) @@ -1785,6 +1786,68 @@ async def _validate_request( raise ValueError(result) def _extract_usage_tokens(self, response: Any) -> tuple[int, int, int, int, int]: + """Extract legacy usage tuple through the normalized usage record.""" + + record = extract_usage_record(response) + return ( + record.prompt_tokens_for_mark_success, + record.completion_tokens, + record.cache_read_tokens, + record.cache_write_tokens, + record.reasoning_tokens, + ) + + def _account_for_response_usage( + self, + provider: str, + model: str, + response: Any, + context: RequestContext, + ) -> tuple[UsageRecord, CostBreakdown]: + """Normalize usage and advisory cost for one successful response.""" + + usage_record = extract_usage_record( + response, + provider=provider, + model=model, + source="executor_response", + ) + plugin = self._get_plugin_instance(provider) + cost_breakdown = CostCalculator(provider_plugin=plugin).calculate( + usage_record, + model=model, + response=response, + ) + self._trace_usage_accounting(context, usage_record, cost_breakdown) + return usage_record, cost_breakdown + + @staticmethod + def _trace_usage_accounting( + context: RequestContext, + usage_record: UsageRecord, + cost_breakdown: CostBreakdown, + ) -> None: + """Record normalized usage/cost trace data without affecting requests.""" + + if not context.transaction_logger: + return + context.transaction_logger.log_transform_pass( + "usage_accounting_summary", + {"usage": usage_record.to_dict(), "cost": cost_breakdown.to_dict()}, + direction="metadata", + stage="final", + metadata={ + "provider": usage_record.provider, + "model": usage_record.model, + "source": usage_record.source, + "pricing_source": cost_breakdown.pricing_source, + }, + snapshot=False, + ) + + def _legacy_extract_usage_tokens(self, response: Any) -> tuple[int, int, int, int, int]: + """Previous extraction logic kept temporarily for comparison/debugging.""" + prompt_tokens = 0 completion_tokens = 0 cached_tokens = 0 diff --git a/tests/test_executor_usage_accounting.py b/tests/test_executor_usage_accounting.py new file mode 100644 index 000000000..4730c8857 --- /dev/null +++ b/tests/test_executor_usage_accounting.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace + +from rotator_library.client.executor import RequestExecutor +from rotator_library.core.types import RequestContext +from rotator_library.transaction_logger import TransactionLogger + + +def _executor() -> RequestExecutor: + return RequestExecutor({}, None, None, None, {}, None) + + +def test_executor_accounts_for_non_streaming_usage_and_cost_trace(tmp_path, monkeypatch) -> None: + executor = _executor() + response = SimpleNamespace( + usage=SimpleNamespace( + prompt_tokens=100, + completion_tokens=30, + prompt_tokens_details={"cached_tokens": 40}, + completion_tokens_details={"reasoning_tokens": 10}, + ) + ) + logger = TransactionLogger("openai", "gpt-test", parent_dir=tmp_path) + context = RequestContext( + provider="openai", + model="gpt-test", + kwargs={}, + streaming=False, + credentials=[], + deadline=0, + transaction_logger=logger, + ) + monkeypatch.setattr( + "rotator_library.usage.costs.litellm.get_model_info", + lambda model: {"input_cost_per_token": 0.001, "output_cost_per_token": 0.002}, + ) + + usage, cost = executor._account_for_response_usage("openai", "gpt-test", response, context) + + assert usage.input_tokens == 60 + assert usage.cache_read_tokens == 40 + assert usage.completion_tokens == 20 + assert usage.reasoning_tokens == 10 + assert cost.total_cost > 0 + entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + assert entries[-1]["pass_name"] == "usage_accounting_summary" + assert entries[-1]["data"]["usage"]["total_tokens"] == usage.total_tokens + + +def test_executor_extract_usage_tokens_uses_normalized_record() -> None: + response = SimpleNamespace( + usage=SimpleNamespace( + prompt_tokens=10, + completion_tokens=5, + completion_tokens_details={"reasoning_tokens": 2}, + ) + ) + + assert _executor()._extract_usage_tokens(response) == (10, 3, 0, 0, 2) From cf964f25bd218a09e66b42b8b66b8c5893749307 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:56:41 +0200 Subject: [PATCH 062/182] feat(usage): account for stream usage Wires StreamingHandler final usage through normalized UsageRecord extraction and advisory CostCalculator while preserving existing mark_success buckets, stream metrics, and SSE output. Adds stream usage accounting trace coverage and skip-cost behavior tests. Tests: python -m pytest tests/test_streaming_usage_accounting.py tests/test_request_executor_stream_metrics.py tests/test_streaming_fallback_policy.py tests/test_usage_accounting.py tests/test_usage_costs.py tests/test_responses_streaming.py tests/test_native_provider_streaming.py --- src/rotator_library/client/streaming.py | 113 ++++++++++++----------- tests/test_streaming_usage_accounting.py | 61 ++++++++++++ 2 files changed, 121 insertions(+), 53 deletions(-) create mode 100644 tests/test_streaming_usage_accounting.py diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py index a5769f1df..31d22499e 100644 --- a/src/rotator_library/client/streaming.py +++ b/src/rotator_library/client/streaming.py @@ -26,6 +26,8 @@ from ..core.types import ProcessedChunk from ..core.utils import normalize_usage_for_response from ..streaming import StreamEvent, StreamMonitor, stream_event_from_sse_chunk +from ..usage.accounting import UsageRecord, extract_usage_record +from ..usage.costs import CostBreakdown, CostCalculator if TYPE_CHECKING: from ..usage.manager import CredentialContext @@ -82,6 +84,7 @@ async def wrap_stream( prompt_tokens_uncached = 0 completion_tokens = 0 thinking_tokens = 0 + usage_record = UsageRecord(source="stream") assistant_parts: List[str] = [] tool_call_ids: List[str] = [] monitor = StreamMonitor(clock=time.monotonic) @@ -158,52 +161,17 @@ async def wrap_stream( # Only update if not already tool_calls (highest priority) accumulated_finish_reason = processed.finish_reason if processed.usage and isinstance(processed.usage, dict): - # Extract token counts from final chunk - prompt_tokens = processed.usage.get("prompt_tokens", 0) - completion_tokens = processed.usage.get("completion_tokens", 0) - prompt_details = processed.usage.get("prompt_tokens_details") - if prompt_details: - if isinstance(prompt_details, dict): - prompt_tokens_cached = ( - prompt_details.get("cached_tokens", 0) or 0 - ) - prompt_tokens_cache_write = ( - prompt_details.get("cache_creation_tokens", 0) or 0 - ) - else: - prompt_tokens_cached = ( - getattr(prompt_details, "cached_tokens", 0) or 0 - ) - prompt_tokens_cache_write = ( - getattr(prompt_details, "cache_creation_tokens", 0) - or 0 - ) - completion_details = processed.usage.get( - "completion_tokens_details" - ) - if completion_details: - if isinstance(completion_details, dict): - thinking_tokens = ( - completion_details.get("reasoning_tokens", 0) or 0 - ) - else: - thinking_tokens = ( - getattr(completion_details, "reasoning_tokens", 0) - or 0 - ) - if processed.usage.get("cache_read_tokens") is not None: - prompt_tokens_cached = ( - processed.usage.get("cache_read_tokens") or 0 - ) - if processed.usage.get("cache_creation_tokens") is not None: - prompt_tokens_cache_write = ( - processed.usage.get("cache_creation_tokens") or 0 - ) - if thinking_tokens and completion_tokens >= thinking_tokens: - completion_tokens = completion_tokens - thinking_tokens - prompt_tokens_uncached = max( - 0, prompt_tokens - prompt_tokens_cached + usage_record = extract_usage_record( + processed.usage, + model=model, + source="stream_final_chunk", ) + prompt_tokens = usage_record.input_tokens + usage_record.cache_read_tokens + completion_tokens = usage_record.completion_tokens + thinking_tokens = usage_record.reasoning_tokens + prompt_tokens_cached = usage_record.cache_read_tokens + prompt_tokens_cache_write = usage_record.cache_write_tokens + prompt_tokens_uncached = usage_record.prompt_tokens_for_mark_success yield processed.sse_string @@ -269,20 +237,23 @@ async def wrap_stream( # Record usage if stream completed if stream_completed: if cred_context: - approx_cost = 0.0 - if not skip_cost_calculation: - approx_cost = self._calculate_stream_cost( - model, - prompt_tokens_uncached + prompt_tokens_cached, - completion_tokens + thinking_tokens, - ) + cost_breakdown = self._calculate_stream_cost_breakdown( + model, + usage_record, + skip_cost_calculation=skip_cost_calculation, + ) + self._log_stream_usage_accounting( + transaction_logger, + usage_record, + cost_breakdown, + ) cred_context.mark_success( prompt_tokens=prompt_tokens_uncached, completion_tokens=completion_tokens, thinking_tokens=thinking_tokens, prompt_tokens_cache_read=prompt_tokens_cached, prompt_tokens_cache_write=prompt_tokens_cache_write, - approx_cost=approx_cost, + approx_cost=cost_breakdown.total_cost, ) if response_callback and (assistant_parts or tool_call_ids): @@ -552,6 +523,42 @@ def _calculate_stream_cost( lib_logger.debug(f"Stream cost calculation failed for {model}: {exc}") return 0.0 + def _calculate_stream_cost_breakdown( + self, + model: str, + usage_record: UsageRecord, + *, + skip_cost_calculation: bool, + ) -> CostBreakdown: + """Calculate advisory stream cost through the shared cost helper.""" + + if skip_cost_calculation: + return CostBreakdown(pricing_source="skipped") + return CostCalculator().calculate(usage_record, model=model) + + @staticmethod + def _log_stream_usage_accounting( + transaction_logger: Optional[Any], + usage_record: UsageRecord, + cost_breakdown: CostBreakdown, + ) -> None: + """Trace normalized stream usage without affecting stream delivery.""" + + if not transaction_logger: + return + transaction_logger.log_transform_pass( + "usage_accounting_summary", + {"usage": usage_record.to_dict(), "cost": cost_breakdown.to_dict()}, + direction="metadata", + stage="final", + transport="sse", + metadata={ + "source": usage_record.source, + "pricing_source": cost_breakdown.pricing_source, + }, + snapshot=False, + ) + class StreamBuffer: """ diff --git a/tests/test_streaming_usage_accounting.py b/tests/test_streaming_usage_accounting.py new file mode 100644 index 000000000..ad29ae004 --- /dev/null +++ b/tests/test_streaming_usage_accounting.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.client.streaming import StreamingHandler +from rotator_library.transaction_logger import TransactionLogger + + +class FakeCredentialContext: + def __init__(self) -> None: + self.success_kwargs = None + + def mark_success(self, **kwargs) -> None: + self.success_kwargs = kwargs + + +async def _usage_chunks(): + yield {"id": "chunk_1", "choices": [{"delta": {"content": "hi"}}]} + yield { + "id": "chunk_2", + "choices": [{"delta": {}, "finish_reason": "stop"}], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 30, + "prompt_tokens_details": {"cached_tokens": 40, "cache_creation_tokens": 5}, + "completion_tokens_details": {"reasoning_tokens": 10}, + }, + } + + +@pytest.mark.asyncio +async def test_streaming_usage_uses_normalized_accounting_and_trace(tmp_path, monkeypatch) -> None: + monkeypatch.setattr( + "rotator_library.usage.costs.litellm.get_model_info", + lambda model: {"input_cost_per_token": 0.001, "output_cost_per_token": 0.002}, + ) + cred_context = FakeCredentialContext() + logger = TransactionLogger("openai", "gpt-test", parent_dir=tmp_path) + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(_usage_chunks(), "cred", "gpt-test", cred_context=cred_context, transaction_logger=logger)] + + assert chunks[-1] == "data: [DONE]\n\n" + assert cred_context.success_kwargs["prompt_tokens"] == 60 + assert cred_context.success_kwargs["prompt_tokens_cache_read"] == 40 + assert cred_context.success_kwargs["prompt_tokens_cache_write"] == 5 + assert cred_context.success_kwargs["completion_tokens"] == 20 + assert cred_context.success_kwargs["thinking_tokens"] == 10 + assert cred_context.success_kwargs["approx_cost"] > 0 + entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + assert any(entry["pass_name"] == "usage_accounting_summary" for entry in entries) + + +@pytest.mark.asyncio +async def test_streaming_usage_skip_cost_returns_zero() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_usage_chunks(), "cred", "gpt-test", cred_context=cred_context, skip_cost_calculation=True)] + + assert cred_context.success_kwargs["approx_cost"] == 0.0 From 665b14a227c025fcd959562ca680be72285705a6 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 02:57:59 +0200 Subject: [PATCH 063/182] feat(usage): add quota snapshots Adds read-only QuotaSnapshot helpers built from existing CredentialState model/group windows. This provides client-safe quota reporting without changing LimitEngine, WindowLimitChecker, selection, or persistence behavior. Tests: python -m pytest tests/test_usage_quota_snapshots.py tests/test_usage_accounting.py tests/test_selection_engine.py --- src/rotator_library/usage/__init__.py | 3 + src/rotator_library/usage/quota.py | 118 ++++++++++++++++++++++++++ tests/test_usage_quota_snapshots.py | 46 ++++++++++ 3 files changed, 167 insertions(+) create mode 100644 src/rotator_library/usage/quota.py create mode 100644 tests/test_usage_quota_snapshots.py diff --git a/src/rotator_library/usage/__init__.py b/src/rotator_library/usage/__init__.py index fe9f9616d..2ab7f5bd8 100644 --- a/src/rotator_library/usage/__init__.py +++ b/src/rotator_library/usage/__init__.py @@ -58,6 +58,7 @@ from .manager import UsageManager, CredentialContext from .accounting import UsageRecord, extract_usage_record from .costs import CostBreakdown, CostCalculator, ModelPricing +from .quota import QuotaSnapshot, build_quota_snapshots __all__ = [ # Main public API @@ -67,6 +68,8 @@ "CostBreakdown", "CostCalculator", "ModelPricing", + "QuotaSnapshot", + "build_quota_snapshots", "extract_usage_record", # Types "WindowStats", diff --git a/src/rotator_library/usage/quota.py b/src/rotator_library/usage/quota.py new file mode 100644 index 000000000..b42759aa3 --- /dev/null +++ b/src/rotator_library/usage/quota.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Read-only quota snapshot helpers built from existing usage state.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Mapping, Optional + +from ..error_handler import mask_credential +from ..protocols import serialize_value +from .types import CredentialState, WindowStats + + +@dataclass(frozen=True) +class QuotaSnapshot: + """Client-safe view of one usage window. + + Snapshots are reporting-only. They intentionally do not participate in limit + checks, selection, or persistence so existing quota behavior remains owned by + `UsageManager`, `TrackingEngine`, and `LimitEngine`. + """ + + provider: str + model: Optional[str] + quota_group: Optional[str] + credential_id: Optional[str] + window_name: str + limit: Optional[int] + used: int + remaining: Optional[int] + reset_at: Optional[float] + source: str + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return { + "provider": self.provider, + "model": self.model, + "quota_group": self.quota_group, + "credential_id": self.credential_id, + "window_name": self.window_name, + "limit": self.limit, + "used": self.used, + "remaining": self.remaining, + "reset_at": self.reset_at, + "source": self.source, + "metadata": serialize_value(self.metadata), + } + + +def build_quota_snapshots( + *, + provider: str, + states: Mapping[str, CredentialState], + model: Optional[str] = None, + quota_group: Optional[str] = None, + include_credentials: bool = True, +) -> list[QuotaSnapshot]: + """Build read-only quota snapshots from credential states.""" + + snapshots: list[QuotaSnapshot] = [] + for stable_id, state in states.items(): + credential_id = mask_credential(stable_id) if include_credentials else None + if model: + model_stats = state.get_model_stats(model, create=False) + if model_stats: + snapshots.extend( + _snapshots_for_windows( + provider=provider, + model=model, + quota_group=None, + credential_id=credential_id, + windows=model_stats.windows, + source="model", + ) + ) + if quota_group: + group_stats = state.get_group_stats(quota_group, create=False) + if group_stats: + snapshots.extend( + _snapshots_for_windows( + provider=provider, + model=model, + quota_group=quota_group, + credential_id=credential_id, + windows=group_stats.windows, + source="group", + ) + ) + return snapshots + + +def _snapshots_for_windows( + *, + provider: str, + model: Optional[str], + quota_group: Optional[str], + credential_id: Optional[str], + windows: Mapping[str, WindowStats], + source: str, +) -> list[QuotaSnapshot]: + return [ + QuotaSnapshot( + provider=provider, + model=model, + quota_group=quota_group, + credential_id=credential_id, + window_name=window.name, + limit=window.limit, + used=window.request_count, + remaining=window.remaining, + reset_at=window.reset_at, + source=source, + ) + for window in windows.values() + ] diff --git a/tests/test_usage_quota_snapshots.py b/tests/test_usage_quota_snapshots.py new file mode 100644 index 000000000..cc18de095 --- /dev/null +++ b/tests/test_usage_quota_snapshots.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from rotator_library.usage.quota import build_quota_snapshots +from rotator_library.usage.types import CredentialState, WindowStats + + +def _state() -> CredentialState: + state = CredentialState(stable_id="credential-secret", provider="openai", accessor="credential-secret") + model_stats = state.get_model_stats("gpt-test") + model_stats.windows["daily"] = WindowStats(name="daily", request_count=3, limit=10, reset_at=123.0) + group_stats = state.get_group_stats("chat") + group_stats.windows["daily"] = WindowStats(name="daily", request_count=5, limit=20, reset_at=456.0) + return state + + +def test_build_quota_snapshots_for_model_window() -> None: + snapshots = build_quota_snapshots(provider="openai", states={"credential-secret": _state()}, model="gpt-test") + + assert len(snapshots) == 1 + snapshot = snapshots[0] + assert snapshot.provider == "openai" + assert snapshot.model == "gpt-test" + assert snapshot.quota_group is None + assert snapshot.used == 3 + assert snapshot.remaining == 7 + assert snapshot.credential_id != "credential-secret" + + +def test_build_quota_snapshots_for_group_window_without_credentials() -> None: + snapshots = build_quota_snapshots( + provider="openai", + states={"credential-secret": _state()}, + model="gpt-test", + quota_group="chat", + include_credentials=False, + ) + + group_snapshot = [snapshot for snapshot in snapshots if snapshot.source == "group"][0] + assert group_snapshot.quota_group == "chat" + assert group_snapshot.used == 5 + assert group_snapshot.remaining == 15 + assert group_snapshot.credential_id is None + + +def test_build_quota_snapshots_missing_windows_returns_empty() -> None: + assert build_quota_snapshots(provider="openai", states={"credential-secret": _state()}, model="missing") == [] From 07c7a21a31fe2704bb0a7360b883c3da0c7b9841 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 03:00:18 +0200 Subject: [PATCH 064/182] feat(usage): trace responses and native usage Adds normalized usage accounting traces for Responses create/stream paths and native provider execution. Responses bridge now preserves usage detail fields so reasoning and cache buckets can be normalized without changing stored response behavior. Tests: python -m pytest tests/test_responses_usage_accounting.py tests/test_native_usage_accounting.py tests/test_responses_service.py tests/test_native_provider_executor.py tests/test_responses_bridge.py tests/test_responses_routes.py tests/test_responses_streaming.py tests/test_protocol_responses.py tests/test_native_provider_streaming.py --- .../native_provider/executor.py | 15 +++++++ src/rotator_library/responses/bridge.py | 9 +++- src/rotator_library/responses/service.py | 26 ++++++++++++ tests/test_native_usage_accounting.py | 41 +++++++++++++++++++ tests/test_responses_service.py | 1 + tests/test_responses_usage_accounting.py | 36 ++++++++++++++++ 6 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 tests/test_native_usage_accounting.py create mode 100644 tests/test_responses_usage_accounting.py diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index cddc8d60f..84e776f2d 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -10,6 +10,7 @@ from ..adapters import get_adapter, run_adapter_chain from ..field_cache import FieldCacheEngine from ..protocols import get_protocol +from ..usage.accounting import extract_usage_record from .context import NativeProviderContext from .http import NativeHTTPTransport from .streaming import stream_event_payload @@ -54,6 +55,20 @@ async def execute(self, raw_request: dict[str, Any], context: NativeProviderCont self._trace(context, "parsed_native_provider_response", provider_response, direction="response", stage="protocol") provider_response = await run_adapter_chain(adapters, provider_response, context.adapter_context(), stage="response") await cache_engine.extract("response", provider_response, context.field_cache_context(), transaction_logger=logger) + usage_record = extract_usage_record( + provider_response, + provider=context.provider, + model=context.model, + source="native_provider_response", + ) + self._trace( + context, + "usage_accounting_summary", + {"usage": usage_record.to_dict()}, + direction="metadata", + stage="final", + snapshot=False, + ) self._trace(context, "final_client_response", provider_response, direction="response", stage="final") return provider_response except Exception as exc: diff --git a/src/rotator_library/responses/bridge.py b/src/rotator_library/responses/bridge.py index eb919731a..3c2c9425c 100644 --- a/src/rotator_library/responses/bridge.py +++ b/src/rotator_library/responses/bridge.py @@ -225,11 +225,18 @@ def _status_from_chat(response: dict[str, Any]) -> str: def _usage_to_responses(usage: Any) -> Any: if not isinstance(usage, dict): return usage - return { + result = { "input_tokens": usage.get("prompt_tokens", usage.get("input_tokens", 0)), "output_tokens": usage.get("completion_tokens", usage.get("output_tokens", 0)), "total_tokens": usage.get("total_tokens", 0), } + prompt_details = usage.get("prompt_tokens_details") or usage.get("input_tokens_details") + if isinstance(prompt_details, dict): + result["input_tokens_details"] = {"cached_tokens": prompt_details.get("cached_tokens", 0)} + completion_details = usage.get("completion_tokens_details") or usage.get("output_tokens_details") + if isinstance(completion_details, dict): + result["output_tokens_details"] = {"reasoning_tokens": completion_details.get("reasoning_tokens", 0)} + return result def _as_dict(value: Any) -> dict[str, Any]: diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index 35639c846..93d9dfc9b 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -11,6 +11,7 @@ from ..protocols import ProtocolContext from ..streaming import StreamEvent, StreamMonitor +from ..usage.accounting import extract_usage_record from ..protocols.responses import ResponsesProtocol from .bridge import ResponsesBridge from .store import InMemoryResponsesStore, ResponsesStore @@ -93,6 +94,7 @@ async def create_response( response_payload = self.bridge.from_chat_response(chat_response, unified) self._trace(transaction_logger, "parsed_unified_response", response_payload, direction="response", stage="protocol") + self._trace_responses_usage(transaction_logger, response_payload, unified.model, source="responses_response") if raw_request.get("store", True): stored = self._stored_response(raw_request, response_payload, parent) @@ -208,6 +210,7 @@ async def stream_response( done_item = output_item_done_payload(state) yield formatter.format_event("response.output_item.done", done_item) completed = response_completed_payload(state, _usage_to_responses_stream(usage)) + self._trace_responses_usage(transaction_logger, completed, unified.model, source="responses_stream") await self._store_stream_response(stream_request, completed, parent) self._trace(transaction_logger, "stored_responses_stream_response", completed, direction="metadata", stage="final") monitor.complete() @@ -344,6 +347,29 @@ def _log_transform_error(transaction_logger: Optional[Any], pass_name: str, erro if transaction_logger: transaction_logger.log_transform_error(pass_name, error, payload=payload, stage="adapter", protocol="responses") + def _trace_responses_usage( + self, + transaction_logger: Optional[Any], + response_payload: dict[str, Any], + model: str, + *, + source: str, + ) -> None: + """Trace normalized Responses usage without changing stored payloads.""" + + usage = response_payload.get("usage") if isinstance(response_payload, dict) else None + if not usage: + return + record = extract_usage_record(usage, provider="responses", model=model, source=source) + self._trace( + transaction_logger, + "usage_accounting_summary", + {"usage": record.to_dict()}, + direction="metadata", + stage="final", + metadata={"source": source}, + ) + async def _store_stream_response( self, raw_request: dict[str, Any], diff --git a/tests/test_native_usage_accounting.py b/tests/test_native_usage_accounting.py new file mode 100644 index 000000000..62116a6d4 --- /dev/null +++ b/tests/test_native_usage_accounting.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor +from rotator_library.transaction_logger import TransactionLogger + + +class FakeNativeTransport(NativeHTTPTransport): + def __init__(self): + pass + + async def post_json(self, endpoint, *, headers, payload): + return { + "id": "chatcmpl_1", + "model": "gpt-test", + "choices": [{"message": {"role": "assistant", "content": "hi"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 10, "completion_tokens": 6, "completion_tokens_details": {"reasoning_tokens": 2}}, + } + + +@pytest.mark.asyncio +async def test_native_executor_traces_normalized_usage(tmp_path) -> None: + logger = TransactionLogger("openai", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext( + provider="openai", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + headers={}, + transaction_logger=logger, + ) + + await NativeProviderExecutor().execute({"model": "gpt-test", "messages": [{"role": "user", "content": "hi"}]}, context, FakeNativeTransport()) + + entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + usage_entries = [entry for entry in entries if entry["pass_name"] == "usage_accounting_summary"] + assert usage_entries[-1]["data"]["usage"]["completion_tokens"] == 4 + assert usage_entries[-1]["data"]["usage"]["reasoning_tokens"] == 2 diff --git a/tests/test_responses_service.py b/tests/test_responses_service.py index 0408958c6..1832491ea 100644 --- a/tests/test_responses_service.py +++ b/tests/test_responses_service.py @@ -117,6 +117,7 @@ async def test_service_emits_transform_trace_passes(tmp_path) -> None: "responses_bridge_chat_request", "raw_chat_bridge_response", "parsed_unified_response", + "usage_accounting_summary", "stored_responses_response", "final_responses_response", ] diff --git a/tests/test_responses_usage_accounting.py b/tests/test_responses_usage_accounting.py new file mode 100644 index 000000000..48be0d027 --- /dev/null +++ b/tests/test_responses_usage_accounting.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import json + +import pytest + +from rotator_library.responses import InMemoryResponsesStore, ResponsesService +from rotator_library.transaction_logger import TransactionLogger + + +class FakeUsageClient: + async def acompletion(self, **kwargs): + return { + "id": "chatcmpl_1", + "model": "gpt-test", + "choices": [{"message": {"role": "assistant", "content": "hi"}, "finish_reason": "stop"}], + "usage": { + "prompt_tokens": 20, + "completion_tokens": 8, + "completion_tokens_details": {"reasoning_tokens": 3}, + }, + } + + +@pytest.mark.asyncio +async def test_responses_create_traces_normalized_usage(tmp_path) -> None: + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + service = ResponsesService(store=InMemoryResponsesStore()) + + await service.create_response({"model": "gpt-test", "input": "hello"}, FakeUsageClient(), transaction_logger=logger) + + entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + usage_entries = [entry for entry in entries if entry["pass_name"] == "usage_accounting_summary"] + assert usage_entries + assert usage_entries[-1]["data"]["usage"]["completion_tokens"] == 5 + assert usage_entries[-1]["data"]["usage"]["reasoning_tokens"] == 3 From f047bf141ca500a281912ae13bdaa2a2c9519ae1 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 03:10:35 +0200 Subject: [PATCH 065/182] fix(usage): close Phase 9 review gaps Removes dead legacy accounting methods, adds regression coverage for fallback success-only usage and zero-usage streams, and preserves Responses streaming usage detail fields when provider chunks include cache or reasoning details. Tests: python -m pytest tests/test_request_executor_fallback_groups.py tests/test_streaming_usage_accounting.py tests/test_responses_usage_accounting.py tests/test_responses_streaming.py tests/test_executor_usage_accounting.py tests/test_usage_accounting.py tests/test_usage_costs.py tests/test_native_usage_accounting.py --- src/rotator_library/client/executor.py | 80 ------------------- src/rotator_library/client/streaming.py | 22 ----- src/rotator_library/responses/service.py | 9 ++- .../test_request_executor_fallback_groups.py | 24 ++++++ tests/test_responses_usage_accounting.py | 20 +++++ tests/test_streaming_usage_accounting.py | 18 +++++ 6 files changed, 70 insertions(+), 103 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 41c83fe63..2c8b622a9 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -1845,64 +1845,6 @@ def _trace_usage_accounting( snapshot=False, ) - def _legacy_extract_usage_tokens(self, response: Any) -> tuple[int, int, int, int, int]: - """Previous extraction logic kept temporarily for comparison/debugging.""" - - prompt_tokens = 0 - completion_tokens = 0 - cached_tokens = 0 - cache_write_tokens = 0 - thinking_tokens = 0 - - if hasattr(response, "usage") and response.usage: - prompt_tokens = getattr(response.usage, "prompt_tokens", 0) or 0 - completion_tokens = getattr(response.usage, "completion_tokens", 0) or 0 - - prompt_details = getattr(response.usage, "prompt_tokens_details", None) - if prompt_details: - if isinstance(prompt_details, dict): - cached_tokens = prompt_details.get("cached_tokens", 0) or 0 - cache_write_tokens = ( - prompt_details.get("cache_creation_tokens", 0) or 0 - ) - else: - cached_tokens = getattr(prompt_details, "cached_tokens", 0) or 0 - cache_write_tokens = ( - getattr(prompt_details, "cache_creation_tokens", 0) or 0 - ) - - completion_details = getattr( - response.usage, "completion_tokens_details", None - ) - if completion_details: - if isinstance(completion_details, dict): - thinking_tokens = completion_details.get("reasoning_tokens", 0) or 0 - else: - thinking_tokens = ( - getattr(completion_details, "reasoning_tokens", 0) or 0 - ) - - cache_read_tokens = getattr(response.usage, "cache_read_tokens", None) - if cache_read_tokens is not None: - cached_tokens = cache_read_tokens or 0 - cache_creation_tokens = getattr( - response.usage, "cache_creation_tokens", None - ) - if cache_creation_tokens is not None: - cache_write_tokens = cache_creation_tokens or 0 - - if thinking_tokens and completion_tokens >= thinking_tokens: - completion_tokens = completion_tokens - thinking_tokens - - uncached_prompt = max(0, prompt_tokens - cached_tokens) - return ( - uncached_prompt, - completion_tokens, - cached_tokens, - cache_write_tokens, - thinking_tokens, - ) - @staticmethod def _normalize_response_usage(response: Any, model: str) -> Any: """ @@ -1916,28 +1858,6 @@ def _normalize_response_usage(response: Any, model: str) -> Any: normalize_usage_for_response(response.usage, model) return response - def _calculate_cost(self, provider: str, model: str, response: Any) -> float: - plugin = self._get_plugin_instance(provider) - if plugin and getattr(plugin, "skip_cost_calculation", False): - return 0.0 - - try: - if isinstance(response, litellm.EmbeddingResponse): - model_info = litellm.get_model_info(model) - input_cost = model_info.get("input_cost_per_token") - if input_cost: - return (response.usage.prompt_tokens or 0) * input_cost - return 0.0 - - cost = litellm.completion_cost( - completion_response=response, - model=model, - ) - return float(cost) if cost is not None else 0.0 - except Exception as exc: - lib_logger.debug(f"Cost calculation failed for {model}: {exc}") - return 0.0 - async def _transaction_logging_stream_wrapper( self, stream: AsyncGenerator[str, None], diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py index 31d22499e..b13291fa4 100644 --- a/src/rotator_library/client/streaming.py +++ b/src/rotator_library/client/streaming.py @@ -20,8 +20,6 @@ import time from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional, TYPE_CHECKING -import litellm - from ..core.errors import StreamedAPIError, CredentialNeedsReauthError from ..core.types import ProcessedChunk from ..core.utils import normalize_usage_for_response @@ -503,26 +501,6 @@ def _try_extract_error( return None - def _calculate_stream_cost( - self, - model: str, - prompt_tokens: int, - completion_tokens: int, - ) -> float: - try: - model_info = litellm.get_model_info(model) - input_cost = model_info.get("input_cost_per_token") - output_cost = model_info.get("output_cost_per_token") - total_cost = 0.0 - if input_cost: - total_cost += prompt_tokens * input_cost - if output_cost: - total_cost += completion_tokens * output_cost - return total_cost - except Exception as exc: - lib_logger.debug(f"Stream cost calculation failed for {model}: {exc}") - return 0.0 - def _calculate_stream_cost_breakdown( self, model: str, diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index 93d9dfc9b..9b58a6a00 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -402,8 +402,15 @@ def _chunk_text_delta(chunk: dict[str, Any]) -> str: def _usage_to_responses_stream(usage: Any) -> Any: if not isinstance(usage, dict): return usage - return { + result = { "input_tokens": usage.get("prompt_tokens", usage.get("input_tokens", 0)), "output_tokens": usage.get("completion_tokens", usage.get("output_tokens", 0)), "total_tokens": usage.get("total_tokens", 0), } + prompt_details = usage.get("prompt_tokens_details") or usage.get("input_tokens_details") + if isinstance(prompt_details, dict): + result["input_tokens_details"] = {"cached_tokens": prompt_details.get("cached_tokens", 0)} + completion_details = usage.get("completion_tokens_details") or usage.get("output_tokens_details") + if isinstance(completion_details, dict): + result["output_tokens_details"] = {"reasoning_tokens": completion_details.get("reasoning_tokens", 0)} + return result diff --git a/tests/test_request_executor_fallback_groups.py b/tests/test_request_executor_fallback_groups.py index c408c2a4d..ad7799999 100644 --- a/tests/test_request_executor_fallback_groups.py +++ b/tests/test_request_executor_fallback_groups.py @@ -162,3 +162,27 @@ async def test_non_streaming_fallback_group_emits_routing_trace(tmp_path) -> Non assert "routing_target_attempt_failed" in pass_names assert "routing_fallback_selected" in pass_names assert "routing_target_attempt_succeeded" in pass_names + + +@pytest.mark.asyncio +async def test_non_streaming_fallback_group_records_usage_only_on_successful_target() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + recorded = [] + + async def fake_execute(self, context): + attempts.append(context.provider) + if len(attempts) == 1: + raise ClassifiedFailure("rate_limit") + recorded.append(context.provider) + return {"id": "ok", "model": context.model} + + executor._execute_non_streaming = MethodType(fake_execute, executor) + + result = await executor._execute_non_streaming_with_fallback( + _context(routing_targets=(parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1"))) + ) + + assert result == {"id": "ok", "model": "openai/gpt-5.1"} + assert attempts == ["codex", "openai"] + assert recorded == ["openai"] diff --git a/tests/test_responses_usage_accounting.py b/tests/test_responses_usage_accounting.py index 48be0d027..789868bb0 100644 --- a/tests/test_responses_usage_accounting.py +++ b/tests/test_responses_usage_accounting.py @@ -22,6 +22,15 @@ async def acompletion(self, **kwargs): } +class FakeStreamingUsageClient: + async def acompletion(self, **kwargs): + async def gen(): + yield 'data: {"id":"chunk_1","choices":[{"delta":{"content":"hi"}}]}\n\n' + yield 'data: {"id":"chunk_2","choices":[{"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":20,"completion_tokens":8,"total_tokens":28,"completion_tokens_details":{"reasoning_tokens":3}}}\n\n' + + return gen() + + @pytest.mark.asyncio async def test_responses_create_traces_normalized_usage(tmp_path) -> None: logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) @@ -34,3 +43,14 @@ async def test_responses_create_traces_normalized_usage(tmp_path) -> None: assert usage_entries assert usage_entries[-1]["data"]["usage"]["completion_tokens"] == 5 assert usage_entries[-1]["data"]["usage"]["reasoning_tokens"] == 3 + + +@pytest.mark.asyncio +async def test_responses_stream_preserves_usage_details_in_completed_event() -> None: + service = ResponsesService(store=InMemoryResponsesStore()) + + chunks = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "hello", "stream": True}, FakeStreamingUsageClient())] + completed = [chunk for chunk in chunks if "response.completed" in chunk][0] + payload = json.loads(completed.split("data: ", 1)[1]) + + assert payload["usage"]["output_tokens_details"] == {"reasoning_tokens": 3} diff --git a/tests/test_streaming_usage_accounting.py b/tests/test_streaming_usage_accounting.py index ad29ae004..0534092aa 100644 --- a/tests/test_streaming_usage_accounting.py +++ b/tests/test_streaming_usage_accounting.py @@ -30,6 +30,11 @@ async def _usage_chunks(): } +async def _zero_usage_chunks(): + yield {"id": "chunk_1", "choices": [{"delta": {"content": "hi"}}]} + yield {"id": "chunk_2", "choices": [{"delta": {}, "finish_reason": "stop"}]} + + @pytest.mark.asyncio async def test_streaming_usage_uses_normalized_accounting_and_trace(tmp_path, monkeypatch) -> None: monkeypatch.setattr( @@ -59,3 +64,16 @@ async def test_streaming_usage_skip_cost_returns_zero() -> None: _ = [chunk async for chunk in StreamingHandler().wrap_stream(_usage_chunks(), "cred", "gpt-test", cred_context=cred_context, skip_cost_calculation=True)] assert cred_context.success_kwargs["approx_cost"] == 0.0 + + +@pytest.mark.asyncio +async def test_streaming_without_usage_still_marks_success_with_zero_usage() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_zero_usage_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["prompt_tokens"] == 0 + assert cred_context.success_kwargs["completion_tokens"] == 0 + assert cred_context.success_kwargs["thinking_tokens"] == 0 + assert cred_context.success_kwargs["prompt_tokens_cache_read"] == 0 + assert cred_context.success_kwargs["prompt_tokens_cache_write"] == 0 From 751ef724d5b9b13fc0149c9ccaf4e89f904e75f7 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 03:26:08 +0200 Subject: [PATCH 066/182] chore(usage): remove unused accounting wrapper Removes the final unused RequestExecutor usage-token wrapper after Phase 9 moved production accounting to UsageRecord and CostCalculator directly. Tests: python -m pytest tests/test_usage_accounting.py tests/test_usage_costs.py tests/test_usage_quota_snapshots.py tests/test_executor_usage_accounting.py tests/test_streaming_usage_accounting.py tests/test_responses_usage_accounting.py tests/test_native_usage_accounting.py tests/test_request_executor_fallback_groups.py --- src/rotator_library/client/executor.py | 14 +------------- tests/test_executor_usage_accounting.py | 12 ------------ 2 files changed, 1 insertion(+), 25 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 2c8b622a9..c1f6ed4ce 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -1785,18 +1785,6 @@ async def _validate_request( if isinstance(result, str): raise ValueError(result) - def _extract_usage_tokens(self, response: Any) -> tuple[int, int, int, int, int]: - """Extract legacy usage tuple through the normalized usage record.""" - - record = extract_usage_record(response) - return ( - record.prompt_tokens_for_mark_success, - record.completion_tokens, - record.cache_read_tokens, - record.cache_write_tokens, - record.reasoning_tokens, - ) - def _account_for_response_usage( self, provider: str, @@ -1852,7 +1840,7 @@ def _normalize_response_usage(response: Any, model: str) -> Any: Delegates to normalize_usage_for_response which handles both dicts (streaming) and pydantic objects (non-streaming). - Internal tracking values from _extract_usage_tokens are unaffected. + Internal tracking values from UsageRecord accounting are unaffected. """ if hasattr(response, "usage") and response.usage: normalize_usage_for_response(response.usage, model) diff --git a/tests/test_executor_usage_accounting.py b/tests/test_executor_usage_accounting.py index 4730c8857..b31114ee8 100644 --- a/tests/test_executor_usage_accounting.py +++ b/tests/test_executor_usage_accounting.py @@ -47,15 +47,3 @@ def test_executor_accounts_for_non_streaming_usage_and_cost_trace(tmp_path, monk entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] assert entries[-1]["pass_name"] == "usage_accounting_summary" assert entries[-1]["data"]["usage"]["total_tokens"] == usage.total_tokens - - -def test_executor_extract_usage_tokens_uses_normalized_record() -> None: - response = SimpleNamespace( - usage=SimpleNamespace( - prompt_tokens=10, - completion_tokens=5, - completion_tokens_details={"reasoning_tokens": 2}, - ) - ) - - assert _executor()._extract_usage_tokens(response) == (10, 3, 0, 0, 2) From a9f57297caadd325f70de8a1ae81d5afaedc853d Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 03:29:32 +0200 Subject: [PATCH 067/182] docs(experimental): plan config polish Adds the full Phase 10 plan for optional JSON config, env override precedence, routing/pricing/streaming/field-cache config, validation, docs, tests, and review workflow. --- docs/experimental/phase-10-config-polish.md | 286 ++++++++++++++++++++ 1 file changed, 286 insertions(+) create mode 100644 docs/experimental/phase-10-config-polish.md diff --git a/docs/experimental/phase-10-config-polish.md b/docs/experimental/phase-10-config-polish.md new file mode 100644 index 000000000..fd0b884ab --- /dev/null +++ b/docs/experimental/phase-10-config-polish.md @@ -0,0 +1,286 @@ +# Phase 10 Plan: Config Polish + +## Goal + +Make the new protocol/routing/field-cache/streaming/pricing features configurable in a consistent, documented, testable way without replacing the current `.env` workflow. Phase 10 should add a small optional JSON configuration layer, keep environment variables as the final override, document all new knobs in `.env.example`, and expose validation helpers so invalid config fails clearly instead of silently changing routing or accounting behavior. + +## Non-Goals + +- Do not introduce SQLite or any database. +- Do not replace `.env` as the primary user workflow. +- Do not replace provider class declarations, `UsageManager`, `SessionTracker`, `SelectionEngine`, or provider quota trackers. +- Do not implement full multi-user/security config. +- Do not move secrets into JSON config. +- Do not require JSON config for existing deployments. +- Do not rewrite every direct `os.getenv()` call in the repo. +- Do not change default routing, pricing, usage, or streaming behavior when no new config is present. + +## Configuration Precedence + +- Built-in defaults and provider declarations are the base. +- Optional JSON config may provide structured routing/pricing/streaming/provider metadata. +- Environment variables override JSON config. +- Request-level explicit routing/provider fields still win for that request where already supported. +- Secrets remain environment/OAuth-file based; JSON config should not contain API keys or bearer tokens. + +## Current Code Context + +- `.env.example` documents many legacy env vars but not all Phase 6-9 additions. +- `routing/config.py` currently reads fallback groups and model routes from env only. +- Phase 9 added `ModelPricing`, `CostCalculator`, and `ProviderInterface.get_model_pricing()` but not env/JSON pricing. +- Phase 8 added stream metrics primitives but runtime TTFB/stall/heartbeat env knobs are not implemented. +- Phase 7 added provider cooldown env parsing in `retry_policy.provider_cooldown_env()`. +- Phase 3 added field-cache rule dataclasses, but no user-facing JSON config parser for cache rules. +- Provider base URLs and native provider declarations are mostly provider methods or direct env vars. +- Existing usage config env parsing is large and should not be rewritten in Phase 10. +- Reports are user-facing only and should remain uncommitted. + +## Files To Add + +- `src/rotator_library/config/experimental.py` +- `tests/test_experimental_config.py` +- `tests/test_config_pricing.py` +- `tests/test_config_routing_json.py` +- `tests/test_config_stream_settings.py` +- `tests/test_env_example_experimental_config.py` +- Maybe `docs/experimental/config-reference.md` if the implementation needs more detail than `.env.example`. + +## Files Likely To Touch + +- `src/rotator_library/routing/config.py` +- `src/rotator_library/usage/costs.py` +- `src/rotator_library/client/streaming.py` +- `src/rotator_library/client/executor.py` only if streaming env knobs need to be passed through. +- `src/rotator_library/field_cache/rules.py` or equivalent only if JSON parsing needs a small helper. +- `src/rotator_library/providers/provider_interface.py` only if pricing declarations need docstring/typing alignment. +- `.env.example` +- `docs/experimental/phase-10-config-polish.md` + +## JSON Config Model + +Add `ExperimentalConfig` dataclass with optional sections: + +- `routing` +- `pricing` +- `streaming` +- `field_cache` +- `providers` + +Add loader: + +- `load_experimental_config(path=None, env=None)` +- If `path` is `None`, read from `LLM_PROXY_CONFIG_FILE` or `PROXY_CONFIG_FILE`. +- Missing path returns an empty config. +- Invalid JSON raises `ExperimentalConfigError`. +- Unknown top-level sections should be preserved in metadata or warned about, not fatal. + +Add helpers: + +- `as_bool()` +- `as_int()` +- `as_float()` +- `env_key(provider, model, suffix)` sanitizing provider/model names consistently. + +JSON should not read or interpolate secrets. + +## Routing JSON + +Support JSON shape: + +- `routing.fallback_groups..targets = ["codex/gpt-5.1-codex@native", "openai/gpt-5.1@litellm_fallback"]` +- `routing.fallback_groups..failover_on = ["rate_limit", "server_error"]` +- `routing.fallback_groups..stop_on = ["authentication", "validation"]` +- `routing.model_routes. = "group:code_chain"` or target spec. + +`routing/config.py` should merge JSON first, then env: + +- env `FALLBACK_GROUPS`, `FALLBACK_GROUP_*`, `MODEL_ROUTE_*` override or add entries. +- env overrides should retain current behavior for existing deployments. + +Validation: + +- empty groups are invalid. +- `group:` model routes must reference known groups after merge. +- invalid target specs raise `RoutingConfigError`. + +Tests: + +- JSON-only routing works. +- env override replaces JSON group targets. +- env model route can reference JSON group. +- invalid JSON group route fails clearly. +- existing env-only routing tests still pass. + +## Pricing Config + +Add env pricing support in `CostCalculator` or a small helper: + +- `MODEL_PRICE_{PROVIDER}_{MODEL}_INPUT` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_OUTPUT` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_CACHE_READ` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_CACHE_WRITE` +- `MODEL_PRICE_{PROVIDER}_{MODEL}_REASONING` + +Add JSON pricing support: + +- `pricing...input` +- `pricing...output` +- `pricing...cache_read` +- `pricing...cache_write` +- `pricing...reasoning` +- optional `currency` + +Precedence: + +- provider explicit `get_model_pricing()` first, because provider code may know exact native pricing. +- env pricing overrides JSON pricing. +- JSON pricing applies before LiteLLM fallback. +- LiteLLM remains last fallback. + +If env parsing is invalid, log warning and ignore that component. + +Tests: + +- JSON pricing calculates buckets. +- env pricing overrides JSON pricing. +- provider explicit pricing remains highest priority. +- missing pricing remains `unavailable`. +- `skip_cost_calculation` still wins over all pricing sources. + +## Streaming Settings + +Add `StreamRuntimeSettings` dataclass: + +- `ttfb_timeout_seconds` +- `stall_timeout_seconds` +- `heartbeat_seconds` +- `trace_metrics` + +Load from JSON `streaming` and env: + +- `STREAM_TTFB_TIMEOUT_SECONDS` +- `STREAM_STALL_TIMEOUT_SECONDS` +- `STREAM_HEARTBEAT_SECONDS` +- `STREAM_TRACE_METRICS` + +Phase 10 should not enforce timeouts by default. + +Minimal runtime integration: + +- `StreamingHandler.wrap_stream()` should read settings and conditionally emit lifecycle metrics only when trace metrics enabled, default true to preserve Phase 8 behavior. +- Do not implement heartbeat injection unless small and obviously safe. If implemented, default disabled and only emit SSE comments during controlled tests. +- Do not abort streams on TTFB/stall by default. If values are configured, trace `stream_stall_detected` when observable, but do not sever client streams unless explicitly `STREAM_STALL_ABORT=true` is introduced and tested. Prefer no abort in Phase 10. + +Tests: + +- JSON/env settings parse. +- env overrides JSON. +- `STREAM_TRACE_METRICS=false` suppresses lifecycle trace passes while SSE output remains unchanged. +- default trace behavior remains current. + +## Field-Cache Config + +If feasible, add parser for JSON field-cache rule declarations without wiring every provider: + +- `field_cache..[]` +- fields: `name`, `source`, `path`, `scope`, `mode`, `target_path`, `max_entries`. + +Keep it as helper-only if integration is too risky: + +- parse into existing `FieldCacheRule` dataclasses. +- providers/Phase 10+ can call it from `get_field_cache_rules()`. + +Tests: + +- valid JSON rule parses. +- invalid paths/rule modes fail clearly. +- provider/model wildcard merge order documented and tested if implemented. + +If the existing field-cache dataclasses do not support all desired fields cleanly, document and defer live provider integration. + +## Provider Metadata Config + +- Support safe non-secret provider metadata only. +- API base URLs can remain existing `*_API_BASE` env vars for now. +- JSON may define provider protocol/adapter names for future use, but Phase 10 should not make untrusted JSON instantiate arbitrary classes. +- Allow only names that already exist in protocol/adapter registries if validation is implemented. +- Do not place API keys, OAuth tokens, or authorization headers in JSON. +- Tests can validate config rejects/ignores secret-looking keys like `api_key`, `authorization`, `access_token`. + +## Config Validation And Diagnostics + +Add a validation result: + +- warnings for unknown sections/keys. +- errors for invalid routing target specs and invalid numeric pricing. +- no credential values in error messages. + +Add a CLI-free helper test; do not add a new CLI unless tiny. + +Add transform trace or startup log only if a config is loaded: + +- log path, loaded sections, warning count. +- no config contents with secrets. + +## `.env.example` Updates + +Add section for Phase 6 routing: + +- `FALLBACK_GROUPS` +- `FALLBACK_GROUP_` +- `MODEL_ROUTE_` +- execution suffixes `@auto`, `@native`, `@custom`, `@litellm_fallback`. + +Add section for Phase 7 provider cooldown: + +- `PROVIDER_COOLDOWN_MIN_SECONDS` +- `PROVIDER_COOLDOWN_DEFAULT_SECONDS` +- `PROVIDER_COOLDOWN_ON_QUOTA` +- transient retry delay/jitter already in defaults but should be documented. + +Add section for Phase 8 streaming: + +- `STREAM_RETRY_ON_REASONING_ONLY` +- `STREAM_TRACE_METRICS` +- `STREAM_TTFB_TIMEOUT_SECONDS` +- `STREAM_STALL_TIMEOUT_SECONDS` +- `STREAM_HEARTBEAT_SECONDS` + +Add section for Phase 9 pricing: + +- `MODEL_PRICE___INPUT` +- `MODEL_PRICE___OUTPUT` +- cache/reasoning variants. + +Add optional JSON config: + +- `LLM_PROXY_CONFIG_FILE=./config/llm-proxy.json` +- note that env overrides JSON and secrets should stay in env/OAuth files. + +Keep docs concise; do not document every old legacy env in Phase 10 unless already present. + +## Test Modernization + +- Do not take on the entire stale broad test suite unless small. +- Phase 10 can add focused tests for config parsing and maintained regression set. +- Existing stale ignored tests should not block completion unless they are tracked and part of maintained runs. + +## Implementation Checkpoints + +1. Add `ExperimentalConfig` loader and validation helpers with tests. +2. Integrate JSON+env routing config merge with existing routing tests. +3. Add JSON/env model pricing support to `CostCalculator` with tests. +4. Add stream runtime settings parsing and `STREAM_TRACE_METRICS` runtime integration with tests. +5. Add field-cache JSON rule parser helper if feasible with tests. +6. Update `.env.example` and optional config reference doc. +7. Run Phase 10 focused tests and Phase 1-9 maintained regressions. +8. Review with `explore` and `explore-heavy`; fix findings; write uncommitted Phase 10 report. + +## Risks And Mitigations + +- Config precedence confusion. Mitigation: tests for defaults, JSON, env override, provider explicit override, and request-level route behavior. +- Accidentally allowing secrets in JSON. Mitigation: validation rejects secret-looking keys in safe config sections. +- Routing merge could alter existing env behavior. Mitigation: env-only tests remain unchanged and env overrides JSON. +- Pricing could become authoritative incorrectly. Mitigation: keep costs advisory and return zero/unavailable when ambiguous. +- Streaming trace toggles could hide useful debug data unexpectedly. Mitigation: default trace metrics remains enabled. +- Over-scoping into a full config framework. Mitigation: only structured config for new experimental features; leave legacy env parsing intact. From 3059fcff8eb8ee54fb9a51063cd06b7fb3d0d802 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 03:32:46 +0200 Subject: [PATCH 068/182] feat(config): add experimental config loader Adds an optional JSON config loader for experimental routing, pricing, streaming, field-cache, and provider metadata. The loader rejects secret-like JSON keys, keeps env vars as the override layer, and stays out of eager config imports to avoid startup import cycles. Tests: python -m pytest tests/test_experimental_config.py --- src/rotator_library/config/experimental.py | 322 +++++++++++++++++++++ tests/test_experimental_config.py | 73 +++++ 2 files changed, 395 insertions(+) create mode 100644 src/rotator_library/config/experimental.py create mode 100644 tests/test_experimental_config.py diff --git a/src/rotator_library/config/experimental.py b/src/rotator_library/config/experimental.py new file mode 100644 index 000000000..8f72ed385 --- /dev/null +++ b/src/rotator_library/config/experimental.py @@ -0,0 +1,322 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Optional structured configuration for experimental native features. + +The proxy remains environment-first: existing `.env` variables keep working and +environment variables override this JSON layer. This module intentionally avoids +secrets. API keys, OAuth tokens, bearer headers, and similar values must remain +in environment variables or provider-managed credential files. +""" + +from __future__ import annotations + +import json +import os +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Mapping, Optional + +from ..field_cache import FieldCacheInjection, FieldCacheRule +from ..usage.costs import ModelPricing + +_CONFIG_ENV_KEYS = ("LLM_PROXY_CONFIG_FILE", "PROXY_CONFIG_FILE") +_KNOWN_SECTIONS = {"routing", "pricing", "streaming", "field_cache", "providers"} +_SECRET_KEY_PARTS = ("api_key", "authorization", "access_token", "refresh_token", "client_secret", "bearer_token", "password") + + +class ExperimentalConfigError(ValueError): + """Raised when optional structured config is malformed or unsafe.""" + + +@dataclass(frozen=True) +class ExperimentalConfig: + """Parsed optional JSON config. + + Sections are stored as dictionaries rather than deep custom classes so Phase + 10 can layer config onto existing feature-specific parsers without creating + a second full application configuration system. + """ + + routing: dict[str, Any] = field(default_factory=dict) + pricing: dict[str, Any] = field(default_factory=dict) + streaming: dict[str, Any] = field(default_factory=dict) + field_cache: dict[str, Any] = field(default_factory=dict) + providers: dict[str, Any] = field(default_factory=dict) + unknown_sections: dict[str, Any] = field(default_factory=dict) + warnings: tuple[str, ...] = () + path: Optional[str] = None + + @property + def is_empty(self) -> bool: + return not (self.routing or self.pricing or self.streaming or self.field_cache or self.providers or self.unknown_sections) + + +@dataclass(frozen=True) +class StreamRuntimeSettings: + """Runtime stream observability settings. + + Timeout and heartbeat values are parsed here for future enforcement, but + Phase 10 only wires `trace_metrics` into runtime behavior. This avoids + surprising long-running reasoning streams while still validating config. + """ + + ttfb_timeout_seconds: Optional[float] = None + stall_timeout_seconds: Optional[float] = None + heartbeat_seconds: Optional[float] = None + trace_metrics: bool = True + + +def load_experimental_config(path: str | os.PathLike[str] | None = None, env: Mapping[str, str] | None = None) -> ExperimentalConfig: + """Load optional JSON config from an explicit path or config env var.""" + + source = env if env is not None else os.environ + resolved = Path(path) if path is not None else _path_from_env(source) + if resolved is None or not resolved.exists(): + return ExperimentalConfig(path=str(resolved) if resolved is not None else None) + try: + data = json.loads(resolved.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + raise ExperimentalConfigError(f"Invalid JSON config at {resolved}: {exc.msg}") from exc + if not isinstance(data, dict): + raise ExperimentalConfigError("JSON config root must be an object") + _reject_secret_keys(data) + warnings = tuple(f"Unknown config section '{key}' ignored by current runtime" for key in data if key not in _KNOWN_SECTIONS) + unknown = {key: value for key, value in data.items() if key not in _KNOWN_SECTIONS} + return ExperimentalConfig( + routing=_dict_section(data, "routing"), + pricing=_dict_section(data, "pricing"), + streaming=_dict_section(data, "streaming"), + field_cache=_dict_section(data, "field_cache"), + providers=_dict_section(data, "providers"), + unknown_sections=unknown, + warnings=warnings, + path=str(resolved), + ) + + +def load_config_from_mapping(data: Mapping[str, Any]) -> ExperimentalConfig: + """Build config from an in-memory mapping for tests and provider helpers.""" + + _reject_secret_keys(data) + warnings = tuple(f"Unknown config section '{key}' ignored by current runtime" for key in data if key not in _KNOWN_SECTIONS) + return ExperimentalConfig( + routing=_dict_section(data, "routing"), + pricing=_dict_section(data, "pricing"), + streaming=_dict_section(data, "streaming"), + field_cache=_dict_section(data, "field_cache"), + providers=_dict_section(data, "providers"), + unknown_sections={key: value for key, value in data.items() if key not in _KNOWN_SECTIONS}, + warnings=warnings, + ) + + +def get_configured_model_pricing( + provider: str, + model: str, + *, + config: ExperimentalConfig | None = None, + env: Mapping[str, str] | None = None, +) -> Optional[ModelPricing]: + """Return JSON/env pricing for a provider/model, with env taking priority.""" + + source = env if env is not None else os.environ + env_pricing = _pricing_from_env(provider, model, source) + if env_pricing: + return env_pricing + active = config if config is not None else load_experimental_config(env=source) + pricing_section = active.pricing.get(provider, {}) if isinstance(active.pricing, dict) else {} + raw = pricing_section.get(model) if isinstance(pricing_section, dict) else None + if not isinstance(raw, dict): + return None + return ModelPricing( + input_cost_per_token=as_float(raw.get("input", raw.get("input_cost_per_token", 0.0)), name="pricing.input"), + output_cost_per_token=as_float(raw.get("output", raw.get("output_cost_per_token", 0.0)), name="pricing.output"), + cache_read_cost_per_token=as_float(raw.get("cache_read", raw.get("cache_read_cost_per_token", 0.0)), name="pricing.cache_read"), + cache_write_cost_per_token=as_float(raw.get("cache_write", raw.get("cache_write_cost_per_token", 0.0)), name="pricing.cache_write"), + reasoning_cost_per_token=as_float(raw.get("reasoning", raw.get("reasoning_cost_per_token", 0.0)), name="pricing.reasoning"), + currency=str(raw.get("currency", "USD")), + source="json_config", + ) + + +def get_stream_runtime_settings( + *, + config: ExperimentalConfig | None = None, + env: Mapping[str, str] | None = None, +) -> StreamRuntimeSettings: + """Return stream runtime settings with environment overriding JSON.""" + + source = env if env is not None else os.environ + active = config if config is not None else load_experimental_config(env=source) + streaming = active.streaming if isinstance(active.streaming, dict) else {} + + return StreamRuntimeSettings( + ttfb_timeout_seconds=_optional_positive_float(_env_or_json(source, "STREAM_TTFB_TIMEOUT_SECONDS", streaming, "ttfb_timeout_seconds"), "STREAM_TTFB_TIMEOUT_SECONDS"), + stall_timeout_seconds=_optional_positive_float(_env_or_json(source, "STREAM_STALL_TIMEOUT_SECONDS", streaming, "stall_timeout_seconds"), "STREAM_STALL_TIMEOUT_SECONDS"), + heartbeat_seconds=_optional_positive_float(_env_or_json(source, "STREAM_HEARTBEAT_SECONDS", streaming, "heartbeat_seconds"), "STREAM_HEARTBEAT_SECONDS"), + trace_metrics=as_bool(_env_or_json(source, "STREAM_TRACE_METRICS", streaming, "trace_metrics", default=True), name="STREAM_TRACE_METRICS"), + ) + + +def parse_field_cache_rules(config: ExperimentalConfig, provider: str, model: str) -> tuple[FieldCacheRule, ...]: + """Parse configured field-cache rules for a provider/model. + + Wildcard model rules are returned before exact model rules so providers can + define general preservation behavior and then append model-specific rules. + This helper is intentionally not auto-wired into providers; providers decide + whether external config is appropriate for their protocol state. + """ + + provider_rules = config.field_cache.get(provider, {}) if isinstance(config.field_cache, dict) else {} + if not isinstance(provider_rules, dict): + return () + raw_rules: list[Any] = [] + for key in ("*", model): + value = provider_rules.get(key, []) + if isinstance(value, list): + raw_rules.extend(value) + return tuple(_field_cache_rule_from_dict(rule) for rule in raw_rules if isinstance(rule, dict)) + + +def as_bool(value: Any, *, name: str) -> bool: + """Parse a JSON/env boolean value.""" + + if isinstance(value, bool): + return value + if isinstance(value, (int, float)) and value in (0, 1): + return bool(value) + if isinstance(value, str): + text = value.strip().lower() + if text in {"1", "true", "yes", "on"}: + return True + if text in {"0", "false", "no", "off"}: + return False + raise ExperimentalConfigError(f"Invalid boolean for {name}") + + +def as_float(value: Any, *, name: str) -> float: + """Parse a JSON/env float value with redacted errors.""" + + try: + return float(value) + except (TypeError, ValueError) as exc: + raise ExperimentalConfigError(f"Invalid number for {name}") from exc + + +def env_price_key(provider: str, model: str, suffix: str) -> str: + """Return normalized model-price environment variable name.""" + + return f"MODEL_PRICE_{_env_part(provider)}_{_env_part(model)}_{_env_part(suffix)}" + + +def _path_from_env(env: Mapping[str, str]) -> Optional[Path]: + for key in _CONFIG_ENV_KEYS: + value = env.get(key) + if value: + return Path(value) + return None + + +def _dict_section(data: Mapping[str, Any], key: str) -> dict[str, Any]: + value = data.get(key, {}) + if value is None: + return {} + if not isinstance(value, dict): + raise ExperimentalConfigError(f"Config section '{key}' must be an object") + return dict(value) + + +def _reject_secret_keys(value: Any, path: str = "config") -> None: + if isinstance(value, Mapping): + for key, nested in value.items(): + key_text = str(key).lower() + if any(part in key_text for part in _SECRET_KEY_PARTS): + raise ExperimentalConfigError(f"Unsafe secret-like key in JSON config at {path}.{key}") + _reject_secret_keys(nested, f"{path}.{key}") + elif isinstance(value, list): + for index, nested in enumerate(value): + _reject_secret_keys(nested, f"{path}[{index}]") + + +def _pricing_from_env(provider: str, model: str, env: Mapping[str, str]) -> Optional[ModelPricing]: + suffixes = { + "input": "INPUT", + "output": "OUTPUT", + "cache_read": "CACHE_READ", + "cache_write": "CACHE_WRITE", + "reasoning": "REASONING", + } + values: dict[str, float] = {} + for field_name, suffix in suffixes.items(): + key = env_price_key(provider, model, suffix) + raw = env.get(key) + if raw not in (None, ""): + values[field_name] = as_float(raw, name=key) + if not values: + return None + return ModelPricing( + input_cost_per_token=values.get("input", 0.0), + output_cost_per_token=values.get("output", 0.0), + cache_read_cost_per_token=values.get("cache_read", 0.0), + cache_write_cost_per_token=values.get("cache_write", 0.0), + reasoning_cost_per_token=values.get("reasoning", 0.0), + source="env", + ) + + +def _env_or_json(env: Mapping[str, str], env_key: str, data: Mapping[str, Any], json_key: str, default: Any = None) -> Any: + if env_key in env: + return env[env_key] + return data.get(json_key, default) + + +def _optional_positive_float(value: Any, name: str) -> Optional[float]: + if value in (None, ""): + return None + parsed = as_float(value, name=name) + return parsed if parsed > 0 else None + + +def _field_cache_rule_from_dict(data: Mapping[str, Any]) -> FieldCacheRule: + inject_data = data.get("inject") + inject = None + if isinstance(inject_data, Mapping): + inject = FieldCacheInjection( + target=str(inject_data.get("target", "request")), + path=str(inject_data.get("path", data.get("target_path", ""))), + when_missing_only=as_bool(inject_data.get("when_missing_only", False), name="field_cache.inject.when_missing_only"), + insert=as_bool(inject_data.get("insert", False), name="field_cache.inject.insert"), + as_list=as_bool(inject_data.get("as_list", False), name="field_cache.inject.as_list"), + ) + elif data.get("target_path"): + inject = FieldCacheInjection(target=str(data.get("target", "request")), path=str(data["target_path"])) + scope = data.get("scope", ("provider", "model", "classifier", "session")) + if isinstance(scope, str): + scope_values = tuple(part.strip() for part in scope.split(",") if part.strip()) + elif isinstance(scope, (list, tuple)): + scope_values = tuple(str(part) for part in scope) + else: + raise ExperimentalConfigError("field_cache.scope must be a string or list") + try: + return FieldCacheRule( + name=str(data["name"]), + source=str(data["source"]), + path=str(data["path"]), + mode=str(data.get("mode", "last")), + scope=scope_values, + inject=inject, + enabled=as_bool(data.get("enabled", True), name="field_cache.enabled"), + ttl_seconds=int(data["ttl_seconds"]) if data.get("ttl_seconds") is not None else None, + metadata=dict(data.get("metadata", {})) if isinstance(data.get("metadata", {}), dict) else {}, + allow_missing_session=as_bool(data.get("allow_missing_session", False), name="field_cache.allow_missing_session"), + ) + except KeyError as exc: + raise ExperimentalConfigError(f"Missing field-cache rule key {exc.args[0]}") from exc + except ValueError as exc: + raise ExperimentalConfigError(str(exc)) from exc + + +def _env_part(value: str) -> str: + return re.sub(r"[^A-Z0-9]+", "_", value.upper()).strip("_") diff --git a/tests/test_experimental_config.py b/tests/test_experimental_config.py new file mode 100644 index 000000000..7c9577c10 --- /dev/null +++ b/tests/test_experimental_config.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import pytest + +from rotator_library.config.experimental import ( + ExperimentalConfigError, + env_price_key, + get_stream_runtime_settings, + load_config_from_mapping, + load_experimental_config, + parse_field_cache_rules, +) + + +def test_missing_config_file_returns_empty(tmp_path) -> None: + config = load_experimental_config(tmp_path / "missing.json", env={}) + + assert config.is_empty + + +def test_loads_config_from_env_path(tmp_path) -> None: + path = tmp_path / "config.json" + path.write_text('{"routing":{"model_routes":{"code":"group:chain"}},"extra":{}}', encoding="utf-8") + + config = load_experimental_config(env={"LLM_PROXY_CONFIG_FILE": str(path)}) + + assert config.routing["model_routes"]["code"] == "group:chain" + assert config.unknown_sections == {"extra": {}} + assert config.warnings + + +def test_rejects_secret_like_json_keys() -> None: + with pytest.raises(ExperimentalConfigError): + load_config_from_mapping({"providers": {"openai": {"api_key": "hidden"}}}) + + +def test_invalid_json_raises(tmp_path) -> None: + path = tmp_path / "config.json" + path.write_text("{", encoding="utf-8") + + with pytest.raises(ExperimentalConfigError): + load_experimental_config(path, env={}) + + +def test_stream_runtime_settings_env_overrides_json() -> None: + config = load_config_from_mapping({"streaming": {"trace_metrics": True, "stall_timeout_seconds": 30}}) + + settings = get_stream_runtime_settings(config=config, env={"STREAM_TRACE_METRICS": "false"}) + + assert settings.trace_metrics is False + assert settings.stall_timeout_seconds == 30 + + +def test_field_cache_rules_parse_wildcard_then_model_specific() -> None: + config = load_config_from_mapping( + { + "field_cache": { + "gemini_cli": { + "*": [{"name": "thought", "source": "response", "path": "$.thought", "target_path": "$.cached_thought"}], + "gemini-3": [{"name": "signature", "source": "response", "path": "$.sig", "scope": ["provider", "model"]}], + } + } + } + ) + + rules = parse_field_cache_rules(config, "gemini_cli", "gemini-3") + + assert [rule.name for rule in rules] == ["thought", "signature"] + assert rules[0].inject is not None + + +def test_env_price_key_sanitizes_provider_and_model() -> None: + assert env_price_key("openai", "gpt-5.1-mini", "cache_read") == "MODEL_PRICE_OPENAI_GPT_5_1_MINI_CACHE_READ" From 3996cabd1a2df406436384f7368b258376b38618 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 03:34:13 +0200 Subject: [PATCH 069/182] feat(config): support json routing config Merges optional JSON fallback groups and model routes before environment variables, keeping existing env-only routing behavior and allowing env groups/routes to override JSON entries. Tests: python -m pytest tests/test_routing_config.py tests/test_config_routing_json.py tests/test_request_builder_routing.py tests/test_fallback_resolver.py --- src/rotator_library/routing/config.py | 74 ++++++++++++++++++++++++--- tests/test_config_routing_json.py | 57 +++++++++++++++++++++ 2 files changed, 124 insertions(+), 7 deletions(-) create mode 100644 tests/test_config_routing_json.py diff --git a/src/rotator_library/routing/config.py b/src/rotator_library/routing/config.py index 71a149bb2..efb3eb27f 100644 --- a/src/rotator_library/routing/config.py +++ b/src/rotator_library/routing/config.py @@ -8,7 +8,7 @@ import os from collections.abc import Mapping -from .types import FallbackGroup, RouteTarget, RoutingConfig +from .types import DEFAULT_FAILOVER_ON, DEFAULT_STOP_ON, FallbackGroup, RouteTarget, RoutingConfig class RoutingConfigError(ValueError): @@ -34,14 +34,24 @@ def parse_route_target(spec: str) -> RouteTarget: return RouteTarget(provider=provider.strip(), model=model.strip(), execution=(execution.strip() or "auto")) -def load_routing_config_from_env(env: Mapping[str, str] | None = None) -> RoutingConfig: - """Load fallback groups and model-route aliases from environment variables.""" +def load_routing_config_from_env(env: Mapping[str, str] | None = None, config: object | None = None) -> RoutingConfig: + """Load fallback groups and model-route aliases from JSON then environment. + + Environment variables intentionally remain the final override layer so the + existing `.env` deployment model keeps working exactly as before. The + optional JSON object is a convenience for structured routing plans. + """ source = env if env is not None else os.environ + if config is None: + from ..config.experimental import load_experimental_config + + config = load_experimental_config(env=source) + groups: dict[str, FallbackGroup] = _groups_from_json_config(config) + group_names = _csv(source.get("FALLBACK_GROUPS", "")) if len(group_names) != len(set(group_names)): raise RoutingConfigError("fallback group names must be unique") - groups: dict[str, FallbackGroup] = {} for name in group_names: key = f"FALLBACK_GROUP_{_env_key(name)}" target_specs = _csv(source.get(key, "")) @@ -49,18 +59,68 @@ def load_routing_config_from_env(env: Mapping[str, str] | None = None) -> Routin raise RoutingConfigError(f"fallback group {name} has no targets") groups[name] = FallbackGroup(name=name, targets=tuple(parse_route_target(spec) for spec in target_specs)) - model_routes: dict[str, str] = {} + model_routes: dict[str, str] = _model_routes_from_json_config(config) for key, value in source.items(): if not key.startswith("MODEL_ROUTE_"): continue model_alias = key[len("MODEL_ROUTE_") :].lower() route = value.strip() - if route.startswith("group:") and route[len("group:") :] not in groups: - raise RoutingConfigError(f"model route {key} references unknown fallback group {route}") model_routes[model_alias] = route + _validate_model_routes(model_routes, groups) return RoutingConfig(fallback_groups=groups, model_routes=model_routes) +def _groups_from_json_config(config: object) -> dict[str, FallbackGroup]: + routing = getattr(config, "routing", {}) + if not isinstance(routing, Mapping): + return {} + raw_groups = routing.get("fallback_groups", {}) + if not isinstance(raw_groups, Mapping): + raise RoutingConfigError("routing.fallback_groups must be an object") + groups: dict[str, FallbackGroup] = {} + for name, raw_group in raw_groups.items(): + if not isinstance(raw_group, Mapping): + raise RoutingConfigError(f"fallback group {name} must be an object") + raw_targets = raw_group.get("targets", []) + if not isinstance(raw_targets, list) or not raw_targets: + raise RoutingConfigError(f"fallback group {name} has no targets") + groups[str(name)] = FallbackGroup( + name=str(name), + targets=tuple(parse_route_target(str(spec)) for spec in raw_targets), + failover_on=_string_set(raw_group.get("failover_on"), DEFAULT_FAILOVER_ON), + stop_on=_string_set(raw_group.get("stop_on"), DEFAULT_STOP_ON), + max_targets=int(raw_group["max_targets"]) if raw_group.get("max_targets") is not None else None, + metadata=dict(raw_group.get("metadata", {})) if isinstance(raw_group.get("metadata", {}), Mapping) else {}, + ) + return groups + + +def _model_routes_from_json_config(config: object) -> dict[str, str]: + routing = getattr(config, "routing", {}) + if not isinstance(routing, Mapping): + return {} + raw_routes = routing.get("model_routes", {}) + if not isinstance(raw_routes, Mapping): + raise RoutingConfigError("routing.model_routes must be an object") + return {str(alias).lower(): str(route).strip() for alias, route in raw_routes.items()} + + +def _validate_model_routes(model_routes: Mapping[str, str], groups: Mapping[str, FallbackGroup]) -> None: + for key, route in model_routes.items(): + if route.startswith("group:") and route[len("group:") :] not in groups: + raise RoutingConfigError(f"model route {key} references unknown fallback group {route}") + + +def _string_set(value: object, default: frozenset[str]) -> frozenset[str]: + if value is None: + return default + if isinstance(value, str): + return frozenset(_csv(value)) + if isinstance(value, list): + return frozenset(str(item) for item in value) + raise RoutingConfigError("routing policy lists must be strings or arrays") + + def _csv(value: str) -> list[str]: return [part.strip() for part in value.split(",") if part.strip()] diff --git a/tests/test_config_routing_json.py b/tests/test_config_routing_json.py new file mode 100644 index 000000000..06c90c7b6 --- /dev/null +++ b/tests/test_config_routing_json.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import pytest + +from rotator_library.config.experimental import load_config_from_mapping +from rotator_library.routing import RoutingConfigError, load_routing_config_from_env + + +def test_json_routing_config_loads_fallback_group_and_route() -> None: + experimental = load_config_from_mapping( + { + "routing": { + "fallback_groups": { + "code_chain": { + "targets": ["codex/gpt-5.1@native", "openai/gpt-5.1@litellm_fallback"], + "failover_on": ["rate_limit"], + "stop_on": ["authentication"], + } + }, + "model_routes": {"code": "group:code_chain"}, + } + } + ) + + config = load_routing_config_from_env({}, config=experimental) + + assert config.fallback_groups["code_chain"].targets[0].execution == "native" + assert config.fallback_groups["code_chain"].failover_on == frozenset({"rate_limit"}) + assert config.model_routes["code"] == "group:code_chain" + + +def test_env_group_overrides_json_group_targets() -> None: + experimental = load_config_from_mapping( + {"routing": {"fallback_groups": {"chain": {"targets": ["codex/gpt"]}}, "model_routes": {"code": "group:chain"}}} + ) + + config = load_routing_config_from_env( + {"FALLBACK_GROUPS": "chain", "FALLBACK_GROUP_CHAIN": "openai/gpt-5.1", "MODEL_ROUTE_CODE": "group:chain"}, + config=experimental, + ) + + assert config.fallback_groups["chain"].targets[0].provider == "openai" + + +def test_env_route_can_reference_json_group() -> None: + experimental = load_config_from_mapping({"routing": {"fallback_groups": {"chain": {"targets": ["codex/gpt"]}}}}) + + config = load_routing_config_from_env({"MODEL_ROUTE_CODE": "group:chain"}, config=experimental) + + assert config.model_routes["code"] == "group:chain" + + +def test_json_route_rejects_unknown_group() -> None: + experimental = load_config_from_mapping({"routing": {"model_routes": {"code": "group:missing"}}}) + + with pytest.raises(RoutingConfigError): + load_routing_config_from_env({}, config=experimental) From c2d5dbb544623abb3e066228a46f415ad6ae7e2a Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 03:35:14 +0200 Subject: [PATCH 070/182] feat(config): add configured model pricing Adds JSON and environment model pricing lookup to CostCalculator while preserving skip-cost providers, explicit provider pricing priority, and LiteLLM as the final advisory fallback. Tests: python -m pytest tests/test_config_pricing.py tests/test_usage_costs.py tests/test_executor_usage_accounting.py tests/test_streaming_usage_accounting.py tests/test_experimental_config.py --- src/rotator_library/usage/costs.py | 25 ++++++++++- tests/test_config_pricing.py | 68 ++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 2 deletions(-) create mode 100644 tests/test_config_pricing.py diff --git a/src/rotator_library/usage/costs.py b/src/rotator_library/usage/costs.py index 55cf1ed92..12c850ea1 100644 --- a/src/rotator_library/usage/costs.py +++ b/src/rotator_library/usage/costs.py @@ -65,16 +65,21 @@ def to_dict(self) -> dict[str, Any]: class CostCalculator: """Calculate advisory costs without replacing usage tracking.""" - def __init__(self, *, provider_plugin: Any = None, use_litellm_fallback: bool = True) -> None: + def __init__(self, *, provider_plugin: Any = None, use_litellm_fallback: bool = True, config: Any = None, env: Any = None) -> None: self.provider_plugin = provider_plugin self.use_litellm_fallback = use_litellm_fallback + self.config = config + self.env = env - def calculate(self, usage: UsageRecord, *, model: str, response: Any = None) -> CostBreakdown: + def calculate(self, usage: UsageRecord, *, model: str, response: Any = None, provider: str | None = None) -> CostBreakdown: """Return an advisory cost breakdown for a normalized usage record.""" if self.provider_plugin and getattr(self.provider_plugin, "skip_cost_calculation", False): return CostBreakdown(pricing_source="skipped", metadata={"reason": "provider_skip_cost_calculation"}) pricing = self._provider_pricing(model) + if pricing: + return _calculate_from_pricing(usage, pricing) + pricing = self._configured_pricing(provider or usage.provider or _provider_from_model(model), model) if pricing: return _calculate_from_pricing(usage, pricing) if self.use_litellm_fallback: @@ -96,6 +101,13 @@ def _provider_pricing(self, model: str) -> Optional[ModelPricing]: return ModelPricing(**pricing) return None + def _configured_pricing(self, provider: str | None, model: str) -> Optional[ModelPricing]: + if not provider: + return None + from ..config.experimental import get_configured_model_pricing + + return get_configured_model_pricing(provider, _model_without_provider(provider, model), config=self.config, env=self.env) + @staticmethod def _litellm_cost(usage: UsageRecord, *, model: str, response: Any = None) -> Optional[CostBreakdown]: if response is not None: @@ -134,3 +146,12 @@ def _calculate_from_pricing(usage: UsageRecord, pricing: ModelPricing) -> CostBr currency=pricing.currency, pricing_source=pricing.source, ) + + +def _provider_from_model(model: str) -> Optional[str]: + return model.split("/", 1)[0] if "/" in model else None + + +def _model_without_provider(provider: str, model: str) -> str: + prefix = f"{provider}/" + return model[len(prefix) :] if model.startswith(prefix) else model diff --git a/tests/test_config_pricing.py b/tests/test_config_pricing.py new file mode 100644 index 000000000..39269b4b5 --- /dev/null +++ b/tests/test_config_pricing.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from rotator_library.config.experimental import env_price_key, load_config_from_mapping +from rotator_library.usage.accounting import UsageRecord +from rotator_library.usage.costs import CostCalculator, ModelPricing + + +class ExplicitPricingProvider: + def get_model_pricing(self, model: str) -> ModelPricing: + return ModelPricing(input_cost_per_token=9.0, source="provider") + + +class SkipCostProvider: + skip_cost_calculation = True + + +def _usage() -> UsageRecord: + return UsageRecord(input_tokens=10, completion_tokens=5, reasoning_tokens=2, cache_read_tokens=3, cache_write_tokens=4, provider="openai", model="gpt-test") + + +def test_json_pricing_calculates_all_buckets() -> None: + config = load_config_from_mapping( + {"pricing": {"openai": {"gpt-test": {"input": 1.0, "output": 2.0, "reasoning": 3.0, "cache_read": 0.5, "cache_write": 0.75}}}} + ) + + cost = CostCalculator(config=config, use_litellm_fallback=False).calculate(_usage(), model="openai/gpt-test") + + assert cost.pricing_source == "json_config" + assert cost.input_cost == 10.0 + assert cost.output_cost == 10.0 + assert cost.reasoning_cost == 6.0 + assert cost.cache_read_cost == 1.5 + assert cost.cache_write_cost == 3.0 + + +def test_env_pricing_overrides_json_pricing() -> None: + config = load_config_from_mapping({"pricing": {"openai": {"gpt-test": {"input": 1.0}}}}) + env = {env_price_key("openai", "gpt-test", "input"): "4.0"} + + cost = CostCalculator(config=config, env=env, use_litellm_fallback=False).calculate(_usage(), model="openai/gpt-test") + + assert cost.pricing_source == "env" + assert cost.input_cost == 40.0 + + +def test_provider_pricing_beats_env_and_json_pricing() -> None: + config = load_config_from_mapping({"pricing": {"openai": {"gpt-test": {"input": 1.0}}}}) + env = {env_price_key("openai", "gpt-test", "input"): "4.0"} + + cost = CostCalculator(provider_plugin=ExplicitPricingProvider(), config=config, env=env, use_litellm_fallback=False).calculate(_usage(), model="openai/gpt-test") + + assert cost.pricing_source == "provider" + assert cost.input_cost == 90.0 + + +def test_skip_cost_provider_beats_all_config_pricing() -> None: + config = load_config_from_mapping({"pricing": {"openai": {"gpt-test": {"input": 1.0}}}}) + + cost = CostCalculator(provider_plugin=SkipCostProvider(), config=config, use_litellm_fallback=False).calculate(_usage(), model="openai/gpt-test") + + assert cost.pricing_source == "skipped" + assert cost.total_cost == 0.0 + + +def test_missing_pricing_remains_unavailable() -> None: + cost = CostCalculator(use_litellm_fallback=False).calculate(_usage(), model="openai/gpt-test") + + assert cost.pricing_source == "unavailable" From 422f8fb8457ddf7795d6a9e38a1f965d5cb23d70 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 03:36:22 +0200 Subject: [PATCH 071/182] feat(config): add stream runtime settings Adds parsed stream runtime settings and wires the safe STREAM_TRACE_METRICS toggle so lifecycle metrics can be disabled without changing SSE output. Timeout and heartbeat values are parsed for future use but not enforced. Tests: python -m pytest tests/test_config_stream_settings.py tests/test_request_executor_stream_metrics.py tests/test_streaming_usage_accounting.py tests/test_stream_events.py tests/test_stream_metrics.py tests/test_stream_policy.py tests/test_streaming_error_handler.py --- src/rotator_library/client/streaming.py | 16 ++++++---- tests/test_config_stream_settings.py | 32 +++++++++++++++++++ tests/test_request_executor_stream_metrics.py | 15 +++++++++ 3 files changed, 57 insertions(+), 6 deletions(-) create mode 100644 tests/test_config_stream_settings.py diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py index b13291fa4..0298badde 100644 --- a/src/rotator_library/client/streaming.py +++ b/src/rotator_library/client/streaming.py @@ -86,8 +86,12 @@ async def wrap_stream( assistant_parts: List[str] = [] tool_call_ids: List[str] = [] monitor = StreamMonitor(clock=time.monotonic) + from ..config.experimental import get_stream_runtime_settings + + stream_settings = get_stream_runtime_settings() + lifecycle_logger = transaction_logger if stream_settings.trace_metrics else None self._log_stream_lifecycle( - transaction_logger, + lifecycle_logger, "stream_started", monitor, StreamEvent("started", protocol="openai_chat"), @@ -116,7 +120,7 @@ async def wrap_stream( ) if monitor.metrics.first_byte_at is None: self._log_stream_lifecycle( - transaction_logger, + lifecycle_logger, "stream_first_byte", monitor, raw_event, @@ -145,7 +149,7 @@ async def wrap_stream( monitor.record_event(event) if first_visible: self._log_stream_lifecycle( - transaction_logger, + lifecycle_logger, "stream_first_visible_output", monitor, event, @@ -277,13 +281,13 @@ async def wrap_stream( monitor.complete() self._log_stream_lifecycle( - transaction_logger, + lifecycle_logger, "stream_completed", monitor, StreamEvent("completed", protocol="openai_chat"), ) self._log_stream_lifecycle( - transaction_logger, + lifecycle_logger, "stream_metrics_final", monitor, StreamEvent("metadata", protocol="openai_chat"), @@ -295,7 +299,7 @@ async def wrap_stream( elif request and await request.is_disconnected(): monitor.cancel() self._log_stream_lifecycle( - transaction_logger, + lifecycle_logger, "stream_cancelled", monitor, StreamEvent("cancelled", protocol="openai_chat"), diff --git a/tests/test_config_stream_settings.py b/tests/test_config_stream_settings.py new file mode 100644 index 000000000..2ef253354 --- /dev/null +++ b/tests/test_config_stream_settings.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import pytest + +from rotator_library.config.experimental import ExperimentalConfigError, get_stream_runtime_settings, load_config_from_mapping + + +def test_stream_settings_parse_json_values() -> None: + config = load_config_from_mapping( + {"streaming": {"ttfb_timeout_seconds": 5, "stall_timeout_seconds": 30, "heartbeat_seconds": 10, "trace_metrics": False}} + ) + + settings = get_stream_runtime_settings(config=config, env={}) + + assert settings.ttfb_timeout_seconds == 5 + assert settings.stall_timeout_seconds == 30 + assert settings.heartbeat_seconds == 10 + assert settings.trace_metrics is False + + +def test_stream_settings_env_overrides_json_values() -> None: + config = load_config_from_mapping({"streaming": {"trace_metrics": True, "heartbeat_seconds": 10}}) + + settings = get_stream_runtime_settings(config=config, env={"STREAM_TRACE_METRICS": "false", "STREAM_HEARTBEAT_SECONDS": "2"}) + + assert settings.trace_metrics is False + assert settings.heartbeat_seconds == 2 + + +def test_stream_settings_invalid_boolean_fails_clearly() -> None: + with pytest.raises(ExperimentalConfigError): + get_stream_runtime_settings(env={"STREAM_TRACE_METRICS": "maybe"}) diff --git a/tests/test_request_executor_stream_metrics.py b/tests/test_request_executor_stream_metrics.py index 0bf1c74e9..8917b9008 100644 --- a/tests/test_request_executor_stream_metrics.py +++ b/tests/test_request_executor_stream_metrics.py @@ -31,3 +31,18 @@ async def test_streaming_handler_emits_lifecycle_metrics_without_changing_output assert "stream_first_visible_output" in pass_names assert "stream_completed" in pass_names assert "stream_metrics_final" in pass_names + + +@pytest.mark.asyncio +async def test_stream_trace_metrics_can_be_disabled_without_changing_output(tmp_path, monkeypatch) -> None: + monkeypatch.setenv("STREAM_TRACE_METRICS", "false") + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(_chunks(), "cred", "openai/gpt-test", transaction_logger=logger)] + + assert chunks[0].startswith("data: ") + assert chunks[-1] == "data: [DONE]\n\n" + if (logger.log_dir / "transform_trace.jsonl").exists(): + pass_names = _trace_passes(logger.log_dir) + assert "stream_started" not in pass_names + assert "stream_metrics_final" not in pass_names From f417c0881d0b49b9ce4b88252562a4773fa0c1e8 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 03:37:17 +0200 Subject: [PATCH 072/182] docs(config): document experimental knobs Documents optional JSON config, fallback groups, provider cooldowns, streaming observability, and advisory model pricing in .env.example with a regression test to keep key knobs present. Tests: python -m pytest tests/test_env_example_experimental_config.py tests/test_experimental_config.py tests/test_config_routing_json.py tests/test_config_pricing.py tests/test_config_stream_settings.py --- .env.example | 44 +++++++++++++++++++ tests/test_env_example_experimental_config.py | 26 +++++++++++ 2 files changed, 70 insertions(+) create mode 100644 tests/test_env_example_experimental_config.py diff --git a/.env.example b/.env.example index 1c5b896fe..9d638a8f8 100644 --- a/.env.example +++ b/.env.example @@ -333,6 +333,50 @@ # Default: false # STREAM_RETRY_ON_REASONING_ONLY=false +# --- Optional Structured Config File --- +# Optional JSON config for structured experimental settings such as fallback +# groups, model routes, advisory pricing, streaming observability, and field +# cache rules. Environment variables override JSON config. Do not put API keys, +# OAuth tokens, bearer tokens, or authorization headers in this file. +# LLM_PROXY_CONFIG_FILE=./config/llm-proxy.json + +# --- Fallback Groups and Model Routes --- +# Ordered fallback groups try targets in order when the previous target fails +# with a retryable provider/category error. Execution suffixes: +# @auto Let the provider choose custom/native/LiteLLM behavior +# @native Require native protocol execution +# @custom Require provider custom execution +# @litellm_fallback Explicitly use LiteLLM fallback +# FALLBACK_GROUPS=code_chain +# FALLBACK_GROUP_CODE_CHAIN=codex/gpt-5.1-codex@native,openai/gpt-5.1@litellm_fallback +# MODEL_ROUTE_CODE=group:code_chain + +# --- Provider Cooldown Activation --- +# Provider-level cooldown is conservative and only intended for large/global +# retry-after events, not every per-credential quota error. +# PROVIDER_COOLDOWN_MIN_SECONDS=10 +# PROVIDER_COOLDOWN_DEFAULT_SECONDS=60 +# PROVIDER_COOLDOWN_ON_QUOTA=false + +# --- Streaming Observability --- +# Stream lifecycle metrics are traced by default. Timeout/heartbeat values are +# parsed for future enforcement but are not used to abort streams by default. +# STREAM_TRACE_METRICS=true +# STREAM_TTFB_TIMEOUT_SECONDS=0 +# STREAM_STALL_TIMEOUT_SECONDS=0 +# STREAM_HEARTBEAT_SECONDS=0 + +# --- Advisory Model Pricing --- +# Per-token advisory prices used for usage/cost traces and stored approximate +# cost values when provider/LiteLLM pricing is unavailable. Provider code can +# still supply authoritative pricing, and skip-cost providers always return zero. +# Env names sanitize provider/model by replacing non-alphanumerics with `_`. +# MODEL_PRICE_OPENAI_GPT_5_1_INPUT=0.000001 +# MODEL_PRICE_OPENAI_GPT_5_1_OUTPUT=0.00001 +# MODEL_PRICE_OPENAI_GPT_5_1_CACHE_READ=0.0000001 +# MODEL_PRICE_OPENAI_GPT_5_1_CACHE_WRITE=0.000001 +# MODEL_PRICE_OPENAI_GPT_5_1_REASONING=0.00001 + # ------------------------------------------------------------------------------ # | [ADVANCED] HTTP Timeout Configuration | # ------------------------------------------------------------------------------ diff --git a/tests/test_env_example_experimental_config.py b/tests/test_env_example_experimental_config.py new file mode 100644 index 000000000..5435c17c2 --- /dev/null +++ b/tests/test_env_example_experimental_config.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from pathlib import Path + + +def test_env_example_documents_experimental_config_knobs() -> None: + text = Path(".env.example").read_text(encoding="utf-8") + + for key in ( + "LLM_PROXY_CONFIG_FILE", + "FALLBACK_GROUPS", + "FALLBACK_GROUP_CODE_CHAIN", + "MODEL_ROUTE_CODE", + "PROVIDER_COOLDOWN_MIN_SECONDS", + "PROVIDER_COOLDOWN_DEFAULT_SECONDS", + "PROVIDER_COOLDOWN_ON_QUOTA", + "STREAM_TRACE_METRICS", + "STREAM_TTFB_TIMEOUT_SECONDS", + "STREAM_STALL_TIMEOUT_SECONDS", + "STREAM_HEARTBEAT_SECONDS", + "MODEL_PRICE_OPENAI_GPT_5_1_INPUT", + "MODEL_PRICE_OPENAI_GPT_5_1_REASONING", + ): + assert key in text + + assert "Do not put API keys" in text From b65f948cd1669b9b9b9496f7f8b60399055af0ab Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 03:44:29 +0200 Subject: [PATCH 073/182] test(config): cover runtime pricing config Adds production-style executor and streaming tests proving MODEL_PRICE_* env pricing is picked up when CostCalculator is constructed without explicit config. Also adds the planned as_int helper and documents zero-as-unset stream timeout semantics. Tests: python -m pytest tests/test_experimental_config.py tests/test_executor_usage_accounting.py tests/test_streaming_usage_accounting.py tests/test_config_pricing.py tests/test_config_stream_settings.py tests/test_request_executor_stream_metrics.py --- src/rotator_library/config/experimental.py | 11 +++++++++++ tests/test_executor_usage_accounting.py | 20 ++++++++++++++++++++ tests/test_experimental_config.py | 7 +++++++ tests/test_streaming_usage_accounting.py | 10 ++++++++++ 4 files changed, 48 insertions(+) diff --git a/src/rotator_library/config/experimental.py b/src/rotator_library/config/experimental.py index 8f72ed385..31767d047 100644 --- a/src/rotator_library/config/experimental.py +++ b/src/rotator_library/config/experimental.py @@ -205,6 +205,15 @@ def as_float(value: Any, *, name: str) -> float: raise ExperimentalConfigError(f"Invalid number for {name}") from exc +def as_int(value: Any, *, name: str) -> int: + """Parse a JSON/env integer value with redacted errors.""" + + try: + return int(value) + except (TypeError, ValueError) as exc: + raise ExperimentalConfigError(f"Invalid integer for {name}") from exc + + def env_price_key(provider: str, model: str, suffix: str) -> str: """Return normalized model-price environment variable name.""" @@ -276,6 +285,8 @@ def _optional_positive_float(value: Any, name: str) -> Optional[float]: if value in (None, ""): return None parsed = as_float(value, name=name) + # Zero and negative values mean "not configured" for timeout-like knobs. + # Runtime enforcement is intentionally disabled by default. return parsed if parsed > 0 else None diff --git a/tests/test_executor_usage_accounting.py b/tests/test_executor_usage_accounting.py index b31114ee8..c1786c082 100644 --- a/tests/test_executor_usage_accounting.py +++ b/tests/test_executor_usage_accounting.py @@ -47,3 +47,23 @@ def test_executor_accounts_for_non_streaming_usage_and_cost_trace(tmp_path, monk entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] assert entries[-1]["pass_name"] == "usage_accounting_summary" assert entries[-1]["data"]["usage"]["total_tokens"] == usage.total_tokens + + +def test_executor_accounting_uses_configured_env_pricing(tmp_path, monkeypatch) -> None: + monkeypatch.setenv("MODEL_PRICE_OPENAI_GPT_TEST_INPUT", "2.0") + executor = _executor() + response = SimpleNamespace(usage=SimpleNamespace(prompt_tokens=3, completion_tokens=0)) + context = RequestContext( + provider="openai", + model="gpt-test", + kwargs={}, + streaming=False, + credentials=[], + deadline=0, + transaction_logger=TransactionLogger("openai", "gpt-test", parent_dir=tmp_path), + ) + + _, cost = executor._account_for_response_usage("openai", "gpt-test", response, context) + + assert cost.pricing_source == "env" + assert cost.input_cost == 6.0 diff --git a/tests/test_experimental_config.py b/tests/test_experimental_config.py index 7c9577c10..313d0483b 100644 --- a/tests/test_experimental_config.py +++ b/tests/test_experimental_config.py @@ -4,6 +4,7 @@ from rotator_library.config.experimental import ( ExperimentalConfigError, + as_int, env_price_key, get_stream_runtime_settings, load_config_from_mapping, @@ -71,3 +72,9 @@ def test_field_cache_rules_parse_wildcard_then_model_specific() -> None: def test_env_price_key_sanitizes_provider_and_model() -> None: assert env_price_key("openai", "gpt-5.1-mini", "cache_read") == "MODEL_PRICE_OPENAI_GPT_5_1_MINI_CACHE_READ" + + +def test_as_int_parses_integers_with_redacted_errors() -> None: + assert as_int("5", name="TEST_INT") == 5 + with pytest.raises(ExperimentalConfigError, match="TEST_INT"): + as_int("not-secret-value", name="TEST_INT") diff --git a/tests/test_streaming_usage_accounting.py b/tests/test_streaming_usage_accounting.py index 0534092aa..0699b50c8 100644 --- a/tests/test_streaming_usage_accounting.py +++ b/tests/test_streaming_usage_accounting.py @@ -77,3 +77,13 @@ async def test_streaming_without_usage_still_marks_success_with_zero_usage() -> assert cred_context.success_kwargs["thinking_tokens"] == 0 assert cred_context.success_kwargs["prompt_tokens_cache_read"] == 0 assert cred_context.success_kwargs["prompt_tokens_cache_write"] == 0 + + +@pytest.mark.asyncio +async def test_streaming_usage_uses_configured_env_pricing(monkeypatch) -> None: + monkeypatch.setenv("MODEL_PRICE_OPENAI_GPT_TEST_INPUT", "2.0") + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_usage_chunks(), "cred", "openai/gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 120.0 From 6e641aef2d895cd4b178bd3764543a9f056b1ad2 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 08:00:07 +0200 Subject: [PATCH 074/182] docs(experimental): plan protocol breadth correction Adds the corrective Phase 1b plan for closing protocol-foundation gaps found during the full audit pass. The plan covers an explicit operation model, non-chat protocol adapters for embeddings, images, audio, Ollama, and MCP, count-token capability seams, tests, transaction-logging implications, and review requirements. Reports and unrelated docs/issues content remain uncommitted. --- ...ase-1b-protocol-breadth-operation-model.md | 157 ++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 docs/experimental/phase-1b-protocol-breadth-operation-model.md diff --git a/docs/experimental/phase-1b-protocol-breadth-operation-model.md b/docs/experimental/phase-1b-protocol-breadth-operation-model.md new file mode 100644 index 000000000..19ea02c03 --- /dev/null +++ b/docs/experimental/phase-1b-protocol-breadth-operation-model.md @@ -0,0 +1,157 @@ +# Phase 1b: Protocol Breadth And Operation Model + +## Goal + +Correct the main Phase 1 audit failure. Phase 1 created a good chat/message protocol foundation, but the starting requirements asked for protocols as the highest priority and for broad protocol coverage. The current protocol layer is still too chat-oriented: it has OpenAI Chat, Anthropic Messages, Gemini, Responses, and LiteLLM fallback, but no explicit operation model for embeddings, image generation, audio transcription, speech/TTS, Ollama-native chat/generate, MCP-style tool gateway calls, or count-token operations. + +## Non-Goals + +- Do not wire every new operation into live FastAPI routes in this corrective slice. This is protocol foundation work. +- Do not replace current embeddings or Anthropic count_tokens routes yet. +- Do not replace existing provider execution, `UsageManager`, `SessionTracker`, routing, or LiteLLM fallback. +- Do not add SQLite or persistent admin/config databases. +- Do not implement full MCP transport/proxy behavior yet; define the protocol shape and operation seam so it is not blocked later. +- Do not fake WebSocket runtime support; keep explicit future transport/capability metadata. +- Do not touch uncommitted phase reports or `docs/issues/`. + +## Current Code Context + +- `src/rotator_library/protocols/types.py` has unified chat/message/request/response/stream dataclasses, but `UnifiedRequest` has no `operation` field and no dedicated multimodal operation carrier fields. +- `ProtocolAdapter` has `supported_transports` but no `supported_operations` or `supports_operation()`. +- Existing concrete protocols are chat/message/generate-content focused. +- `ContentBlock` already has generic `type`, `source`, `raw`, and `extra`, which can support multimodal payloads with careful extension. +- Registry auto-discovery exists and should be reused for new protocol modules. +- Tests exist only for the initial protocol modules. + +## Files To Add + +- `src/rotator_library/protocols/operation.py` +- `src/rotator_library/protocols/openai_embeddings.py` +- `src/rotator_library/protocols/openai_images.py` +- `src/rotator_library/protocols/openai_audio.py` +- `src/rotator_library/protocols/ollama.py` +- `src/rotator_library/protocols/mcp.py` +- `tests/test_protocol_operation_model.py` +- `tests/test_protocol_openai_embeddings.py` +- `tests/test_protocol_openai_images_audio.py` +- `tests/test_protocol_ollama_mcp.py` + +## Files To Touch + +- `src/rotator_library/protocols/types.py` +- `src/rotator_library/protocols/base.py` +- `src/rotator_library/protocols/__init__.py` +- Existing protocol modules where they should declare `supported_operations`. +- Existing protocol tests if serialization expectations need operation fields added. + +## Data Model Additions + +- Add operation constants in a small `operation.py` module, not a rigid enum, so custom/local protocols can introduce operation strings without fighting the core. +- Initial standard operation names: `chat`, `messages`, `responses`, `count_tokens`, `embeddings`, `image_generation`, `image_edit`, `image_variation`, `audio_transcription`, `audio_translation`, `speech`, `ollama_chat`, `ollama_generate`, `mcp`, and `unknown`. +- Add `operation` to `UnifiedRequest`. +- Add flexible operation carrier fields to `UnifiedRequest`: `input`, `modalities`, and `files`. +- Add `operation` to `UnifiedResponse`. +- Add flexible response fields to `UnifiedResponse`: `data` and `content_type`. +- Keep existing fields and defaults chat-compatible so existing adapters do not break. +- Update explicit serialization fields so transform logging can see the new values. + +## Protocol Adapter Additions + +- Add class attribute `supported_operations`. +- Add `supports_operation(operation_name)`. +- Existing protocols should declare their primary operations. +- The base adapter defaults to `operation="unknown"` when no operation is identified. +- LiteLLM fallback stays broad and explicit, not a native implementation claim. + +## New Protocol Modules + +### OpenAI Embeddings + +- Protocol name: `openai_embeddings`. +- Aliases: `embeddings`. +- Operation: `embeddings`. +- Parse/build `/v1/embeddings` style fields: `model`, `input`, `encoding_format`, `dimensions`, and `user`. +- Parse `data[].embedding` and `usage` from responses. +- Preserve unknown fields. + +### OpenAI Images + +- Protocol name: `openai_images`. +- Aliases: `images`, `image_generation`. +- Operations: `image_generation`, `image_edit`, and `image_variation`. +- Parse/build standard OpenAI image fields: `prompt`, `model`, `n`, `size`, `quality`, `style`, `response_format`, `image`, `mask`, and `user`. +- Parse response `data` entries with `url`, `b64_json`, and `revised_prompt`. +- Preserve file/source metadata without reading file contents. + +### OpenAI Audio + +- Protocol name: `openai_audio`. +- Aliases: `audio`, `audio_transcription`, and `speech`. +- Operations: `audio_transcription`, `audio_translation`, and `speech`. +- Parse/build transcription/translation fields: `file`, `model`, `language`, `prompt`, `response_format`, `temperature`, and `timestamp_granularities`. +- Parse/build speech fields: `input`, `model`, `voice`, `response_format`, and `speed`. +- Preserve binary or text provider responses in `raw`; map JSON responses into `data` or `output` when possible. + +### Ollama + +- Protocol name: `ollama`. +- Operations: `ollama_chat`, `ollama_generate`, and compatible `embeddings`. +- Parse/build `/api/chat`, `/api/generate`, and `/api/embeddings` shapes. +- Map `messages`, `prompt`, `options`, `system`, `template`, `context`, `keep_alive`, `format`, and `stream`. +- Parse final and stream chunks while preserving raw context/performance fields. + +### MCP + +- Protocol name: `mcp`. +- Operation: `mcp`. +- Define a JSON-RPC carrier for `initialize`, `tools/list`, `tools/call`, `resources/list`, `resources/read`, `prompts/list`, `prompts/get`, and generic method calls. +- Do not claim a full MCP proxy implementation in this slice. +- Preserve `jsonrpc`, `method`, `params`, `id`, `result`, and `error` fields exactly. + +## Count Tokens + +- Add count-token operation support where it naturally fits, especially Anthropic Messages. +- Avoid duplicating current count-token routes in this slice. +- Ensure `ProtocolAdapter.supports_operation("count_tokens")` can represent a protocol's support. + +## Tests + +- Operation model serialization includes `operation`, `input`, `modalities`, `files`, response `data`, and `content_type`. +- `ProtocolAdapter.supports_operation()` works and remains override-friendly. +- Registry discovers all new protocols. +- Embeddings request/response parse/build round-trip and usage extraction. +- Images request/response parse/build, including edit file/mask references. +- Audio transcription and speech request/response parse/build. +- Ollama chat/generate parse/build and stream parsing. +- MCP request/response/error round-trip. +- Existing protocol regression tests remain clean. + +## Transaction Logging Implications + +- This slice only adds trace-friendly fields and protocol adapters. +- Do not wire live transform logs here. +- Preserve `raw` and `extra` so Phase 2b can log every state. + +## Docstrings And Comments Required + +- Public classes must explain what API shape they model and what remains intentionally future-facing. +- Operation constants must document that strings are intentionally extensible. +- Lossy conversions, especially binary/audio/image payload preservation, need comments explaining that file contents are not inspected and provider-specific metadata stays in `extra`. + +## Commit Checkpoints + +1. Plan doc commit: `docs(experimental): plan protocol breadth correction`. +2. Operation model/types/base adapter commit with tests. +3. Embeddings/images/audio protocol commit with tests. +4. Ollama/MCP/count-token capability commit with tests. +5. Focused protocol regression run. +6. `explore` and `explore-heavy` review against this plan and the source requirements. +7. Fix findings and re-review if substantial. +8. Write uncommitted Phase 1b report. + +## Risks + +- Adding fields to dataclasses can break exact dict assertions. Mitigation: update tests only where serialization now intentionally includes new fields. +- A rigid operation enum could block custom providers. Mitigation: use string constants/helpers, not a closed enum. +- New non-chat adapters could overclaim live support. Mitigation: protocol package supports parse/build only; route/provider wiring comes later. +- MCP can sprawl. Mitigation: implement JSON-RPC carrier shape only, not a full proxy. From fd2f6a200d85c8dd3725c9f9b084c36db6c99396 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 08:02:38 +0200 Subject: [PATCH 075/182] feat(protocols): add operation model Adds extensible protocol operation names, operation capability checks on ProtocolAdapter, and broad request/response carrier fields for non-chat protocols. Existing chat, messages, Gemini, Responses, and LiteLLM fallback adapters now advertise supported operations without changing live execution paths. Tests: pytest tests/test_protocol_operation_model.py tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py --- src/rotator_library/protocols/__init__.py | 34 ++++++++++ .../protocols/anthropic_messages.py | 2 + src/rotator_library/protocols/base.py | 40 +++++++++++- src/rotator_library/protocols/gemini.py | 2 + .../protocols/litellm_fallback.py | 2 + src/rotator_library/protocols/openai_chat.py | 2 + src/rotator_library/protocols/operation.py | 38 +++++++++++ src/rotator_library/protocols/responses.py | 2 + src/rotator_library/protocols/types.py | 18 ++++++ tests/test_protocol_operation_model.py | 64 +++++++++++++++++++ 10 files changed, 203 insertions(+), 1 deletion(-) create mode 100644 src/rotator_library/protocols/operation.py create mode 100644 tests/test_protocol_operation_model.py diff --git a/src/rotator_library/protocols/__init__.py b/src/rotator_library/protocols/__init__.py index a16b9fc20..f528f2046 100644 --- a/src/rotator_library/protocols/__init__.py +++ b/src/rotator_library/protocols/__init__.py @@ -9,6 +9,24 @@ """ from .base import ProtocolAdapter +from .operation import ( + OPERATION_AUDIO_TRANSCRIPTION, + OPERATION_AUDIO_TRANSLATION, + OPERATION_CHAT, + OPERATION_COUNT_TOKENS, + OPERATION_EMBEDDINGS, + OPERATION_IMAGE_EDIT, + OPERATION_IMAGE_GENERATION, + OPERATION_IMAGE_VARIATION, + OPERATION_MCP, + OPERATION_MESSAGES, + OPERATION_OLLAMA_CHAT, + OPERATION_OLLAMA_GENERATE, + OPERATION_RESPONSES, + OPERATION_SPEECH, + OPERATION_UNKNOWN, + normalize_operation, +) from .registry import ( PROTOCOL_ALIASES, PROTOCOL_PLUGINS, @@ -42,6 +60,21 @@ "PROTOCOL_PLUGINS", "ContentBlock", "CostDetails", + "OPERATION_AUDIO_TRANSCRIPTION", + "OPERATION_AUDIO_TRANSLATION", + "OPERATION_CHAT", + "OPERATION_COUNT_TOKENS", + "OPERATION_EMBEDDINGS", + "OPERATION_IMAGE_EDIT", + "OPERATION_IMAGE_GENERATION", + "OPERATION_IMAGE_VARIATION", + "OPERATION_MCP", + "OPERATION_MESSAGES", + "OPERATION_OLLAMA_CHAT", + "OPERATION_OLLAMA_GENERATE", + "OPERATION_RESPONSES", + "OPERATION_SPEECH", + "OPERATION_UNKNOWN", "ProtocolAdapter", "ProtocolContext", "ProtocolError", @@ -58,6 +91,7 @@ "get_protocol", "get_protocol_class", "list_protocols", + "normalize_operation", "register_protocol", "resolve_protocol_name", "serialize_value", diff --git a/src/rotator_library/protocols/anthropic_messages.py b/src/rotator_library/protocols/anthropic_messages.py index b9a81c248..665d02ebb 100644 --- a/src/rotator_library/protocols/anthropic_messages.py +++ b/src/rotator_library/protocols/anthropic_messages.py @@ -15,6 +15,7 @@ from typing import Any, ClassVar, Iterable from .base import ProtocolAdapter +from .operation import OPERATION_COUNT_TOKENS, OPERATION_MESSAGES from .types import ( ContentBlock, ProtocolContext, @@ -56,6 +57,7 @@ class AnthropicMessagesProtocol(ProtocolAdapter): name: ClassVar[str] = "anthropic_messages" aliases: ClassVar[tuple[str, ...]] = ("anthropic", "messages", "claude_messages") supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_MESSAGES, OPERATION_COUNT_TOKENS) def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: request = dict(raw_request or {}) diff --git a/src/rotator_library/protocols/base.py b/src/rotator_library/protocols/base.py index 967d26926..293db1c05 100644 --- a/src/rotator_library/protocols/base.py +++ b/src/rotator_library/protocols/base.py @@ -13,6 +13,7 @@ from copy import deepcopy from typing import Any, ClassVar +from .operation import OPERATION_UNKNOWN, normalize_operation from .types import ( ProtocolContext, ProtocolError, @@ -36,6 +37,7 @@ class ProtocolAdapter: aliases: ClassVar[tuple[str, ...]] = () supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") future_transports: ClassVar[tuple[str, ...]] = () + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_UNKNOWN,) def supports_transport(self, transport_name: str) -> bool: """Return whether this protocol can format the requested transport.""" @@ -47,15 +49,34 @@ def is_future_transport(self, transport_name: str) -> bool: return transport_name in self.future_transports + def supports_operation(self, operation_name: str) -> bool: + """Return whether this adapter natively models an operation. + + Operation names are string based so custom protocols can add their own + values. The default base adapter only claims ``unknown`` and keeps raw + payloads intact; concrete adapters should list every operation they can + parse/build without relying on LiteLLM. + """ + + return normalize_operation(operation_name) in self.supported_operations + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: """Parse a raw client/provider request into a unified request.""" request = dict(raw_request or {}) return UnifiedRequest( + operation=normalize_operation(request.get("operation")), model=str(request.get("model") or getattr(context, "model", None) or ""), stream=bool(request.get("stream", False)), + input=deepcopy(request.get("input")), + modalities=list(request.get("modalities") or []), + files=list(request.get("files") or []), raw=deepcopy(raw_request), - extra={k: deepcopy(v) for k, v in request.items() if k not in {"model", "stream"}}, + extra={ + k: deepcopy(v) + for k, v in request.items() + if k not in {"operation", "model", "stream", "input", "modalities", "files"} + }, ) def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: @@ -76,6 +97,14 @@ def build_request(self, unified_request: UnifiedRequest, context: ProtocolContex payload=unified_request.raw, ) payload = {"model": unified_request.model, "stream": unified_request.stream} + if unified_request.operation != OPERATION_UNKNOWN: + payload["operation"] = unified_request.operation + if unified_request.input is not None: + payload["input"] = deepcopy(unified_request.input) + if unified_request.modalities: + payload["modalities"] = deepcopy(unified_request.modalities) + if unified_request.files: + payload["files"] = deepcopy(unified_request.files) payload.update(deepcopy(unified_request.extra)) return payload @@ -84,8 +113,11 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No response = raw_response if isinstance(raw_response, dict) else {} return UnifiedResponse( + operation=normalize_operation(response.get("operation") if isinstance(response, dict) else None), id=response.get("id") if isinstance(response, dict) else None, model=response.get("model") if isinstance(response, dict) else getattr(context, "model", None), + data=deepcopy(response.get("data") or []) if isinstance(response, dict) else [], + content_type=response.get("content_type") if isinstance(response, dict) else None, raw=deepcopy(raw_response), extra=deepcopy(response) if isinstance(response, dict) else {}, ) @@ -96,10 +128,16 @@ def format_response(self, unified_response: UnifiedResponse, context: ProtocolCo if isinstance(unified_response.raw, dict): return deepcopy(unified_response.raw) payload = deepcopy(unified_response.extra) + if unified_response.operation != OPERATION_UNKNOWN: + payload.setdefault("operation", unified_response.operation) if unified_response.id is not None: payload.setdefault("id", unified_response.id) if unified_response.model is not None: payload.setdefault("model", unified_response.model) + if unified_response.data: + payload.setdefault("data", deepcopy(unified_response.data)) + if unified_response.content_type is not None: + payload.setdefault("content_type", unified_response.content_type) return payload def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: diff --git a/src/rotator_library/protocols/gemini.py b/src/rotator_library/protocols/gemini.py index a812c41db..cf7547dc3 100644 --- a/src/rotator_library/protocols/gemini.py +++ b/src/rotator_library/protocols/gemini.py @@ -15,6 +15,7 @@ from typing import Any, ClassVar, Iterable from .base import ProtocolAdapter +from .operation import OPERATION_CHAT, OPERATION_COUNT_TOKENS from .types import ( ContentBlock, ProtocolContext, @@ -54,6 +55,7 @@ class GeminiProtocol(ProtocolAdapter): name: ClassVar[str] = "gemini" aliases: ClassVar[tuple[str, ...]] = ("google_gemini", "generate_content") supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_CHAT, OPERATION_COUNT_TOKENS) def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: request = dict(raw_request or {}) diff --git a/src/rotator_library/protocols/litellm_fallback.py b/src/rotator_library/protocols/litellm_fallback.py index f8bd5a794..6c66352c6 100644 --- a/src/rotator_library/protocols/litellm_fallback.py +++ b/src/rotator_library/protocols/litellm_fallback.py @@ -13,6 +13,7 @@ from typing import ClassVar from .base import ProtocolAdapter +from .operation import OPERATION_UNKNOWN class LiteLLMFallbackProtocol(ProtocolAdapter): @@ -26,3 +27,4 @@ class LiteLLMFallbackProtocol(ProtocolAdapter): name: ClassVar[str] = "litellm_fallback" aliases: ClassVar[tuple[str, ...]] = ("litellm", "fallback") supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_UNKNOWN,) diff --git a/src/rotator_library/protocols/openai_chat.py b/src/rotator_library/protocols/openai_chat.py index aaaa32d94..9e496c806 100644 --- a/src/rotator_library/protocols/openai_chat.py +++ b/src/rotator_library/protocols/openai_chat.py @@ -16,6 +16,7 @@ from typing import Any, ClassVar, Iterable from .base import ProtocolAdapter +from .operation import OPERATION_CHAT from .types import ( ContentBlock, CostDetails, @@ -81,6 +82,7 @@ class OpenAIChatProtocol(ProtocolAdapter): "chat_completions", "openai_chat_completions", ) + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_CHAT,) supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: diff --git a/src/rotator_library/protocols/operation.py b/src/rotator_library/protocols/operation.py new file mode 100644 index 000000000..91d7cb5b8 --- /dev/null +++ b/src/rotator_library/protocols/operation.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Shared protocol operation names. + +Operations are deliberately plain strings instead of a closed enum. The native +protocol system is meant to be extended by local/custom providers, so the core +must provide well-known names without blocking new operations that are not known +today. +""" + +from __future__ import annotations + +from typing import Final + +OPERATION_UNKNOWN: Final[str] = "unknown" +OPERATION_CHAT: Final[str] = "chat" +OPERATION_MESSAGES: Final[str] = "messages" +OPERATION_RESPONSES: Final[str] = "responses" +OPERATION_COUNT_TOKENS: Final[str] = "count_tokens" +OPERATION_EMBEDDINGS: Final[str] = "embeddings" +OPERATION_IMAGE_GENERATION: Final[str] = "image_generation" +OPERATION_IMAGE_EDIT: Final[str] = "image_edit" +OPERATION_IMAGE_VARIATION: Final[str] = "image_variation" +OPERATION_AUDIO_TRANSCRIPTION: Final[str] = "audio_transcription" +OPERATION_AUDIO_TRANSLATION: Final[str] = "audio_translation" +OPERATION_SPEECH: Final[str] = "speech" +OPERATION_OLLAMA_CHAT: Final[str] = "ollama_chat" +OPERATION_OLLAMA_GENERATE: Final[str] = "ollama_generate" +OPERATION_MCP: Final[str] = "mcp" + + +def normalize_operation(operation: str | None) -> str: + """Normalize an operation name while preserving custom extensions.""" + + if not operation: + return OPERATION_UNKNOWN + return str(operation).strip().lower() or OPERATION_UNKNOWN diff --git a/src/rotator_library/protocols/responses.py b/src/rotator_library/protocols/responses.py index 1330e72d7..a722495f8 100644 --- a/src/rotator_library/protocols/responses.py +++ b/src/rotator_library/protocols/responses.py @@ -15,6 +15,7 @@ from typing import Any, ClassVar, Iterable from .base import ProtocolAdapter +from .operation import OPERATION_RESPONSES from .types import ( ContentBlock, CostDetails, @@ -68,6 +69,7 @@ class ResponsesProtocol(ProtocolAdapter): name: ClassVar[str] = "responses" aliases: ClassVar[tuple[str, ...]] = ("openai_responses", "response_api") supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_RESPONSES,) future_transports: ClassVar[tuple[str, ...]] = ("websocket",) def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: diff --git a/src/rotator_library/protocols/types.py b/src/rotator_library/protocols/types.py index 4eb4a889a..b0fda21af 100644 --- a/src/rotator_library/protocols/types.py +++ b/src/rotator_library/protocols/types.py @@ -19,6 +19,8 @@ from pathlib import Path from typing import Any, ClassVar, Iterable, Mapping, Optional +from .operation import OPERATION_UNKNOWN + JsonObject = dict[str, Any] @@ -250,11 +252,15 @@ class UnifiedMessage(ProtocolSerializable): class UnifiedRequest(ProtocolSerializable): """A request after parsing from a client or provider protocol.""" + operation: str = OPERATION_UNKNOWN model: str = "" messages: list[UnifiedMessage] = field(default_factory=list) system: list[ContentBlock] = field(default_factory=list) tools: list[ToolDefinition] = field(default_factory=list) stream: bool = False + input: Any = None + modalities: list[str] = field(default_factory=list) + files: list[Any] = field(default_factory=list) generation_params: JsonObject = field(default_factory=dict) response_format: Any = None previous_response_id: Optional[str] = None @@ -263,11 +269,15 @@ class UnifiedRequest(ProtocolSerializable): extra: JsonObject = field(default_factory=dict) _fields: ClassVar[tuple[str, ...]] = ( + "operation", "model", "messages", "system", "tools", "stream", + "input", + "modalities", + "files", "generation_params", "response_format", "previous_response_id", @@ -281,10 +291,13 @@ class UnifiedRequest(ProtocolSerializable): class UnifiedResponse(ProtocolSerializable): """A complete provider/client response in protocol-neutral form.""" + operation: str = OPERATION_UNKNOWN id: Optional[str] = None model: Optional[str] = None messages: list[UnifiedMessage] = field(default_factory=list) output: list[Any] = field(default_factory=list) + data: list[Any] = field(default_factory=list) + content_type: Optional[str] = None stop_reason: Optional[str] = None usage: Optional[Usage] = None metadata: JsonObject = field(default_factory=dict) @@ -292,10 +305,13 @@ class UnifiedResponse(ProtocolSerializable): extra: JsonObject = field(default_factory=dict) _fields: ClassVar[tuple[str, ...]] = ( + "operation", "id", "model", "messages", "output", + "data", + "content_type", "stop_reason", "usage", "metadata", @@ -313,6 +329,7 @@ class UnifiedStreamEvent(ProtocolSerializable): """ type: str + operation: str = OPERATION_UNKNOWN delta: Optional[UnifiedMessage] = None message: Optional[UnifiedMessage] = None tool_call: Optional[ToolCall] = None @@ -323,6 +340,7 @@ class UnifiedStreamEvent(ProtocolSerializable): _fields: ClassVar[tuple[str, ...]] = ( "type", + "operation", "delta", "message", "tool_call", diff --git a/tests/test_protocol_operation_model.py b/tests/test_protocol_operation_model.py new file mode 100644 index 000000000..0f989269a --- /dev/null +++ b/tests/test_protocol_operation_model.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from rotator_library.protocols import ( + OPERATION_CHAT, + OPERATION_EMBEDDINGS, + OPERATION_UNKNOWN, + ProtocolAdapter, + UnifiedRequest, + UnifiedResponse, + get_protocol, + normalize_operation, +) + + +def test_operation_names_are_extensible_strings() -> None: + assert normalize_operation(" Chat ") == OPERATION_CHAT + assert normalize_operation(None) == OPERATION_UNKNOWN + assert normalize_operation("custom_gateway_op") == "custom_gateway_op" + + +def test_unified_request_response_carry_non_chat_operation_fields() -> None: + request = UnifiedRequest( + operation=OPERATION_EMBEDDINGS, + model="text-embedding-test", + input=["a", "b"], + modalities=["text"], + files=[{"name": "audio.wav", "content_type": "audio/wav"}], + ) + response = UnifiedResponse( + operation=OPERATION_EMBEDDINGS, + model="text-embedding-test", + data=[{"embedding": [0.1, 0.2]}], + content_type="application/json", + ) + + assert request.to_dict()["operation"] == OPERATION_EMBEDDINGS + assert request.to_dict()["input"] == ["a", "b"] + assert response.to_dict()["data"][0]["embedding"] == [0.1, 0.2] + assert response.to_dict()["content_type"] == "application/json" + + +def test_base_adapter_preserves_operation_fields() -> None: + adapter = ProtocolAdapter() + raw = { + "operation": "custom_op", + "model": "provider/model", + "input": "payload", + "modalities": ["text"], + "files": [{"name": "file.bin"}], + "custom": True, + } + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + + assert unified.operation == "custom_op" + assert unified.input == "payload" + assert rebuilt == raw + + +def test_protocols_advertise_supported_operations() -> None: + assert get_protocol("openai_chat").supports_operation(OPERATION_CHAT) + assert not get_protocol("openai_chat").supports_operation(OPERATION_EMBEDDINGS) + assert get_protocol("litellm_fallback").supports_operation(OPERATION_UNKNOWN) From 93487be4a7f3a9c078136b2a227e4df60655a90c Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 08:06:37 +0200 Subject: [PATCH 076/182] feat(protocols): add non-chat protocol adapters Adds first-class protocol foundations for OpenAI-compatible embeddings, images, audio transcription/speech, Ollama-native chat/generate/embeddings, and MCP JSON-RPC carrier envelopes. These adapters are parse/build foundations only; they do not claim live route wiring or replace existing execution paths. Raw payloads and extension fields are preserved for transform logging and future provider overrides. Tests: pytest tests/test_protocol_openai_embeddings.py tests/test_protocol_openai_images_audio.py tests/test_protocol_ollama_mcp.py tests/test_protocol_operation_model.py tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py --- src/rotator_library/protocols/mcp.py | 88 ++++++++++++ src/rotator_library/protocols/ollama.py | 127 ++++++++++++++++++ src/rotator_library/protocols/openai_audio.py | 89 ++++++++++++ .../protocols/openai_embeddings.py | 80 +++++++++++ .../protocols/openai_images.py | 93 +++++++++++++ tests/test_protocol_ollama_mcp.py | 33 +++++ tests/test_protocol_openai_embeddings.py | 28 ++++ tests/test_protocol_openai_images_audio.py | 38 ++++++ 8 files changed, 576 insertions(+) create mode 100644 src/rotator_library/protocols/mcp.py create mode 100644 src/rotator_library/protocols/ollama.py create mode 100644 src/rotator_library/protocols/openai_audio.py create mode 100644 src/rotator_library/protocols/openai_embeddings.py create mode 100644 src/rotator_library/protocols/openai_images.py create mode 100644 tests/test_protocol_ollama_mcp.py create mode 100644 tests/test_protocol_openai_embeddings.py create mode 100644 tests/test_protocol_openai_images_audio.py diff --git a/src/rotator_library/protocols/mcp.py b/src/rotator_library/protocols/mcp.py new file mode 100644 index 000000000..f04ae2c65 --- /dev/null +++ b/src/rotator_library/protocols/mcp.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""MCP JSON-RPC carrier protocol adapter. + +This is not a full MCP proxy implementation. It gives the native protocol layer a +lossless request/response carrier for future MCP gateway work, keeping method, +params, ids, results, and errors intact for transform logging and routing. +""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, ClassVar + +from .base import ProtocolAdapter +from .operation import OPERATION_MCP +from .types import ProtocolContext, UnifiedRequest, UnifiedResponse, UnifiedStreamEvent + +_REQUEST_CORE_FIELDS = {"jsonrpc", "id", "method", "params"} +_RESPONSE_CORE_FIELDS = {"jsonrpc", "id", "result", "error"} + + +class MCPProtocol(ProtocolAdapter): + """Adapter for MCP-style JSON-RPC request and response envelopes.""" + + name: ClassVar[str] = "mcp" + aliases: ClassVar[tuple[str, ...]] = ("model_context_protocol", "jsonrpc_mcp") + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_MCP,) + supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") + future_transports: ClassVar[tuple[str, ...]] = ("websocket",) + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + metadata = { + "jsonrpc": request.get("jsonrpc", "2.0"), + "id": deepcopy(request.get("id")), + "method": request.get("method"), + } + return UnifiedRequest( + operation=OPERATION_MCP, + input=deepcopy(request.get("params") or {}), + metadata=metadata, + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _REQUEST_CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload = { + "jsonrpc": unified_request.metadata.get("jsonrpc", "2.0"), + "method": unified_request.metadata.get("method"), + "params": deepcopy(unified_request.input or {}), + } + if "id" in unified_request.metadata: + payload["id"] = deepcopy(unified_request.metadata.get("id")) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + response = raw_response if isinstance(raw_response, dict) else {} + metadata = {"jsonrpc": response.get("jsonrpc", "2.0"), "id": deepcopy(response.get("id"))} + extra = {k: deepcopy(v) for k, v in response.items() if k not in _RESPONSE_CORE_FIELDS} + if "error" in response: + # JSON-RPC errors are not provider exceptions here; they are protocol + # payloads that must survive transform logging and response rebuilds. + extra["error"] = deepcopy(response["error"]) + return UnifiedResponse( + operation=OPERATION_MCP, + data=[deepcopy(response["result"])] if "result" in response else [], + metadata=metadata, + raw=deepcopy(raw_response), + extra=extra, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + payload = {"jsonrpc": unified_response.metadata.get("jsonrpc", "2.0")} + if "id" in unified_response.metadata: + payload["id"] = deepcopy(unified_response.metadata.get("id")) + if unified_response.extra.get("error") is not None: + payload["error"] = deepcopy(unified_response.extra["error"]) + else: + payload["result"] = deepcopy(unified_response.data[0] if unified_response.data else None) + payload.update({k: deepcopy(v) for k, v in unified_response.extra.items() if k != "error"}) + return payload + + def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: + data = raw_event if isinstance(raw_event, dict) else {"event": raw_event} + return UnifiedStreamEvent(type=str(data.get("method") or data.get("type") or "message"), operation=OPERATION_MCP, raw=deepcopy(raw_event), extra=deepcopy(data)) diff --git a/src/rotator_library/protocols/ollama.py b/src/rotator_library/protocols/ollama.py new file mode 100644 index 000000000..6ef36c6bb --- /dev/null +++ b/src/rotator_library/protocols/ollama.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Ollama-native chat, generate, and embeddings protocol adapter.""" + +from __future__ import annotations + +import json +from copy import deepcopy +from typing import Any, ClassVar + +from .base import ProtocolAdapter +from .operation import OPERATION_EMBEDDINGS, OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, normalize_operation +from .types import ContentBlock, ProtocolContext, UnifiedMessage, UnifiedRequest, UnifiedResponse, UnifiedStreamEvent, Usage + +_OPTION_FIELDS = {"options", "format", "keep_alive", "template", "context", "raw", "suffix"} +_CORE_FIELDS = {"operation", "model", "messages", "prompt", "input", "stream", "system", *_OPTION_FIELDS} + + +class OllamaProtocol(ProtocolAdapter): + """Adapter for Ollama `/api/chat`, `/api/generate`, and embeddings shapes.""" + + name: ClassVar[str] = "ollama" + aliases: ClassVar[tuple[str, ...]] = ("ollama_native",) + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, OPERATION_EMBEDDINGS) + supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse", "jsonl") + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + operation = _ollama_operation(request) + return UnifiedRequest( + operation=operation, + model=str(request.get("model") or getattr(context, "model", None) or ""), + messages=[_message_from_ollama(message) for message in request.get("messages") or []], + system=[ContentBlock(type="text", text=str(request["system"]))] if "system" in request else [], + stream=bool(request.get("stream", False)), + input=deepcopy(request.get("input") if operation == OPERATION_EMBEDDINGS else request.get("prompt")), + generation_params={k: deepcopy(request[k]) for k in _OPTION_FIELDS if k in request}, + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload: dict[str, Any] = {"model": unified_request.model, "stream": unified_request.stream} + if unified_request.operation == OPERATION_OLLAMA_CHAT: + payload["messages"] = [_message_to_ollama(message) for message in unified_request.messages] + elif unified_request.operation == OPERATION_EMBEDDINGS: + payload["input"] = deepcopy(unified_request.input) + else: + payload["prompt"] = deepcopy(unified_request.input) + if unified_request.system: + payload["system"] = "".join(block.text or "" for block in unified_request.system if block.type == "text") + payload.update(deepcopy(unified_request.generation_params)) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + response = raw_response if isinstance(raw_response, dict) else {} + output = [] + if isinstance(response.get("message"), dict): + output.append(_message_from_ollama(response["message"]).to_dict()) + elif "response" in response: + output.append(response.get("response")) + return UnifiedResponse( + operation=_ollama_operation(response), + model=response.get("model") or getattr(context, "model", None), + output=output, + data=deepcopy(response.get("embeddings") or response.get("embedding") or []), + usage=_ollama_usage(response), + raw=deepcopy(raw_response), + extra=deepcopy(response), + ) + + def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: + data = _json_event(raw_event) + if not isinstance(data, dict): + return UnifiedStreamEvent(type="metadata", raw=deepcopy(raw_event), extra={"unparsed": True}) + event_type = "done" if data.get("done") else "message_delta" + delta = None + if isinstance(data.get("message"), dict): + delta = _message_from_ollama(data["message"]) + elif data.get("response"): + delta = UnifiedMessage(role="assistant", content=[ContentBlock(type="text", text=str(data["response"]))]) + return UnifiedStreamEvent(type=event_type, operation=_ollama_operation(data), delta=delta, usage=_ollama_usage(data), raw=deepcopy(raw_event), extra=deepcopy(data)) + + +def _ollama_operation(request: dict[str, Any]) -> str: + explicit = normalize_operation(request.get("operation")) + if explicit in {OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, OPERATION_EMBEDDINGS}: + return explicit + if "messages" in request or "message" in request: + return OPERATION_OLLAMA_CHAT + if "embeddings" in request or "embedding" in request or "input" in request: + return OPERATION_EMBEDDINGS + return OPERATION_OLLAMA_GENERATE + + +def _message_from_ollama(message: dict[str, Any]) -> UnifiedMessage: + return UnifiedMessage(role=str(message.get("role") or "assistant"), content=[ContentBlock(type="text", text=str(message.get("content") or ""))], raw=deepcopy(message), extra={k: deepcopy(v) for k, v in message.items() if k not in {"role", "content"}}) + + +def _message_to_ollama(message: UnifiedMessage) -> dict[str, Any]: + payload = {"role": message.role, "content": "".join(block.text or "" for block in message.content if block.type == "text")} + payload.update(deepcopy(message.extra)) + return payload + + +def _json_event(raw_event: Any) -> Any: + if isinstance(raw_event, dict): + return raw_event + if not isinstance(raw_event, str): + return None + text = raw_event.strip() + if text.startswith("data:"): + text = text[5:].strip() + try: + return json.loads(text) + except json.JSONDecodeError: + return None + + +def _ollama_usage(response: dict[str, Any]) -> Usage | None: + prompt_tokens = int(response.get("prompt_eval_count") or 0) + output_tokens = int(response.get("eval_count") or 0) + if not prompt_tokens and not output_tokens: + return None + return Usage(input_tokens=prompt_tokens, output_tokens=output_tokens, raw={k: deepcopy(v) for k, v in response.items() if k.endswith("count") or k.endswith("duration")}) diff --git a/src/rotator_library/protocols/openai_audio.py b/src/rotator_library/protocols/openai_audio.py new file mode 100644 index 000000000..28da9a0f4 --- /dev/null +++ b/src/rotator_library/protocols/openai_audio.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""OpenAI-compatible audio transcription/translation and speech protocol adapter.""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, ClassVar + +from .base import ProtocolAdapter +from .operation import OPERATION_AUDIO_TRANSCRIPTION, OPERATION_AUDIO_TRANSLATION, OPERATION_SPEECH, normalize_operation +from .types import ProtocolContext, UnifiedRequest, UnifiedResponse + +_AUDIO_OPTION_FIELDS = {"language", "prompt", "response_format", "temperature", "timestamp_granularities"} +_SPEECH_OPTION_FIELDS = {"voice", "response_format", "speed"} +_CORE_FIELDS = {"operation", "model", "file", "input", *_AUDIO_OPTION_FIELDS, *_SPEECH_OPTION_FIELDS} + + +class OpenAIAudioProtocol(ProtocolAdapter): + """Adapter for OpenAI audio transcription/translation and speech requests. + + Audio file and generated-audio bytes are intentionally preserved instead of + interpreted. Transports/providers decide multipart and binary handling; this + protocol only records enough structure for routing, tracing, and tests. + """ + + name: ClassVar[str] = "openai_audio" + aliases: ClassVar[tuple[str, ...]] = ("audio", "audio_transcription", "speech", "tts") + supported_operations: ClassVar[tuple[str, ...]] = ( + OPERATION_AUDIO_TRANSCRIPTION, + OPERATION_AUDIO_TRANSLATION, + OPERATION_SPEECH, + ) + supported_transports: ClassVar[tuple[str, ...]] = ("http",) + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + operation = _audio_operation(request) + files = [] + if "file" in request: + files.append({"field": "file", "value": deepcopy(request["file"])}) + options = _SPEECH_OPTION_FIELDS if operation == OPERATION_SPEECH else _AUDIO_OPTION_FIELDS + return UnifiedRequest( + operation=operation, + model=str(request.get("model") or getattr(context, "model", None) or ""), + input=deepcopy(request.get("input") if operation == OPERATION_SPEECH else request.get("prompt")), + files=files, + generation_params={k: deepcopy(request[k]) for k in options if k in request}, + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload: dict[str, Any] = {"model": unified_request.model} + if unified_request.operation == OPERATION_SPEECH: + payload["input"] = deepcopy(unified_request.input) + elif unified_request.input is not None: + payload["prompt"] = deepcopy(unified_request.input) + for file_entry in unified_request.files: + if isinstance(file_entry, dict) and file_entry.get("field"): + payload[str(file_entry["field"])] = deepcopy(file_entry.get("value")) + payload.update(deepcopy(unified_request.generation_params)) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + if isinstance(raw_response, dict): + response = raw_response + return UnifiedResponse( + operation=normalize_operation(response.get("operation")), + model=response.get("model") or getattr(context, "model", None), + output=[deepcopy(response["text"])] if "text" in response else [], + data=deepcopy(response.get("data") or []), + content_type=response.get("content_type") or "application/json", + raw=deepcopy(raw_response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"operation", "model", "text", "data", "content_type"}}, + ) + content_type = "text/plain" if isinstance(raw_response, str) else "application/octet-stream" + return UnifiedResponse(content_type=content_type, raw=deepcopy(raw_response), output=[deepcopy(raw_response)] if isinstance(raw_response, str) else []) + + +def _audio_operation(request: dict[str, Any]) -> str: + explicit = normalize_operation(request.get("operation")) + if explicit in {OPERATION_AUDIO_TRANSCRIPTION, OPERATION_AUDIO_TRANSLATION, OPERATION_SPEECH}: + return explicit + if "voice" in request or ("input" in request and "file" not in request): + return OPERATION_SPEECH + return OPERATION_AUDIO_TRANSCRIPTION diff --git a/src/rotator_library/protocols/openai_embeddings.py b/src/rotator_library/protocols/openai_embeddings.py new file mode 100644 index 000000000..1ba955b99 --- /dev/null +++ b/src/rotator_library/protocols/openai_embeddings.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""OpenAI-compatible embeddings protocol adapter.""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, ClassVar + +from .base import ProtocolAdapter +from .operation import OPERATION_EMBEDDINGS +from .types import ProtocolContext, UnifiedRequest, UnifiedResponse, Usage + +_REQUEST_CORE_FIELDS = {"model", "input", "encoding_format", "dimensions", "user", "operation"} +_REQUEST_OPTION_FIELDS = {"encoding_format", "dimensions", "user"} + + +class OpenAIEmbeddingsProtocol(ProtocolAdapter): + """Adapter for `/v1/embeddings` style request and response payloads. + + The adapter intentionally treats embedding vectors as opaque data entries. + That keeps it usable for OpenAI-compatible providers with additional index, + metadata, or sparse-vector fields without narrowing the schema too early. + """ + + name: ClassVar[str] = "openai_embeddings" + aliases: ClassVar[tuple[str, ...]] = ("embeddings", "openai_embedding") + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_EMBEDDINGS,) + supported_transports: ClassVar[tuple[str, ...]] = ("http",) + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + return UnifiedRequest( + operation=OPERATION_EMBEDDINGS, + model=str(request.get("model") or getattr(context, "model", None) or ""), + input=deepcopy(request.get("input")), + generation_params={k: deepcopy(request[k]) for k in _REQUEST_OPTION_FIELDS if k in request}, + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _REQUEST_CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload = {"model": unified_request.model, "input": deepcopy(unified_request.input)} + payload.update(deepcopy(unified_request.generation_params)) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + response = raw_response if isinstance(raw_response, dict) else {} + return UnifiedResponse( + operation=OPERATION_EMBEDDINGS, + model=response.get("model") or getattr(context, "model", None), + data=deepcopy(response.get("data") or []), + usage=self.extract_usage(response, context), + raw=deepcopy(raw_response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"model", "data", "usage"}}, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + payload = {"object": "list", "data": deepcopy(unified_response.data)} + if unified_response.model: + payload["model"] = unified_response.model + if unified_response.usage: + payload["usage"] = unified_response.usage.raw or unified_response.usage.to_dict() + payload.update(deepcopy(unified_response.extra)) + return payload + + def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: + if isinstance(raw_or_unified, UnifiedResponse): + return raw_or_unified.usage + usage = raw_or_unified.get("usage") if isinstance(raw_or_unified, dict) else None + if not isinstance(usage, dict): + return None + return Usage( + input_tokens=int(usage.get("prompt_tokens") or usage.get("input_tokens") or usage.get("total_tokens") or 0), + output_tokens=int(usage.get("output_tokens") or 0), + total_tokens=int(usage.get("total_tokens") or 0), + raw=deepcopy(usage), + ) diff --git a/src/rotator_library/protocols/openai_images.py b/src/rotator_library/protocols/openai_images.py new file mode 100644 index 000000000..84de94549 --- /dev/null +++ b/src/rotator_library/protocols/openai_images.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""OpenAI-compatible image generation/edit/variation protocol adapter.""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, ClassVar + +from .base import ProtocolAdapter +from .operation import OPERATION_IMAGE_EDIT, OPERATION_IMAGE_GENERATION, OPERATION_IMAGE_VARIATION, normalize_operation +from .types import ProtocolContext, UnifiedRequest, UnifiedResponse + +_OPTION_FIELDS = {"n", "size", "quality", "style", "response_format", "user", "background", "moderation"} +_CORE_FIELDS = {"operation", "model", "prompt", "image", "mask", *_OPTION_FIELDS} + + +class OpenAIImagesProtocol(ProtocolAdapter): + """Adapter for image generation, edit, and variation request shapes. + + File references are preserved as metadata in ``UnifiedRequest.files``. The + adapter never reads file contents; multipart assembly belongs to transport or + provider execution code so protocol parsing remains side-effect free. + """ + + name: ClassVar[str] = "openai_images" + aliases: ClassVar[tuple[str, ...]] = ("images", "image_generation", "openai_image") + supported_operations: ClassVar[tuple[str, ...]] = ( + OPERATION_IMAGE_GENERATION, + OPERATION_IMAGE_EDIT, + OPERATION_IMAGE_VARIATION, + ) + supported_transports: ClassVar[tuple[str, ...]] = ("http",) + + def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + request = dict(raw_request or {}) + operation = _image_operation(request) + files = [] + if "image" in request: + files.append({"field": "image", "value": deepcopy(request["image"])}) + if "mask" in request: + files.append({"field": "mask", "value": deepcopy(request["mask"])}) + return UnifiedRequest( + operation=operation, + model=str(request.get("model") or getattr(context, "model", None) or ""), + input=deepcopy(request.get("prompt")), + files=files, + generation_params={k: deepcopy(request[k]) for k in _OPTION_FIELDS if k in request}, + raw=deepcopy(raw_request), + extra={k: deepcopy(v) for k, v in request.items() if k not in _CORE_FIELDS}, + ) + + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + payload: dict[str, Any] = {} + if unified_request.model: + payload["model"] = unified_request.model + if unified_request.input is not None: + payload["prompt"] = deepcopy(unified_request.input) + for file_entry in unified_request.files: + if isinstance(file_entry, dict) and file_entry.get("field"): + payload[str(file_entry["field"])] = deepcopy(file_entry.get("value")) + payload.update(deepcopy(unified_request.generation_params)) + payload.update(deepcopy(unified_request.extra)) + return payload + + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + response = raw_response if isinstance(raw_response, dict) else {} + return UnifiedResponse( + operation=OPERATION_IMAGE_GENERATION, + model=response.get("model") or getattr(context, "model", None), + data=deepcopy(response.get("data") or []), + raw=deepcopy(raw_response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"model", "data"}}, + ) + + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + payload = {"data": deepcopy(unified_response.data)} + if unified_response.model: + payload["model"] = unified_response.model + payload.update(deepcopy(unified_response.extra)) + return payload + + +def _image_operation(request: dict[str, Any]) -> str: + explicit = normalize_operation(request.get("operation")) + if explicit in {OPERATION_IMAGE_GENERATION, OPERATION_IMAGE_EDIT, OPERATION_IMAGE_VARIATION}: + return explicit + if "image" in request and "prompt" in request: + return OPERATION_IMAGE_EDIT + if "image" in request: + return OPERATION_IMAGE_VARIATION + return OPERATION_IMAGE_GENERATION diff --git a/tests/test_protocol_ollama_mcp.py b/tests/test_protocol_ollama_mcp.py new file mode 100644 index 000000000..aeca7f7f4 --- /dev/null +++ b/tests/test_protocol_ollama_mcp.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from rotator_library.protocols import OPERATION_MCP, OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, get_protocol + + +def test_ollama_chat_generate_and_stream_shapes() -> None: + adapter = get_protocol("ollama") + chat = adapter.parse_request({"model": "llama3", "messages": [{"role": "user", "content": "hi"}], "stream": True}) + generate = adapter.parse_request({"model": "llama3", "prompt": "write", "options": {"temperature": 0.1}}) + chunk = adapter.parse_stream_event('{"model":"llama3","response":"he","done":false,"eval_count":2}') + + assert chat.operation == OPERATION_OLLAMA_CHAT + assert adapter.build_request(chat)["messages"][0]["content"] == "hi" + assert generate.operation == OPERATION_OLLAMA_GENERATE + assert adapter.build_request(generate)["prompt"] == "write" + assert chunk.delta is not None + assert chunk.delta.content[0].text == "he" + assert chunk.usage is not None + assert chunk.usage.output_tokens == 2 + + +def test_mcp_jsonrpc_round_trip_and_error_preservation() -> None: + adapter = get_protocol("mcp") + request = {"jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": {"name": "lookup"}} + unified = adapter.parse_request(request) + rebuilt = adapter.build_request(unified) + response = adapter.parse_response({"jsonrpc": "2.0", "id": 1, "result": {"content": []}}) + error = adapter.parse_response({"jsonrpc": "2.0", "id": 1, "error": {"code": -1, "message": "failed"}}) + + assert unified.operation == OPERATION_MCP + assert rebuilt == request + assert adapter.format_response(response)["result"] == {"content": []} + assert adapter.format_response(error)["error"]["message"] == "failed" diff --git a/tests/test_protocol_openai_embeddings.py b/tests/test_protocol_openai_embeddings.py new file mode 100644 index 000000000..d8a63a288 --- /dev/null +++ b/tests/test_protocol_openai_embeddings.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from rotator_library.protocols import OPERATION_EMBEDDINGS, get_protocol, list_protocols + + +def test_openai_embeddings_protocol_round_trip_and_usage() -> None: + adapter = get_protocol("openai_embeddings") + raw = {"model": "text-embedding-test", "input": ["one", "two"], "dimensions": 128, "custom": True} + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + response = adapter.parse_response( + { + "model": "text-embedding-test", + "data": [{"object": "embedding", "embedding": [0.1, 0.2], "index": 0}], + "usage": {"prompt_tokens": 4, "total_tokens": 4}, + } + ) + + assert "openai_embeddings" in list_protocols() + assert get_protocol("embeddings") is adapter + assert adapter.supports_operation(OPERATION_EMBEDDINGS) + assert unified.operation == OPERATION_EMBEDDINGS + assert unified.input == ["one", "two"] + assert rebuilt == raw + assert response.data[0]["embedding"] == [0.1, 0.2] + assert response.usage is not None + assert response.usage.input_tokens == 4 diff --git a/tests/test_protocol_openai_images_audio.py b/tests/test_protocol_openai_images_audio.py new file mode 100644 index 000000000..be632fe11 --- /dev/null +++ b/tests/test_protocol_openai_images_audio.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from rotator_library.protocols import ( + OPERATION_AUDIO_TRANSCRIPTION, + OPERATION_IMAGE_EDIT, + OPERATION_IMAGE_GENERATION, + OPERATION_SPEECH, + get_protocol, +) + + +def test_openai_images_generation_and_edit_shapes() -> None: + adapter = get_protocol("openai_images") + generation = adapter.parse_request({"model": "gpt-image-test", "prompt": "draw a red cube", "size": "1024x1024"}) + edit = adapter.parse_request({"model": "gpt-image-test", "prompt": "make it blue", "image": "image-ref", "mask": "mask-ref"}) + response = adapter.parse_response({"data": [{"url": "https://example.test/image.png", "revised_prompt": "cube"}]}) + + assert generation.operation == OPERATION_IMAGE_GENERATION + assert adapter.build_request(generation)["prompt"] == "draw a red cube" + assert edit.operation == OPERATION_IMAGE_EDIT + assert {entry["field"] for entry in edit.files} == {"image", "mask"} + assert adapter.build_request(edit)["image"] == "image-ref" + assert response.data[0]["revised_prompt"] == "cube" + + +def test_openai_audio_transcription_and_speech_shapes() -> None: + adapter = get_protocol("openai_audio") + transcription = adapter.parse_request({"model": "whisper-test", "file": "audio-ref", "language": "en"}) + speech = adapter.parse_request({"model": "tts-test", "input": "hello", "voice": "alloy"}) + text_response = adapter.parse_response({"text": "hello world"}) + binary_response = adapter.parse_response(b"RIFF") + + assert transcription.operation == OPERATION_AUDIO_TRANSCRIPTION + assert transcription.files[0]["value"] == "audio-ref" + assert speech.operation == OPERATION_SPEECH + assert adapter.build_request(speech)["voice"] == "alloy" + assert text_response.output == ["hello world"] + assert binary_response.content_type == "application/octet-stream" From 41b6284993d17b7d065a0aa0d627ced83599a24e Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 08:18:26 +0200 Subject: [PATCH 077/182] fix(protocols): stamp operation metadata consistently Closes Phase 1b review findings by stamping parsed core protocol requests, responses, and stream events with their operations, excluding operation.py from protocol auto-discovery, and tightening multi-operation response fidelity. Hardens MCP JSON-RPC preservation for notifications and falsey params, narrows future-only transport metadata, and adds count-token operation coverage for Anthropic/Gemini protocol shapes. Tests: pytest tests/test_protocol_operation_model.py tests/test_protocol_openai_embeddings.py tests/test_protocol_openai_images_audio.py tests/test_protocol_ollama_mcp.py tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py --- .../protocols/anthropic_messages.py | 37 +++++++++++++++---- src/rotator_library/protocols/gemini.py | 32 ++++++++++++++-- src/rotator_library/protocols/mcp.py | 24 +++++++----- src/rotator_library/protocols/ollama.py | 8 ++-- src/rotator_library/protocols/openai_audio.py | 17 ++++++++- src/rotator_library/protocols/openai_chat.py | 7 +++- .../protocols/openai_images.py | 14 ++++++- src/rotator_library/protocols/registry.py | 2 +- src/rotator_library/protocols/responses.py | 14 ++++--- tests/test_protocol_ollama_mcp.py | 9 +++++ tests/test_protocol_openai_images_audio.py | 11 ++++++ tests/test_protocol_operation_model.py | 29 +++++++++++++++ 12 files changed, 168 insertions(+), 36 deletions(-) diff --git a/src/rotator_library/protocols/anthropic_messages.py b/src/rotator_library/protocols/anthropic_messages.py index 665d02ebb..0df594828 100644 --- a/src/rotator_library/protocols/anthropic_messages.py +++ b/src/rotator_library/protocols/anthropic_messages.py @@ -15,7 +15,7 @@ from typing import Any, ClassVar, Iterable from .base import ProtocolAdapter -from .operation import OPERATION_COUNT_TOKENS, OPERATION_MESSAGES +from .operation import OPERATION_COUNT_TOKENS, OPERATION_MESSAGES, OPERATION_UNKNOWN, normalize_operation from .types import ( ContentBlock, ProtocolContext, @@ -62,6 +62,7 @@ class AnthropicMessagesProtocol(ProtocolAdapter): def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: request = dict(raw_request or {}) return UnifiedRequest( + operation=_operation_from_context(context, OPERATION_MESSAGES), model=str(request.get("model") or getattr(context, "model", None) or ""), messages=[self._parse_message(message) for message in request.get("messages") or []], system=self._parse_system(request.get("system")), @@ -101,6 +102,7 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No ) self._promote_message_blocks(message) return UnifiedResponse( + operation=_response_operation(response, context), id=response.get("id"), model=response.get("model") or getattr(context, "model", None), messages=[message] if response else [], @@ -129,20 +131,20 @@ def format_response(self, unified_response: UnifiedResponse, context: ProtocolCo def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: event = _decode_sse_data(raw_event) if event == "[DONE]": - return UnifiedStreamEvent(type="done", raw=deepcopy(raw_event)) + return UnifiedStreamEvent(type="done", operation=OPERATION_MESSAGES, raw=deepcopy(raw_event)) data = _as_dict(event) event_type = str(data.get("type") or "chunk") if event_type == "error" or data.get("error") is not None: - return UnifiedStreamEvent(type="error", error=deepcopy(data.get("error", data)), raw=deepcopy(raw_event), extra={"payload": data}) + return UnifiedStreamEvent(type="error", operation=OPERATION_MESSAGES, error=deepcopy(data.get("error", data)), raw=deepcopy(raw_event), extra={"payload": data}) if event_type == "message_start": response = self.parse_response(data.get("message") or {}, context) - return UnifiedStreamEvent(type="message_start", message=response.messages[0] if response.messages else None, usage=response.usage, raw=deepcopy(raw_event), extra={"payload": data}) + return UnifiedStreamEvent(type="message_start", operation=OPERATION_MESSAGES, message=response.messages[0] if response.messages else None, usage=response.usage, raw=deepcopy(raw_event), extra={"payload": data}) if event_type == "message_delta": - return UnifiedStreamEvent(type="message_delta", usage=self.extract_usage(data.get("usage") or {}, context), raw=deepcopy(raw_event), extra={"payload": data, "stop_reason": (data.get("delta") or {}).get("stop_reason")}) + return UnifiedStreamEvent(type="message_delta", operation=OPERATION_MESSAGES, usage=self.extract_usage(data.get("usage") or {}, context), raw=deepcopy(raw_event), extra={"payload": data, "stop_reason": (data.get("delta") or {}).get("stop_reason")}) if event_type in {"content_block_start", "content_block_delta", "content_block_stop"}: return self._parse_content_stream_event(data, raw_event) - return UnifiedStreamEvent(type=event_type, raw=deepcopy(raw_event), extra={"payload": data}) + return UnifiedStreamEvent(type=event_type, operation=OPERATION_MESSAGES, raw=deepcopy(raw_event), extra={"payload": data}) def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: if isinstance(raw_or_unified, (UnifiedResponse, UnifiedStreamEvent)): @@ -322,7 +324,28 @@ def _parse_content_stream_event(self, data: dict[str, Any], raw_event: Any) -> U content_block = ContentBlock(type=str(delta_type), reasoning=reasoning, raw=deepcopy(delta)) message = UnifiedMessage(role="assistant", content=[content_block] if content_block else []) self._promote_message_blocks(message) - return UnifiedStreamEvent(type=str(data.get("type") or "content_block_delta"), delta=message, raw=deepcopy(raw_event), extra={"payload": data, "index": data.get("index")}) + return UnifiedStreamEvent(type=str(data.get("type") or "content_block_delta"), operation=OPERATION_MESSAGES, delta=message, raw=deepcopy(raw_event), extra={"payload": data, "index": data.get("index")}) + + +def _operation_from_context(context: ProtocolContext | None, default: str) -> str: + if context and isinstance(context.provider_options, dict): + operation = normalize_operation(context.provider_options.get("operation")) + if operation != OPERATION_UNKNOWN: + return operation + if context and isinstance(context.metadata, dict): + operation = normalize_operation(context.metadata.get("operation")) + if operation != OPERATION_UNKNOWN: + return operation + return default + + +def _response_operation(response: dict[str, Any], context: ProtocolContext | None) -> str: + requested = _operation_from_context(context, OPERATION_MESSAGES) + if requested == OPERATION_COUNT_TOKENS: + return OPERATION_COUNT_TOKENS + if "input_tokens" in response and not response.get("content") and not response.get("id"): + return OPERATION_COUNT_TOKENS + return OPERATION_MESSAGES def _as_dict(value: Any) -> dict[str, Any]: diff --git a/src/rotator_library/protocols/gemini.py b/src/rotator_library/protocols/gemini.py index cf7547dc3..f1fad2da4 100644 --- a/src/rotator_library/protocols/gemini.py +++ b/src/rotator_library/protocols/gemini.py @@ -15,7 +15,7 @@ from typing import Any, ClassVar, Iterable from .base import ProtocolAdapter -from .operation import OPERATION_CHAT, OPERATION_COUNT_TOKENS +from .operation import OPERATION_CHAT, OPERATION_COUNT_TOKENS, OPERATION_UNKNOWN, normalize_operation from .types import ( ContentBlock, ProtocolContext, @@ -62,6 +62,7 @@ def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | generation_config = deepcopy(request.get("generationConfig") or request.get("generation_config") or {}) safety_settings = deepcopy(request.get("safetySettings") or request.get("safety_settings") or []) return UnifiedRequest( + operation=_operation_from_context(context, OPERATION_CHAT), model=str(request.get("model") or getattr(context, "model", None) or ""), messages=[self._parse_content(content) for content in request.get("contents") or []], system=self._parse_system(request.get("systemInstruction") or request.get("system_instruction")), @@ -107,6 +108,7 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No if candidate.get("finishReason") is not None: stop_reason = candidate.get("finishReason") return UnifiedResponse( + operation=_response_operation(response, context), id=response.get("responseId") or response.get("id"), model=response.get("modelVersion") or getattr(context, "model", None), messages=messages, @@ -138,12 +140,13 @@ def format_response(self, unified_response: UnifiedResponse, context: ProtocolCo def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: event = _decode_sse_data(raw_event) if event == "[DONE]": - return UnifiedStreamEvent(type="done", raw=deepcopy(raw_event)) + return UnifiedStreamEvent(type="done", operation=OPERATION_CHAT, raw=deepcopy(raw_event)) data = _as_dict(event) response = self.parse_response(data, context) message = response.messages[0] if response.messages else None return UnifiedStreamEvent( type="message_delta" if message else "chunk", + operation=response.operation, delta=message, usage=response.usage, raw=deepcopy(raw_event), @@ -155,7 +158,7 @@ def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = N return raw_or_unified.usage payload = _as_dict(raw_or_unified) usage = payload.get("usageMetadata") if isinstance(payload.get("usageMetadata"), dict) else payload - if not isinstance(usage, dict) or not any(key.endswith("TokenCount") for key in usage): + if not isinstance(usage, dict) or (not any(key.endswith("TokenCount") for key in usage) and "totalTokens" not in usage): return None input_tokens = int(usage.get("promptTokenCount") or 0) output_tokens = int(usage.get("candidatesTokenCount") or 0) @@ -163,7 +166,7 @@ def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = N return Usage( input_tokens=input_tokens, output_tokens=output_tokens, - total_tokens=int(usage.get("totalTokenCount") or input_tokens + output_tokens + reasoning_tokens), + total_tokens=int(usage.get("totalTokenCount") or usage.get("totalTokens") or input_tokens + output_tokens + reasoning_tokens), cache_read_tokens=int(usage.get("cachedContentTokenCount") or 0), reasoning_tokens=reasoning_tokens, raw=deepcopy(usage), @@ -345,6 +348,27 @@ def _as_dict(value: Any) -> dict[str, Any]: return {} +def _operation_from_context(context: ProtocolContext | None, default: str) -> str: + if context and isinstance(context.provider_options, dict): + operation = normalize_operation(context.provider_options.get("operation")) + if operation != OPERATION_UNKNOWN: + return operation + if context and isinstance(context.metadata, dict): + operation = normalize_operation(context.metadata.get("operation")) + if operation != OPERATION_UNKNOWN: + return operation + return default + + +def _response_operation(response: dict[str, Any], context: ProtocolContext | None) -> str: + requested = _operation_from_context(context, OPERATION_CHAT) + if requested == OPERATION_COUNT_TOKENS: + return OPERATION_COUNT_TOKENS + if "totalTokens" in response and "candidates" not in response: + return OPERATION_COUNT_TOKENS + return OPERATION_CHAT + + def _decode_sse_data(raw_event: Any) -> Any: if not isinstance(raw_event, str): return raw_event diff --git a/src/rotator_library/protocols/mcp.py b/src/rotator_library/protocols/mcp.py index f04ae2c65..845b210cc 100644 --- a/src/rotator_library/protocols/mcp.py +++ b/src/rotator_library/protocols/mcp.py @@ -27,19 +27,22 @@ class MCPProtocol(ProtocolAdapter): name: ClassVar[str] = "mcp" aliases: ClassVar[tuple[str, ...]] = ("model_context_protocol", "jsonrpc_mcp") supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_MCP,) - supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") - future_transports: ClassVar[tuple[str, ...]] = ("websocket",) + supported_transports: ClassVar[tuple[str, ...]] = ("http",) + future_transports: ClassVar[tuple[str, ...]] = ("sse", "websocket") def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: request = dict(raw_request or {}) metadata = { "jsonrpc": request.get("jsonrpc", "2.0"), - "id": deepcopy(request.get("id")), "method": request.get("method"), + "has_id": "id" in request, + "has_params": "params" in request, } + if "id" in request: + metadata["id"] = deepcopy(request.get("id")) return UnifiedRequest( operation=OPERATION_MCP, - input=deepcopy(request.get("params") or {}), + input=deepcopy(request.get("params")) if "params" in request else None, metadata=metadata, raw=deepcopy(raw_request), extra={k: deepcopy(v) for k, v in request.items() if k not in _REQUEST_CORE_FIELDS}, @@ -49,16 +52,19 @@ def build_request(self, unified_request: UnifiedRequest, context: ProtocolContex payload = { "jsonrpc": unified_request.metadata.get("jsonrpc", "2.0"), "method": unified_request.metadata.get("method"), - "params": deepcopy(unified_request.input or {}), } - if "id" in unified_request.metadata: + if unified_request.metadata.get("has_params", True): + payload["params"] = deepcopy(unified_request.input) + if unified_request.metadata.get("has_id"): payload["id"] = deepcopy(unified_request.metadata.get("id")) payload.update(deepcopy(unified_request.extra)) return payload def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: response = raw_response if isinstance(raw_response, dict) else {} - metadata = {"jsonrpc": response.get("jsonrpc", "2.0"), "id": deepcopy(response.get("id"))} + metadata = {"jsonrpc": response.get("jsonrpc", "2.0"), "has_id": "id" in response} + if "id" in response: + metadata["id"] = deepcopy(response.get("id")) extra = {k: deepcopy(v) for k, v in response.items() if k not in _RESPONSE_CORE_FIELDS} if "error" in response: # JSON-RPC errors are not provider exceptions here; they are protocol @@ -74,7 +80,7 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: payload = {"jsonrpc": unified_response.metadata.get("jsonrpc", "2.0")} - if "id" in unified_response.metadata: + if unified_response.metadata.get("has_id"): payload["id"] = deepcopy(unified_response.metadata.get("id")) if unified_response.extra.get("error") is not None: payload["error"] = deepcopy(unified_response.extra["error"]) @@ -85,4 +91,4 @@ def format_response(self, unified_response: UnifiedResponse, context: ProtocolCo def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: data = raw_event if isinstance(raw_event, dict) else {"event": raw_event} - return UnifiedStreamEvent(type=str(data.get("method") or data.get("type") or "message"), operation=OPERATION_MCP, raw=deepcopy(raw_event), extra=deepcopy(data)) + return UnifiedStreamEvent(type=str(data.get("method") or data.get("type") or "message"), operation=OPERATION_MCP, error=deepcopy(data.get("error")), raw=deepcopy(raw_event), extra=deepcopy(data)) diff --git a/src/rotator_library/protocols/ollama.py b/src/rotator_library/protocols/ollama.py index 6ef36c6bb..f10af5116 100644 --- a/src/rotator_library/protocols/ollama.py +++ b/src/rotator_library/protocols/ollama.py @@ -23,7 +23,7 @@ class OllamaProtocol(ProtocolAdapter): name: ClassVar[str] = "ollama" aliases: ClassVar[tuple[str, ...]] = ("ollama_native",) supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, OPERATION_EMBEDDINGS) - supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse", "jsonl") + supported_transports: ClassVar[tuple[str, ...]] = ("http", "jsonl") def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: request = dict(raw_request or {}) @@ -68,13 +68,13 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No data=deepcopy(response.get("embeddings") or response.get("embedding") or []), usage=_ollama_usage(response), raw=deepcopy(raw_response), - extra=deepcopy(response), + extra={k: deepcopy(v) for k, v in response.items() if k not in {"model", "message", "response", "embeddings", "embedding", "prompt_eval_count", "eval_count"}}, ) def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: data = _json_event(raw_event) if not isinstance(data, dict): - return UnifiedStreamEvent(type="metadata", raw=deepcopy(raw_event), extra={"unparsed": True}) + return UnifiedStreamEvent(type="metadata", operation=OPERATION_OLLAMA_GENERATE, raw=deepcopy(raw_event), extra={"unparsed": True}) event_type = "done" if data.get("done") else "message_delta" delta = None if isinstance(data.get("message"), dict): @@ -90,7 +90,7 @@ def _ollama_operation(request: dict[str, Any]) -> str: return explicit if "messages" in request or "message" in request: return OPERATION_OLLAMA_CHAT - if "embeddings" in request or "embedding" in request or "input" in request: + if "embeddings" in request or "embedding" in request or "input" in request or request.get("endpoint") == "embeddings": return OPERATION_EMBEDDINGS return OPERATION_OLLAMA_GENERATE diff --git a/src/rotator_library/protocols/openai_audio.py b/src/rotator_library/protocols/openai_audio.py index 28da9a0f4..b626f8c3b 100644 --- a/src/rotator_library/protocols/openai_audio.py +++ b/src/rotator_library/protocols/openai_audio.py @@ -68,7 +68,7 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No if isinstance(raw_response, dict): response = raw_response return UnifiedResponse( - operation=normalize_operation(response.get("operation")), + operation=_context_operation(context, normalize_operation(response.get("operation"))), model=response.get("model") or getattr(context, "model", None), output=[deepcopy(response["text"])] if "text" in response else [], data=deepcopy(response.get("data") or []), @@ -77,7 +77,8 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No extra={k: deepcopy(v) for k, v in response.items() if k not in {"operation", "model", "text", "data", "content_type"}}, ) content_type = "text/plain" if isinstance(raw_response, str) else "application/octet-stream" - return UnifiedResponse(content_type=content_type, raw=deepcopy(raw_response), output=[deepcopy(raw_response)] if isinstance(raw_response, str) else []) + default_operation = OPERATION_AUDIO_TRANSCRIPTION if isinstance(raw_response, str) else OPERATION_SPEECH + return UnifiedResponse(operation=_context_operation(context, default_operation), content_type=content_type, raw=deepcopy(raw_response), output=[deepcopy(raw_response)] if isinstance(raw_response, str) else []) def _audio_operation(request: dict[str, Any]) -> str: @@ -87,3 +88,15 @@ def _audio_operation(request: dict[str, Any]) -> str: if "voice" in request or ("input" in request and "file" not in request): return OPERATION_SPEECH return OPERATION_AUDIO_TRANSCRIPTION + + +def _context_operation(context: ProtocolContext | None, default: str) -> str: + if context and isinstance(context.provider_options, dict): + operation = normalize_operation(context.provider_options.get("operation")) + if operation in {OPERATION_AUDIO_TRANSCRIPTION, OPERATION_AUDIO_TRANSLATION, OPERATION_SPEECH}: + return operation + if context and isinstance(context.metadata, dict): + operation = normalize_operation(context.metadata.get("operation")) + if operation in {OPERATION_AUDIO_TRANSCRIPTION, OPERATION_AUDIO_TRANSLATION, OPERATION_SPEECH}: + return operation + return default diff --git a/src/rotator_library/protocols/openai_chat.py b/src/rotator_library/protocols/openai_chat.py index 9e496c806..119572fcd 100644 --- a/src/rotator_library/protocols/openai_chat.py +++ b/src/rotator_library/protocols/openai_chat.py @@ -93,6 +93,7 @@ def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | extra = {k: deepcopy(v) for k, v in request.items() if k not in _REQUEST_CORE_FIELDS} return UnifiedRequest( + operation=OPERATION_CHAT, model=str(request.get("model") or getattr(context, "model", None) or ""), messages=messages, tools=tools, @@ -135,6 +136,7 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No stop_reason = choice.get("finish_reason") return UnifiedResponse( + operation=OPERATION_CHAT, id=response.get("id"), model=response.get("model") or getattr(context, "model", None), messages=messages, @@ -173,10 +175,10 @@ def format_response(self, unified_response: UnifiedResponse, context: ProtocolCo def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: event = _decode_sse_data(raw_event) if event == "[DONE]": - return UnifiedStreamEvent(type="done", raw=deepcopy(raw_event)) + return UnifiedStreamEvent(type="done", operation=OPERATION_CHAT, raw=deepcopy(raw_event)) data = _as_dict(event) if data.get("error") is not None: - return UnifiedStreamEvent(type="error", error=deepcopy(data["error"]), raw=deepcopy(raw_event), extra={"payload": data}) + return UnifiedStreamEvent(type="error", operation=OPERATION_CHAT, error=deepcopy(data["error"]), raw=deepcopy(raw_event), extra={"payload": data}) delta_message = None finish_reason = None @@ -192,6 +194,7 @@ def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = N usage = self.extract_usage(data, context) return UnifiedStreamEvent( type="message_delta" if delta_message else "chunk", + operation=OPERATION_CHAT, delta=delta_message, usage=usage, raw=deepcopy(raw_event), diff --git a/src/rotator_library/protocols/openai_images.py b/src/rotator_library/protocols/openai_images.py index 84de94549..fbf382dbb 100644 --- a/src/rotator_library/protocols/openai_images.py +++ b/src/rotator_library/protocols/openai_images.py @@ -67,7 +67,7 @@ def build_request(self, unified_request: UnifiedRequest, context: ProtocolContex def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: response = raw_response if isinstance(raw_response, dict) else {} return UnifiedResponse( - operation=OPERATION_IMAGE_GENERATION, + operation=_context_operation(context, OPERATION_IMAGE_GENERATION), model=response.get("model") or getattr(context, "model", None), data=deepcopy(response.get("data") or []), raw=deepcopy(raw_response), @@ -91,3 +91,15 @@ def _image_operation(request: dict[str, Any]) -> str: if "image" in request: return OPERATION_IMAGE_VARIATION return OPERATION_IMAGE_GENERATION + + +def _context_operation(context: ProtocolContext | None, default: str) -> str: + if context and isinstance(context.provider_options, dict): + operation = normalize_operation(context.provider_options.get("operation")) + if operation in {OPERATION_IMAGE_GENERATION, OPERATION_IMAGE_EDIT, OPERATION_IMAGE_VARIATION}: + return operation + if context and isinstance(context.metadata, dict): + operation = normalize_operation(context.metadata.get("operation")) + if operation in {OPERATION_IMAGE_GENERATION, OPERATION_IMAGE_EDIT, OPERATION_IMAGE_VARIATION}: + return operation + return default diff --git a/src/rotator_library/protocols/registry.py b/src/rotator_library/protocols/registry.py index 1923d5830..a27a192fa 100644 --- a/src/rotator_library/protocols/registry.py +++ b/src/rotator_library/protocols/registry.py @@ -19,7 +19,7 @@ PROTOCOL_ALIASES: dict[str, str] = {} _PROTOCOL_INSTANCES: dict[str, ProtocolAdapter] = {} -_INFRASTRUCTURE_MODULES = {"base", "registry", "types"} +_INFRASTRUCTURE_MODULES = {"base", "operation", "registry", "types"} def register_protocol(protocol_class: Type[ProtocolAdapter], *, replace: bool = False) -> Type[ProtocolAdapter]: diff --git a/src/rotator_library/protocols/responses.py b/src/rotator_library/protocols/responses.py index a722495f8..ffeb89e25 100644 --- a/src/rotator_library/protocols/responses.py +++ b/src/rotator_library/protocols/responses.py @@ -76,6 +76,7 @@ def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | request = dict(raw_request or {}) generation_params = {k: deepcopy(request[k]) for k in _GENERATION_PARAMS if k in request and k != "instructions"} return UnifiedRequest( + operation=OPERATION_RESPONSES, model=str(request.get("model") or getattr(context, "model", None) or ""), messages=self._parse_input(request.get("input")), system=text_blocks(request.get("instructions")) if request.get("instructions") is not None else [], @@ -119,6 +120,7 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No parsed.extra["_output_index"] = index messages.append(parsed) return UnifiedResponse( + operation=OPERATION_RESPONSES, id=response.get("id"), model=response.get("model") or getattr(context, "model", None), messages=messages, @@ -156,21 +158,21 @@ def format_response(self, unified_response: UnifiedResponse, context: ProtocolCo def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: event = _decode_sse_data(raw_event) if event == "[DONE]": - return UnifiedStreamEvent(type="done", raw=deepcopy(raw_event)) + return UnifiedStreamEvent(type="done", operation=OPERATION_RESPONSES, raw=deepcopy(raw_event)) data = _as_dict(event) event_type = str(data.get("type") or data.get("event") or "chunk") if event_type in {"error", "response.error"} or data.get("error") is not None: - return UnifiedStreamEvent(type="error", error=deepcopy(data.get("error", data)), raw=deepcopy(raw_event), extra={"payload": data}) + return UnifiedStreamEvent(type="error", operation=OPERATION_RESPONSES, error=deepcopy(data.get("error", data)), raw=deepcopy(raw_event), extra={"payload": data}) if event_type in {"response.completed", "response.failed", "response.incomplete"}: response = self.parse_response(data.get("response") or {}, context) - return UnifiedStreamEvent(type=event_type, message=response.messages[0] if response.messages else None, usage=response.usage, raw=deepcopy(raw_event), extra={"payload": data}) + return UnifiedStreamEvent(type=event_type, operation=OPERATION_RESPONSES, message=response.messages[0] if response.messages else None, usage=response.usage, raw=deepcopy(raw_event), extra={"payload": data}) if event_type == "response.output_text.delta": message = UnifiedMessage(role="assistant", content=text_blocks(data.get("delta") or "")) - return UnifiedStreamEvent(type="message_delta", delta=message, raw=deepcopy(raw_event), extra={"payload": data, "output_index": data.get("output_index"), "content_index": data.get("content_index")}) + return UnifiedStreamEvent(type="message_delta", operation=OPERATION_RESPONSES, delta=message, raw=deepcopy(raw_event), extra={"payload": data, "output_index": data.get("output_index"), "content_index": data.get("content_index")}) if event_type in {"response.output_item.added", "response.output_item.done"} and isinstance(data.get("item"), dict): message = self._parse_output_item(data["item"]) - return UnifiedStreamEvent(type=event_type, message=message, raw=deepcopy(raw_event), extra={"payload": data}) - return UnifiedStreamEvent(type=event_type, raw=deepcopy(raw_event), extra={"payload": data}) + return UnifiedStreamEvent(type=event_type, operation=OPERATION_RESPONSES, message=message, raw=deepcopy(raw_event), extra={"payload": data}) + return UnifiedStreamEvent(type=event_type, operation=OPERATION_RESPONSES, raw=deepcopy(raw_event), extra={"payload": data}) def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: if isinstance(raw_or_unified, (UnifiedResponse, UnifiedStreamEvent)): diff --git a/tests/test_protocol_ollama_mcp.py b/tests/test_protocol_ollama_mcp.py index aeca7f7f4..654902604 100644 --- a/tests/test_protocol_ollama_mcp.py +++ b/tests/test_protocol_ollama_mcp.py @@ -31,3 +31,12 @@ def test_mcp_jsonrpc_round_trip_and_error_preservation() -> None: assert rebuilt == request assert adapter.format_response(response)["result"] == {"content": []} assert adapter.format_response(error)["error"]["message"] == "failed" + + +def test_mcp_preserves_notifications_and_falsey_params() -> None: + adapter = get_protocol("mcp") + notification = {"jsonrpc": "2.0", "method": "notifications/initialized"} + false_params = {"jsonrpc": "2.0", "id": 0, "method": "tools/call", "params": False} + + assert adapter.build_request(adapter.parse_request(notification)) == notification + assert adapter.build_request(adapter.parse_request(false_params)) == false_params diff --git a/tests/test_protocol_openai_images_audio.py b/tests/test_protocol_openai_images_audio.py index be632fe11..6f8a6e99b 100644 --- a/tests/test_protocol_openai_images_audio.py +++ b/tests/test_protocol_openai_images_audio.py @@ -5,6 +5,7 @@ OPERATION_IMAGE_EDIT, OPERATION_IMAGE_GENERATION, OPERATION_SPEECH, + ProtocolContext, get_protocol, ) @@ -14,6 +15,10 @@ def test_openai_images_generation_and_edit_shapes() -> None: generation = adapter.parse_request({"model": "gpt-image-test", "prompt": "draw a red cube", "size": "1024x1024"}) edit = adapter.parse_request({"model": "gpt-image-test", "prompt": "make it blue", "image": "image-ref", "mask": "mask-ref"}) response = adapter.parse_response({"data": [{"url": "https://example.test/image.png", "revised_prompt": "cube"}]}) + edit_response = adapter.parse_response( + {"data": [{"url": "https://example.test/edit.png"}]}, + ProtocolContext(provider_options={"operation": OPERATION_IMAGE_EDIT}), + ) assert generation.operation == OPERATION_IMAGE_GENERATION assert adapter.build_request(generation)["prompt"] == "draw a red cube" @@ -21,6 +26,7 @@ def test_openai_images_generation_and_edit_shapes() -> None: assert {entry["field"] for entry in edit.files} == {"image", "mask"} assert adapter.build_request(edit)["image"] == "image-ref" assert response.data[0]["revised_prompt"] == "cube" + assert edit_response.operation == OPERATION_IMAGE_EDIT def test_openai_audio_transcription_and_speech_shapes() -> None: @@ -28,6 +34,10 @@ def test_openai_audio_transcription_and_speech_shapes() -> None: transcription = adapter.parse_request({"model": "whisper-test", "file": "audio-ref", "language": "en"}) speech = adapter.parse_request({"model": "tts-test", "input": "hello", "voice": "alloy"}) text_response = adapter.parse_response({"text": "hello world"}) + translation_response = adapter.parse_response( + {"text": "bonjour"}, + ProtocolContext(provider_options={"operation": "audio_translation"}), + ) binary_response = adapter.parse_response(b"RIFF") assert transcription.operation == OPERATION_AUDIO_TRANSCRIPTION @@ -35,4 +45,5 @@ def test_openai_audio_transcription_and_speech_shapes() -> None: assert speech.operation == OPERATION_SPEECH assert adapter.build_request(speech)["voice"] == "alloy" assert text_response.output == ["hello world"] + assert translation_response.operation == "audio_translation" assert binary_response.content_type == "application/octet-stream" diff --git a/tests/test_protocol_operation_model.py b/tests/test_protocol_operation_model.py index 0f989269a..509575994 100644 --- a/tests/test_protocol_operation_model.py +++ b/tests/test_protocol_operation_model.py @@ -2,9 +2,13 @@ from rotator_library.protocols import ( OPERATION_CHAT, + OPERATION_COUNT_TOKENS, OPERATION_EMBEDDINGS, + OPERATION_MESSAGES, + OPERATION_RESPONSES, OPERATION_UNKNOWN, ProtocolAdapter, + ProtocolContext, UnifiedRequest, UnifiedResponse, get_protocol, @@ -62,3 +66,28 @@ def test_protocols_advertise_supported_operations() -> None: assert get_protocol("openai_chat").supports_operation(OPERATION_CHAT) assert not get_protocol("openai_chat").supports_operation(OPERATION_EMBEDDINGS) assert get_protocol("litellm_fallback").supports_operation(OPERATION_UNKNOWN) + + +def test_core_protocols_stamp_parsed_operation_fields() -> None: + chat = get_protocol("openai_chat").parse_request({"model": "m", "messages": []}) + messages = get_protocol("anthropic_messages").parse_request({"model": "m", "messages": []}) + responses = get_protocol("responses").parse_request({"model": "m", "input": "hi"}) + gemini = get_protocol("gemini").parse_request({"contents": []}) + + assert chat.operation == OPERATION_CHAT + assert messages.operation == OPERATION_MESSAGES + assert responses.operation == OPERATION_RESPONSES + assert gemini.operation == OPERATION_CHAT + + +def test_count_tokens_operation_can_be_context_selected() -> None: + context = ProtocolContext(provider_options={"operation": OPERATION_COUNT_TOKENS}) + anthropic = get_protocol("anthropic_messages").parse_response({"input_tokens": 12}, context) + gemini = get_protocol("gemini").parse_response({"totalTokens": 14}, context) + + assert anthropic.operation == OPERATION_COUNT_TOKENS + assert anthropic.usage is not None + assert anthropic.usage.input_tokens == 12 + assert gemini.operation == OPERATION_COUNT_TOKENS + assert gemini.usage is not None + assert gemini.usage.total_tokens == 14 From 90e26cea62a7d47e1073aa427972f5fddc7395b8 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 08:30:23 +0200 Subject: [PATCH 078/182] fix(protocols): close audio and ollama review gaps Addresses the remaining Phase 1b re-review findings by honoring context-selected audio translation requests, preventing stale transcription prompts from overwriting transformed input, and preserving binary/text audio response formatting. Also fixes Ollama prompt-style embeddings, maps final chat responses into unified messages, narrows duplicate extra fields, and adds regression coverage for the edge cases identified by the review agents. Tests: pytest tests/test_protocol_operation_model.py tests/test_protocol_openai_embeddings.py tests/test_protocol_openai_images_audio.py tests/test_protocol_ollama_mcp.py tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py --- src/rotator_library/protocols/ollama.py | 21 +++++++++--- src/rotator_library/protocols/openai_audio.py | 32 ++++++++++++++++--- tests/test_protocol_ollama_mcp.py | 7 +++- tests/test_protocol_openai_images_audio.py | 12 +++++++ 4 files changed, 63 insertions(+), 9 deletions(-) diff --git a/src/rotator_library/protocols/ollama.py b/src/rotator_library/protocols/ollama.py index f10af5116..5b3a96474 100644 --- a/src/rotator_library/protocols/ollama.py +++ b/src/rotator_library/protocols/ollama.py @@ -28,14 +28,20 @@ class OllamaProtocol(ProtocolAdapter): def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: request = dict(raw_request or {}) operation = _ollama_operation(request) + input_field = "input" + if operation == OPERATION_EMBEDDINGS and "input" not in request and "prompt" in request: + # Older Ollama embeddings endpoints use `prompt`; newer endpoints use + # `input`. Preserve the original spelling so round-trips are lossless. + input_field = "prompt" return UnifiedRequest( operation=operation, model=str(request.get("model") or getattr(context, "model", None) or ""), messages=[_message_from_ollama(message) for message in request.get("messages") or []], system=[ContentBlock(type="text", text=str(request["system"]))] if "system" in request else [], stream=bool(request.get("stream", False)), - input=deepcopy(request.get("input") if operation == OPERATION_EMBEDDINGS else request.get("prompt")), + input=deepcopy(request.get(input_field) if operation == OPERATION_EMBEDDINGS else request.get("prompt")), generation_params={k: deepcopy(request[k]) for k in _OPTION_FIELDS if k in request}, + metadata={"embedding_input_field": input_field} if operation == OPERATION_EMBEDDINGS else {}, raw=deepcopy(raw_request), extra={k: deepcopy(v) for k, v in request.items() if k not in _CORE_FIELDS}, ) @@ -45,7 +51,8 @@ def build_request(self, unified_request: UnifiedRequest, context: ProtocolContex if unified_request.operation == OPERATION_OLLAMA_CHAT: payload["messages"] = [_message_to_ollama(message) for message in unified_request.messages] elif unified_request.operation == OPERATION_EMBEDDINGS: - payload["input"] = deepcopy(unified_request.input) + input_field = str(unified_request.metadata.get("embedding_input_field") or "input") + payload[input_field] = deepcopy(unified_request.input) else: payload["prompt"] = deepcopy(unified_request.input) if unified_request.system: @@ -57,18 +64,22 @@ def build_request(self, unified_request: UnifiedRequest, context: ProtocolContex def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: response = raw_response if isinstance(raw_response, dict) else {} output = [] + messages: list[UnifiedMessage] = [] if isinstance(response.get("message"), dict): - output.append(_message_from_ollama(response["message"]).to_dict()) + message = _message_from_ollama(response["message"]) + messages.append(message) + output.append(message.to_dict()) elif "response" in response: output.append(response.get("response")) return UnifiedResponse( operation=_ollama_operation(response), model=response.get("model") or getattr(context, "model", None), + messages=messages, output=output, data=deepcopy(response.get("embeddings") or response.get("embedding") or []), usage=_ollama_usage(response), raw=deepcopy(raw_response), - extra={k: deepcopy(v) for k, v in response.items() if k not in {"model", "message", "response", "embeddings", "embedding", "prompt_eval_count", "eval_count"}}, + extra={k: deepcopy(v) for k, v in response.items() if k not in {"model", "message", "response", "embeddings", "embedding", "prompt_eval_count", "eval_count", "total_duration", "load_duration", "prompt_eval_duration", "eval_duration"}}, ) def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: @@ -90,6 +101,8 @@ def _ollama_operation(request: dict[str, Any]) -> str: return explicit if "messages" in request or "message" in request: return OPERATION_OLLAMA_CHAT + if "prompt" in request and request.get("endpoint") != "embeddings": + return OPERATION_OLLAMA_GENERATE if "embeddings" in request or "embedding" in request or "input" in request or request.get("endpoint") == "embeddings": return OPERATION_EMBEDDINGS return OPERATION_OLLAMA_GENERATE diff --git a/src/rotator_library/protocols/openai_audio.py b/src/rotator_library/protocols/openai_audio.py index b626f8c3b..be0fb5b2d 100644 --- a/src/rotator_library/protocols/openai_audio.py +++ b/src/rotator_library/protocols/openai_audio.py @@ -12,9 +12,9 @@ from .operation import OPERATION_AUDIO_TRANSCRIPTION, OPERATION_AUDIO_TRANSLATION, OPERATION_SPEECH, normalize_operation from .types import ProtocolContext, UnifiedRequest, UnifiedResponse -_AUDIO_OPTION_FIELDS = {"language", "prompt", "response_format", "temperature", "timestamp_granularities"} +_AUDIO_OPTION_FIELDS = {"language", "response_format", "temperature", "timestamp_granularities"} _SPEECH_OPTION_FIELDS = {"voice", "response_format", "speed"} -_CORE_FIELDS = {"operation", "model", "file", "input", *_AUDIO_OPTION_FIELDS, *_SPEECH_OPTION_FIELDS} +_CORE_FIELDS = {"operation", "model", "file", "input", "prompt", *_AUDIO_OPTION_FIELDS, *_SPEECH_OPTION_FIELDS} class OpenAIAudioProtocol(ProtocolAdapter): @@ -36,7 +36,7 @@ class OpenAIAudioProtocol(ProtocolAdapter): def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: request = dict(raw_request or {}) - operation = _audio_operation(request) + operation = _audio_operation(request, context) files = [] if "file" in request: files.append({"field": "file", "value": deepcopy(request["file"])}) @@ -64,6 +64,27 @@ def build_request(self, unified_request: UnifiedRequest, context: ProtocolContex payload.update(deepcopy(unified_request.extra)) return payload + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> Any: + """Format audio responses without forcing binary/text payloads into JSON. + + Speech endpoints can return raw audio bytes while transcription endpoints + can return JSON or plain text depending on `response_format`. Returning + the preserved raw body for non-dict responses keeps the protocol adapter + honest until a transport layer decides headers and streaming behavior. + """ + + if unified_response.raw is not None and not isinstance(unified_response.raw, dict): + return deepcopy(unified_response.raw) + payload: dict[str, Any] = {} + if unified_response.output: + payload["text"] = deepcopy(unified_response.output[0]) + if unified_response.data: + payload["data"] = deepcopy(unified_response.data) + if unified_response.content_type: + payload["content_type"] = unified_response.content_type + payload.update(deepcopy(unified_response.extra)) + return payload + def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: if isinstance(raw_response, dict): response = raw_response @@ -81,10 +102,13 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No return UnifiedResponse(operation=_context_operation(context, default_operation), content_type=content_type, raw=deepcopy(raw_response), output=[deepcopy(raw_response)] if isinstance(raw_response, str) else []) -def _audio_operation(request: dict[str, Any]) -> str: +def _audio_operation(request: dict[str, Any], context: ProtocolContext | None = None) -> str: explicit = normalize_operation(request.get("operation")) if explicit in {OPERATION_AUDIO_TRANSCRIPTION, OPERATION_AUDIO_TRANSLATION, OPERATION_SPEECH}: return explicit + contextual = _context_operation(context, "unknown") + if contextual in {OPERATION_AUDIO_TRANSCRIPTION, OPERATION_AUDIO_TRANSLATION, OPERATION_SPEECH}: + return contextual if "voice" in request or ("input" in request and "file" not in request): return OPERATION_SPEECH return OPERATION_AUDIO_TRANSCRIPTION diff --git a/tests/test_protocol_ollama_mcp.py b/tests/test_protocol_ollama_mcp.py index 654902604..aea8522ae 100644 --- a/tests/test_protocol_ollama_mcp.py +++ b/tests/test_protocol_ollama_mcp.py @@ -1,18 +1,23 @@ from __future__ import annotations -from rotator_library.protocols import OPERATION_MCP, OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, get_protocol +from rotator_library.protocols import OPERATION_EMBEDDINGS, OPERATION_MCP, OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, get_protocol def test_ollama_chat_generate_and_stream_shapes() -> None: adapter = get_protocol("ollama") chat = adapter.parse_request({"model": "llama3", "messages": [{"role": "user", "content": "hi"}], "stream": True}) generate = adapter.parse_request({"model": "llama3", "prompt": "write", "options": {"temperature": 0.1}}) + embeddings = adapter.parse_request({"model": "llama3", "operation": OPERATION_EMBEDDINGS, "prompt": "embed this"}) + final_chat = adapter.parse_response({"model": "llama3", "message": {"role": "assistant", "content": "hello"}, "done": True}) chunk = adapter.parse_stream_event('{"model":"llama3","response":"he","done":false,"eval_count":2}') assert chat.operation == OPERATION_OLLAMA_CHAT assert adapter.build_request(chat)["messages"][0]["content"] == "hi" assert generate.operation == OPERATION_OLLAMA_GENERATE assert adapter.build_request(generate)["prompt"] == "write" + assert embeddings.operation == OPERATION_EMBEDDINGS + assert adapter.build_request(embeddings)["prompt"] == "embed this" + assert final_chat.messages[0].content[0].text == "hello" assert chunk.delta is not None assert chunk.delta.content[0].text == "he" assert chunk.usage is not None diff --git a/tests/test_protocol_openai_images_audio.py b/tests/test_protocol_openai_images_audio.py index 6f8a6e99b..ad108202c 100644 --- a/tests/test_protocol_openai_images_audio.py +++ b/tests/test_protocol_openai_images_audio.py @@ -2,8 +2,10 @@ from rotator_library.protocols import ( OPERATION_AUDIO_TRANSCRIPTION, + OPERATION_AUDIO_TRANSLATION, OPERATION_IMAGE_EDIT, OPERATION_IMAGE_GENERATION, + OPERATION_IMAGE_VARIATION, OPERATION_SPEECH, ProtocolContext, get_protocol, @@ -14,6 +16,7 @@ def test_openai_images_generation_and_edit_shapes() -> None: adapter = get_protocol("openai_images") generation = adapter.parse_request({"model": "gpt-image-test", "prompt": "draw a red cube", "size": "1024x1024"}) edit = adapter.parse_request({"model": "gpt-image-test", "prompt": "make it blue", "image": "image-ref", "mask": "mask-ref"}) + variation = adapter.parse_request({"model": "gpt-image-test", "image": "image-ref"}) response = adapter.parse_response({"data": [{"url": "https://example.test/image.png", "revised_prompt": "cube"}]}) edit_response = adapter.parse_response( {"data": [{"url": "https://example.test/edit.png"}]}, @@ -23,6 +26,7 @@ def test_openai_images_generation_and_edit_shapes() -> None: assert generation.operation == OPERATION_IMAGE_GENERATION assert adapter.build_request(generation)["prompt"] == "draw a red cube" assert edit.operation == OPERATION_IMAGE_EDIT + assert variation.operation == OPERATION_IMAGE_VARIATION assert {entry["field"] for entry in edit.files} == {"image", "mask"} assert adapter.build_request(edit)["image"] == "image-ref" assert response.data[0]["revised_prompt"] == "cube" @@ -32,6 +36,10 @@ def test_openai_images_generation_and_edit_shapes() -> None: def test_openai_audio_transcription_and_speech_shapes() -> None: adapter = get_protocol("openai_audio") transcription = adapter.parse_request({"model": "whisper-test", "file": "audio-ref", "language": "en"}) + translation = adapter.parse_request( + {"model": "whisper-test", "file": "audio-ref", "prompt": "domain hint"}, + ProtocolContext(provider_options={"operation": OPERATION_AUDIO_TRANSLATION}), + ) speech = adapter.parse_request({"model": "tts-test", "input": "hello", "voice": "alloy"}) text_response = adapter.parse_response({"text": "hello world"}) translation_response = adapter.parse_response( @@ -42,8 +50,12 @@ def test_openai_audio_transcription_and_speech_shapes() -> None: assert transcription.operation == OPERATION_AUDIO_TRANSCRIPTION assert transcription.files[0]["value"] == "audio-ref" + assert translation.operation == OPERATION_AUDIO_TRANSLATION + translation.input = "mutated hint" + assert adapter.build_request(translation)["prompt"] == "mutated hint" assert speech.operation == OPERATION_SPEECH assert adapter.build_request(speech)["voice"] == "alloy" assert text_response.output == ["hello world"] assert translation_response.operation == "audio_translation" assert binary_response.content_type == "application/octet-stream" + assert adapter.format_response(binary_response) == b"RIFF" From 495904e313bf81be07f0a25f3cc4407374314218 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 08:46:18 +0200 Subject: [PATCH 079/182] fix(protocols): harden count tokens and ollama semantics Closes the final Phase 1b heavy-review findings by formatting Anthropic and Gemini count-token responses in their native shapes and adding tests for those round trips. Preserves Ollama omitted stream semantics, allows context-selected embeddings without nonstandard body fields, keeps prompt/input embeddings round-trippable, and documents/preserves MCP JSON-RPC batches. Also stamps common audio JSON responses with a useful default operation while still allowing context overrides. Tests: pytest tests/test_protocol_operation_model.py tests/test_protocol_openai_embeddings.py tests/test_protocol_openai_images_audio.py tests/test_protocol_ollama_mcp.py tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py --- .../protocols/anthropic_messages.py | 10 ++++-- src/rotator_library/protocols/gemini.py | 5 +++ src/rotator_library/protocols/mcp.py | 15 +++++++-- src/rotator_library/protocols/ollama.py | 22 +++++++++---- src/rotator_library/protocols/openai_audio.py | 7 +++- tests/test_protocol_ollama_mcp.py | 33 ++++++++++++++++++- tests/test_protocol_openai_images_audio.py | 1 + tests/test_protocol_operation_model.py | 2 ++ 8 files changed, 82 insertions(+), 13 deletions(-) diff --git a/src/rotator_library/protocols/anthropic_messages.py b/src/rotator_library/protocols/anthropic_messages.py index 0df594828..5e5b31fa9 100644 --- a/src/rotator_library/protocols/anthropic_messages.py +++ b/src/rotator_library/protocols/anthropic_messages.py @@ -94,6 +94,7 @@ def build_request(self, unified_request: UnifiedRequest, context: ProtocolContex def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: response = _as_dict(raw_response) + operation = _response_operation(response, context) message = UnifiedMessage( role=str(response.get("role") or "assistant"), content=self._parse_content(response.get("content")), @@ -102,10 +103,10 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No ) self._promote_message_blocks(message) return UnifiedResponse( - operation=_response_operation(response, context), + operation=operation, id=response.get("id"), model=response.get("model") or getattr(context, "model", None), - messages=[message] if response else [], + messages=[] if operation == OPERATION_COUNT_TOKENS else [message] if response else [], stop_reason=response.get("stop_reason"), usage=self.extract_usage(response, context), metadata={"stop_sequence": response.get("stop_sequence"), "type": response.get("type")}, @@ -114,6 +115,11 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No ) def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + if unified_response.operation == OPERATION_COUNT_TOKENS: + usage = unified_response.usage + payload = {"input_tokens": usage.input_tokens if usage else 0} + payload.update(deepcopy(unified_response.extra)) + return payload message = unified_response.messages[0] if unified_response.messages else UnifiedMessage(role="assistant") payload = { "id": unified_response.id, diff --git a/src/rotator_library/protocols/gemini.py b/src/rotator_library/protocols/gemini.py index f1fad2da4..90141511a 100644 --- a/src/rotator_library/protocols/gemini.py +++ b/src/rotator_library/protocols/gemini.py @@ -120,6 +120,11 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No ) def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + if unified_response.operation == OPERATION_COUNT_TOKENS: + usage = unified_response.usage + payload = {"totalTokens": usage.total_tokens if usage else 0} + payload.update(deepcopy(unified_response.extra)) + return payload candidates = [] for index, message in enumerate(unified_response.messages): candidate = {"index": index, "content": self._format_content(message)} diff --git a/src/rotator_library/protocols/mcp.py b/src/rotator_library/protocols/mcp.py index 845b210cc..e51a1ec5c 100644 --- a/src/rotator_library/protocols/mcp.py +++ b/src/rotator_library/protocols/mcp.py @@ -5,7 +5,8 @@ This is not a full MCP proxy implementation. It gives the native protocol layer a lossless request/response carrier for future MCP gateway work, keeping method, -params, ids, results, and errors intact for transform logging and routing. +params, ids, results, errors, and JSON-RPC batch arrays intact for transform +logging and routing. """ from __future__ import annotations @@ -31,6 +32,8 @@ class MCPProtocol(ProtocolAdapter): future_transports: ClassVar[tuple[str, ...]] = ("sse", "websocket") def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: + if isinstance(raw_request, list): + return UnifiedRequest(operation=OPERATION_MCP, input=deepcopy(raw_request), metadata={"batch": True}, raw=deepcopy(raw_request)) request = dict(raw_request or {}) metadata = { "jsonrpc": request.get("jsonrpc", "2.0"), @@ -48,7 +51,9 @@ def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | extra={k: deepcopy(v) for k, v in request.items() if k not in _REQUEST_CORE_FIELDS}, ) - def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: + def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> Any: + if unified_request.metadata.get("batch"): + return deepcopy(unified_request.input or []) payload = { "jsonrpc": unified_request.metadata.get("jsonrpc", "2.0"), "method": unified_request.metadata.get("method"), @@ -61,6 +66,8 @@ def build_request(self, unified_request: UnifiedRequest, context: ProtocolContex return payload def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: + if isinstance(raw_response, list): + return UnifiedResponse(operation=OPERATION_MCP, data=deepcopy(raw_response), metadata={"batch": True}, raw=deepcopy(raw_response)) response = raw_response if isinstance(raw_response, dict) else {} metadata = {"jsonrpc": response.get("jsonrpc", "2.0"), "has_id": "id" in response} if "id" in response: @@ -78,7 +85,9 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No extra=extra, ) - def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> Any: + if unified_response.metadata.get("batch"): + return deepcopy(unified_response.data) payload = {"jsonrpc": unified_response.metadata.get("jsonrpc", "2.0")} if unified_response.metadata.get("has_id"): payload["id"] = deepcopy(unified_response.metadata.get("id")) diff --git a/src/rotator_library/protocols/ollama.py b/src/rotator_library/protocols/ollama.py index 5b3a96474..45e613820 100644 --- a/src/rotator_library/protocols/ollama.py +++ b/src/rotator_library/protocols/ollama.py @@ -27,7 +27,7 @@ class OllamaProtocol(ProtocolAdapter): def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: request = dict(raw_request or {}) - operation = _ollama_operation(request) + operation = _ollama_operation(request, context) input_field = "input" if operation == OPERATION_EMBEDDINGS and "input" not in request and "prompt" in request: # Older Ollama embeddings endpoints use `prompt`; newer endpoints use @@ -41,13 +41,15 @@ def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | stream=bool(request.get("stream", False)), input=deepcopy(request.get(input_field) if operation == OPERATION_EMBEDDINGS else request.get("prompt")), generation_params={k: deepcopy(request[k]) for k in _OPTION_FIELDS if k in request}, - metadata={"embedding_input_field": input_field} if operation == OPERATION_EMBEDDINGS else {}, + metadata={**({"embedding_input_field": input_field} if operation == OPERATION_EMBEDDINGS else {}), "has_stream": "stream" in request}, raw=deepcopy(raw_request), extra={k: deepcopy(v) for k, v in request.items() if k not in _CORE_FIELDS}, ) def build_request(self, unified_request: UnifiedRequest, context: ProtocolContext | None = None) -> dict[str, Any]: - payload: dict[str, Any] = {"model": unified_request.model, "stream": unified_request.stream} + payload: dict[str, Any] = {"model": unified_request.model} + if unified_request.stream or unified_request.metadata.get("has_stream"): + payload["stream"] = unified_request.stream if unified_request.operation == OPERATION_OLLAMA_CHAT: payload["messages"] = [_message_to_ollama(message) for message in unified_request.messages] elif unified_request.operation == OPERATION_EMBEDDINGS: @@ -72,7 +74,7 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No elif "response" in response: output.append(response.get("response")) return UnifiedResponse( - operation=_ollama_operation(response), + operation=_ollama_operation(response, context), model=response.get("model") or getattr(context, "model", None), messages=messages, output=output, @@ -92,13 +94,21 @@ def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = N delta = _message_from_ollama(data["message"]) elif data.get("response"): delta = UnifiedMessage(role="assistant", content=[ContentBlock(type="text", text=str(data["response"]))]) - return UnifiedStreamEvent(type=event_type, operation=_ollama_operation(data), delta=delta, usage=_ollama_usage(data), raw=deepcopy(raw_event), extra=deepcopy(data)) + return UnifiedStreamEvent(type=event_type, operation=_ollama_operation(data, context), delta=delta, usage=_ollama_usage(data), raw=deepcopy(raw_event), extra=deepcopy(data)) -def _ollama_operation(request: dict[str, Any]) -> str: +def _ollama_operation(request: dict[str, Any], context: ProtocolContext | None = None) -> str: explicit = normalize_operation(request.get("operation")) if explicit in {OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, OPERATION_EMBEDDINGS}: return explicit + if context and isinstance(context.provider_options, dict): + operation = normalize_operation(context.provider_options.get("operation")) + if operation in {OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, OPERATION_EMBEDDINGS}: + return operation + if context and isinstance(context.metadata, dict): + operation = normalize_operation(context.metadata.get("operation")) + if operation in {OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, OPERATION_EMBEDDINGS}: + return operation if "messages" in request or "message" in request: return OPERATION_OLLAMA_CHAT if "prompt" in request and request.get("endpoint") != "embeddings": diff --git a/src/rotator_library/protocols/openai_audio.py b/src/rotator_library/protocols/openai_audio.py index be0fb5b2d..392f5b19e 100644 --- a/src/rotator_library/protocols/openai_audio.py +++ b/src/rotator_library/protocols/openai_audio.py @@ -88,8 +88,11 @@ def format_response(self, unified_response: UnifiedResponse, context: ProtocolCo def parse_response(self, raw_response: Any, context: ProtocolContext | None = None) -> UnifiedResponse: if isinstance(raw_response, dict): response = raw_response + operation = normalize_operation(response.get("operation")) + if operation == "unknown": + operation = OPERATION_AUDIO_TRANSCRIPTION return UnifiedResponse( - operation=_context_operation(context, normalize_operation(response.get("operation"))), + operation=_context_operation(context, operation), model=response.get("model") or getattr(context, "model", None), output=[deepcopy(response["text"])] if "text" in response else [], data=deepcopy(response.get("data") or []), @@ -98,6 +101,8 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No extra={k: deepcopy(v) for k, v in response.items() if k not in {"operation", "model", "text", "data", "content_type"}}, ) content_type = "text/plain" if isinstance(raw_response, str) else "application/octet-stream" + # Providers normally return transcription text as JSON/text and speech as + # bytes. Context can still override this heuristic for unusual endpoints. default_operation = OPERATION_AUDIO_TRANSCRIPTION if isinstance(raw_response, str) else OPERATION_SPEECH return UnifiedResponse(operation=_context_operation(context, default_operation), content_type=content_type, raw=deepcopy(raw_response), output=[deepcopy(raw_response)] if isinstance(raw_response, str) else []) diff --git a/tests/test_protocol_ollama_mcp.py b/tests/test_protocol_ollama_mcp.py index aea8522ae..fbd494565 100644 --- a/tests/test_protocol_ollama_mcp.py +++ b/tests/test_protocol_ollama_mcp.py @@ -1,23 +1,39 @@ from __future__ import annotations -from rotator_library.protocols import OPERATION_EMBEDDINGS, OPERATION_MCP, OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, get_protocol +from rotator_library.protocols import OPERATION_EMBEDDINGS, OPERATION_MCP, OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, ProtocolContext, get_protocol def test_ollama_chat_generate_and_stream_shapes() -> None: adapter = get_protocol("ollama") chat = adapter.parse_request({"model": "llama3", "messages": [{"role": "user", "content": "hi"}], "stream": True}) generate = adapter.parse_request({"model": "llama3", "prompt": "write", "options": {"temperature": 0.1}}) + explicit_stream_false = adapter.parse_request({"model": "llama3", "prompt": "write", "stream": False}) embeddings = adapter.parse_request({"model": "llama3", "operation": OPERATION_EMBEDDINGS, "prompt": "embed this"}) + contextual_embeddings = adapter.parse_request( + {"model": "llama3", "prompt": "embed via context"}, + ProtocolContext(provider_options={"operation": OPERATION_EMBEDDINGS}), + ) + new_embeddings = adapter.parse_request({"model": "llama3", "input": "embed input"}) final_chat = adapter.parse_response({"model": "llama3", "message": {"role": "assistant", "content": "hello"}, "done": True}) + final_generate = adapter.parse_response({"model": "llama3", "response": "generated", "done": True}) + embeddings_response = adapter.parse_response({"model": "llama3", "embeddings": [[0.1, 0.2]]}) chunk = adapter.parse_stream_event('{"model":"llama3","response":"he","done":false,"eval_count":2}') assert chat.operation == OPERATION_OLLAMA_CHAT assert adapter.build_request(chat)["messages"][0]["content"] == "hi" assert generate.operation == OPERATION_OLLAMA_GENERATE assert adapter.build_request(generate)["prompt"] == "write" + assert "stream" not in adapter.build_request(generate) + assert adapter.build_request(explicit_stream_false)["stream"] is False assert embeddings.operation == OPERATION_EMBEDDINGS assert adapter.build_request(embeddings)["prompt"] == "embed this" + assert contextual_embeddings.operation == OPERATION_EMBEDDINGS + assert adapter.build_request(contextual_embeddings)["prompt"] == "embed via context" + assert adapter.build_request(new_embeddings)["input"] == "embed input" assert final_chat.messages[0].content[0].text == "hello" + assert final_generate.output == ["generated"] + assert final_generate.messages == [] + assert embeddings_response.data == [[0.1, 0.2]] assert chunk.delta is not None assert chunk.delta.content[0].text == "he" assert chunk.usage is not None @@ -45,3 +61,18 @@ def test_mcp_preserves_notifications_and_falsey_params() -> None: assert adapter.build_request(adapter.parse_request(notification)) == notification assert adapter.build_request(adapter.parse_request(false_params)) == false_params + + +def test_mcp_preserves_jsonrpc_batches() -> None: + adapter = get_protocol("mcp") + batch = [ + {"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}}, + {"jsonrpc": "2.0", "id": 2, "method": "resources/list", "params": None}, + ] + response_batch = [ + {"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}, + {"jsonrpc": "2.0", "id": 2, "error": {"code": -1, "message": "failed"}}, + ] + + assert adapter.build_request(adapter.parse_request(batch)) == batch + assert adapter.format_response(adapter.parse_response(response_batch)) == response_batch diff --git a/tests/test_protocol_openai_images_audio.py b/tests/test_protocol_openai_images_audio.py index ad108202c..111175b82 100644 --- a/tests/test_protocol_openai_images_audio.py +++ b/tests/test_protocol_openai_images_audio.py @@ -55,6 +55,7 @@ def test_openai_audio_transcription_and_speech_shapes() -> None: assert adapter.build_request(translation)["prompt"] == "mutated hint" assert speech.operation == OPERATION_SPEECH assert adapter.build_request(speech)["voice"] == "alloy" + assert text_response.operation == OPERATION_AUDIO_TRANSCRIPTION assert text_response.output == ["hello world"] assert translation_response.operation == "audio_translation" assert binary_response.content_type == "application/octet-stream" diff --git a/tests/test_protocol_operation_model.py b/tests/test_protocol_operation_model.py index 509575994..7048c8e0b 100644 --- a/tests/test_protocol_operation_model.py +++ b/tests/test_protocol_operation_model.py @@ -88,6 +88,8 @@ def test_count_tokens_operation_can_be_context_selected() -> None: assert anthropic.operation == OPERATION_COUNT_TOKENS assert anthropic.usage is not None assert anthropic.usage.input_tokens == 12 + assert get_protocol("anthropic_messages").format_response(anthropic) == {"input_tokens": 12} assert gemini.operation == OPERATION_COUNT_TOKENS assert gemini.usage is not None assert gemini.usage.total_tokens == 14 + assert get_protocol("gemini").format_response(gemini) == {"totalTokens": 14} From 77f7666b9cdd3196d7b6730638459dbe438be295 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 08:58:13 +0200 Subject: [PATCH 080/182] fix(protocols): close final operation hardening notes Restricts Anthropic and Gemini context-selected operations to operations each adapter actually supports, preventing bad context metadata from stamping unsupported operations onto parsed payloads. Ensures normalized count-token usage wins over preserved raw fields during format_response, and adds response-format coverage for embeddings and audio dict outputs. Tests: pytest tests/test_protocol_operation_model.py tests/test_protocol_openai_embeddings.py tests/test_protocol_openai_images_audio.py tests/test_protocol_ollama_mcp.py tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py --- .../protocols/anthropic_messages.py | 11 +++++++---- src/rotator_library/protocols/gemini.py | 11 +++++++---- tests/test_protocol_openai_embeddings.py | 2 ++ tests/test_protocol_openai_images_audio.py | 1 + tests/test_protocol_operation_model.py | 15 +++++++++++++++ 5 files changed, 32 insertions(+), 8 deletions(-) diff --git a/src/rotator_library/protocols/anthropic_messages.py b/src/rotator_library/protocols/anthropic_messages.py index 5e5b31fa9..024bdfae5 100644 --- a/src/rotator_library/protocols/anthropic_messages.py +++ b/src/rotator_library/protocols/anthropic_messages.py @@ -117,8 +117,10 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: if unified_response.operation == OPERATION_COUNT_TOKENS: usage = unified_response.usage - payload = {"input_tokens": usage.input_tokens if usage else 0} - payload.update(deepcopy(unified_response.extra)) + payload = deepcopy(unified_response.extra) + # Normalized usage wins over raw preserved fields so later adapters + # can correct counts without stale provider keys shadowing them. + payload["input_tokens"] = usage.input_tokens if usage else 0 return payload message = unified_response.messages[0] if unified_response.messages else UnifiedMessage(role="assistant") payload = { @@ -334,13 +336,14 @@ def _parse_content_stream_event(self, data: dict[str, Any], raw_event: Any) -> U def _operation_from_context(context: ProtocolContext | None, default: str) -> str: + supported = {OPERATION_MESSAGES, OPERATION_COUNT_TOKENS} if context and isinstance(context.provider_options, dict): operation = normalize_operation(context.provider_options.get("operation")) - if operation != OPERATION_UNKNOWN: + if operation in supported: return operation if context and isinstance(context.metadata, dict): operation = normalize_operation(context.metadata.get("operation")) - if operation != OPERATION_UNKNOWN: + if operation in supported: return operation return default diff --git a/src/rotator_library/protocols/gemini.py b/src/rotator_library/protocols/gemini.py index 90141511a..203810ad8 100644 --- a/src/rotator_library/protocols/gemini.py +++ b/src/rotator_library/protocols/gemini.py @@ -122,8 +122,10 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: if unified_response.operation == OPERATION_COUNT_TOKENS: usage = unified_response.usage - payload = {"totalTokens": usage.total_tokens if usage else 0} - payload.update(deepcopy(unified_response.extra)) + payload = deepcopy(unified_response.extra) + # Normalized usage wins over raw preserved fields so later adapters + # can correct counts without stale provider keys shadowing them. + payload["totalTokens"] = usage.total_tokens if usage else 0 return payload candidates = [] for index, message in enumerate(unified_response.messages): @@ -354,13 +356,14 @@ def _as_dict(value: Any) -> dict[str, Any]: def _operation_from_context(context: ProtocolContext | None, default: str) -> str: + supported = {OPERATION_CHAT, OPERATION_COUNT_TOKENS} if context and isinstance(context.provider_options, dict): operation = normalize_operation(context.provider_options.get("operation")) - if operation != OPERATION_UNKNOWN: + if operation in supported: return operation if context and isinstance(context.metadata, dict): operation = normalize_operation(context.metadata.get("operation")) - if operation != OPERATION_UNKNOWN: + if operation in supported: return operation return default diff --git a/tests/test_protocol_openai_embeddings.py b/tests/test_protocol_openai_embeddings.py index d8a63a288..f398f7a33 100644 --- a/tests/test_protocol_openai_embeddings.py +++ b/tests/test_protocol_openai_embeddings.py @@ -26,3 +26,5 @@ def test_openai_embeddings_protocol_round_trip_and_usage() -> None: assert response.data[0]["embedding"] == [0.1, 0.2] assert response.usage is not None assert response.usage.input_tokens == 4 + assert adapter.format_response(response)["data"] == response.data + assert adapter.format_response(response)["usage"]["total_tokens"] == 4 diff --git a/tests/test_protocol_openai_images_audio.py b/tests/test_protocol_openai_images_audio.py index 111175b82..4fd0cd3ce 100644 --- a/tests/test_protocol_openai_images_audio.py +++ b/tests/test_protocol_openai_images_audio.py @@ -57,6 +57,7 @@ def test_openai_audio_transcription_and_speech_shapes() -> None: assert adapter.build_request(speech)["voice"] == "alloy" assert text_response.operation == OPERATION_AUDIO_TRANSCRIPTION assert text_response.output == ["hello world"] + assert adapter.format_response(text_response)["text"] == "hello world" assert translation_response.operation == "audio_translation" assert binary_response.content_type == "application/octet-stream" assert adapter.format_response(binary_response) == b"RIFF" diff --git a/tests/test_protocol_operation_model.py b/tests/test_protocol_operation_model.py index 7048c8e0b..c2c5d0e93 100644 --- a/tests/test_protocol_operation_model.py +++ b/tests/test_protocol_operation_model.py @@ -93,3 +93,18 @@ def test_count_tokens_operation_can_be_context_selected() -> None: assert gemini.usage is not None assert gemini.usage.total_tokens == 14 assert get_protocol("gemini").format_response(gemini) == {"totalTokens": 14} + + anthropic.usage.input_tokens = 13 + gemini.usage.total_tokens = 15 + assert get_protocol("anthropic_messages").format_response(anthropic)["input_tokens"] == 13 + assert get_protocol("gemini").format_response(gemini)["totalTokens"] == 15 + + +def test_context_operation_helpers_reject_unsupported_operations() -> None: + context = ProtocolContext(provider_options={"operation": OPERATION_EMBEDDINGS}) + + anthropic = get_protocol("anthropic_messages").parse_request({"model": "m", "messages": []}, context) + gemini = get_protocol("gemini").parse_request({"contents": []}, context) + + assert anthropic.operation == OPERATION_MESSAGES + assert gemini.operation == OPERATION_CHAT From 60997093d902a437748dbf408f6e88998e1984d3 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 09:02:31 +0200 Subject: [PATCH 081/182] docs(experimental): plan transform trace coverage correction Adds the corrective Phase 2b plan for closing transform-trace coverage gaps found during the full audit pass. The plan covers provider transform boundaries, LiteLLM execution request/response states, native protocol parse/build/adapter/cache passes, Responses service traces, stream events, tests, and review requirements. User-facing reports and unrelated dirty files remain uncommitted. --- .../phase-2b-transform-trace-coverage.md | 184 ++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100644 docs/experimental/phase-2b-transform-trace-coverage.md diff --git a/docs/experimental/phase-2b-transform-trace-coverage.md b/docs/experimental/phase-2b-transform-trace-coverage.md new file mode 100644 index 000000000..034874939 --- /dev/null +++ b/docs/experimental/phase-2b-transform-trace-coverage.md @@ -0,0 +1,184 @@ +# Phase 2b: Complete Live Transform-Pass Trace Coverage + +## Goal + +Correct the Phase 2 audit gap. Phase 2 created the trace writer and added important request, response, stream, provider, and error trace entries, but the validation pass found that it still does not capture every meaningful real transformation boundary. Phase 2b makes transform tracing systematic across the live LiteLLM path, provider transforms, native protocol execution, Responses service/streaming, adapter chains, field-cache passes, and stream wrappers without changing request behavior. + +## Non-Goals + +- Do not replace LiteLLM execution in this phase. +- Do not make non-chat protocols live routes here. +- Do not implement field-cache persistence or new field-cache semantics here; only improve logging of existing cache operations. +- Do not implement WebSocket runtime support here. +- Do not add security/multi-user features beyond existing redaction behavior. +- Do not commit user-facing reports. +- Do not touch unrelated dirty `ARCHITECTURE.md`, `STRUCTURE.md`, `.opencode/`, `docs/issues/`, or old phase reports unless explicitly needed. + +## Current Code State + +- `TransformTraceWriter` exists with JSONL and snapshot files. +- `TransactionLogger.log_transform_pass()` and `log_transform_error()` exist and are additive. +- Existing legacy trace entries include `raw_client_request`, `prepared_provider_request`, `final_client_response`, stream chunk entries, provider request/chunk/final/error entries, routing metadata, retry/cooldown metadata, and usage summaries. +- `ProviderTransforms.apply()` mutates request kwargs through built-in transforms, provider hook transforms, model options, and LiteLLM conversion, but it does not emit per-step trace entries. +- `RequestExecutor._prepare_request_kwargs()` calls `ProviderTransforms.apply()`, then `log_transformed_request()` logs only the final prepared provider kwargs. +- Non-streaming LiteLLM/provider responses are logged only after success as `final_client_response`; there is no explicit `raw_provider_response` or `post_usage_normalization_response` distinction. +- Streaming logs `raw_stream_chunk`, `parsed_stream_chunk`, and `assembled_stream_response`, but the live stream handler has additional normalization/usage/error decision work that is not fully trace-covered. +- `NativeProviderExecutor` traces some native passes, but it skips key states: raw native input, parsed unified request, built provider request before adapters, after response adapters, field-cache extraction details, and final formatting boundaries in a complete request/response sequence. +- `run_adapter_chain()` traces `before_adapter_chain` and `after_adapter`, but not a final `after_adapter_chain` summary and not enough metadata to distinguish built-in vs provider-declared adapter chains in live reviews. +- `FieldCacheEngine` logs per-rule operations, but injection/extraction needs start/end summary trace entries so debugging can see when a cache pass happened but no rule matched. +- Responses service currently logs a service create pass, but stream output and bridge conversions are not fully represented as protocol/bridge/final boundaries. + +## Trace Taxonomy + +### Client Request + +- `raw_client_request` +- `pre_provider_transform_request` +- `after_builtin_provider_transform` +- `after_provider_hook_transform` +- `after_provider_model_options` +- `before_litellm_conversion` +- `after_litellm_conversion` +- `prepared_provider_request` + +### LiteLLM / Provider Execution + +- `provider_execution_request` +- `raw_provider_response` +- `post_usage_normalization_response` +- `final_client_response` + +### Native Protocol Request + +- `raw_native_client_request` +- `native_protocol_selected` +- `parsed_native_unified_request` +- `built_native_provider_request` +- `before_adapter_chain` +- `after_adapter` +- `after_adapter_chain` +- `field_cache_injection_start` +- `field_cache_rule_*` +- `field_cache_injection_complete` +- `native_provider_request` + +### Native Protocol Response + +- `raw_native_provider_response` +- `parsed_native_unified_response` +- `formatted_native_response` +- `after_response_adapter_chain` +- `field_cache_extraction_start` +- `field_cache_rule_*` +- `field_cache_extraction_complete` +- `usage_accounting_summary` +- `final_client_response` + +### Native Stream + +- `native_provider_stream_request` +- `raw_native_provider_stream_chunk` +- `parsed_native_stream_event` +- `after_stream_event_adapter_chain` +- `after_field_cache_stream_extraction` +- `formatted_client_stream_event` + +### Responses API + +- `responses_raw_request` +- `responses_parsed_request` +- `responses_bridge_chat_request` +- `responses_bridge_chat_response` +- `responses_stored_response` +- `responses_final_response` +- stream equivalents for created, delta, completed, bridge chunk, and final states where available. + +### Errors + +- Continue using `transform_log_error`. +- Include failed pass name, component, provider, protocol, transport, and sanitized payload. + +## Implementation Plan + +1. Add small local trace helpers where they avoid repeated fragile `if transaction_logger` blocks. + - Avoid global state. + - Avoid dependency cycles. + - Prefer local helpers when a shared helper would create import loops. + +2. Expand `ProviderTransforms.apply()` tracing. + - Add optional keyword-only `transaction_logger`, `credential_id`, `transport`, and `trace_metadata` parameters. + - Trace `pre_provider_transform_request` before any mutation. + - Trace `after_builtin_provider_transform` after each built-in transform that returns a modification string. + - Trace `after_provider_hook_transform` when provider hooks modify or report modifications. + - Trace provider hook errors as `transform_log_error` without changing failure behavior. + - Trace `after_provider_model_options` when options are applied. + - Trace `before_litellm_conversion` and `after_litellm_conversion` around `convert_for_litellm()`. + +3. Pass trace context from `RequestExecutor._prepare_request_kwargs()`. + - Provide transaction logger, provider/model/session/scope/classifier metadata, transport, and stable credential ID where available. + - Add an optional `credential_id` argument rather than changing existing call semantics. + +4. Add raw provider and normalization response trace in non-streaming executor. + - Trace `provider_execution_request` immediately before the provider call after pre-request callback. + - Trace `after_pre_request_callback` if callback changed kwargs. + - Trace `raw_provider_response` immediately after provider call returns. + - Trace `post_usage_normalization_response` after `_normalize_response_usage()` and before returning. + +5. Expand streaming trace coverage. + - Ensure streaming request path gets the same provider-transform traces as non-streaming. + - Keep `raw_stream_chunk`, `parsed_stream_chunk`, and `assembled_stream_response`. + - Add `stream_error_event` for error SSE payloads produced by executor fallback/error handling. + - Add `stream_done_event` for `[DONE]` boundaries with snapshots disabled. + - Do not change emitted SSE bytes. + +6. Expand native provider trace coverage. + - Add `raw_native_client_request` before protocol parse. + - Trace `parsed_native_unified_request` after `protocol.parse_request()`. + - Trace `built_native_provider_request` after `protocol.build_request()` and before adapters. + - Trace `after_request_adapter_chain` after request adapters. + - Trace `parsed_native_unified_response`, `formatted_native_response`, and `after_response_adapter_chain` on response path. + - Add equivalent stream-event boundaries where applicable. + +7. Expand adapter chain summary traces. + - Keep `before_adapter_chain` and per-adapter `after_adapter`. + - Add `after_adapter_chain` with adapter list, stage, and count even when there are no adapters. + - Include `changed_from_previous` where feasible. + - Disable snapshots for stream-event summaries. + +8. Expand field-cache summary traces. + - Log `field_cache_injection_start` / `field_cache_extraction_start` before the pass. + - Log `field_cache_injection_complete` / `field_cache_extraction_complete` after completion. + - Include rule count, matched count, changed count, source, and operation type. + - Keep per-rule trace entries intact. + +9. Expand Responses API and bridge tracing. + - Trace raw request, parsed unified request, bridge chat request, bridge response, stored response, and final response where those boundaries exist. + - Trace emitted streaming response events with stable pass names and snapshots disabled for per-chunk events. + - Preserve route bodies and SSE event ordering. + +10. Tests. + - Extend `tests/test_transaction_logger_transform_trace.py` for new pass names and redaction stability. + - Add or extend tests for `ProviderTransforms.apply()` tracing each stage without changing payload behavior. + - Add native provider executor trace order tests for non-streaming and streaming. + - Add adapter chain summary trace test. + - Add field-cache start/complete trace tests, including no-rules/no-match cases. + - Add Responses service/stream trace tests if existing fixtures can cover it without large integration cost. + - Run protocol regressions because trace serialization touches protocol objects. + +## Commit Checkpoints + +1. `docs(experimental): plan transform trace coverage correction`. +2. Provider transform trace coverage and tests. +3. Executor request/response/stream trace coverage and tests. +4. Native provider/adapter/field-cache trace summaries and tests. +5. Responses trace coverage and tests. +6. Review fixes after `explore` and `explore-heavy`. +7. User-facing Phase 2b report, uncommitted. + +## Risks And Mitigations + +- Trace logging accidentally mutates payloads. Mitigation: pass deep-copied or read-only snapshots and rely on `sanitize_for_trace()`. +- New trace arguments break callers. Mitigation: make all new args keyword-only optional and keep old call sites valid. +- Stream trace output grows too much. Mitigation: keep per-stream-event snapshots disabled and log compact metadata. +- Import cycles. Mitigation: use private local helpers when needed. +- Brittle tests. Mitigation: assert pass presence and causal relative order only where required. From bd4a904399f41a1785bad8c86ee60d7999d697ad Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 09:06:58 +0200 Subject: [PATCH 082/182] feat(logging): trace provider transform boundaries Adds observability-only trace entries around live provider request transforms, including the pre-transform payload, built-in transforms, provider hook transforms, model option application, and LiteLLM conversion boundaries. RequestExecutor now passes transaction, transport, stable credential, session, scope, and classifier context into ProviderTransforms without changing request behavior. Tests: pytest tests/test_transaction_logger_transform_trace.py --- src/rotator_library/client/executor.py | 22 +++- src/rotator_library/client/transforms.py | 122 ++++++++++++++++++ ...test_transaction_logger_transform_trace.py | 53 ++++++++ 3 files changed, 195 insertions(+), 2 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index c1f6ed4ce..983a6ff50 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -372,6 +372,8 @@ async def _prepare_request_kwargs( model: str, cred: str, context: "RequestContext", + *, + credential_id: Optional[str] = None, ) -> Dict[str, Any]: """ Prepare request kwargs with transforms, sanitization, and provider params. @@ -392,6 +394,14 @@ async def _prepare_request_kwargs( cred, context.kwargs.copy(), provider_config_override=context.provider_config, + transaction_logger=context.transaction_logger, + credential_id=credential_id, + transport="sse" if context.streaming else "http", + trace_metadata={ + "session_id": context.session_id, + "scope_key": context.usage_manager_key, + "classifier": context.classifier, + }, ) # Sanitize request payload @@ -881,7 +891,11 @@ async def _execute_non_streaming( try: # Prepare request kwargs kwargs = await self._prepare_request_kwargs( - provider, model, cred, context + provider, + model, + cred, + context, + credential_id=cred_context.stable_id, ) # Log transformed request if it differs from original @@ -1100,7 +1114,11 @@ async def _execute_streaming( try: # Prepare request kwargs kwargs = await self._prepare_request_kwargs( - provider, model, cred, context + provider, + model, + cred, + context, + credential_id=cred_context.stable_id, ) # Log transformed request if it differs from original diff --git a/src/rotator_library/client/transforms.py b/src/rotator_library/client/transforms.py index fb2a02d88..c909b4215 100644 --- a/src/rotator_library/client/transforms.py +++ b/src/rotator_library/client/transforms.py @@ -15,6 +15,7 @@ """ import logging +from copy import deepcopy from typing import Any, Callable, Dict, List, Optional lib_logger = logging.getLogger("rotator_library") @@ -81,6 +82,11 @@ async def apply( credential: str, kwargs: Dict[str, Any], provider_config_override: Optional[Dict[str, Any]] = None, + *, + transaction_logger: Optional[Any] = None, + credential_id: Optional[str] = None, + transport: Optional[str] = None, + trace_metadata: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """ Apply all applicable transforms to request kwargs. @@ -95,43 +101,126 @@ async def apply( Modified kwargs """ modifications: List[str] = [] + trace_metadata = dict(trace_metadata or {}) + _trace_transform_pass( + transaction_logger, + "pre_provider_transform_request", + kwargs, + provider=provider, + model=model, + credential_id=credential_id, + transport=transport, + metadata={"phase": "start", **trace_metadata}, + ) # 1. Apply built-in transforms for transform_provider, transforms in self._transforms.items(): # Check if transform applies (provider match or model contains pattern) if transform_provider == provider or transform_provider in model.lower(): for transform in transforms: + before = deepcopy(kwargs) result = transform(kwargs, model, provider) if result: modifications.append(result) + _trace_transform_pass( + transaction_logger, + "after_builtin_provider_transform", + kwargs, + provider=provider, + model=model, + credential_id=credential_id, + transport=transport, + changed_from_previous=before != kwargs, + metadata={ + "transform_provider": transform_provider, + "transform_name": getattr(transform, "__name__", repr(transform)), + "modification": result, + **trace_metadata, + }, + ) # 2. Apply provider hook transforms (async) plugin = self._get_plugin_instance(provider) if plugin and hasattr(plugin, "transform_request"): try: + before = deepcopy(kwargs) hook_result = await plugin.transform_request(kwargs, model, credential) if hook_result: modifications.extend(hook_result) + if hook_result or before != kwargs: + _trace_transform_pass( + transaction_logger, + "after_provider_hook_transform", + kwargs, + provider=provider, + model=model, + credential_id=credential_id, + transport=transport, + changed_from_previous=before != kwargs, + metadata={"modifications": hook_result or [], **trace_metadata}, + ) except Exception as e: lib_logger.debug(f"Provider transform_request hook failed: {e}") + if transaction_logger: + transaction_logger.log_transform_error( + "provider_hook_transform", + e, + payload=kwargs, + stage="client", + metadata={"provider": provider, "model": model, **trace_metadata}, + ) # 3. Apply model-specific options from provider if plugin and hasattr(plugin, "get_model_options"): model_options = plugin.get_model_options(model) if model_options: + before = deepcopy(kwargs) for key, value in model_options.items(): if key == "reasoning_effort": kwargs["reasoning_effort"] = value elif key not in kwargs: kwargs[key] = value modifications.append(f"applied model options for {model}") + _trace_transform_pass( + transaction_logger, + "after_provider_model_options", + kwargs, + provider=provider, + model=model, + credential_id=credential_id, + transport=transport, + changed_from_previous=before != kwargs, + metadata={"model_options": deepcopy(model_options), **trace_metadata}, + ) # 4. Apply LiteLLM conversion if config available if self._config and hasattr(self._config, "convert_for_litellm"): + before = deepcopy(kwargs) + _trace_transform_pass( + transaction_logger, + "before_litellm_conversion", + kwargs, + provider=provider, + model=model, + credential_id=credential_id, + transport=transport, + metadata={"provider_config_override": bool(provider_config_override), **trace_metadata}, + ) kwargs = self._config.convert_for_litellm( provider_override=provider_config_override, **kwargs, ) + _trace_transform_pass( + transaction_logger, + "after_litellm_conversion", + kwargs, + provider=provider, + model=model, + credential_id=credential_id, + transport=transport, + changed_from_previous=before != kwargs, + metadata={"provider_config_override": bool(provider_config_override), **trace_metadata}, + ) if modifications: lib_logger.debug( @@ -312,3 +401,36 @@ def _transform_dedaluslabs_tool_choice( # caused 400 errors on models that don't support those categories (e.g. Gemma). # See gemini_provider.py for the full removal comment with previous defaults. # ========================================================================= + + +def _trace_transform_pass( + transaction_logger: Optional[Any], + pass_name: str, + payload: Dict[str, Any], + *, + provider: str, + model: str, + credential_id: Optional[str], + transport: Optional[str], + metadata: Dict[str, Any], + changed_from_previous: Optional[bool] = None, +) -> None: + """Record provider-transform states without changing transform behavior. + + Transform tracing is observability-only. This helper centralizes pass + metadata so individual transforms can stay focused on payload mutation while + the transaction trace still shows each live request state. + """ + + if not transaction_logger: + return + transaction_logger.log_transform_pass( + pass_name, + payload, + direction="request", + stage="client", + credential_id=credential_id, + transport=transport, + changed_from_previous=changed_from_previous, + metadata={"provider": provider, "model": model, **metadata}, + ) diff --git a/tests/test_transaction_logger_transform_trace.py b/tests/test_transaction_logger_transform_trace.py index 956b23ba3..1b090516c 100644 --- a/tests/test_transaction_logger_transform_trace.py +++ b/tests/test_transaction_logger_transform_trace.py @@ -5,6 +5,7 @@ import pytest from rotator_library.client.executor import RequestExecutor +from rotator_library.client.transforms import ProviderTransforms from rotator_library.transaction_logger import ProviderLogger, TransactionLogger from rotator_library.transform_trace import REDACTED @@ -140,3 +141,55 @@ def test_log_transform_error_uses_standard_shape_and_scrubs_text(tmp_path) -> No assert entries[0]["pass_name"] == "transform_log_error" assert entries[0]["data"]["failed_pass_name"] == "after_field_cache_injection" assert "secret" not in json.dumps(entries[0]["data"]) + + +@pytest.mark.asyncio +async def test_provider_transforms_trace_each_live_boundary(tmp_path) -> None: + class HookPlugin: + async def transform_request(self, kwargs, model, credential): + kwargs["hooked"] = credential + return ["hooked request"] + + def get_model_options(self, model): + return {"reasoning_effort": "low", "temperature": 0.2} + + class Config: + def convert_for_litellm(self, provider_override=None, **kwargs): + converted = dict(kwargs) + converted["converted_for_litellm"] = True + return converted + + logger = TransactionLogger("dedaluslabs", "dedaluslabs/test", parent_dir=tmp_path) + transforms = ProviderTransforms( + {"dedaluslabs": HookPlugin()}, + provider_config=Config(), + ) + + result = await transforms.apply( + "dedaluslabs", + "dedaluslabs/test", + "secret-credential", + {"model": "dedaluslabs/test", "tool_choice": "auto"}, + transaction_logger=logger, + credential_id="stable_cred", + transport="http", + trace_metadata={"scope_key": "scope"}, + ) + + assert "tool_choice" not in result + assert result["hooked"] == "secret-credential" + assert result["reasoning_effort"] == "low" + assert result["converted_for_litellm"] is True + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert pass_names == [ + "pre_provider_transform_request", + "after_builtin_provider_transform", + "after_provider_hook_transform", + "after_provider_model_options", + "before_litellm_conversion", + "after_litellm_conversion", + ] + assert all(entry["credential_id"] == "stable_cred" for entry in entries) + assert entries[1]["metadata"]["transform_provider"] == "dedaluslabs" + assert entries[-1]["changed_from_previous"] is True From 25e631d7c316265e2457082bc89485519f34a1ea Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 09:09:13 +0200 Subject: [PATCH 083/182] feat(logging): trace executor response boundaries Adds live executor trace entries for provider execution requests, raw provider responses, post-usage-normalization responses, pre-request callback mutations, stream error events, and stream DONE boundaries. The new trace entries are additive and preserve request/response/SSE behavior while making LiteLLM and custom execution boundaries visible in transform_trace.jsonl. Tests: pytest tests/test_transaction_logger_transform_trace.py --- src/rotator_library/client/executor.py | 125 +++++++++++++++++- ...test_transaction_logger_transform_trace.py | 18 +++ 2 files changed, 142 insertions(+), 1 deletion(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 983a6ff50..accd32d6b 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -18,6 +18,7 @@ import os import random import time +from copy import deepcopy from typing import ( Any, AsyncGenerator, @@ -496,7 +497,18 @@ async def _run_pre_request_callback( """ if context.pre_request_callback: try: + before = deepcopy(kwargs) await context.pre_request_callback(context.request, kwargs) + if before != kwargs: + self._log_executor_trace( + context, + "after_pre_request_callback", + kwargs, + direction="request", + stage="client", + changed_from_previous=True, + snapshot=True, + ) except Exception as e: if self._abort_on_callback_error: raise PreRequestCallbackError(str(e)) from e @@ -516,6 +528,15 @@ async def _execute_provider_request( target = _current_route_target(context) execution = target.execution if target else "auto" + self._log_executor_trace( + context, + "provider_execution_request", + kwargs, + direction="request", + stage="provider", + credential_id=credential_id, + metadata={"execution": execution, "provider": provider, "model": model}, + ) if execution == "litellm_fallback": self._log_routing_trace( context, @@ -787,6 +808,42 @@ def _log_routing_trace(context: RequestContext, pass_name: str, data: Any, *, me snapshot=False, ) + @staticmethod + def _log_executor_trace( + context: RequestContext, + pass_name: str, + data: Any, + *, + direction: str, + stage: str, + credential_id: Optional[str] = None, + changed_from_previous: Optional[bool] = None, + metadata: Optional[Dict[str, Any]] = None, + snapshot: bool = True, + ) -> None: + """Record live executor trace boundaries without affecting requests.""" + + if not context.transaction_logger: + return + context.transaction_logger.log_transform_pass( + pass_name, + data, + direction=direction, + stage=stage, + credential_id=credential_id, + transport="sse" if context.streaming else "http", + changed_from_previous=changed_from_previous, + metadata={ + "provider": context.provider, + "model": context.model, + "session_id": context.session_id, + "scope_key": context.usage_manager_key, + "classifier": context.classifier, + **(metadata or {}), + }, + snapshot=snapshot, + ) + async def _prepare_execution( self, context: RequestContext, @@ -933,6 +990,15 @@ async def _execute_non_streaming( kwargs, context, ) + self._log_executor_trace( + context, + "raw_provider_response", + response, + direction="response", + stage="provider", + credential_id=cred_context.stable_id, + metadata={"provider": provider, "model": model}, + ) # Success! Extract token usage if available usage_record, cost_breakdown = self._account_for_response_usage( @@ -980,7 +1046,17 @@ async def _execute_non_streaming( f"Failed to log response: {log_err}" ) - return self._normalize_response_usage(response, model) + normalized_response = self._normalize_response_usage(response, model) + self._log_executor_trace( + context, + "post_usage_normalization_response", + normalized_response, + direction="response", + stage="final", + credential_id=cred_context.stable_id, + metadata={"provider": provider, "model": model}, + ) + return normalized_response except Exception as e: last_exception = e @@ -1164,6 +1240,15 @@ async def _execute_streaming( # Make the API call if plugin and plugin.has_custom_logic(): kwargs["credential_identifier"] = credential_secret + self._log_executor_trace( + context, + "provider_execution_request", + kwargs, + direction="request", + stage="provider", + credential_id=cred_context.stable_id, + metadata={"execution": "custom_stream", "provider": provider, "model": model}, + ) stream = await plugin.acompletion( self._http_client, **kwargs ) @@ -1173,8 +1258,28 @@ async def _execute_streaming( self._apply_litellm_logger(kwargs) # Remove internal context before litellm call kwargs.pop("transaction_context", None) + self._log_executor_trace( + context, + "provider_execution_request", + kwargs, + direction="request", + stage="provider", + credential_id=cred_context.stable_id, + metadata={"execution": "litellm_stream", "provider": provider, "model": model}, + ) stream = await litellm.acompletion(**kwargs) + self._log_executor_trace( + context, + "raw_provider_stream_response", + stream, + direction="response", + stage="provider", + credential_id=cred_context.stable_id, + metadata={"provider": provider, "model": model}, + snapshot=False, + ) + # Hand off to streaming handler with cred_context # The handler will call mark_success on completion base_stream = self._streaming_handler.wrap_stream( @@ -1894,6 +1999,15 @@ async def _transaction_logging_stream_wrapper( transport="sse", snapshot=False, ) + if sse_line.startswith("data: [DONE]"): + transaction_logger.log_transform_pass( + "stream_done_event", + {"raw": sse_line}, + direction="stream", + stage="final", + transport="sse", + snapshot=False, + ) yield sse_line # Parse and accumulate for final logging @@ -1906,6 +2020,15 @@ async def _transaction_logging_stream_wrapper( chunk_data = json.loads(content) chunks.append(chunk_data) transaction_logger.log_stream_chunk(chunk_data) + if isinstance(chunk_data, dict) and chunk_data.get("error") is not None: + transaction_logger.log_transform_pass( + "stream_error_event", + chunk_data, + direction="stream", + stage="client", + transport="sse", + snapshot=False, + ) except json.JSONDecodeError: lib_logger.debug( f"Failed to parse chunk for logging: {sse_line[:100]}" diff --git a/tests/test_transaction_logger_transform_trace.py b/tests/test_transaction_logger_transform_trace.py index 1b090516c..cbcd78f78 100644 --- a/tests/test_transaction_logger_transform_trace.py +++ b/tests/test_transaction_logger_transform_trace.py @@ -102,10 +102,28 @@ async def stream(): pass_names = [entry["pass_name"] for entry in entries] assert pass_names.count("raw_stream_chunk") == 3 assert pass_names.count("parsed_stream_chunk") == 2 + assert pass_names.count("stream_done_event") == 1 assert "assembled_stream_response" in pass_names assert "final_client_response" in pass_names +@pytest.mark.asyncio +async def test_stream_wrapper_records_error_events(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + executor = RequestExecutor.__new__(RequestExecutor) + + async def stream(): + yield 'data: {"error":{"type":"rate_limit","message":"slow down"}}\n\n' + yield "data: [DONE]\n\n" + + chunks = [chunk async for chunk in executor._transaction_logging_stream_wrapper(stream(), logger, {})] + + assert chunks[-1] == "data: [DONE]\n\n" + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "stream_error_event" in pass_names + assert "stream_done_event" in pass_names + + def test_transaction_logger_disabled_writes_no_trace(tmp_path) -> None: logger = TransactionLogger("openai", "openai/gpt-test", enabled=False, parent_dir=tmp_path) From 46429163624bea2f41ab18ad0b967e6fbce77385 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 09:29:04 +0200 Subject: [PATCH 084/182] feat(logging): trace native adapter cache boundaries Adds final adapter-chain summary traces, field-cache start/complete summaries, and fuller native provider parse/build/adapter/cache response trace boundaries. The changes make no-rule cache passes visible, distinguish unified request/response protocol states from formatted provider payloads, and preserve existing execution behavior. Tests: pytest tests/test_adapter_registry.py tests/test_field_cache_engine.py tests/test_native_provider_executor.py --- src/rotator_library/adapters/base.py | 9 +++- src/rotator_library/field_cache/engine.py | 43 ++++++++++++++++++- .../native_provider/executor.py | 13 +++++- tests/test_adapter_registry.py | 27 ++++++++++++ tests/test_field_cache_engine.py | 28 ++++++++++++ tests/test_native_provider_executor.py | 12 +++++- 6 files changed, 127 insertions(+), 5 deletions(-) diff --git a/src/rotator_library/adapters/base.py b/src/rotator_library/adapters/base.py index 9e7caa6da..26f39bf18 100644 --- a/src/rotator_library/adapters/base.py +++ b/src/rotator_library/adapters/base.py @@ -107,6 +107,7 @@ async def run_adapter_chain( """ current = payload if mutate else deepcopy(payload) + original = deepcopy(current) adapter_names = [adapter.name for adapter in adapters] _trace(context, "before_adapter_chain", current, stage=stage, metadata={"adapters": adapter_names}) for adapter in adapters: @@ -132,5 +133,11 @@ async def run_adapter_chain( stage=stage, metadata={"adapter": adapter.name, "adapter_stage": stage, "changed": current != before}, ) - _trace(context, "after_adapter_chain", current, stage=stage, metadata={"adapters": adapter_names}) + _trace( + context, + "after_adapter_chain", + current, + stage=stage, + metadata={"adapters": adapter_names, "adapter_stage": stage, "adapter_count": len(adapter_names), "changed": original != current}, + ) return current diff --git a/src/rotator_library/field_cache/engine.py b/src/rotator_library/field_cache/engine.py index 58b43e962..c40ba67df 100644 --- a/src/rotator_library/field_cache/engine.py +++ b/src/rotator_library/field_cache/engine.py @@ -88,7 +88,9 @@ async def extract( transaction_logger: Optional[Any] = None, ) -> list[FieldCacheOperation]: operations: list[FieldCacheOperation] = [] - for rule in self._rules_for_source(source): + rules = self._rules_for_source(source) + self._trace_summary(transaction_logger, "field_cache_extraction_start", payload, source=source, target=None, rules=rules, operations=operations) + for rule in rules: operation = FieldCacheOperation(rule_name=rule.name, cache_key=build_cache_key(rule, context)) self._trace(transaction_logger, "before_field_cache_extraction", payload, rule, operation, source=source) if not operation.cache_key: @@ -109,6 +111,7 @@ async def extract( raise operations.append(operation) self._trace(transaction_logger, "after_field_cache_extraction", payload, rule, operation, source=source) + self._trace_summary(transaction_logger, "field_cache_extraction_complete", payload, source=source, target=None, rules=rules, operations=operations) return operations async def inject( @@ -122,7 +125,9 @@ async def inject( ) -> tuple[Any, list[FieldCacheOperation]]: updated = payload if mutate else deepcopy(payload) operations: list[FieldCacheOperation] = [] - for rule in self._rules_for_injection(target): + rules = self._rules_for_injection(target) + self._trace_summary(transaction_logger, "field_cache_injection_start", updated, source=None, target=target, rules=rules, operations=operations) + for rule in rules: operation = FieldCacheOperation(rule_name=rule.name, cache_key=build_cache_key(rule, context)) if not rule.inject: continue @@ -154,6 +159,7 @@ async def inject( raise operations.append(operation) self._trace(transaction_logger, "after_field_cache_injection", updated, rule, operation, target=target) + self._trace_summary(transaction_logger, "field_cache_injection_complete", updated, source=None, target=target, rules=rules, operations=operations) return updated, operations def _rules_for_source(self, source: str) -> list[FieldCacheRule]: @@ -214,6 +220,39 @@ def _trace( snapshot=rule.source != "stream_event", ) + def _trace_summary( + self, + transaction_logger: Optional[Any], + pass_name: str, + payload: Any, + *, + source: Optional[str], + target: Optional[str], + rules: list[FieldCacheRule], + operations: list[FieldCacheOperation], + ) -> None: + """Record cache-pass boundaries even when no individual rule matches.""" + + if not transaction_logger: + return + transaction_logger.log_transform_pass( + pass_name, + payload, + direction="request" if target else "response" if source == "response" else "stream" if source == "stream_event" else "metadata", + stage="adapter", + metadata={ + "source": source, + "target": target, + "rule_count": len(rules), + "operation_count": len(operations), + "matched_count": sum(operation.matched for operation in operations), + "changed_count": sum(1 for operation in operations if operation.changed), + "hit_count": sum(1 for operation in operations if operation.hit), + "skipped_count": sum(1 for operation in operations if operation.skipped), + }, + snapshot=(source != "stream_event"), + ) + def _log_error(self, transaction_logger: Optional[Any], pass_name: str, error: BaseException, payload: Any, rule: FieldCacheRule) -> None: if not transaction_logger: return diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index 84e776f2d..a2532e3fa 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -34,11 +34,15 @@ async def execute(self, raw_request: dict[str, Any], context: NativeProviderCont protocol = get_protocol(context.protocol_name) self._trace(context, "native_protocol_selected", {"protocol": protocol.name}, direction="metadata", stage="protocol") try: + self._trace(context, "raw_native_client_request", raw_request, direction="request", stage="client") protocol_context = context.protocol_context() unified_request = protocol.parse_request(raw_request, protocol_context) + self._trace(context, "parsed_native_unified_request", unified_request, direction="request", stage="protocol") provider_request = protocol.build_request(unified_request, protocol_context) + self._trace(context, "built_native_provider_request", provider_request, direction="request", stage="protocol") adapters = [get_adapter(name) for name in context.adapter_names] provider_request = await run_adapter_chain(adapters, provider_request, context.adapter_context(), stage="request") + self._trace(context, "after_request_adapter_chain", provider_request, direction="request", stage="adapter") cache_engine = FieldCacheEngine(context.field_cache_rules, store=self.field_cache_store) provider_request, _ = await cache_engine.inject( "request", @@ -51,9 +55,11 @@ async def execute(self, raw_request: dict[str, Any], context: NativeProviderCont raw_response = await transport.post_json(context.endpoint, headers=context.headers, payload=provider_request) self._trace(context, "raw_native_provider_response", raw_response, direction="response", stage="provider") unified_response = protocol.parse_response(raw_response, protocol_context) + self._trace(context, "parsed_native_unified_response", unified_response, direction="response", stage="protocol") provider_response = protocol.format_response(unified_response, protocol_context) - self._trace(context, "parsed_native_provider_response", provider_response, direction="response", stage="protocol") + self._trace(context, "formatted_native_response", provider_response, direction="response", stage="protocol") provider_response = await run_adapter_chain(adapters, provider_response, context.adapter_context(), stage="response") + self._trace(context, "after_response_adapter_chain", provider_response, direction="response", stage="adapter") await cache_engine.extract("response", provider_response, context.field_cache_context(), transaction_logger=logger) usage_record = extract_usage_record( provider_response, @@ -90,13 +96,17 @@ async def stream(self, raw_request: dict[str, Any], context: NativeProviderConte protocol = get_protocol(context.protocol_name) self._trace(context, "native_protocol_selected", {"protocol": protocol.name}, direction="metadata", stage="protocol") try: + self._trace(context, "raw_native_client_request", raw_request, direction="request", stage="client") protocol_context = context.protocol_context() request_payload = dict(raw_request) request_payload["stream"] = True unified_request = protocol.parse_request(request_payload, protocol_context) + self._trace(context, "parsed_native_unified_request", unified_request, direction="request", stage="protocol") provider_request = protocol.build_request(unified_request, protocol_context) + self._trace(context, "built_native_provider_request", provider_request, direction="request", stage="protocol") adapters = [get_adapter(name) for name in context.adapter_names] provider_request = await run_adapter_chain(adapters, provider_request, context.adapter_context(), stage="request") + self._trace(context, "after_request_adapter_chain", provider_request, direction="request", stage="adapter") cache_engine = FieldCacheEngine(context.field_cache_rules, store=self.field_cache_store) provider_request, _ = await cache_engine.inject( "request", @@ -109,6 +119,7 @@ async def stream(self, raw_request: dict[str, Any], context: NativeProviderConte async for raw_chunk in transport.stream_json_lines(context.endpoint, headers=context.headers, payload=provider_request): self._trace(context, "raw_native_provider_stream_chunk", raw_chunk, direction="stream", stage="provider") event = protocol.parse_stream_event(raw_chunk, protocol_context) + self._trace(context, "parsed_native_unified_stream_event", event, direction="stream", stage="protocol", snapshot=False) event_payload = stream_event_payload(event) self._trace(context, "parsed_native_stream_event", event_payload, direction="stream", stage="protocol") await cache_engine.extract("stream_event", event_payload, context.field_cache_context(), transaction_logger=logger) diff --git a/tests/test_adapter_registry.py b/tests/test_adapter_registry.py index 59965f4d5..739a079d8 100644 --- a/tests/test_adapter_registry.py +++ b/tests/test_adapter_registry.py @@ -1,5 +1,7 @@ from __future__ import annotations +import json + import pytest from rotator_library.adapters import ( @@ -12,6 +14,11 @@ resolve_adapter_name, run_adapter_chain, ) +from rotator_library.transaction_logger import TransactionLogger + + +def _trace_entries(log_dir): + return [json.loads(line) for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] def test_adapter_registry_auto_discovers_builtins_and_aliases() -> None: @@ -120,3 +127,23 @@ async def test_adapter_chain_order_is_preserved() -> None: assert result["model"] == "native" assert result["messages"][0]["role"] == "system" + + +@pytest.mark.asyncio +async def test_adapter_chain_traces_final_summary(tmp_path) -> None: + logger = TransactionLogger("native", "native/test", parent_dir=tmp_path) + payload = {"model": "public", "messages": []} + context = AdapterContext( + adapter_config={"model_override": {"model": "native"}}, + transaction_logger=logger, + protocol="openai_chat", + ) + + result = await run_adapter_chain([get_adapter("model_override")], payload, context, stage="request") + + assert result["model"] == "native" + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert pass_names == ["before_adapter_chain", "after_adapter", "after_adapter_chain"] + assert entries[-1]["metadata"]["adapter_count"] == 1 + assert entries[-1]["metadata"]["changed"] is True diff --git a/tests/test_field_cache_engine.py b/tests/test_field_cache_engine.py index 6a47545a6..741aaca81 100644 --- a/tests/test_field_cache_engine.py +++ b/tests/test_field_cache_engine.py @@ -13,6 +13,11 @@ ProviderCacheFieldStore, build_cache_key, ) +from rotator_library.transaction_logger import TransactionLogger + + +def _trace_entries(log_dir): + return [json.loads(line) for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] def _reasoning_rule(mode: str = "last", scope=("provider", "model", "session")) -> FieldCacheRule: @@ -159,6 +164,29 @@ async def test_trace_sample_values_are_truncated() -> None: assert operations[0].sample_values[0].endswith("...") +@pytest.mark.asyncio +async def test_field_cache_traces_start_and_complete_even_without_matching_rules(tmp_path) -> None: + logger = TransactionLogger("openai", "gpt-test", parent_dir=tmp_path) + engine = FieldCacheEngine([]) + + operations = await engine.extract("response", {"choices": []}, _context(), transaction_logger=logger) + updated, injection_operations = await engine.inject("request", {"messages": []}, _context(), transaction_logger=logger) + + assert operations == [] + assert injection_operations == [] + assert updated == {"messages": []} + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert pass_names == [ + "field_cache_extraction_start", + "field_cache_extraction_complete", + "field_cache_injection_start", + "field_cache_injection_complete", + ] + assert entries[1]["metadata"]["rule_count"] == 0 + assert entries[-1]["metadata"]["operation_count"] == 0 + + def test_per_tool_call_requires_tool_call_id_path() -> None: with pytest.raises(ValueError): FieldCacheEngine([ diff --git a/tests/test_native_provider_executor.py b/tests/test_native_provider_executor.py index 14ca38850..20593d96a 100644 --- a/tests/test_native_provider_executor.py +++ b/tests/test_native_provider_executor.py @@ -72,11 +72,21 @@ async def test_native_provider_executor_runs_protocol_adapter_cache_and_trace(tm assert client.calls[0]["json"]["model"] == "provider/gpt-test" pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] assert "native_protocol_selected" in pass_names + assert "raw_native_client_request" in pass_names + assert "parsed_native_unified_request" in pass_names + assert "built_native_provider_request" in pass_names + assert "after_request_adapter_chain" in pass_names + assert "field_cache_injection_start" in pass_names assert "after_field_cache_injection" in pass_names + assert "field_cache_injection_complete" in pass_names assert "native_provider_request" in pass_names assert "raw_native_provider_response" in pass_names - assert "parsed_native_provider_response" in pass_names + assert "parsed_native_unified_response" in pass_names + assert "formatted_native_response" in pass_names + assert "after_response_adapter_chain" in pass_names + assert "field_cache_extraction_start" in pass_names assert "after_field_cache_extraction" in pass_names + assert "field_cache_extraction_complete" in pass_names assert "final_client_response" in pass_names From 6b261b79cdea7d32a6343957b05bc437240adb4e Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 09:31:05 +0200 Subject: [PATCH 085/182] feat(logging): align responses trace boundaries Renames and expands Responses API transform trace entries to match the Phase 2b taxonomy for raw request, parsed request, bridge response, stored response, final response, and emitted stream events. The route payloads and SSE event order are unchanged; only transform_trace.jsonl receives more precise pass names and stream event boundaries. Tests: pytest tests/test_responses_service.py tests/test_responses_streaming.py --- src/rotator_library/responses/service.py | 33 ++++++++++++++---------- tests/test_responses_service.py | 12 ++++----- tests/test_responses_streaming.py | 3 +++ 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index 9b58a6a00..47530c33e 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -73,9 +73,9 @@ async def create_response( if raw_request.get("stream"): raise ResponsesServiceError("Use stream_response for streaming requests", status_code=400) - self._trace(transaction_logger, "raw_responses_request", raw_request, direction="request", stage="client") + self._trace(transaction_logger, "responses_raw_request", raw_request, direction="request", stage="client") unified = self.protocol.parse_request(raw_request, ProtocolContext(source_protocol="responses")) - self._trace(transaction_logger, "parsed_unified_request", unified.to_dict(), direction="request", stage="protocol") + self._trace(transaction_logger, "responses_parsed_request", unified.to_dict(), direction="request", stage="protocol") parent = await self._load_previous_response(unified.previous_response_id, transaction_logger) chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_response=parent.response if parent else None) @@ -90,18 +90,18 @@ async def create_response( ) chat_response = await client.acompletion(request=request, **chat_kwargs) - self._trace(transaction_logger, "raw_chat_bridge_response", self._response_to_dict(chat_response), direction="response", stage="provider") + self._trace(transaction_logger, "responses_bridge_chat_response", self._response_to_dict(chat_response), direction="response", stage="provider") response_payload = self.bridge.from_chat_response(chat_response, unified) - self._trace(transaction_logger, "parsed_unified_response", response_payload, direction="response", stage="protocol") + self._trace(transaction_logger, "responses_parsed_response", response_payload, direction="response", stage="protocol") self._trace_responses_usage(transaction_logger, response_payload, unified.model, source="responses_response") if raw_request.get("store", True): stored = self._stored_response(raw_request, response_payload, parent) await self.store.save(stored) - self._trace(transaction_logger, "stored_responses_response", stored.to_dict(), direction="metadata", stage="final") + self._trace(transaction_logger, "responses_stored_response", stored.to_dict(), direction="metadata", stage="final") - self._trace(transaction_logger, "final_responses_response", response_payload, direction="response", stage="final") + self._trace(transaction_logger, "responses_final_response", response_payload, direction="response", stage="final") return response_payload async def stream_response( @@ -119,9 +119,9 @@ async def stream_response( formatter = ResponsesSSEFormatter() stream_request = dict(raw_request) stream_request["stream"] = True - self._trace(transaction_logger, "raw_responses_request", stream_request, direction="request", stage="client") + self._trace(transaction_logger, "responses_raw_request", stream_request, direction="request", stage="client") unified = self.protocol.parse_request(stream_request, ProtocolContext(source_protocol="responses", transport="sse")) - self._trace(transaction_logger, "parsed_unified_request", unified.to_dict(), direction="request", stage="protocol") + self._trace(transaction_logger, "responses_parsed_request", unified.to_dict(), direction="request", stage="protocol") parent = await self._load_previous_response(unified.previous_response_id, transaction_logger) chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_response=parent.response if parent else None) bridge_metadata = chat_kwargs.pop("_responses_bridge", {}) @@ -148,7 +148,9 @@ async def stream_response( stage="client", metadata={"transport": "sse"}, ) - yield formatter.format_event("response.created", response_created_payload(response_id, unified.model)) + created = response_created_payload(response_id, unified.model) + self._trace(transaction_logger, "responses_stream_event_created", created, direction="stream", stage="final", metadata={"transport": "sse"}) + yield formatter.format_event("response.created", created) try: chat_stream = await client.acompletion(request=request, **chat_kwargs) async for raw_chunk in chat_stream: @@ -175,7 +177,7 @@ async def stream_response( if not item_started: item_started = True added = output_item_added_payload(state) - self._trace(transaction_logger, "formatted_responses_stream_event", added, direction="stream", stage="final") + self._trace(transaction_logger, "responses_stream_event_output_item_added", added, direction="stream", stage="final", metadata={"transport": "sse"}) yield formatter.format_event("response.output_item.added", added) state = ResponsesStreamState( response_id=state.response_id, @@ -202,17 +204,20 @@ async def stream_response( stage="final", metadata={"transport": "sse"}, ) - self._trace(transaction_logger, "formatted_responses_stream_event", event, direction="stream", stage="final") + self._trace(transaction_logger, "responses_stream_event_output_text_delta", event, direction="stream", stage="final", metadata={"transport": "sse"}) yield formatter.format_event("response.output_text.delta", event) if not item_started: - yield formatter.format_event("response.output_item.added", output_item_added_payload(state)) + added = output_item_added_payload(state) + self._trace(transaction_logger, "responses_stream_event_output_item_added", added, direction="stream", stage="final", metadata={"transport": "sse"}) + yield formatter.format_event("response.output_item.added", added) done_item = output_item_done_payload(state) + self._trace(transaction_logger, "responses_stream_event_output_item_done", done_item, direction="stream", stage="final", metadata={"transport": "sse"}) yield formatter.format_event("response.output_item.done", done_item) completed = response_completed_payload(state, _usage_to_responses_stream(usage)) self._trace_responses_usage(transaction_logger, completed, unified.model, source="responses_stream") await self._store_stream_response(stream_request, completed, parent) - self._trace(transaction_logger, "stored_responses_stream_response", completed, direction="metadata", stage="final") + self._trace(transaction_logger, "responses_stored_stream_response", completed, direction="metadata", stage="final") monitor.complete() self._trace( transaction_logger, @@ -230,12 +235,14 @@ async def stream_response( stage="final", metadata={"transport": "sse"}, ) + self._trace(transaction_logger, "responses_stream_event_completed", completed, direction="stream", stage="final", metadata={"transport": "sse"}) yield formatter.format_event("response.completed", completed) yield formatter.done() except Exception as exc: monitor.record_event(StreamEvent("error", protocol="responses", data={"error_type": exc.__class__.__name__})) failed = response_failed_payload(response_id, unified.model, {"message": str(exc), "type": exc.__class__.__name__}) self._log_transform_error(transaction_logger, "responses_stream", exc, stream_request) + self._trace(transaction_logger, "responses_stream_event_failed", failed, direction="stream", stage="final", metadata={"transport": "sse"}) self._trace( transaction_logger, "stream_metrics_final", diff --git a/tests/test_responses_service.py b/tests/test_responses_service.py index 1832491ea..4495d95f1 100644 --- a/tests/test_responses_service.py +++ b/tests/test_responses_service.py @@ -112,12 +112,12 @@ async def test_service_emits_transform_trace_passes(tmp_path) -> None: pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] assert pass_names == [ - "raw_responses_request", - "parsed_unified_request", + "responses_raw_request", + "responses_parsed_request", "responses_bridge_chat_request", - "raw_chat_bridge_response", - "parsed_unified_response", + "responses_bridge_chat_response", + "responses_parsed_response", "usage_accounting_summary", - "stored_responses_response", - "final_responses_response", + "responses_stored_response", + "responses_final_response", ] diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index 937d822e0..67ed2eba7 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -100,3 +100,6 @@ async def test_responses_stream_records_common_stream_metrics(tmp_path) -> None: assert "stream_first_visible_output" in trace_text assert "stream_completed" in trace_text assert "stream_metrics_final" in trace_text + assert "responses_stream_event_created" in trace_text + assert "responses_stream_event_output_text_delta" in trace_text + assert "responses_stream_event_completed" in trace_text From c0e745c9b77644caf74d656496391b926e9716ab Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 09:44:26 +0200 Subject: [PATCH 086/182] fix(logging): close transform trace review gaps Fixes Phase 2b review findings by tracing executor-generated terminal stream errors and DONE events, exposing sanitizer/provider-param/framework-context request mutations, and recording final credential-injected LiteLLM/custom execution requests. Avoids trace-only deep copies when tracing is disabled, removes raw cached sample values and payload bodies from field-cache trace entries, scrubs Responses stream failure traces, and only logs stream storage when storage actually occurs. Tests: pytest tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py tests/test_adapter_registry.py tests/test_field_cache_engine.py tests/test_native_provider_executor.py tests/test_responses_service.py tests/test_responses_streaming.py tests/test_protocol_operation_model.py tests/test_protocol_openai_embeddings.py tests/test_protocol_openai_images_audio.py tests/test_protocol_ollama_mcp.py tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py --- src/rotator_library/adapters/base.py | 11 +- src/rotator_library/client/executor.py | 124 +++++++++++++++--- src/rotator_library/client/transforms.py | 21 +-- src/rotator_library/field_cache/engine.py | 26 +++- src/rotator_library/responses/service.py | 16 ++- tests/test_field_cache_engine.py | 21 +++ tests/test_responses_streaming.py | 34 ++++- ...test_transaction_logger_transform_trace.py | 46 +++++++ 8 files changed, 251 insertions(+), 48 deletions(-) diff --git a/src/rotator_library/adapters/base.py b/src/rotator_library/adapters/base.py index 26f39bf18..89f4e48a8 100644 --- a/src/rotator_library/adapters/base.py +++ b/src/rotator_library/adapters/base.py @@ -107,11 +107,12 @@ async def run_adapter_chain( """ current = payload if mutate else deepcopy(payload) - original = deepcopy(current) + tracing_enabled = context.transaction_logger is not None + original = deepcopy(current) if tracing_enabled else None adapter_names = [adapter.name for adapter in adapters] _trace(context, "before_adapter_chain", current, stage=stage, metadata={"adapters": adapter_names}) for adapter in adapters: - before = deepcopy(current) + before = deepcopy(current) if tracing_enabled else None try: current = await adapter.transform(stage, current, context) except Exception as exc: @@ -119,7 +120,7 @@ async def run_adapter_chain( context.transaction_logger.log_transform_error( f"adapter:{adapter.name}:{stage}", exc, - payload=before, + payload=before if before is not None else current, stage="adapter", protocol=context.protocol, transport=context.transport, @@ -131,13 +132,13 @@ async def run_adapter_chain( "after_adapter", current, stage=stage, - metadata={"adapter": adapter.name, "adapter_stage": stage, "changed": current != before}, + metadata={"adapter": adapter.name, "adapter_stage": stage, "changed": (current != before) if before is not None else None}, ) _trace( context, "after_adapter_chain", current, stage=stage, - metadata={"adapters": adapter_names, "adapter_stage": stage, "adapter_count": len(adapter_names), "changed": original != current}, + metadata={"adapters": adapter_names, "adapter_stage": stage, "adapter_count": len(adapter_names), "changed": (original != current) if original is not None else None}, ) return current diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index accd32d6b..b2e9ab102 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -405,15 +405,48 @@ async def _prepare_request_kwargs( }, ) - # Sanitize request payload + # Sanitize request payload. Some provider compatibility fields are + # intentionally removed here, so record it as its own transform pass. + before_sanitize = deepcopy(kwargs) if context.transaction_logger else None kwargs = sanitize_request_payload(kwargs, model) + self._log_executor_trace( + context, + "after_request_sanitization", + kwargs, + direction="request", + stage="client", + credential_id=credential_id, + changed_from_previous=(before_sanitize != kwargs) if before_sanitize is not None else None, + metadata={"provider": provider, "model": model}, + ) # Apply provider-specific LiteLLM params + before_params = deepcopy(kwargs) if context.transaction_logger else None self._apply_litellm_provider_params(provider, kwargs) + self._log_executor_trace( + context, + "after_litellm_provider_params", + kwargs, + direction="request", + stage="client", + credential_id=credential_id, + changed_from_previous=(before_params != kwargs) if before_params is not None else None, + metadata={"provider": provider, "model": model}, + ) # Add transaction context for provider logging if context.transaction_logger: kwargs["transaction_context"] = context.transaction_logger.get_context() + self._log_executor_trace( + context, + "after_transaction_context_attached", + kwargs, + direction="request", + stage="client", + credential_id=credential_id, + changed_from_previous=True, + metadata={"provider": provider, "model": model}, + ) return kwargs @@ -497,9 +530,9 @@ async def _run_pre_request_callback( """ if context.pre_request_callback: try: - before = deepcopy(kwargs) + before = deepcopy(kwargs) if context.transaction_logger else None await context.pre_request_callback(context.request, kwargs) - if before != kwargs: + if before is not None and before != kwargs: self._log_executor_trace( context, "after_pre_request_callback", @@ -530,7 +563,7 @@ async def _execute_provider_request( execution = target.execution if target else "auto" self._log_executor_trace( context, - "provider_execution_request", + "pre_provider_execution_request", kwargs, direction="request", stage="provider", @@ -543,12 +576,21 @@ async def _execute_provider_request( "routing_litellm_fallback", _target_trace(target) if target else {"provider": provider, "model": model}, ) - return await self._execute_litellm_request(kwargs, credential_secret) + return await self._execute_litellm_request(kwargs, credential_secret, context=context, credential_id=credential_id) if execution == "custom" or (execution == "auto" and plugin and plugin.has_custom_logic()): if not plugin or not plugin.has_custom_logic(): raise RoutingExecutionError(f"Provider {provider} does not support custom execution") kwargs["credential_identifier"] = credential_secret + self._log_executor_trace( + context, + "provider_execution_request", + kwargs, + direction="request", + stage="provider", + credential_id=credential_id, + metadata={"execution": "custom", "provider": provider, "model": model}, + ) return await plugin.acompletion(self._http_client, **kwargs) if execution == "native" or (execution == "auto" and _provider_native_protocol(plugin, model, target)): @@ -569,14 +611,31 @@ async def _execute_provider_request( ) return await NativeProviderExecutor().execute(dict(kwargs), native_context, NativeHTTPTransport(self._http_client)) - return await self._execute_litellm_request(kwargs, credential_secret) + return await self._execute_litellm_request(kwargs, credential_secret, context=context, credential_id=credential_id) - async def _execute_litellm_request(self, kwargs: Dict[str, Any], credential_secret: str) -> Any: + async def _execute_litellm_request( + self, + kwargs: Dict[str, Any], + credential_secret: str, + *, + context: Optional[RequestContext] = None, + credential_id: Optional[str] = None, + ) -> Any: """Execute the existing LiteLLM request path.""" kwargs["api_key"] = credential_secret self._apply_litellm_logger(kwargs) kwargs.pop("transaction_context", None) + if context: + self._log_executor_trace( + context, + "provider_execution_request", + kwargs, + direction="request", + stage="provider", + credential_id=credential_id, + metadata={"execution": "litellm", "provider": context.provider, "model": context.model}, + ) return await litellm.acompletion(**kwargs) def _build_native_provider_context( @@ -844,6 +903,29 @@ def _log_executor_trace( snapshot=snapshot, ) + def _terminal_stream_error_lines(self, context: RequestContext, error_data: Dict[str, Any]) -> Tuple[str, str]: + """Return executor-created terminal SSE lines and trace them first.""" + + error_line = f"data: {json.dumps(error_data)}\n\n" + done_line = "data: [DONE]\n\n" + self._log_executor_trace( + context, + "stream_error_event", + error_data, + direction="stream", + stage="client", + snapshot=False, + ) + self._log_executor_trace( + context, + "stream_done_event", + {"raw": done_line}, + direction="stream", + stage="final", + snapshot=False, + ) + return error_line, done_line + async def _prepare_execution( self, context: RequestContext, @@ -1130,8 +1212,8 @@ async def _execute_streaming( "type": "proxy_error", } } - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + for line in self._terminal_stream_error_lines(context, error_data): + yield line return error_accumulator = RequestErrorAccumulator() @@ -1353,8 +1435,8 @@ async def _execute_streaming( "type": "quota_exhausted", } } - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + for line in self._terminal_stream_error_lines(context, error_data): + yield line return else: retry_state.reset_quota_failures() @@ -1374,8 +1456,8 @@ async def _execute_streaming( "type": classified.error_type, } } - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + for line in self._terminal_stream_error_lines(context, error_data): + yield line return small_cooldown_threshold = int( @@ -1444,8 +1526,8 @@ async def _execute_streaming( "type": "quota_exhausted", } } - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + for line in self._terminal_stream_error_lines(context, error_data): + yield line return else: retry_state.reset_quota_failures() @@ -1599,20 +1681,20 @@ async def _execute_streaming( # All credentials exhausted or timeout error_accumulator.timeout_occurred = time.time() >= deadline error_data = error_accumulator.build_client_error_response() - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + for line in self._terminal_stream_error_lines(context, error_data): + yield line except NoAvailableKeysError as e: lib_logger.error(f"No keys available: {e}") error_data = {"error": {"message": str(e), "type": "proxy_busy"}} - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + for line in self._terminal_stream_error_lines(context, error_data): + yield line except Exception as e: lib_logger.error(f"Unhandled exception in streaming: {e}", exc_info=True) error_data = {"error": {"message": str(e), "type": "proxy_internal_error"}} - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + for line in self._terminal_stream_error_lines(context, error_data): + yield line def _apply_litellm_provider_params( self, provider: str, kwargs: Dict[str, Any] diff --git a/src/rotator_library/client/transforms.py b/src/rotator_library/client/transforms.py index c909b4215..8538c069f 100644 --- a/src/rotator_library/client/transforms.py +++ b/src/rotator_library/client/transforms.py @@ -118,7 +118,7 @@ async def apply( # Check if transform applies (provider match or model contains pattern) if transform_provider == provider or transform_provider in model.lower(): for transform in transforms: - before = deepcopy(kwargs) + before = deepcopy(kwargs) if transaction_logger else None result = transform(kwargs, model, provider) if result: modifications.append(result) @@ -130,7 +130,7 @@ async def apply( model=model, credential_id=credential_id, transport=transport, - changed_from_previous=before != kwargs, + changed_from_previous=(before != kwargs) if before is not None else None, metadata={ "transform_provider": transform_provider, "transform_name": getattr(transform, "__name__", repr(transform)), @@ -143,11 +143,11 @@ async def apply( plugin = self._get_plugin_instance(provider) if plugin and hasattr(plugin, "transform_request"): try: - before = deepcopy(kwargs) + before = deepcopy(kwargs) if transaction_logger else None hook_result = await plugin.transform_request(kwargs, model, credential) if hook_result: modifications.extend(hook_result) - if hook_result or before != kwargs: + if hook_result or (before is not None and before != kwargs): _trace_transform_pass( transaction_logger, "after_provider_hook_transform", @@ -156,7 +156,7 @@ async def apply( model=model, credential_id=credential_id, transport=transport, - changed_from_previous=before != kwargs, + changed_from_previous=(before != kwargs) if before is not None else None, metadata={"modifications": hook_result or [], **trace_metadata}, ) except Exception as e: @@ -167,14 +167,15 @@ async def apply( e, payload=kwargs, stage="client", - metadata={"provider": provider, "model": model, **trace_metadata}, + transport=transport, + metadata={"provider": provider, "model": model, "credential_id": credential_id, **trace_metadata}, ) # 3. Apply model-specific options from provider if plugin and hasattr(plugin, "get_model_options"): model_options = plugin.get_model_options(model) if model_options: - before = deepcopy(kwargs) + before = deepcopy(kwargs) if transaction_logger else None for key, value in model_options.items(): if key == "reasoning_effort": kwargs["reasoning_effort"] = value @@ -189,13 +190,13 @@ async def apply( model=model, credential_id=credential_id, transport=transport, - changed_from_previous=before != kwargs, + changed_from_previous=(before != kwargs) if before is not None else None, metadata={"model_options": deepcopy(model_options), **trace_metadata}, ) # 4. Apply LiteLLM conversion if config available if self._config and hasattr(self._config, "convert_for_litellm"): - before = deepcopy(kwargs) + before = deepcopy(kwargs) if transaction_logger else None _trace_transform_pass( transaction_logger, "before_litellm_conversion", @@ -218,7 +219,7 @@ async def apply( model=model, credential_id=credential_id, transport=transport, - changed_from_previous=before != kwargs, + changed_from_previous=(before != kwargs) if before is not None else None, metadata={"provider_config_override": bool(provider_config_override), **trace_metadata}, ) diff --git a/src/rotator_library/field_cache/engine.py b/src/rotator_library/field_cache/engine.py index c40ba67df..7b55cc9f0 100644 --- a/src/rotator_library/field_cache/engine.py +++ b/src/rotator_library/field_cache/engine.py @@ -199,7 +199,7 @@ def _trace( return transaction_logger.log_transform_pass( pass_name, - payload, + _payload_shape(payload), direction=_trace_direction(pass_name, rule.source, extra_metadata), stage="adapter", metadata={ @@ -214,7 +214,10 @@ def _trace( "hit": operation.hit, "skipped": operation.skipped, "reason": operation.reason, - "sample_values": operation.sample_values[:3], + # Cached fields can include provider signatures or session keys. + # Trace only shape/count metadata; keep raw samples out of logs. + "sample_value_count": len(operation.sample_values), + "sample_value_types": [type(value).__name__ for value in operation.sample_values[:3]], **extra_metadata, }, snapshot=rule.source != "stream_event", @@ -237,8 +240,8 @@ def _trace_summary( return transaction_logger.log_transform_pass( pass_name, - payload, - direction="request" if target else "response" if source == "response" else "stream" if source == "stream_event" else "metadata", + _payload_shape(payload), + direction="request" if target or source == "request" else "response" if source == "response" else "stream" if source == "stream_event" else "metadata", stage="adapter", metadata={ "source": source, @@ -294,3 +297,18 @@ def _sample_values(values: list[Any], *, max_items: int = 3, max_text: int = 500 else: samples.append(value) return samples + + +def _payload_shape(payload: Any) -> dict[str, Any]: + """Return non-sensitive payload shape metadata for cache traces. + + Cache rules often target provider signatures, session IDs, and other opaque + state. Logging full payloads would expose exactly the fields the cache is + designed to preserve, so field-cache traces record structure only. + """ + + if isinstance(payload, dict): + return {"payload_type": "dict", "keys": sorted(str(key) for key in payload.keys())[:20]} + if isinstance(payload, list): + return {"payload_type": "list", "length": len(payload)} + return {"payload_type": type(payload).__name__} diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index 47530c33e..69642429f 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -216,8 +216,11 @@ async def stream_response( yield formatter.format_event("response.output_item.done", done_item) completed = response_completed_payload(state, _usage_to_responses_stream(usage)) self._trace_responses_usage(transaction_logger, completed, unified.model, source="responses_stream") - await self._store_stream_response(stream_request, completed, parent) - self._trace(transaction_logger, "responses_stored_stream_response", completed, direction="metadata", stage="final") + stored = await self._store_stream_response(stream_request, completed, parent) + if stored: + self._trace(transaction_logger, "responses_stored_stream_response", completed, direction="metadata", stage="final") + else: + self._trace(transaction_logger, "responses_store_skipped", {"response_id": completed.get("id")}, direction="metadata", stage="final") monitor.complete() self._trace( transaction_logger, @@ -242,7 +245,7 @@ async def stream_response( monitor.record_event(StreamEvent("error", protocol="responses", data={"error_type": exc.__class__.__name__})) failed = response_failed_payload(response_id, unified.model, {"message": str(exc), "type": exc.__class__.__name__}) self._log_transform_error(transaction_logger, "responses_stream", exc, stream_request) - self._trace(transaction_logger, "responses_stream_event_failed", failed, direction="stream", stage="final", metadata={"transport": "sse"}) + self._trace(transaction_logger, "responses_stream_event_failed", failed, direction="stream", stage="final", metadata={"transport": "sse"}, scrub_strings=True) self._trace( transaction_logger, "stream_metrics_final", @@ -337,6 +340,7 @@ def _trace( direction: str, stage: str, metadata: Optional[dict[str, Any]] = None, + scrub_strings: bool = False, ) -> None: if not transaction_logger: return @@ -347,6 +351,7 @@ def _trace( stage=stage, protocol="responses", metadata=metadata or {}, + scrub_strings=scrub_strings, ) @staticmethod @@ -382,10 +387,11 @@ async def _store_stream_response( raw_request: dict[str, Any], response_payload: dict[str, Any], parent: Optional[StoredResponse], - ) -> None: + ) -> bool: if not raw_request.get("store", True): - return + return False await self.store.save(self._stored_response(raw_request, response_payload, parent)) + return True def _input_items(raw_request: dict[str, Any]) -> list[Any]: diff --git a/tests/test_field_cache_engine.py b/tests/test_field_cache_engine.py index 741aaca81..a911951e6 100644 --- a/tests/test_field_cache_engine.py +++ b/tests/test_field_cache_engine.py @@ -164,6 +164,27 @@ async def test_trace_sample_values_are_truncated() -> None: assert operations[0].sample_values[0].endswith("...") +@pytest.mark.asyncio +async def test_field_cache_trace_omits_raw_sample_values(tmp_path) -> None: + logger = TransactionLogger("openai", "gpt-test", parent_dir=tmp_path) + rule = _reasoning_rule() + engine = FieldCacheEngine([rule]) + + await engine.extract( + "response", + {"choices": [{"message": {"reasoning_content": "provider-signature-secret"}}]}, + _context(), + transaction_logger=logger, + ) + + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "provider-signature-secret" not in trace_text + entries = _trace_entries(logger.log_dir) + after_entry = next(entry for entry in entries if entry["pass_name"] == "after_field_cache_extraction") + assert after_entry["metadata"]["sample_value_count"] == 1 + assert after_entry["metadata"]["sample_value_types"] == ["str"] + + @pytest.mark.asyncio async def test_field_cache_traces_start_and_complete_even_without_matching_rules(tmp_path) -> None: logger = TransactionLogger("openai", "gpt-test", parent_dir=tmp_path) diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index 67ed2eba7..8eb408b3d 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -22,10 +22,15 @@ async def chunks(): class FailingStreamingClient: + def __init__(self, message: str = "stream exploded") -> None: + self.message = message + async def acompletion(self, **kwargs): + message = self.message + async def chunks(): yield 'data: {"choices":[{"delta":{"content":"before"}}]}\n\n' - raise RuntimeError("stream exploded") + raise RuntimeError(message) return chunks() @@ -56,14 +61,18 @@ async def test_stream_response_emits_responses_sse_events_and_stores_final_respo @pytest.mark.asyncio -async def test_stream_response_store_false_does_not_persist() -> None: +async def test_stream_response_store_false_does_not_persist(tmp_path) -> None: store = InMemoryResponsesStore() service = ResponsesService(store=store) + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) - events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True, "store": False}, FakeStreamingClient())] + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True, "store": False}, FakeStreamingClient(), transaction_logger=logger)] response_id = events[0].split('"id": "')[1].split('"')[0] assert await store.get(response_id) is None + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "responses_store_skipped" in trace_text + assert "responses_stored_stream_response" not in trace_text @pytest.mark.asyncio @@ -78,6 +87,25 @@ async def test_stream_response_errors_emit_failed_event() -> None: assert event_text.endswith("data: [DONE]\n\n") +@pytest.mark.asyncio +async def test_stream_response_failure_trace_scrubs_header_like_secret_text(tmp_path) -> None: + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + service = ResponsesService(store=InMemoryResponsesStore()) + + _ = [ + chunk + async for chunk in service.stream_response( + {"model": "gpt-test", "input": "Hello", "stream": True}, + FailingStreamingClient("Authorization: Bearer secret-token"), + transaction_logger=logger, + ) + ] + + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "secret-token" not in trace_text + assert "[REDACTED]" in trace_text + + def test_transport_formatters_expose_sse_and_websocket_seam() -> None: assert ResponsesSSEFormatter().transport == "sse" websocket = ResponsesWebSocketFormatter() diff --git a/tests/test_transaction_logger_transform_trace.py b/tests/test_transaction_logger_transform_trace.py index cbcd78f78..60be0b93d 100644 --- a/tests/test_transaction_logger_transform_trace.py +++ b/tests/test_transaction_logger_transform_trace.py @@ -124,6 +124,28 @@ async def stream(): assert "stream_done_event" in pass_names +def test_executor_terminal_stream_errors_are_traced(tmp_path) -> None: + class Context: + pass + + context = Context() + context.transaction_logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + context.streaming = True + context.provider = "openai" + context.model = "openai/gpt-test" + context.session_id = None + context.usage_manager_key = "openai" + context.classifier = None + + executor = RequestExecutor.__new__(RequestExecutor) + + lines = executor._terminal_stream_error_lines(context, {"error": {"type": "proxy_error"}}) + + assert lines[-1] == "data: [DONE]\n\n" + pass_names = [entry["pass_name"] for entry in _trace_entries(context.transaction_logger.log_dir)] + assert pass_names == ["stream_error_event", "stream_done_event"] + + def test_transaction_logger_disabled_writes_no_trace(tmp_path) -> None: logger = TransactionLogger("openai", "openai/gpt-test", enabled=False, parent_dir=tmp_path) @@ -211,3 +233,27 @@ def convert_for_litellm(self, provider_override=None, **kwargs): assert all(entry["credential_id"] == "stable_cred" for entry in entries) assert entries[1]["metadata"]["transform_provider"] == "dedaluslabs" assert entries[-1]["changed_from_previous"] is True + + +@pytest.mark.asyncio +async def test_provider_transforms_do_not_deepcopy_for_trace_when_disabled() -> None: + class NoCopyValue: + def __deepcopy__(self, memo): + raise AssertionError("trace comparison should not copy when tracing is disabled") + + class HookPlugin: + async def transform_request(self, kwargs, model, credential): + kwargs["hooked"] = True + return ["hooked request"] + + transforms = ProviderTransforms({"dedaluslabs": HookPlugin()}) + + result = await transforms.apply( + "dedaluslabs", + "dedaluslabs/test", + "secret-credential", + {"model": "dedaluslabs/test", "tool_choice": "auto", "opaque": NoCopyValue()}, + ) + + assert "tool_choice" not in result + assert result["hooked"] is True From 55d673744cc93bd07abaa50d59a8989747ffaf9c Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 09:54:47 +0200 Subject: [PATCH 087/182] fix(logging): harden trace redaction and done events Closes the remaining Phase 2b heavy-review findings by shape-logging field-cache error payloads, removing raw cached samples from trace data, broadening string redaction for quoted header forms, and scrubbing Responses stream failure event traces. Also avoids trace-only response/deepcopy work when tracing is disabled and records Responses stream DONE sentinels without changing emitted SSE bytes. Tests: pytest tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py tests/test_adapter_registry.py tests/test_field_cache_engine.py tests/test_native_provider_executor.py tests/test_responses_service.py tests/test_responses_streaming.py tests/test_protocol_operation_model.py tests/test_protocol_openai_embeddings.py tests/test_protocol_openai_images_audio.py tests/test_protocol_ollama_mcp.py tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py --- src/rotator_library/client/transforms.py | 2 +- src/rotator_library/field_cache/engine.py | 2 +- src/rotator_library/responses/service.py | 5 ++++- src/rotator_library/transform_trace.py | 2 +- tests/test_field_cache_engine.py | 22 ++++++++++++++++++++++ tests/test_responses_streaming.py | 3 ++- tests/test_transform_trace.py | 3 ++- 7 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/rotator_library/client/transforms.py b/src/rotator_library/client/transforms.py index 8538c069f..3b9558f93 100644 --- a/src/rotator_library/client/transforms.py +++ b/src/rotator_library/client/transforms.py @@ -191,7 +191,7 @@ async def apply( credential_id=credential_id, transport=transport, changed_from_previous=(before != kwargs) if before is not None else None, - metadata={"model_options": deepcopy(model_options), **trace_metadata}, + metadata={"model_options": deepcopy(model_options) if transaction_logger else None, **trace_metadata}, ) # 4. Apply LiteLLM conversion if config available diff --git a/src/rotator_library/field_cache/engine.py b/src/rotator_library/field_cache/engine.py index 7b55cc9f0..b8bee8de1 100644 --- a/src/rotator_library/field_cache/engine.py +++ b/src/rotator_library/field_cache/engine.py @@ -262,7 +262,7 @@ def _log_error(self, transaction_logger: Optional[Any], pass_name: str, error: B transaction_logger.log_transform_error( pass_name, error, - payload=payload, + payload=_payload_shape(payload), stage="adapter", metadata={"rule_name": rule.name, "path": rule.path, "mode": rule.mode}, ) diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index 69642429f..ec51c8908 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -90,7 +90,8 @@ async def create_response( ) chat_response = await client.acompletion(request=request, **chat_kwargs) - self._trace(transaction_logger, "responses_bridge_chat_response", self._response_to_dict(chat_response), direction="response", stage="provider") + if transaction_logger: + self._trace(transaction_logger, "responses_bridge_chat_response", self._response_to_dict(chat_response), direction="response", stage="provider") response_payload = self.bridge.from_chat_response(chat_response, unified) self._trace(transaction_logger, "responses_parsed_response", response_payload, direction="response", stage="protocol") @@ -240,6 +241,7 @@ async def stream_response( ) self._trace(transaction_logger, "responses_stream_event_completed", completed, direction="stream", stage="final", metadata={"transport": "sse"}) yield formatter.format_event("response.completed", completed) + self._trace(transaction_logger, "stream_done_event", {"raw": formatter.done()}, direction="stream", stage="final", metadata={"transport": "sse"}) yield formatter.done() except Exception as exc: monitor.record_event(StreamEvent("error", protocol="responses", data={"error_type": exc.__class__.__name__})) @@ -255,6 +257,7 @@ async def stream_response( metadata={"transport": "sse", "failed": True}, ) yield formatter.format_event("response.failed", failed) + self._trace(transaction_logger, "stream_done_event", {"raw": formatter.done()}, direction="stream", stage="final", metadata={"transport": "sse", "failed": True}) yield formatter.done() async def get_response(self, response_id: str) -> dict[str, Any]: diff --git a/src/rotator_library/transform_trace.py b/src/rotator_library/transform_trace.py index 00144facb..8e74c6303 100644 --- a/src/rotator_library/transform_trace.py +++ b/src/rotator_library/transform_trace.py @@ -47,7 +47,7 @@ ) _SENSITIVE_TEXT_RE = re.compile( - r"(?im)\b(authorization|proxy-authorization|x-api-key|x-goog-api-key|openai-api-key|api[_-]?key|access[_-]?token|refresh[_-]?token|client[_-]?secret|cookie|set-cookie)\b(\s*[:=]\s*)([^\r\n]+)" + r"(?im)\b(authorization|proxy-authorization|x-api-key|x-goog-api-key|openai-api-key|api[_-]?key|access[_-]?token|refresh[_-]?token|client[_-]?secret|cookie|set-cookie)\b(['\"]?\s*[:=]\s*['\"]?)([^'\"\r\n,}]+)" ) diff --git a/tests/test_field_cache_engine.py b/tests/test_field_cache_engine.py index a911951e6..8a8150d15 100644 --- a/tests/test_field_cache_engine.py +++ b/tests/test_field_cache_engine.py @@ -185,6 +185,28 @@ async def test_field_cache_trace_omits_raw_sample_values(tmp_path) -> None: assert after_entry["metadata"]["sample_value_types"] == ["str"] +@pytest.mark.asyncio +async def test_field_cache_error_trace_omits_raw_payload_values(tmp_path) -> None: + class FailingStore(InMemoryFieldCacheStore): + async def set(self, key, value, *, append=False): + raise RuntimeError("store failed") + + logger = TransactionLogger("openai", "gpt-test", parent_dir=tmp_path) + engine = FieldCacheEngine([_reasoning_rule()], store=FailingStore()) + + with pytest.raises(RuntimeError): + await engine.extract( + "response", + {"choices": [{"message": {"reasoning_content": "provider-signature-secret"}}]}, + _context(), + transaction_logger=logger, + ) + + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "provider-signature-secret" not in trace_text + assert "payload_type" in trace_text + + @pytest.mark.asyncio async def test_field_cache_traces_start_and_complete_even_without_matching_rules(tmp_path) -> None: logger = TransactionLogger("openai", "gpt-test", parent_dir=tmp_path) diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index 8eb408b3d..54f2d7c98 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -96,7 +96,7 @@ async def test_stream_response_failure_trace_scrubs_header_like_secret_text(tmp_ chunk async for chunk in service.stream_response( {"model": "gpt-test", "input": "Hello", "stream": True}, - FailingStreamingClient("Authorization: Bearer secret-token"), + FailingStreamingClient("{'Authorization': 'Bearer secret-token'}"), transaction_logger=logger, ) ] @@ -128,6 +128,7 @@ async def test_responses_stream_records_common_stream_metrics(tmp_path) -> None: assert "stream_first_visible_output" in trace_text assert "stream_completed" in trace_text assert "stream_metrics_final" in trace_text + assert "stream_done_event" in trace_text assert "responses_stream_event_created" in trace_text assert "responses_stream_event_output_text_delta" in trace_text assert "responses_stream_event_completed" in trace_text diff --git a/tests/test_transform_trace.py b/tests/test_transform_trace.py index 928be5720..b3ed8928d 100644 --- a/tests/test_transform_trace.py +++ b/tests/test_transform_trace.py @@ -38,13 +38,14 @@ def test_sanitize_for_trace_redacts_sensitive_keys_recursively() -> None: def test_scrub_sensitive_text_targets_header_like_fragments_only() -> None: - text = "normal token text remains\nAuthorization: Bearer abc123\nset-cookie: sid=secret" + text = "normal token text remains\nAuthorization: Bearer abc123\nset-cookie: sid=secret\n{'Authorization': 'Bearer quoted'}" scrubbed = scrub_sensitive_text(text) assert "normal token text remains" in scrubbed assert "Authorization: [REDACTED]" in scrubbed assert "set-cookie: [REDACTED]" in scrubbed + assert "quoted" not in scrubbed def test_sanitize_for_trace_extracts_sdk_like_objects_before_repr() -> None: From 0880df1eea297ca9167ec10b7d8a1060639bbf14 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 10:04:41 +0200 Subject: [PATCH 088/182] fix(logging): close phase 2b acceptance gaps Updates the older field-cache trace test to expect the new start/complete summary boundaries and avoids trace-only Responses conversions and metrics payload construction when tracing is disabled. This closes the final Phase 2b acceptance findings while preserving route bodies and SSE output. Tests: pytest tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py tests/test_adapter_registry.py tests/test_field_cache_engine.py tests/test_field_cache_trace.py tests/test_native_provider_executor.py tests/test_responses_service.py tests/test_responses_streaming.py tests/test_protocol_operation_model.py tests/test_protocol_openai_embeddings.py tests/test_protocol_openai_images_audio.py tests/test_protocol_ollama_mcp.py tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py --- src/rotator_library/responses/service.py | 107 ++++++++++++----------- tests/test_field_cache_trace.py | 12 ++- 2 files changed, 65 insertions(+), 54 deletions(-) diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index ec51c8908..1a32d385b 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -75,7 +75,8 @@ async def create_response( self._trace(transaction_logger, "responses_raw_request", raw_request, direction="request", stage="client") unified = self.protocol.parse_request(raw_request, ProtocolContext(source_protocol="responses")) - self._trace(transaction_logger, "responses_parsed_request", unified.to_dict(), direction="request", stage="protocol") + if transaction_logger: + self._trace(transaction_logger, "responses_parsed_request", unified.to_dict(), direction="request", stage="protocol") parent = await self._load_previous_response(unified.previous_response_id, transaction_logger) chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_response=parent.response if parent else None) @@ -122,7 +123,8 @@ async def stream_response( stream_request["stream"] = True self._trace(transaction_logger, "responses_raw_request", stream_request, direction="request", stage="client") unified = self.protocol.parse_request(stream_request, ProtocolContext(source_protocol="responses", transport="sse")) - self._trace(transaction_logger, "responses_parsed_request", unified.to_dict(), direction="request", stage="protocol") + if transaction_logger: + self._trace(transaction_logger, "responses_parsed_request", unified.to_dict(), direction="request", stage="protocol") parent = await self._load_previous_response(unified.previous_response_id, transaction_logger) chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_response=parent.response if parent else None) bridge_metadata = chat_kwargs.pop("_responses_bridge", {}) @@ -141,14 +143,15 @@ async def stream_response( usage = None item_started = False monitor = StreamMonitor(clock=time.monotonic) - self._trace( - transaction_logger, - "stream_started", - {"event": StreamEvent("started", protocol="responses").to_dict(), "metrics": monitor.metrics.to_dict()}, - direction="stream", - stage="client", - metadata={"transport": "sse"}, - ) + if transaction_logger: + self._trace( + transaction_logger, + "stream_started", + {"event": StreamEvent("started", protocol="responses").to_dict(), "metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="client", + metadata={"transport": "sse"}, + ) created = response_created_payload(response_id, unified.model) self._trace(transaction_logger, "responses_stream_event_created", created, direction="stream", stage="final", metadata={"transport": "sse"}) yield formatter.format_event("response.created", created) @@ -157,14 +160,15 @@ async def stream_response( async for raw_chunk in chat_stream: if monitor.metrics.first_byte_at is None: monitor.record_event(StreamEvent("raw_chunk", protocol="responses", raw=raw_chunk)) - self._trace( - transaction_logger, - "stream_first_byte", - {"metrics": monitor.metrics.to_dict()}, - direction="stream", - stage="provider", - metadata={"transport": "sse"}, - ) + if transaction_logger: + self._trace( + transaction_logger, + "stream_first_byte", + {"metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="provider", + metadata={"transport": "sse"}, + ) self._trace(transaction_logger, "raw_chat_bridge_stream_chunk", raw_chunk, direction="stream", stage="provider") chunk = parse_chat_sse_chunk(raw_chunk) if not chunk or chunk.get("type") == "done": @@ -197,14 +201,15 @@ async def stream_response( ) ) if first_visible: - self._trace( - transaction_logger, - "stream_first_visible_output", - {"event": event, "metrics": monitor.metrics.to_dict()}, - direction="stream", - stage="final", - metadata={"transport": "sse"}, - ) + if transaction_logger: + self._trace( + transaction_logger, + "stream_first_visible_output", + {"event": event, "metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="final", + metadata={"transport": "sse"}, + ) self._trace(transaction_logger, "responses_stream_event_output_text_delta", event, direction="stream", stage="final", metadata={"transport": "sse"}) yield formatter.format_event("response.output_text.delta", event) @@ -223,22 +228,23 @@ async def stream_response( else: self._trace(transaction_logger, "responses_store_skipped", {"response_id": completed.get("id")}, direction="metadata", stage="final") monitor.complete() - self._trace( - transaction_logger, - "stream_completed", - {"metrics": monitor.metrics.to_dict()}, - direction="stream", - stage="final", - metadata={"transport": "sse"}, - ) - self._trace( - transaction_logger, - "stream_metrics_final", - {"metrics": monitor.metrics.to_dict()}, - direction="stream", - stage="final", - metadata={"transport": "sse"}, - ) + if transaction_logger: + self._trace( + transaction_logger, + "stream_completed", + {"metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="final", + metadata={"transport": "sse"}, + ) + self._trace( + transaction_logger, + "stream_metrics_final", + {"metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="final", + metadata={"transport": "sse"}, + ) self._trace(transaction_logger, "responses_stream_event_completed", completed, direction="stream", stage="final", metadata={"transport": "sse"}) yield formatter.format_event("response.completed", completed) self._trace(transaction_logger, "stream_done_event", {"raw": formatter.done()}, direction="stream", stage="final", metadata={"transport": "sse"}) @@ -248,14 +254,15 @@ async def stream_response( failed = response_failed_payload(response_id, unified.model, {"message": str(exc), "type": exc.__class__.__name__}) self._log_transform_error(transaction_logger, "responses_stream", exc, stream_request) self._trace(transaction_logger, "responses_stream_event_failed", failed, direction="stream", stage="final", metadata={"transport": "sse"}, scrub_strings=True) - self._trace( - transaction_logger, - "stream_metrics_final", - {"metrics": monitor.metrics.to_dict()}, - direction="stream", - stage="final", - metadata={"transport": "sse", "failed": True}, - ) + if transaction_logger: + self._trace( + transaction_logger, + "stream_metrics_final", + {"metrics": monitor.metrics.to_dict()}, + direction="stream", + stage="final", + metadata={"transport": "sse", "failed": True}, + ) yield formatter.format_event("response.failed", failed) self._trace(transaction_logger, "stream_done_event", {"raw": formatter.done()}, direction="stream", stage="final", metadata={"transport": "sse", "failed": True}) yield formatter.done() diff --git a/tests/test_field_cache_trace.py b/tests/test_field_cache_trace.py index 775f033bc..8dd42e718 100644 --- a/tests/test_field_cache_trace.py +++ b/tests/test_field_cache_trace.py @@ -55,15 +55,19 @@ async def test_field_cache_extract_and_inject_emit_before_after_trace_entries(tm pass_names = [entry["pass_name"] for entry in entries] assert updated["messages"][-1]["reasoning_content"] == "hidden" assert pass_names == [ + "field_cache_extraction_start", "before_field_cache_extraction", "after_field_cache_extraction", + "field_cache_extraction_complete", + "field_cache_injection_start", "before_field_cache_injection", "after_field_cache_injection", + "field_cache_injection_complete", ] - assert entries[1]["metadata"]["rule_name"] == "reasoning_content" - assert entries[1]["metadata"]["matched"] == 1 - assert entries[3]["metadata"]["hit"] is True - assert entries[3]["metadata"]["changed"] is True + assert entries[2]["metadata"]["rule_name"] == "reasoning_content" + assert entries[2]["metadata"]["matched"] == 1 + assert entries[6]["metadata"]["hit"] is True + assert entries[6]["metadata"]["changed"] is True @pytest.mark.asyncio From 1cc1aba2049eb1a139218a40ceaa3c2491bf25a5 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 10:19:37 +0200 Subject: [PATCH 089/182] fix(logging): avoid responses trace work when disabled Guards the remaining Responses trace-only usage conversion and previous-response payload conversion paths so they do not run when no transaction logger is attached. Adds regression tests proving usage conversion and previous-response to_dict work are skipped without tracing, closing the final Phase 2b heavy-review finding. Tests: pytest tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py tests/test_adapter_registry.py tests/test_field_cache_engine.py tests/test_field_cache_trace.py tests/test_native_provider_executor.py tests/test_responses_service.py tests/test_responses_streaming.py tests/test_protocol_operation_model.py tests/test_protocol_openai_embeddings.py tests/test_protocol_openai_images_audio.py tests/test_protocol_ollama_mcp.py tests/test_protocol_registry.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py --- src/rotator_library/responses/service.py | 29 +++++++++++--------- tests/test_responses_service.py | 34 ++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index 1a32d385b..11714041b 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -297,19 +297,20 @@ async def _load_previous_response(self, response_id: Optional[str], transaction_ parent = await self.store.get(response_id) if parent is None: raise ResponsesServiceError(f"Previous response not found: {response_id}", status_code=404, error_type="not_found_error") - self._trace( - transaction_logger, - "responses_previous_response_loaded", - parent.to_dict(), - direction="metadata", - stage="adapter", - metadata={ - "previous_response_id": response_id, - "output_count": len(parent.output_items), - "input_item_count": len(parent.input_items), - "bridge_context_expanded": True, - }, - ) + if transaction_logger: + self._trace( + transaction_logger, + "responses_previous_response_loaded", + parent.to_dict(), + direction="metadata", + stage="adapter", + metadata={ + "previous_response_id": response_id, + "output_count": len(parent.output_items), + "input_item_count": len(parent.input_items), + "bridge_context_expanded": True, + }, + ) return parent def _stored_response( @@ -379,6 +380,8 @@ def _trace_responses_usage( ) -> None: """Trace normalized Responses usage without changing stored payloads.""" + if not transaction_logger: + return usage = response_payload.get("usage") if isinstance(response_payload, dict) else None if not usage: return diff --git a/tests/test_responses_service.py b/tests/test_responses_service.py index 4495d95f1..52585e589 100644 --- a/tests/test_responses_service.py +++ b/tests/test_responses_service.py @@ -4,6 +4,7 @@ import pytest +import rotator_library.responses.service as responses_service_module from rotator_library.responses import InMemoryResponsesStore, ResponsesService, ResponsesServiceError, StoredResponse from rotator_library.transaction_logger import TransactionLogger @@ -121,3 +122,36 @@ async def test_service_emits_transform_trace_passes(tmp_path) -> None: "responses_stored_response", "responses_final_response", ] + + +def test_trace_responses_usage_returns_before_conversion_without_logger(monkeypatch) -> None: + service = ResponsesService(store=InMemoryResponsesStore()) + + def fail_extract(*args, **kwargs): + raise AssertionError("usage conversion should be skipped when tracing is disabled") + + monkeypatch.setattr(responses_service_module, "extract_usage_record", fail_extract) + + service._trace_responses_usage(None, {"usage": {"input_tokens": 1}}, "gpt-test", source="test") + + +@pytest.mark.asyncio +async def test_previous_response_trace_payload_skipped_without_logger() -> None: + class Parent: + id = "resp_parent" + response = {"output": []} + output_items = [] + input_items = [] + + def to_dict(self): + raise AssertionError("previous response trace payload should not be built without a logger") + + class Store(InMemoryResponsesStore): + async def get(self, response_id): + return Parent() + + service = ResponsesService(store=Store()) + + parent = await service._load_previous_response("resp_parent", None) + + assert parent.id == "resp_parent" From 0810c189e396946d1d2a40756ffc4285d0a1989b Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 10:24:46 +0200 Subject: [PATCH 090/182] docs(experimental): plan field cache runtime correction Adds the corrective Phase 3b plan for making field-cache behavior runtime-complete instead of mostly declarative. The plan covers process-local default native persistence, turn-aware and per-tool-call modes, insert and TTL behavior, JSON rule merging, tests, and dual-agent review requirements. User-facing reports and unrelated dirty files remain uncommitted. --- .../phase-3b-field-cache-runtime-semantics.md | 147 ++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 docs/experimental/phase-3b-field-cache-runtime-semantics.md diff --git a/docs/experimental/phase-3b-field-cache-runtime-semantics.md b/docs/experimental/phase-3b-field-cache-runtime-semantics.md new file mode 100644 index 000000000..d65494933 --- /dev/null +++ b/docs/experimental/phase-3b-field-cache-runtime-semantics.md @@ -0,0 +1,147 @@ +# Phase 3b: Field-Cache Runtime Semantics And Persistence + +## Goal + +Correct the Phase 3 audit gap. Phase 3 built the adapter and field-cache foundation, but the audit found that the system is still partly declarative: `last_user_turn`, `last_assistant_turn`, and `per_tool_call` are validated/declared but not semantically implemented enough; the native provider executor creates a fresh `InMemoryFieldCacheStore` by default so cached values do not persist across native requests; JSON-configured field-cache rules are parsed but not merged into runtime provider declarations. + +## Non-Goals + +- Do not introduce SQLite or any database. +- Do not replace `SessionTracker`, `UsageManager`, or provider retry logic. +- Do not wire non-native LiteLLM paths through field-cache rules in this phase. +- Do not make field cache a substitute for session affinity. It preserves provider protocol state only. +- Do not implement a full external admin UI/config editor. +- Do not commit user-facing reports. +- Do not touch unrelated dirty files (`ARCHITECTURE.md`, `STRUCTURE.md`, `.opencode/`, `docs/issues/`, old phase reports). + +## Current Code State + +- `FieldCacheRule.mode` allows `last`, `all`, `last_user_turn`, `last_assistant_turn`, and `per_tool_call`. +- `FieldCacheEngine._store_values()` currently treats `last_user_turn` and `last_assistant_turn` exactly like `last`. +- `per_tool_call` currently extracts tool IDs from each extracted value, which only works when the selected value itself contains the tool ID. It cannot correlate values with a sibling or parent tool-call ID from the original payload. +- `FieldCacheInjection` has `as_list`, `when_missing_only`, and `insert`, but engine injection does not use `insert` and does not support selecting a per-tool-call value at injection time. +- `NativeProviderExecutor.__init__(field_cache_store=None)` passes `None` into each `FieldCacheEngine`, causing each request to get a new `InMemoryFieldCacheStore`. That means default native field-cache persistence is per-engine/per-request, not process-level. +- `ProviderCacheFieldStore` exists and no SQLite is present. +- `ExperimentalConfig.parse_field_cache_rules()` can parse JSON rules, but `RequestExecutor._native_context()` currently only uses `plugin.get_field_cache_rules(model)`. +- `NativeProviderContext.field_cache_context()` includes provider/model/credential/session/conversation/classifier plus metadata; metadata is the right place to carry current request payload/turn/tool-call hints without expanding core scope dimensions. +- Phase 2b hardened traces so field-cache logs expose shapes and counts, not raw cached values. + +## Implementation Plan + +### 1. Process-Local Default Native Store + +- Change `NativeProviderExecutor` so a missing `field_cache_store` creates one shared `InMemoryFieldCacheStore` in the executor instance, not a new store inside every `FieldCacheEngine`. +- Keep injected stores honored exactly as before. +- Add tests proving two calls through the same executor can extract on response and inject on a later request using the default store. +- Do not persist to disk by default; process-local is enough for runtime continuity and respects the no-SQLite rule. + +### 2. Value Envelopes For Contextual Modes + +- Add a small internal stored representation for contextual modes, likely dicts with fields such as `value`, `role`, `turn_index`, `tool_call_id`, and `source_path`. +- Keep stored values JSON-serializable for `ProviderCacheFieldStore`. +- Do not expose envelopes to injected payloads except where needed; injection should unwrap `value`. +- Preserve current external behavior for `last` and `all` as much as possible. + +### 3. Turn-Aware Extraction + +- Implement real `last_user_turn` and `last_assistant_turn` semantics. +- Minimum supported shapes: + - OpenAI/Responses-like `messages[*]` objects with `role`. + - Anthropic-like content lists nested under message objects with `role`. + - Generic payloads where a rule provides metadata hints. +- Add rule metadata hints: + - `turn_container_path`: path to turn/message list, default inferred from common roots such as `messages`. + - `turn_role_path`: role field inside each turn, default `role`. + - `turn_value_path`: value path relative to each turn when inference is better than global path. +- If no turn context can be inferred, skip with `reason="turn_context_not_found"` rather than silently behaving like `last`. +- Store only values from the latest matching user/assistant turn. +- Add tests for user turn, assistant turn, skip/no-op, and metadata-configured relative extraction. + +### 4. Per-Tool-Call Correlation And Injection + +- Keep requiring `metadata.tool_call_id_path` for `per_tool_call`. +- Support existing value-relative extraction when each selected value contains the tool ID. +- Add payload-relative correlation with metadata: + - `tool_container_path` + - `tool_call_id_path` relative to each tool container + - `tool_value_path` relative to each tool container +- Store a mapping of `tool_call_id -> value`. +- Add injection selection: + - If `metadata.inject_tool_call_id_path` exists, extract current tool IDs from target payload and inject matching values. + - If `FieldCacheContext.metadata["tool_call_id"]` exists, inject that value. + - If `inject.as_list=True`, inject all matching values as a list. + - Otherwise skip ambiguous maps with `reason="tool_call_id_not_found"` or `reason="ambiguous_tool_call_values"`. +- Add tests for sibling tool ID/value extraction and target-specific injection. + +### 5. Honor `FieldCacheInjection.insert` + +- Wire `rule.inject.insert` into `inject_path()`. +- Add tests for list insertion where path targets a list index or append-like position if supported by the path engine. +- If the path engine does not support safe insertion yet, add minimal support or explicitly validate/reject unsafe insert paths with clear errors. + +### 6. TTL Handling Without New Persistence + +- `FieldCacheRule.ttl_seconds` exists but stores ignore it. +- Add TTL support to `InMemoryFieldCacheStore`. +- Keep `ProviderCacheFieldStore` compatible: + - If the injected provider cache supports TTL arguments, pass them. + - If not, wrap values with expiry metadata and enforce expiry on `get()`. +- Update `FieldCacheStore` protocol with optional TTL on `set`/`append` only if safe; otherwise add internal engine helper methods to avoid breaking third-party stores. +- Add tests for memory-store expiry using a fake clock if possible. +- No SQLite. + +### 7. Merge JSON-Configured Rules Into Native Runtime Context + +- In `RequestExecutor._native_context()`, load optional `ExperimentalConfig` and merge `parse_field_cache_rules(config, provider, model)` with provider-declared rules. +- Provider-declared rules should come first; JSON rules append and can add/override by rule name. +- Override policy: + - If a JSON rule name matches a provider rule name, JSON replaces that rule. + - Otherwise append. +- Environment remains primary for config path and JSON secrets remain rejected by the existing loader. +- Add tests for native context rule merge without real network calls or credentials. +- Avoid top-level imports that create cycles; use local imports where needed. + +### 8. Tests + +- Engine semantics tests for contextual modes. +- Native executor test proving default process-local persistence across separate `execute()` calls. +- Config merge test for JSON field-cache rules. +- TTL store tests. +- Existing trace and protocol regression tests where touched. + +Focused suites: + +- `tests/test_field_cache_paths.py` +- `tests/test_field_cache_engine.py` +- `tests/test_field_cache_trace.py` +- `tests/test_native_provider_executor.py` +- `tests/test_experimental_config.py` +- relevant provider declaration tests +- Phase 1b/2b regression subsets if field-cache/protocol trace serialization is touched + +### 9. Documentation And Comments + +- Update docstrings in `FieldCacheEngine`, `FieldCacheRule`, and store classes to explain mode semantics. +- Document when turn/tool modes skip rather than fallback. +- Document process-local default store vs injected production store. +- Keep comments focused on algorithmic rationale: turn context inference, tool-call correlation, TTL handling, and JSON rule merge precedence. + +## Commit Checkpoints + +1. `docs(experimental): plan field cache runtime correction`. +2. Default persistent native store. +3. Turn/tool-call mode semantics. +4. TTL/store behavior. +5. JSON rule merge/runtime wiring. +6. Review fixes after `explore` and `explore-heavy`. +7. User-facing Phase 3b report, uncommitted. + +## Acceptance Criteria + +- Native field cache persists across requests by default within one executor instance. +- `last_user_turn` and `last_assistant_turn` have real turn-aware semantics and skip safely when context is unavailable. +- `per_tool_call` can correlate tool IDs with sibling values and inject the right cached value by current tool ID. +- `insert`, `ttl_seconds`, and JSON-configured rules are implemented or explicitly validated with clear errors if a requested shape is unsafe. +- No raw cached values leak in traces after Phase 2b hardening. +- No SQLite or new database is introduced. +- Focused tests and dual-agent review pass. From ba79340baf74ac80ce06e9f81a6514a6282fc46c Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 10:26:32 +0200 Subject: [PATCH 091/182] fix(field-cache): persist native cache per executor Changes the native provider executor default field-cache store from a per-request in-memory store to a process-local store owned by the executor instance. This preserves native protocol state across requests without adding SQLite or any new persistence dependency, while injected stores remain supported for production use. Tests: pytest tests/test_native_provider_executor.py --- .../native_provider/executor.py | 10 ++--- tests/test_native_provider_executor.py | 38 +++++++++++++++++++ 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index a2532e3fa..95bebc750 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -8,7 +8,7 @@ from typing import Any, AsyncGenerator from ..adapters import get_adapter, run_adapter_chain -from ..field_cache import FieldCacheEngine +from ..field_cache import FieldCacheEngine, InMemoryFieldCacheStore from ..protocols import get_protocol from ..usage.accounting import extract_usage_record from .context import NativeProviderContext @@ -19,13 +19,13 @@ class NativeProviderExecutor: """Run one native provider request through protocol/adapter/cache passes. - This executor is intentionally not wired into the live client path yet. Phase - 5 providers can test native behavior here first, then later phases can route - declared providers into it without disturbing undeclared providers. + The default field-cache store is process-local per executor. That preserves + provider protocol state across native requests without adding a database; + production callers can still inject a persistent store when needed. """ def __init__(self, *, field_cache_store: Any = None) -> None: - self.field_cache_store = field_cache_store + self.field_cache_store = field_cache_store or InMemoryFieldCacheStore() async def execute(self, raw_request: dict[str, Any], context: NativeProviderContext, transport: NativeHTTPTransport) -> dict[str, Any]: """Execute a non-streaming native provider request.""" diff --git a/tests/test_native_provider_executor.py b/tests/test_native_provider_executor.py index 20593d96a..c8ddc7e4a 100644 --- a/tests/test_native_provider_executor.py +++ b/tests/test_native_provider_executor.py @@ -90,6 +90,44 @@ async def test_native_provider_executor_runs_protocol_adapter_cache_and_trace(tm assert "final_client_response" in pass_names +@pytest.mark.asyncio +async def test_native_provider_default_field_cache_persists_across_requests() -> None: + rule = FieldCacheRule( + name="reasoning", + source="response", + path="choices.0.message.reasoning_content", + inject=FieldCacheInjection(target="request", path="metadata.reasoning_content"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=(rule,), + ) + first_client = FakeHTTPClient( + { + "id": "chat_1", + "model": "gpt-test", + "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "cached-state"}}], + } + ) + second_client = FakeHTTPClient( + { + "id": "chat_2", + "model": "gpt-test", + "choices": [{"message": {"role": "assistant", "content": "ok"}}], + } + ) + executor = NativeProviderExecutor() + + await executor.execute({"model": "gpt-test", "messages": [{"role": "user", "content": "hi"}]}, context, NativeHTTPTransport(first_client)) + await executor.execute({"model": "gpt-test", "messages": [{"role": "user", "content": "again"}]}, context, NativeHTTPTransport(second_client)) + + assert second_client.calls[0]["json"]["metadata"]["reasoning_content"] == "cached-state" + + @pytest.mark.asyncio async def test_native_provider_executor_logs_transform_errors(tmp_path) -> None: logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) From 8d60abda92ac89cb45892d325f71a40ba175aadc Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 10:30:53 +0200 Subject: [PATCH 092/182] feat(field-cache): implement runtime cache modes Adds real runtime semantics for turn-aware and per-tool-call field-cache modes, including safe skips when turn or tool context is unavailable. Also wires insert injection, adds TTL support to in-memory and provider-cache field stores, and preserves existing last/all behavior for simple rules. Tests: pytest tests/test_field_cache_paths.py tests/test_field_cache_engine.py --- src/rotator_library/field_cache/engine.py | 162 +++++++++++++++++++--- src/rotator_library/field_cache/paths.py | 11 +- src/rotator_library/field_cache/store.py | 64 +++++++-- tests/test_field_cache_engine.py | 145 ++++++++++++++++++- tests/test_field_cache_paths.py | 9 ++ 5 files changed, 351 insertions(+), 40 deletions(-) diff --git a/src/rotator_library/field_cache/engine.py b/src/rotator_library/field_cache/engine.py index b8bee8de1..baff02c98 100644 --- a/src/rotator_library/field_cache/engine.py +++ b/src/rotator_library/field_cache/engine.py @@ -55,11 +55,10 @@ def build_cache_key(rule: FieldCacheRule, context: FieldCacheContext) -> Optiona class FieldCacheEngine: """Apply field-cache extraction and injection rules. - The engine is isolated from request execution in Phase 3. It defaults to - copying payloads before injection so tests and future providers can reason - about mutations explicitly. Turn-specific and per-tool modes are validated - but limited until later phases add richer conversation indexing and - per-tool-call injection targets. + The engine preserves provider protocol state; it is not session tracking. + It defaults to copying payloads before injection so providers can opt into + mutation explicitly. Turn and tool-call modes skip safely when the requested + context cannot be inferred rather than silently falling back to `last`. """ def __init__(self, rules: Iterable[FieldCacheRule], store: Optional[FieldCacheStore] = None) -> None: @@ -103,9 +102,8 @@ async def extract( values = extract_path(payload, rule.path) operation.matched = len(values) operation.sample_values = _sample_values(values) - if values: - await self._store_values(rule, operation.cache_key, values) - operation.changed = True + if values or rule.mode in {"last_user_turn", "last_assistant_turn", "per_tool_call"}: + operation.changed = await self._store_values(rule, operation.cache_key, values, payload, operation) except Exception as exc: self._log_error(transaction_logger, "field_cache_extract", exc, payload, rule) raise @@ -146,12 +144,17 @@ async def inject( self._trace(transaction_logger, "after_field_cache_injection", updated, rule, operation, target=target) continue operation.hit = True - value = cached if rule.inject.as_list or rule.mode == "all" else _last_value(cached) + value = self._injection_value(rule, cached, updated, context, operation) + if operation.skipped: + operations.append(operation) + self._trace(transaction_logger, "after_field_cache_injection", updated, rule, operation, target=target) + continue operation.changed = inject_path( updated, rule.inject.path, value, when_missing_only=rule.inject.when_missing_only, + insert=rule.inject.insert, ) operation.sample_values = _sample_values(value if isinstance(value, list) else [value]) except Exception as exc: @@ -168,24 +171,69 @@ def _rules_for_source(self, source: str) -> list[FieldCacheRule]: def _rules_for_injection(self, target: str) -> list[FieldCacheRule]: return [rule for rule in self.rules if rule.enabled and rule.inject and rule.inject.target == target] - async def _store_values(self, rule: FieldCacheRule, cache_key: str, values: list[Any]) -> None: + async def _store_values(self, rule: FieldCacheRule, cache_key: str, values: list[Any], payload: Any, operation: FieldCacheOperation) -> bool: if rule.mode == "all": - await self.store.append(cache_key, values) - return - if rule.mode in {"last", "last_user_turn", "last_assistant_turn"}: - await self.store.set(cache_key, values[-1]) - return + await self.store.append(cache_key, values, ttl_seconds=rule.ttl_seconds) + return True + if rule.mode == "last": + await self.store.set(cache_key, values[-1], ttl_seconds=rule.ttl_seconds) + return True + if rule.mode in {"last_user_turn", "last_assistant_turn"}: + role = "user" if rule.mode == "last_user_turn" else "assistant" + turn_values = _turn_values(rule, payload, role) + if not turn_values: + operation.skipped = True + operation.reason = "turn_context_not_found" + return False + operation.matched = len(turn_values) + operation.sample_values = _sample_values(turn_values) + await self.store.set(cache_key, turn_values[-1], ttl_seconds=rule.ttl_seconds) + return True if rule.mode == "per_tool_call": - tool_path = rule.metadata.get("tool_call_id_path") - stored = {} - for value in values: - tool_ids = extract_path(value, tool_path) if tool_path else [] - if tool_ids: - stored[str(tool_ids[0])] = value - await self.store.set(cache_key, stored) - return + stored = _tool_call_values(rule, payload, values) + if not stored: + operation.skipped = True + operation.reason = "tool_call_id_not_found" + return False + operation.matched = len(stored) + operation.sample_values = _sample_values(list(stored.values())) + await self.store.set(cache_key, stored, ttl_seconds=rule.ttl_seconds) + return True raise ValueError(f"Unsupported field-cache mode: {rule.mode}") + def _injection_value(self, rule: FieldCacheRule, cached: Any, payload: Any, context: FieldCacheContext, operation: FieldCacheOperation) -> Any: + """Select the cached value to inject for the rule's mode. + + Per-tool-call maps require a current tool-call ID so the engine never + injects an arbitrary provider signature into the wrong tool result. + """ + + if rule.mode == "per_tool_call": + if not isinstance(cached, dict): + operation.skipped = True + operation.reason = "invalid_tool_call_cache" + return None + ids = _injection_tool_ids(rule, payload, context) + if not ids: + operation.skipped = True + operation.reason = "tool_call_id_not_found" + return None + matches = [cached[str(tool_id)] for tool_id in ids if str(tool_id) in cached] + if not matches: + operation.skipped = True + operation.reason = "tool_call_cache_miss" + return None + if rule.inject and rule.inject.as_list: + return matches + if len(matches) == 1: + return matches[0] + operation.skipped = True + operation.reason = "ambiguous_tool_call_values" + return None + if rule.inject and (rule.inject.as_list or rule.mode == "all"): + return cached if isinstance(cached, list) else [cached] + return _last_value(cached) + def _trace( self, transaction_logger: Optional[Any], @@ -274,6 +322,74 @@ def _last_value(value: Any) -> Any: return value +def _turn_values(rule: FieldCacheRule, payload: Any, role: str) -> list[Any]: + """Return values from the latest turn matching `role`. + + Rules can provide explicit turn paths for provider-specific payloads. The + default handles the common `messages[*]` shape used by OpenAI-compatible and + Responses-like requests. + """ + + container_path = rule.metadata.get("turn_container_path", "messages") + role_path = rule.metadata.get("turn_role_path", "role") + value_path = rule.metadata.get("turn_value_path") or _message_relative_path(rule.path, container_path) + if not value_path: + return [] + turns = extract_path(payload, str(container_path)) + if len(turns) == 1 and isinstance(turns[0], list): + turns = turns[0] + latest: list[Any] = [] + for turn in turns: + roles = extract_path(turn, str(role_path)) + if roles and str(roles[0]) == role: + values = extract_path(turn, str(value_path)) + if values: + latest = values + return latest + + +def _message_relative_path(path: str, container_path: str) -> Optional[str]: + prefixes = (f"{container_path}.*.", f"{container_path}[-1].") + for prefix in prefixes: + if path.startswith(prefix): + return path[len(prefix) :] + return None + + +def _tool_call_values(rule: FieldCacheRule, payload: Any, values: list[Any]) -> dict[str, Any]: + """Correlate cached values to provider tool-call IDs.""" + + container_path = rule.metadata.get("tool_container_path") + tool_id_path = rule.metadata.get("tool_call_id_path") + tool_value_path = rule.metadata.get("tool_value_path") + stored: dict[str, Any] = {} + if container_path and tool_value_path: + containers = extract_path(payload, str(container_path)) + if len(containers) == 1 and isinstance(containers[0], list): + containers = containers[0] + for container in containers: + tool_ids = extract_path(container, str(tool_id_path)) if tool_id_path else [] + tool_values = extract_path(container, str(tool_value_path)) + if tool_ids and tool_values: + stored[str(tool_ids[0])] = tool_values[-1] + return stored + for value in values: + tool_ids = extract_path(value, str(tool_id_path)) if tool_id_path else [] + if tool_ids: + stored[str(tool_ids[0])] = value + return stored + + +def _injection_tool_ids(rule: FieldCacheRule, payload: Any, context: FieldCacheContext) -> list[str]: + configured = context.metadata.get("tool_call_id") + if configured: + return [str(configured)] + inject_path_value = rule.metadata.get("inject_tool_call_id_path") + if inject_path_value: + return [str(value) for value in extract_path(payload, str(inject_path_value))] + return [] + + def _trace_direction(pass_name: str, source: str, metadata: dict[str, Any]) -> str: if "injection" in pass_name: target = metadata.get("target") diff --git a/src/rotator_library/field_cache/paths.py b/src/rotator_library/field_cache/paths.py index 96997ecad..8f8a6462b 100644 --- a/src/rotator_library/field_cache/paths.py +++ b/src/rotator_library/field_cache/paths.py @@ -103,12 +103,14 @@ def _extract_token(value: Any, token: PathToken) -> list[Any]: raise FieldCachePathError(f"Unknown path token: {token}") -def inject_path(payload: Any, path: str, injected_value: Any, *, when_missing_only: bool = False) -> bool: +def inject_path(payload: Any, path: str, injected_value: Any, *, when_missing_only: bool = False, insert: bool = False) -> bool: """Inject a value at a simple path, creating dict containers as needed. Wildcard injection is rejected because creating multiple branches can be ambiguous and provider-specific. List indexes must already exist; this keeps mutation predictable for message-tail use cases like `messages[-1].field`. + `insert=True` is intentionally limited to final list-index tokens so rules + cannot accidentally create provider-specific list structures. """ tokens = parse_path(path) @@ -124,6 +126,8 @@ def inject_path(payload: Any, path: str, injected_value: Any, *, when_missing_on raise FieldCachePathError(f"Cannot inject key {token.value!r} into non-dict value") key = str(token.value) if is_last: + if insert: + raise FieldCachePathError("insert=True requires a final list index token") if when_missing_only and key in current: return False changed = current.get(key) != injected_value @@ -140,6 +144,11 @@ def inject_path(payload: Any, path: str, injected_value: Any, *, when_missing_on if not (-len(current) <= list_index < len(current)): raise FieldCachePathError(f"List index out of range for field-cache injection: {list_index}") if is_last: + if insert: + if when_missing_only: + return False + current.insert(list_index, injected_value) + return True if when_missing_only and current[list_index] is not None: return False changed = current[list_index] != injected_value diff --git a/src/rotator_library/field_cache/store.py b/src/rotator_library/field_cache/store.py index 9394bebdd..158898e49 100644 --- a/src/rotator_library/field_cache/store.py +++ b/src/rotator_library/field_cache/store.py @@ -6,8 +6,9 @@ from __future__ import annotations import json +import time from copy import deepcopy -from typing import Any, Protocol +from typing import Any, Callable, Protocol from ..protocols import serialize_value @@ -17,35 +18,59 @@ class FieldCacheStore(Protocol): async def get(self, key: str) -> Any: ... - async def set(self, key: str, value: Any) -> None: ... + async def set(self, key: str, value: Any, *, ttl_seconds: int | None = None) -> None: ... - async def append(self, key: str, values: list[Any]) -> list[Any]: ... + async def append(self, key: str, values: list[Any], *, ttl_seconds: int | None = None) -> list[Any]: ... async def clear(self) -> None: ... class InMemoryFieldCacheStore: - """Simple process-local store for tests and lightweight runtime use.""" + """Simple process-local store with optional per-key TTL. - def __init__(self) -> None: + This is the default native runtime store. It intentionally persists only for + the Python process and avoids a database while still preserving protocol + state across requests handled by the same executor instance. + """ + + def __init__(self, *, clock: Callable[[], float] | None = None) -> None: self._values: dict[str, Any] = {} + self._expires_at: dict[str, float] = {} + self._clock = clock or time.monotonic async def get(self, key: str) -> Any: + if self._is_expired(key): + self._values.pop(key, None) + self._expires_at.pop(key, None) + return None return deepcopy(self._values.get(key)) - async def set(self, key: str, value: Any) -> None: + async def set(self, key: str, value: Any, *, ttl_seconds: int | None = None) -> None: self._values[key] = deepcopy(value) + self._set_expiry(key, ttl_seconds) - async def append(self, key: str, values: list[Any]) -> list[Any]: - current = self._values.get(key) + async def append(self, key: str, values: list[Any], *, ttl_seconds: int | None = None) -> list[Any]: + current = await self.get(key) if not isinstance(current, list): current = [] current = deepcopy(current) + deepcopy(values) self._values[key] = current + self._set_expiry(key, ttl_seconds) return deepcopy(current) async def clear(self) -> None: self._values.clear() + self._expires_at.clear() + + def _set_expiry(self, key: str, ttl_seconds: int | None) -> None: + if ttl_seconds is None or ttl_seconds <= 0: + self._expires_at.pop(key, None) + return + self._expires_at[key] = self._clock() + ttl_seconds + + def _is_expired(self, key: str) -> bool: + expires_at = self._expires_at.get(key) + return expires_at is not None and expires_at <= self._clock() class ProviderCacheFieldStore: @@ -63,17 +88,26 @@ async def get(self, key: str) -> Any: raw = await self._cache.retrieve_async(key) if raw is None: return None - return json.loads(raw) - - async def set(self, key: str, value: Any) -> None: - await self._cache.store_async(key, json.dumps(serialize_value(value), ensure_ascii=False)) - - async def append(self, key: str, values: list[Any]) -> list[Any]: + value = json.loads(raw) + if isinstance(value, dict) and value.get("__field_cache_wrapped__") is True: + expires_at = value.get("expires_at") + if isinstance(expires_at, (int, float)) and expires_at <= time.time(): + return None + return value.get("value") + return value + + async def set(self, key: str, value: Any, *, ttl_seconds: int | None = None) -> None: + payload = serialize_value(value) + if ttl_seconds is not None and ttl_seconds > 0: + payload = {"__field_cache_wrapped__": True, "expires_at": time.time() + ttl_seconds, "value": payload} + await self._cache.store_async(key, json.dumps(payload, ensure_ascii=False)) + + async def append(self, key: str, values: list[Any], *, ttl_seconds: int | None = None) -> list[Any]: current = await self.get(key) if not isinstance(current, list): current = [] current = current + serialize_value(values) - await self.set(key, current) + await self.set(key, current, ttl_seconds=ttl_seconds) return current async def clear(self) -> None: diff --git a/tests/test_field_cache_engine.py b/tests/test_field_cache_engine.py index 8a8150d15..3738dd425 100644 --- a/tests/test_field_cache_engine.py +++ b/tests/test_field_cache_engine.py @@ -188,7 +188,7 @@ async def test_field_cache_trace_omits_raw_sample_values(tmp_path) -> None: @pytest.mark.asyncio async def test_field_cache_error_trace_omits_raw_payload_values(tmp_path) -> None: class FailingStore(InMemoryFieldCacheStore): - async def set(self, key, value, *, append=False): + async def set(self, key, value, *, ttl_seconds=None): raise RuntimeError("store failed") logger = TransactionLogger("openai", "gpt-test", parent_dir=tmp_path) @@ -261,6 +261,17 @@ async def clear(self) -> None: assert await store.get("key") == [{"value": 2}] +@pytest.mark.asyncio +async def test_in_memory_store_expires_ttl_values() -> None: + now = 100.0 + store = InMemoryFieldCacheStore(clock=lambda: now) + + await store.set("key", "value", ttl_seconds=5) + assert await store.get("key") == "value" + now = 106.0 + assert await store.get("key") is None + + @pytest.mark.asyncio async def test_in_memory_store_returns_deep_copies() -> None: store = InMemoryFieldCacheStore() @@ -269,3 +280,135 @@ async def test_in_memory_store_returns_deep_copies() -> None: value["nested"].append("mutated") assert await store.get("key") == {"nested": []} + + +@pytest.mark.asyncio +async def test_last_user_turn_uses_latest_user_message() -> None: + rule = FieldCacheRule( + name="user_signature", + source="request", + path="messages.*.metadata.signature", + mode="last_user_turn", + inject=FieldCacheInjection(target="request", path="metadata.signature"), + allow_missing_session=True, + ) + engine = FieldCacheEngine([rule]) + + operations = await engine.extract( + "request", + { + "messages": [ + {"role": "user", "metadata": {"signature": "first-user"}}, + {"role": "assistant", "metadata": {"signature": "assistant"}}, + {"role": "user", "metadata": {"signature": "last-user"}}, + ] + }, + _context(session_id=None), + ) + updated, _ = await engine.inject("request", {"metadata": {}}, _context(session_id=None)) + + assert operations[0].changed is True + assert updated["metadata"]["signature"] == "last-user" + + +@pytest.mark.asyncio +async def test_last_assistant_turn_skips_without_turn_context() -> None: + rule = FieldCacheRule( + name="assistant_signature", + source="response", + path="choices.*.message.signature", + mode="last_assistant_turn", + inject=FieldCacheInjection(target="request", path="metadata.signature"), + allow_missing_session=True, + ) + engine = FieldCacheEngine([rule]) + + operations = await engine.extract("response", {"choices": [{"message": {"signature": "sig"}}]}, _context(session_id=None)) + + assert operations[0].skipped is True + assert operations[0].reason == "turn_context_not_found" + + +@pytest.mark.asyncio +async def test_turn_mode_uses_metadata_configured_relative_paths() -> None: + rule = FieldCacheRule( + name="assistant_signature", + source="response", + path="unused.global.path", + mode="last_assistant_turn", + inject=FieldCacheInjection(target="request", path="metadata.signature"), + allow_missing_session=True, + metadata={"turn_container_path": "messages", "turn_role_path": "kind", "turn_value_path": "parts.*.signature"}, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract( + "response", + {"messages": [{"kind": "assistant", "parts": [{"signature": "first"}]}, {"kind": "assistant", "parts": [{"signature": "second"}]}]}, + _context(session_id=None), + ) + updated, _ = await engine.inject("request", {"metadata": {}}, _context(session_id=None)) + + assert updated["metadata"]["signature"] == "second" + + +@pytest.mark.asyncio +async def test_per_tool_call_correlates_sibling_id_and_value_for_injection() -> None: + rule = FieldCacheRule( + name="tool_signature", + source="response", + path="tool_calls.*.signature", + mode="per_tool_call", + inject=FieldCacheInjection(target="request", path="metadata.signature"), + allow_missing_session=True, + metadata={ + "tool_container_path": "tool_calls", + "tool_call_id_path": "id", + "tool_value_path": "signature", + }, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"tool_calls": [{"id": "call_a", "signature": "sig-a"}, {"id": "call_b", "signature": "sig-b"}]}, _context(session_id=None)) + updated, operations = await engine.inject("request", {"metadata": {}}, _context(session_id=None, metadata={"tool_call_id": "call_b"})) + + assert operations[0].hit is True + assert updated["metadata"]["signature"] == "sig-b" + + +@pytest.mark.asyncio +async def test_per_tool_call_skips_when_current_tool_id_is_ambiguous() -> None: + rule = FieldCacheRule( + name="tool_signature", + source="response", + path="tool_calls.*", + mode="per_tool_call", + inject=FieldCacheInjection(target="request", path="metadata.signature"), + allow_missing_session=True, + metadata={"tool_call_id_path": "id"}, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"tool_calls": [{"id": "call_a", "signature": "sig-a"}]}, _context(session_id=None)) + updated, operations = await engine.inject("request", {"metadata": {}}, _context(session_id=None)) + + assert operations[0].skipped is True + assert operations[0].reason == "tool_call_id_not_found" + assert updated == {"metadata": {}} + + +@pytest.mark.asyncio +async def test_insert_injection_adds_list_entry() -> None: + rule = FieldCacheRule( + name="prefix_message", + source="response", + path="message", + inject=FieldCacheInjection(target="request", path="messages.0", insert=True), + allow_missing_session=True, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"message": {"role": "system", "content": "cached"}}, _context(session_id=None)) + updated, _ = await engine.inject("request", {"messages": [{"role": "user", "content": "hi"}]}, _context(session_id=None)) + + assert updated["messages"] == [{"role": "system", "content": "cached"}, {"role": "user", "content": "hi"}] diff --git a/tests/test_field_cache_paths.py b/tests/test_field_cache_paths.py index 4f5b22066..aaf76191b 100644 --- a/tests/test_field_cache_paths.py +++ b/tests/test_field_cache_paths.py @@ -54,6 +54,15 @@ def test_inject_path_respects_when_missing_only() -> None: assert payload["metadata"]["prompt_cache_key"] == "existing" +def test_inject_path_can_insert_at_final_list_index() -> None: + payload = {"messages": [{"role": "user"}]} + + changed = inject_path(payload, "messages.0", {"role": "system"}, insert=True) + + assert changed is True + assert payload["messages"] == [{"role": "system"}, {"role": "user"}] + + def test_inject_path_rejects_wildcards_and_missing_lists() -> None: with pytest.raises(FieldCachePathError): inject_path({"choices": []}, "choices.*.message.reasoning_content", "x") From b810c5c8cf0a50085831f308711929b5ea977fd9 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 10:32:52 +0200 Subject: [PATCH 093/182] feat(field-cache): merge configured native rules Merges optional JSON-configured field-cache rules into native provider contexts, with JSON rules replacing provider-declared rules by name and appending otherwise. Keeps the experimental config import local to avoid startup cycles and continues rejecting unsafe secret-like JSON keys through the existing config loader. Tests: pytest tests/test_field_cache_paths.py tests/test_field_cache_engine.py tests/test_field_cache_trace.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_request_executor_native_routing.py tests/test_experimental_config.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py --- src/rotator_library/client/executor.py | 32 +++++++++- tests/test_request_executor_native_routing.py | 58 +++++++++++++++++++ 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index b2e9ab102..6c38b82cd 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -671,7 +671,7 @@ def _build_native_provider_context( classifier=context.classifier, adapter_names=tuple(plugin.get_adapter_names(model) if hasattr(plugin, "get_adapter_names") else ()), adapter_config=dict(plugin.get_adapter_config(model) if hasattr(plugin, "get_adapter_config") else {}), - field_cache_rules=tuple(plugin.get_field_cache_rules(model) if hasattr(plugin, "get_field_cache_rules") else ()), + field_cache_rules=_merged_field_cache_rules(provider, model, plugin), transaction_logger=context.transaction_logger, ) @@ -2167,6 +2167,36 @@ def _provider_native_protocol(plugin: Any, model: str, target: Optional[RouteTar return None +def _merged_field_cache_rules(provider: str, model: str, plugin: Any) -> tuple[Any, ...]: + """Merge provider-declared and JSON-configured field-cache rules. + + Provider declarations are the safe default. Optional JSON config can add or + replace rules by name so operators can tune protocol-state preservation per + provider/model without editing provider code. The import is local to keep the + experimental config layer out of executor module initialization. + """ + + declared = list(plugin.get_field_cache_rules(model) if plugin and hasattr(plugin, "get_field_cache_rules") else ()) + try: + from ..config.experimental import load_experimental_config, parse_field_cache_rules + + configured = list(parse_field_cache_rules(load_experimental_config(), provider, model)) + except Exception as exc: + lib_logger.debug("Failed to load configured field-cache rules for %s/%s: %s", provider, model, exc) + configured = [] + if not configured: + return tuple(declared) + merged: dict[str, Any] = {getattr(rule, "name", str(index)): rule for index, rule in enumerate(declared)} + order = [getattr(rule, "name", str(index)) for index, rule in enumerate(declared)] + for rule in configured: + name = getattr(rule, "name", "") + if name and name not in merged: + order.append(name) + if name: + merged[name] = rule + return tuple(merged[name] for name in order if name in merged) + + def _target_scope_value(target: RouteTarget, key: str, default: Any) -> Any: """Read request-scope metadata attached by routing resolution.""" diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index dcb9a0d97..e7f1f481b 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -1,10 +1,13 @@ from __future__ import annotations +import json + import pytest from rotator_library.client import executor as executor_module from rotator_library.client.executor import RequestExecutor, RoutingExecutionError from rotator_library.core.types import RequestContext +from rotator_library.field_cache import FieldCacheInjection, FieldCacheRule from rotator_library.routing import parse_route_target @@ -51,6 +54,18 @@ def get_field_cache_rules(self, model=""): return () +class NativePluginWithRule(NativePlugin): + def get_field_cache_rules(self, model=""): + return ( + FieldCacheRule( + name="state", + source="response", + path="provider.path", + inject=FieldCacheInjection(target="request", path="metadata.state"), + ), + ) + + class CustomPlugin: def __init__(self): self.calls = [] @@ -98,6 +113,49 @@ async def test_native_declared_provider_uses_native_executor_in_auto_mode() -> N assert http_client.calls[0]["headers"]["Authorization"] == "Bearer secret" +def test_native_context_merges_json_field_cache_rules(monkeypatch, tmp_path) -> None: + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "field_cache": { + "provider": { + "*": [ + { + "name": "state", + "source": "response", + "path": "json.path", + "inject": {"target": "request", "path": "metadata.state"}, + }, + { + "name": "extra", + "source": "response", + "path": "json.extra", + "inject": {"target": "request", "path": "metadata.extra"}, + }, + ] + } + } + } + ), + encoding="utf-8", + ) + monkeypatch.setenv("LLM_PROXY_CONFIG_FILE", str(config_path)) + + native_context = _executor()._build_native_provider_context( + "provider", + "provider/gpt-test", + NativePluginWithRule(), + "secret", + "stable", + _context(), + None, + ) + + assert [rule.name for rule in native_context.field_cache_rules] == ["state", "extra"] + assert native_context.field_cache_rules[0].path == "json.path" + + @pytest.mark.asyncio async def test_litellm_fallback_execution_is_explicit(monkeypatch) -> None: calls = [] From 626f2d08df0f60820b5032c2bcc73e138a556871 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 10:45:30 +0200 Subject: [PATCH 094/182] fix(field-cache): close runtime integration gaps Fixes Phase 3b review findings by reusing a shared native executor from RequestExecutor, wiring native streaming selection for native-capable routes, and preserving provider-state cache continuity through the real client path. Hardens traces for cached provider-state fields, preserves list-valued last-mode cache values with envelopes, supports legacy stores without ttl kwargs, validates JSON field-cache config loudly, and matches unprefixed model aliases. Also adds tests for provider-cache TTL expiry, per-tool-call as_list injection, list-valued cache values, empty-list insertion, RequestExecutor-level persistence, config errors, and native trace redaction. Tests: pytest tests/test_field_cache_paths.py tests/test_field_cache_engine.py tests/test_field_cache_trace.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_request_executor_native_routing.py tests/test_experimental_config.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py --- src/rotator_library/client/executor.py | 57 +++++++++++- src/rotator_library/config/experimental.py | 6 +- src/rotator_library/field_cache/engine.py | 38 ++++++-- src/rotator_library/field_cache/paths.py | 16 ++-- src/rotator_library/field_cache/store.py | 6 +- src/rotator_library/transform_trace.py | 14 ++- tests/test_experimental_config.py | 16 ++++ tests/test_field_cache_engine.py | 89 +++++++++++++++++++ tests/test_field_cache_paths.py | 9 ++ tests/test_native_provider_executor.py | 2 + tests/test_request_executor_native_routing.py | 42 ++++++++- tests/test_transform_trace.py | 2 + 12 files changed, 276 insertions(+), 21 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 6c38b82cd..58b86cb3e 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -151,6 +151,7 @@ def __init__( self._litellm_logger_fn = litellm_logger_fn # StreamingHandler no longer needs usage_manager - we pass cred_context directly self._streaming_handler = StreamingHandler() + self._native_executor = NativeProviderExecutor() def _get_transient_retry_delay(self) -> float: """Small jittered delay used before transient retries and rotations.""" @@ -609,10 +610,19 @@ async def _execute_provider_request( _target_trace(target) if target else {"provider": provider, "model": model}, metadata={"protocol": native_context.protocol_name}, ) - return await NativeProviderExecutor().execute(dict(kwargs), native_context, NativeHTTPTransport(self._http_client)) + return await self._get_native_executor().execute(dict(kwargs), native_context, NativeHTTPTransport(self._http_client)) return await self._execute_litellm_request(kwargs, credential_secret, context=context, credential_id=credential_id) + def _get_native_executor(self) -> NativeProviderExecutor: + """Return the shared native executor for process-local field-cache state.""" + + native_executor = getattr(self, "_native_executor", None) + if native_executor is None: + native_executor = NativeProviderExecutor() + self._native_executor = native_executor + return native_executor + async def _execute_litellm_request( self, kwargs: Dict[str, Any], @@ -1319,8 +1329,33 @@ async def _execute_streaming( context, kwargs ) + target = _current_route_target(context) + execution = target.execution if target else "auto" + # Make the API call - if plugin and plugin.has_custom_logic(): + if execution == "native" or ( + execution == "auto" + and plugin + and _provider_native_protocol(plugin, model, target) + and _provider_supports_native_streaming(plugin, model) + ): + native_context = self._build_native_provider_context( + provider, + model, + plugin, + credential_secret, + cred_context.stable_id, + context, + target, + ) + self._log_routing_trace( + context, + "routing_native_stream_execution_selected", + _target_trace(target) if target else {"provider": provider, "model": model}, + metadata={"protocol": native_context.protocol_name}, + ) + stream = self._get_native_executor().stream(dict(kwargs), native_context, NativeHTTPTransport(self._http_client)) + elif plugin and plugin.has_custom_logic(): kwargs["credential_identifier"] = credential_secret self._log_executor_trace( context, @@ -2167,6 +2202,21 @@ def _provider_native_protocol(plugin: Any, model: str, target: Optional[RouteTar return None +def _provider_supports_native_streaming(plugin: Any, model: str) -> bool: + """Return whether a provider declares native streaming support.""" + + support = getattr(plugin, "supports_native_streaming", None) + if not callable(support): + return False + try: + return bool(support(model=model, operation="chat")) + except TypeError: + try: + return bool(support(model)) + except TypeError: + return bool(support()) + + def _merged_field_cache_rules(provider: str, model: str, plugin: Any) -> tuple[Any, ...]: """Merge provider-declared and JSON-configured field-cache rules. @@ -2182,8 +2232,7 @@ def _merged_field_cache_rules(provider: str, model: str, plugin: Any) -> tuple[A configured = list(parse_field_cache_rules(load_experimental_config(), provider, model)) except Exception as exc: - lib_logger.debug("Failed to load configured field-cache rules for %s/%s: %s", provider, model, exc) - configured = [] + raise RoutingExecutionError(f"Invalid field-cache configuration for {provider}/{model}", error_type="configuration_error") from exc if not configured: return tuple(declared) merged: dict[str, Any] = {getattr(rule, "name", str(index)): rule for index, rule in enumerate(declared)} diff --git a/src/rotator_library/config/experimental.py b/src/rotator_library/config/experimental.py index 31767d047..65a0a4ca2 100644 --- a/src/rotator_library/config/experimental.py +++ b/src/rotator_library/config/experimental.py @@ -173,7 +173,11 @@ def parse_field_cache_rules(config: ExperimentalConfig, provider: str, model: st if not isinstance(provider_rules, dict): return () raw_rules: list[Any] = [] - for key in ("*", model): + keys = ["*"] + if "/" in model: + keys.append(model.split("/", 1)[1]) + keys.append(model) + for key in dict.fromkeys(keys): value = provider_rules.get(key, []) if isinstance(value, list): raw_rules.extend(value) diff --git a/src/rotator_library/field_cache/engine.py b/src/rotator_library/field_cache/engine.py index baff02c98..e15e1cead 100644 --- a/src/rotator_library/field_cache/engine.py +++ b/src/rotator_library/field_cache/engine.py @@ -173,10 +173,10 @@ def _rules_for_injection(self, target: str) -> list[FieldCacheRule]: async def _store_values(self, rule: FieldCacheRule, cache_key: str, values: list[Any], payload: Any, operation: FieldCacheOperation) -> bool: if rule.mode == "all": - await self.store.append(cache_key, values, ttl_seconds=rule.ttl_seconds) + await self._store_append(cache_key, values, ttl_seconds=rule.ttl_seconds) return True if rule.mode == "last": - await self.store.set(cache_key, values[-1], ttl_seconds=rule.ttl_seconds) + await self._store_set(cache_key, _wrap_cached_value(values[-1]), ttl_seconds=rule.ttl_seconds) return True if rule.mode in {"last_user_turn", "last_assistant_turn"}: role = "user" if rule.mode == "last_user_turn" else "assistant" @@ -187,7 +187,7 @@ async def _store_values(self, rule: FieldCacheRule, cache_key: str, values: list return False operation.matched = len(turn_values) operation.sample_values = _sample_values(turn_values) - await self.store.set(cache_key, turn_values[-1], ttl_seconds=rule.ttl_seconds) + await self._store_set(cache_key, _wrap_cached_value(turn_values[-1]), ttl_seconds=rule.ttl_seconds) return True if rule.mode == "per_tool_call": stored = _tool_call_values(rule, payload, values) @@ -197,10 +197,24 @@ async def _store_values(self, rule: FieldCacheRule, cache_key: str, values: list return False operation.matched = len(stored) operation.sample_values = _sample_values(list(stored.values())) - await self.store.set(cache_key, stored, ttl_seconds=rule.ttl_seconds) + await self._store_set(cache_key, stored, ttl_seconds=rule.ttl_seconds) return True raise ValueError(f"Unsupported field-cache mode: {rule.mode}") + async def _store_set(self, cache_key: str, value: Any, *, ttl_seconds: Optional[int]) -> None: + try: + await self.store.set(cache_key, value, ttl_seconds=ttl_seconds) + except TypeError: + # Preserve compatibility with simple injected stores that implement + # the original set(key, value) shape. TTL is best-effort there. + await self.store.set(cache_key, value) + + async def _store_append(self, cache_key: str, values: list[Any], *, ttl_seconds: Optional[int]) -> None: + try: + await self.store.append(cache_key, values, ttl_seconds=ttl_seconds) + except TypeError: + await self.store.append(cache_key, values) + def _injection_value(self, rule: FieldCacheRule, cached: Any, payload: Any, context: FieldCacheContext, operation: FieldCacheOperation) -> Any: """Select the cached value to inject for the rule's mode. @@ -218,7 +232,7 @@ def _injection_value(self, rule: FieldCacheRule, cached: Any, payload: Any, cont operation.skipped = True operation.reason = "tool_call_id_not_found" return None - matches = [cached[str(tool_id)] for tool_id in ids if str(tool_id) in cached] + matches = [_unwrap_cached_value(cached[str(tool_id)]) for tool_id in ids if str(tool_id) in cached] if not matches: operation.skipped = True operation.reason = "tool_call_cache_miss" @@ -232,7 +246,7 @@ def _injection_value(self, rule: FieldCacheRule, cached: Any, payload: Any, cont return None if rule.inject and (rule.inject.as_list or rule.mode == "all"): return cached if isinstance(cached, list) else [cached] - return _last_value(cached) + return _unwrap_cached_value(cached) def _trace( self, @@ -322,6 +336,18 @@ def _last_value(value: Any) -> Any: return value +def _wrap_cached_value(value: Any) -> dict[str, Any]: + """Wrap one extracted value so list-valued fields stay intact on injection.""" + + return {"__field_cache_value__": True, "value": deepcopy(value)} + + +def _unwrap_cached_value(value: Any) -> Any: + if isinstance(value, dict) and value.get("__field_cache_value__") is True: + return deepcopy(value.get("value")) + return _last_value(value) + + def _turn_values(rule: FieldCacheRule, payload: Any, role: str) -> list[Any]: """Return values from the latest turn matching `role`. diff --git a/src/rotator_library/field_cache/paths.py b/src/rotator_library/field_cache/paths.py index 8f8a6462b..1dfc29203 100644 --- a/src/rotator_library/field_cache/paths.py +++ b/src/rotator_library/field_cache/paths.py @@ -138,17 +138,21 @@ def inject_path(payload: Any, path: str, injected_value: Any, *, when_missing_on current = current[key] continue if token.kind == "index": - if not isinstance(current, list) or not current: + if not isinstance(current, list) or (not current and not (is_last and insert)): raise FieldCachePathError("Cannot inject into missing or empty list") list_index = int(token.value) + if is_last and insert: + if list_index < 0: + list_index = max(0, len(current) + list_index) + if not (0 <= list_index <= len(current)): + raise FieldCachePathError(f"List index out of range for field-cache insertion: {list_index}") + if when_missing_only: + return False + current.insert(list_index, injected_value) + return True if not (-len(current) <= list_index < len(current)): raise FieldCachePathError(f"List index out of range for field-cache injection: {list_index}") if is_last: - if insert: - if when_missing_only: - return False - current.insert(list_index, injected_value) - return True if when_missing_only and current[list_index] is not None: return False changed = current[list_index] != injected_value diff --git a/src/rotator_library/field_cache/store.py b/src/rotator_library/field_cache/store.py index 158898e49..659e0c35f 100644 --- a/src/rotator_library/field_cache/store.py +++ b/src/rotator_library/field_cache/store.py @@ -12,6 +12,8 @@ from ..protocols import serialize_value +_TTL_ENVELOPE_MARKER = "__llm_proxy_field_cache_ttl_v1__" + class FieldCacheStore(Protocol): """Minimal async store interface used by `FieldCacheEngine`.""" @@ -89,7 +91,7 @@ async def get(self, key: str) -> Any: if raw is None: return None value = json.loads(raw) - if isinstance(value, dict) and value.get("__field_cache_wrapped__") is True: + if isinstance(value, dict) and value.get(_TTL_ENVELOPE_MARKER) is True: expires_at = value.get("expires_at") if isinstance(expires_at, (int, float)) and expires_at <= time.time(): return None @@ -99,7 +101,7 @@ async def get(self, key: str) -> Any: async def set(self, key: str, value: Any, *, ttl_seconds: int | None = None) -> None: payload = serialize_value(value) if ttl_seconds is not None and ttl_seconds > 0: - payload = {"__field_cache_wrapped__": True, "expires_at": time.time() + ttl_seconds, "value": payload} + payload = {_TTL_ENVELOPE_MARKER: True, "expires_at": time.time() + ttl_seconds, "value": payload} await self._cache.store_async(key, json.dumps(payload, ensure_ascii=False)) async def append(self, key: str, values: list[Any], *, ttl_seconds: int | None = None) -> list[Any]: diff --git a/src/rotator_library/transform_trace.py b/src/rotator_library/transform_trace.py index 8e74c6303..3608d140c 100644 --- a/src/rotator_library/transform_trace.py +++ b/src/rotator_library/transform_trace.py @@ -43,6 +43,13 @@ "password", "secret", "token", + # Provider protocol-state fields can contain opaque signatures or + # reasoning continuations that field-cache rules preserve and re-inject. + "reasoning-content", + "reasoningcontent", + "thought-signature", + "thoughtsignature", + "signature", } ) @@ -109,8 +116,13 @@ def sanitize_for_trace(value: Any, *, scrub_strings: bool = False) -> Any: if isinstance(value, Mapping): sanitized = {} + block_type = str(value.get("type", "")).lower() + source_field = str(value.get("source_field", "")).lower() + if not source_field and isinstance(value.get("extra"), Mapping): + source_field = str(value["extra"].get("source_field", "")).lower() + is_reasoning_block = "reasoning" in block_type or source_field in {"reasoning_content", "thoughtsignature", "thought_signature", "signature"} for key, item in value.items(): - if _normalise_key(key) in _SENSITIVE_KEYS: + if _normalise_key(key) in _SENSITIVE_KEYS or (is_reasoning_block and _normalise_key(key) == "text"): sanitized[str(key)] = REDACTED else: sanitized[str(key)] = sanitize_for_trace(item, scrub_strings=scrub_strings) diff --git a/tests/test_experimental_config.py b/tests/test_experimental_config.py index 313d0483b..88ceb2879 100644 --- a/tests/test_experimental_config.py +++ b/tests/test_experimental_config.py @@ -70,6 +70,22 @@ def test_field_cache_rules_parse_wildcard_then_model_specific() -> None: assert rules[0].inject is not None +def test_field_cache_rules_match_unprefixed_model_alias() -> None: + config = load_config_from_mapping( + { + "field_cache": { + "gemini_cli": { + "gemini-3": [{"name": "signature", "source": "response", "path": "sig"}], + } + } + } + ) + + rules = parse_field_cache_rules(config, "gemini_cli", "gemini_cli/gemini-3") + + assert [rule.name for rule in rules] == ["signature"] + + def test_env_price_key_sanitizes_provider_and_model() -> None: assert env_price_key("openai", "gpt-5.1-mini", "cache_read") == "MODEL_PRICE_OPENAI_GPT_5_1_MINI_CACHE_READ" diff --git a/tests/test_field_cache_engine.py b/tests/test_field_cache_engine.py index 3738dd425..6db1cf400 100644 --- a/tests/test_field_cache_engine.py +++ b/tests/test_field_cache_engine.py @@ -272,6 +272,32 @@ async def test_in_memory_store_expires_ttl_values() -> None: assert await store.get("key") is None +@pytest.mark.asyncio +async def test_provider_cache_field_store_expires_ttl_values(monkeypatch) -> None: + class FakeProviderCache: + def __init__(self) -> None: + self.values = {} + + async def retrieve_async(self, key: str): + return self.values.get(key) + + async def store_async(self, key: str, value: str) -> None: + json.loads(value) + self.values[key] = value + + async def clear(self) -> None: + self.values.clear() + + now = 100.0 + monkeypatch.setattr("rotator_library.field_cache.store.time.time", lambda: now) + store = ProviderCacheFieldStore(FakeProviderCache()) + + await store.set("key", "value", ttl_seconds=5) + assert await store.get("key") == "value" + now = 106.0 + assert await store.get("key") is None + + @pytest.mark.asyncio async def test_in_memory_store_returns_deep_copies() -> None: store = InMemoryFieldCacheStore() @@ -311,6 +337,23 @@ async def test_last_user_turn_uses_latest_user_message() -> None: assert updated["metadata"]["signature"] == "last-user" +@pytest.mark.asyncio +async def test_last_mode_preserves_list_valued_field() -> None: + rule = FieldCacheRule( + name="list_value", + source="response", + path="metadata.signatures", + inject=FieldCacheInjection(target="request", path="metadata.signatures"), + allow_missing_session=True, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"metadata": {"signatures": ["a", "b"]}}, _context(session_id=None)) + updated, _ = await engine.inject("request", {"metadata": {}}, _context(session_id=None)) + + assert updated["metadata"]["signatures"] == ["a", "b"] + + @pytest.mark.asyncio async def test_last_assistant_turn_skips_without_turn_context() -> None: rule = FieldCacheRule( @@ -376,6 +419,52 @@ async def test_per_tool_call_correlates_sibling_id_and_value_for_injection() -> assert updated["metadata"]["signature"] == "sig-b" +@pytest.mark.asyncio +async def test_per_tool_call_as_list_injects_matching_values() -> None: + rule = FieldCacheRule( + name="tool_signature", + source="response", + path="tool_calls.*", + mode="per_tool_call", + inject=FieldCacheInjection(target="request", path="metadata.signatures", as_list=True), + allow_missing_session=True, + metadata={"tool_call_id_path": "id", "inject_tool_call_id_path": "tool_ids.*"}, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"tool_calls": [{"id": "a", "signature": "sig-a"}, {"id": "b", "signature": "sig-b"}]}, _context(session_id=None)) + updated, _ = await engine.inject("request", {"metadata": {}, "tool_ids": ["a", "b"]}, _context(session_id=None)) + + assert updated["metadata"]["signatures"] == [{"id": "a", "signature": "sig-a"}, {"id": "b", "signature": "sig-b"}] + + +@pytest.mark.asyncio +async def test_engine_supports_legacy_store_without_ttl_keyword() -> None: + class LegacyStore: + def __init__(self) -> None: + self.values = {} + + async def get(self, key): + return self.values.get(key) + + async def set(self, key, value): + self.values[key] = value + + async def append(self, key, values): + self.values.setdefault(key, []).extend(values) + return self.values[key] + + async def clear(self): + self.values.clear() + + engine = FieldCacheEngine([_reasoning_rule()], store=LegacyStore()) + + await engine.extract("response", {"choices": [{"message": {"reasoning_content": "legacy"}}]}, _context()) + updated, _ = await engine.inject("request", {"messages": [{}]}, _context()) + + assert updated["messages"][-1]["reasoning_content"] == "legacy" + + @pytest.mark.asyncio async def test_per_tool_call_skips_when_current_tool_id_is_ambiguous() -> None: rule = FieldCacheRule( diff --git a/tests/test_field_cache_paths.py b/tests/test_field_cache_paths.py index aaf76191b..8e916b7df 100644 --- a/tests/test_field_cache_paths.py +++ b/tests/test_field_cache_paths.py @@ -63,6 +63,15 @@ def test_inject_path_can_insert_at_final_list_index() -> None: assert payload["messages"] == [{"role": "system"}, {"role": "user"}] +def test_inject_path_can_insert_into_empty_list() -> None: + payload = {"messages": []} + + changed = inject_path(payload, "messages.0", {"role": "system"}, insert=True) + + assert changed is True + assert payload["messages"] == [{"role": "system"}] + + def test_inject_path_rejects_wildcards_and_missing_lists() -> None: with pytest.raises(FieldCachePathError): inject_path({"choices": []}, "choices.*.message.reasoning_content", "x") diff --git a/tests/test_native_provider_executor.py b/tests/test_native_provider_executor.py index c8ddc7e4a..da7cfb59a 100644 --- a/tests/test_native_provider_executor.py +++ b/tests/test_native_provider_executor.py @@ -70,6 +70,8 @@ async def test_native_provider_executor_runs_protocol_adapter_cache_and_trace(tm assert result["id"] == "chat_1" assert client.calls[0]["json"]["model"] == "provider/gpt-test" + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "hidden" not in trace_text pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] assert "native_protocol_selected" in pass_names assert "raw_native_client_request" in pass_names diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index e7f1f481b..538fa1ef1 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -31,6 +31,16 @@ async def post(self, endpoint, *, headers, json): return FakeNativeResponse({"id": "chat_native", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}) +class SequencedHTTPClient: + def __init__(self, responses): + self.responses = list(responses) + self.calls = [] + + async def post(self, endpoint, *, headers, json): + self.calls.append({"endpoint": endpoint, "headers": headers, "json": json}) + return FakeNativeResponse(self.responses.pop(0)) + + class NativePlugin: def has_custom_logic(self): return False @@ -60,8 +70,9 @@ def get_field_cache_rules(self, model=""): FieldCacheRule( name="state", source="response", - path="provider.path", + path="choices.0.message.reasoning_content", inject=FieldCacheInjection(target="request", path="metadata.state"), + allow_missing_session=True, ), ) @@ -113,6 +124,24 @@ async def test_native_declared_provider_uses_native_executor_in_auto_mode() -> N assert http_client.calls[0]["headers"]["Authorization"] == "Bearer secret" +@pytest.mark.asyncio +async def test_request_executor_reuses_native_field_cache_store() -> None: + http_client = SequencedHTTPClient( + [ + {"id": "chat_1", "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "cached"}}]}, + {"id": "chat_2", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}, + ] + ) + executor = _executor(http_client) + context = _context(parse_route_target("provider/gpt-test")) + context.routing_target_index = 0 + + await executor._execute_provider_request("provider", "provider/gpt-test", NativePluginWithRule(), "secret", "stable", dict(context.kwargs), context) + await executor._execute_provider_request("provider", "provider/gpt-test", NativePluginWithRule(), "secret", "stable", dict(context.kwargs), context) + + assert http_client.calls[1]["json"]["metadata"]["state"] == "cached" + + def test_native_context_merges_json_field_cache_rules(monkeypatch, tmp_path) -> None: config_path = tmp_path / "config.json" config_path.write_text( @@ -156,6 +185,17 @@ def test_native_context_merges_json_field_cache_rules(monkeypatch, tmp_path) -> assert native_context.field_cache_rules[0].path == "json.path" +def test_native_context_raises_on_invalid_field_cache_config(monkeypatch, tmp_path) -> None: + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps({"field_cache": {"provider": {"*": [{"name": "bad", "source": "response"}]}}}), encoding="utf-8") + monkeypatch.setenv("LLM_PROXY_CONFIG_FILE", str(config_path)) + + with pytest.raises(RoutingExecutionError) as exc: + _executor()._build_native_provider_context("provider", "provider/gpt-test", NativePlugin(), "secret", "stable", _context(), None) + + assert exc.value.error_type == "configuration_error" + + @pytest.mark.asyncio async def test_litellm_fallback_execution_is_explicit(monkeypatch) -> None: calls = [] diff --git a/tests/test_transform_trace.py b/tests/test_transform_trace.py index b3ed8928d..e289cd15d 100644 --- a/tests/test_transform_trace.py +++ b/tests/test_transform_trace.py @@ -34,6 +34,8 @@ def test_sanitize_for_trace_redacts_sensitive_keys_recursively() -> None: assert sanitized["headers"]["Cookie"] == REDACTED assert sanitized["headers"]["normal"] == "token in normal text" assert sanitized["items"][0]["refresh_token"] == REDACTED + assert sanitize_for_trace({"reasoning_content": "provider-state"})["reasoning_content"] == REDACTED + assert sanitize_for_trace({"thoughtSignature": "provider-state"})["thoughtSignature"] == REDACTED assert sanitized["items"][0]["text"] == "token should remain in value" From 3ff29ac2d85d47335d8468c5d9c9ebed90434921 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 10:55:24 +0200 Subject: [PATCH 095/182] fix(field-cache): harden native trace and config semantics Adds path-based redaction for active field-cache source and injection paths in native provider traces so arbitrary configured cache keys cannot leak provider-state values. Tightens field-cache rule validation for source, target, and scope values, preserves configuration errors through native execution, sets streaming native context transport to sse, and unwraps value envelopes for as_list/per-tool-call cases. Tests: pytest tests/test_field_cache_paths.py tests/test_field_cache_engine.py tests/test_field_cache_trace.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_request_executor_native_routing.py tests/test_experimental_config.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py --- src/rotator_library/client/executor.py | 27 ++++++++ src/rotator_library/field_cache/engine.py | 9 ++- src/rotator_library/field_cache/types.py | 10 +++ .../native_provider/executor.py | 67 ++++++++++++++++++- tests/test_experimental_config.py | 17 +++++ tests/test_field_cache_engine.py | 36 ++++++++++ tests/test_native_provider_executor.py | 36 ++++++++++ 7 files changed, 198 insertions(+), 4 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 58b86cb3e..ef73216e1 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -657,6 +657,7 @@ def _build_native_provider_context( credential_id: str, context: RequestContext, target: Optional[RouteTarget], + transport: str = "http", ) -> NativeProviderContext: """Build native provider context from provider declarations.""" @@ -679,6 +680,7 @@ def _build_native_provider_context( session_id=context.session_id, scope_key=context.usage_manager_key, classifier=context.classifier, + transport=transport, adapter_names=tuple(plugin.get_adapter_names(model) if hasattr(plugin, "get_adapter_names") else ()), adapter_config=dict(plugin.get_adapter_config(model) if hasattr(plugin, "get_adapter_config") else {}), field_cache_rules=_merged_field_cache_rules(provider, model, plugin), @@ -1150,6 +1152,27 @@ async def _execute_non_streaming( ) return normalized_response + except RoutingExecutionError as e: + if e.error_type == "configuration_error": + raise + last_exception = e + action = await self._handle_error_with_context( + e, + cred_context, + model, + provider, + attempt, + error_accumulator, + retry_state, + request_headers, + context, + ) + if action == ErrorAction.RETRY_SAME: + continue + elif action == ErrorAction.ROTATE: + break + else: + raise except Exception as e: last_exception = e action = await self._handle_error_with_context( @@ -1173,6 +1196,9 @@ async def _execute_non_streaming( except PreRequestCallbackError: raise + except RoutingExecutionError as exc: + if exc.error_type == "configuration_error": + raise except Exception: # Let context manager handle cleanup pass @@ -1347,6 +1373,7 @@ async def _execute_streaming( cred_context.stable_id, context, target, + transport="sse", ) self._log_routing_trace( context, diff --git a/src/rotator_library/field_cache/engine.py b/src/rotator_library/field_cache/engine.py index e15e1cead..b32afcef3 100644 --- a/src/rotator_library/field_cache/engine.py +++ b/src/rotator_library/field_cache/engine.py @@ -244,8 +244,11 @@ def _injection_value(self, rule: FieldCacheRule, cached: Any, payload: Any, cont operation.skipped = True operation.reason = "ambiguous_tool_call_values" return None - if rule.inject and (rule.inject.as_list or rule.mode == "all"): + if rule.mode == "all": return cached if isinstance(cached, list) else [cached] + if rule.inject and rule.inject.as_list: + unwrapped = _unwrap_cached_value(cached) + return unwrapped if isinstance(unwrapped, list) else [unwrapped] return _unwrap_cached_value(cached) def _trace( @@ -397,12 +400,12 @@ def _tool_call_values(rule: FieldCacheRule, payload: Any, values: list[Any]) -> tool_ids = extract_path(container, str(tool_id_path)) if tool_id_path else [] tool_values = extract_path(container, str(tool_value_path)) if tool_ids and tool_values: - stored[str(tool_ids[0])] = tool_values[-1] + stored[str(tool_ids[0])] = _wrap_cached_value(tool_values[-1]) return stored for value in values: tool_ids = extract_path(value, str(tool_id_path)) if tool_id_path else [] if tool_ids: - stored[str(tool_ids[0])] = value + stored[str(tool_ids[0])] = _wrap_cached_value(value) return stored diff --git a/src/rotator_library/field_cache/types.py b/src/rotator_library/field_cache/types.py index 02a44add8..1eeb8152f 100644 --- a/src/rotator_library/field_cache/types.py +++ b/src/rotator_library/field_cache/types.py @@ -21,6 +21,9 @@ FieldCacheScope = Literal["provider", "model", "credential", "session", "conversation", "classifier"] DEFAULT_SCOPE: tuple[FieldCacheScope, ...] = ("provider", "model", "classifier", "session") +_VALID_SOURCES = {"request", "response", "stream_event", "unified_request", "unified_response", "unified_stream_event"} +_VALID_TARGETS = {"request", "unified_request", "metadata"} +_VALID_SCOPES = {"provider", "model", "credential", "session", "conversation", "classifier"} @dataclass(frozen=True) @@ -59,8 +62,15 @@ def __post_init__(self) -> None: raise ValueError("FieldCacheRule.name must be non-empty and filesystem-safe") if self.mode not in {"last", "all", "last_user_turn", "last_assistant_turn", "per_tool_call"}: raise ValueError(f"Unsupported field-cache mode: {self.mode}") + if self.source not in _VALID_SOURCES: + raise ValueError(f"Unsupported field-cache source: {self.source}") + if self.inject and self.inject.target not in _VALID_TARGETS: + raise ValueError(f"Unsupported field-cache injection target: {self.inject.target}") if not self.scope: raise ValueError("FieldCacheRule.scope must contain at least one dimension") + invalid_scopes = [scope for scope in self.scope if scope not in _VALID_SCOPES] + if invalid_scopes: + raise ValueError(f"Unsupported field-cache scope: {invalid_scopes[0]}") @dataclass(frozen=True) diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index 95bebc750..df1a9be3b 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -5,11 +5,14 @@ from __future__ import annotations +from copy import deepcopy from typing import Any, AsyncGenerator from ..adapters import get_adapter, run_adapter_chain from ..field_cache import FieldCacheEngine, InMemoryFieldCacheStore +from ..field_cache.paths import FieldCachePathError, PathToken, parse_path from ..protocols import get_protocol +from ..transform_trace import REDACTED from ..usage.accounting import extract_usage_record from .context import NativeProviderContext from .http import NativeHTTPTransport @@ -162,7 +165,7 @@ def _trace( return context.transaction_logger.log_transform_pass( pass_name, - data, + _redact_field_cache_paths(data, context, direction), direction=direction, stage=stage, protocol=context.protocol_name, @@ -178,3 +181,65 @@ def _trace( }, snapshot=snapshot, ) + + +def _redact_field_cache_paths(data: Any, context: NativeProviderContext, direction: str) -> Any: + """Redact configured cache paths before broad native payload traces. + + Field-cache rules can inject opaque state under arbitrary configured keys, + so key-based trace redaction is not enough. Native traces apply the active + rules' source and injection paths to a copy before handing data to the normal + transaction trace sanitizer. + """ + + if not context.field_cache_rules: + return data + redacted = deepcopy(data) + for rule in context.field_cache_rules: + paths: list[str] = [] + if direction == "request" and rule.inject: + paths.append(rule.inject.path) + if direction in {"response", "stream"}: + paths.append(rule.path) + for path in paths: + try: + _redact_path(redacted, parse_path(path)) + except (FieldCachePathError, TypeError, ValueError): + continue + return redacted + + +def _redact_path(value: Any, tokens: tuple[PathToken, ...]) -> None: + if not tokens: + return + token = tokens[0] + rest = tokens[1:] + if token.kind == "key": + if isinstance(value, dict) and token.value in value: + if rest: + _redact_path(value[token.value], rest) + else: + value[token.value] = REDACTED + return + if token.kind == "index": + if isinstance(value, list) and value: + index = int(token.value) + if -len(value) <= index < len(value): + if rest: + _redact_path(value[index], rest) + else: + value[index] = REDACTED + return + if token.kind == "wildcard": + if isinstance(value, dict): + for key in list(value.keys()): + if rest: + _redact_path(value[key], rest) + else: + value[key] = REDACTED + elif isinstance(value, list): + for index, item in enumerate(value): + if rest: + _redact_path(item, rest) + else: + value[index] = REDACTED diff --git a/tests/test_experimental_config.py b/tests/test_experimental_config.py index 88ceb2879..6d78c7175 100644 --- a/tests/test_experimental_config.py +++ b/tests/test_experimental_config.py @@ -86,6 +86,23 @@ def test_field_cache_rules_match_unprefixed_model_alias() -> None: assert [rule.name for rule in rules] == ["signature"] +def test_field_cache_rule_rejects_invalid_config_values() -> None: + config = load_config_from_mapping( + { + "field_cache": { + "provider": { + "*": [ + {"name": "bad_source", "source": "not_a_source", "path": "x"}, + ] + } + } + } + ) + + with pytest.raises(ValueError, match="Unsupported field-cache source"): + parse_field_cache_rules(config, "provider", "model") + + def test_env_price_key_sanitizes_provider_and_model() -> None: assert env_price_key("openai", "gpt-5.1-mini", "cache_read") == "MODEL_PRICE_OPENAI_GPT_5_1_MINI_CACHE_READ" diff --git a/tests/test_field_cache_engine.py b/tests/test_field_cache_engine.py index 6db1cf400..de6862f03 100644 --- a/tests/test_field_cache_engine.py +++ b/tests/test_field_cache_engine.py @@ -354,6 +354,23 @@ async def test_last_mode_preserves_list_valued_field() -> None: assert updated["metadata"]["signatures"] == ["a", "b"] +@pytest.mark.asyncio +async def test_as_list_unwraps_last_mode_value_envelope() -> None: + rule = FieldCacheRule( + name="value", + source="response", + path="value", + inject=FieldCacheInjection(target="request", path="metadata.values", as_list=True), + allow_missing_session=True, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"value": "sig"}, _context(session_id=None)) + updated, _ = await engine.inject("request", {"metadata": {}}, _context(session_id=None)) + + assert updated["metadata"]["values"] == ["sig"] + + @pytest.mark.asyncio async def test_last_assistant_turn_skips_without_turn_context() -> None: rule = FieldCacheRule( @@ -438,6 +455,25 @@ async def test_per_tool_call_as_list_injects_matching_values() -> None: assert updated["metadata"]["signatures"] == [{"id": "a", "signature": "sig-a"}, {"id": "b", "signature": "sig-b"}] +@pytest.mark.asyncio +async def test_per_tool_call_preserves_list_valued_match() -> None: + rule = FieldCacheRule( + name="tool_signatures", + source="response", + path="tool_calls.*.signatures", + mode="per_tool_call", + inject=FieldCacheInjection(target="request", path="metadata.signatures"), + allow_missing_session=True, + metadata={"tool_container_path": "tool_calls", "tool_call_id_path": "id", "tool_value_path": "signatures"}, + ) + engine = FieldCacheEngine([rule]) + + await engine.extract("response", {"tool_calls": [{"id": "call", "signatures": ["a", "b"]}]}, _context(session_id=None)) + updated, _ = await engine.inject("request", {"metadata": {}}, _context(session_id=None, metadata={"tool_call_id": "call"})) + + assert updated["metadata"]["signatures"] == ["a", "b"] + + @pytest.mark.asyncio async def test_engine_supports_legacy_store_without_ttl_keyword() -> None: class LegacyStore: diff --git a/tests/test_native_provider_executor.py b/tests/test_native_provider_executor.py index da7cfb59a..ab123145e 100644 --- a/tests/test_native_provider_executor.py +++ b/tests/test_native_provider_executor.py @@ -130,6 +130,42 @@ async def test_native_provider_default_field_cache_persists_across_requests() -> assert second_client.calls[0]["json"]["metadata"]["reasoning_content"] == "cached-state" +@pytest.mark.asyncio +async def test_native_provider_trace_redacts_configured_injection_paths(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + rule = FieldCacheRule( + name="state", + source="response", + path="choices.0.message.reasoning_content", + inject=FieldCacheInjection(target="request", path="metadata.state"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=(rule,), + transaction_logger=logger, + ) + executor = NativeProviderExecutor() + + await executor.execute( + {"model": "gpt-test", "messages": [{"role": "user", "content": "hi"}]}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "chat_1", "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "opaque-state"}}]})), + ) + await executor.execute( + {"model": "gpt-test", "messages": [{"role": "user", "content": "again"}]}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "chat_2", "choices": [{"message": {"role": "assistant", "content": "ok"}}]})), + ) + + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "opaque-state" not in trace_text + assert '"state": "[REDACTED]"' in trace_text + + @pytest.mark.asyncio async def test_native_provider_executor_logs_transform_errors(tmp_path) -> None: logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) From 7ea21eccadd89f7b17567423123acf90583aad5c Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 11:00:40 +0200 Subject: [PATCH 096/182] fix(field-cache): preserve config errors and provider-state redaction Broadens provider-state key redaction for cache-related fields that can appear after native responses return to the general RequestExecutor trace path. Also re-raises streaming native configuration errors before retry/rotation handling so malformed field-cache config remains visible instead of becoming an unknown rotatable provider failure. Tests: pytest tests/test_field_cache_paths.py tests/test_field_cache_engine.py tests/test_field_cache_trace.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_request_executor_native_routing.py tests/test_experimental_config.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py --- src/rotator_library/client/executor.py | 18 ++++++++++++++++++ src/rotator_library/transform_trace.py | 12 ++++++++++++ tests/test_transform_trace.py | 2 ++ 3 files changed, 32 insertions(+) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index ef73216e1..c1a3477ce 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -1676,6 +1676,21 @@ async def _execute_streaming( continue # Retry + except RoutingExecutionError as e: + if e.error_type == "configuration_error": + raise + last_exception = e + classified = classify_error(e, provider) + if _can_start_stream_provider_cooldown(last_streamed_chunk): + await self._maybe_start_provider_cooldown(provider, classified, context=context) + log_failure(api_key=cred, model=model, attempt=attempt + 1, error=e, request_headers=request_headers) + error_accumulator.record_error(cred, classified, str(e)[:150]) + if not should_rotate_on_error(classified): + cred_context.mark_failure(classified) + raise + cred_context.mark_failure(classified) + break + except Exception as e: last_exception = e classified = classify_error(e, provider) @@ -1733,6 +1748,9 @@ async def _execute_streaming( except PreRequestCallbackError: raise + except RoutingExecutionError as exc: + if exc.error_type == "configuration_error": + raise except Exception: # Let context manager handle cleanup pass diff --git a/src/rotator_library/transform_trace.py b/src/rotator_library/transform_trace.py index 3608d140c..c54039876 100644 --- a/src/rotator_library/transform_trace.py +++ b/src/rotator_library/transform_trace.py @@ -50,6 +50,18 @@ "thought-signature", "thoughtsignature", "signature", + "state", + "provider-state", + "provider-session-id", + "providersessionid", + "prompt-cache-key", + "promptcachekey", + "cache-key", + "cachekey", + "thinking-signatures", + "thinkingsignatures", + "thought-signatures", + "thoughtsignatures", } ) diff --git a/tests/test_transform_trace.py b/tests/test_transform_trace.py index e289cd15d..682ae6b1f 100644 --- a/tests/test_transform_trace.py +++ b/tests/test_transform_trace.py @@ -36,6 +36,8 @@ def test_sanitize_for_trace_redacts_sensitive_keys_recursively() -> None: assert sanitized["items"][0]["refresh_token"] == REDACTED assert sanitize_for_trace({"reasoning_content": "provider-state"})["reasoning_content"] == REDACTED assert sanitize_for_trace({"thoughtSignature": "provider-state"})["thoughtSignature"] == REDACTED + assert sanitize_for_trace({"metadata": {"state": "provider-state"}})["metadata"]["state"] == REDACTED + assert sanitize_for_trace({"metadata": {"prompt_cache_key": "cache-state"}})["metadata"]["prompt_cache_key"] == REDACTED assert sanitized["items"][0]["text"] == "token should remain in value" From f00b48ac947475f1c4fba908b78f0c82fa4541f9 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 11:05:06 +0200 Subject: [PATCH 097/182] fix(field-cache): redact executor native response traces Applies active field-cache rule paths to RequestExecutor-level native response trace payloads so arbitrary configured cached fields are redacted after native responses return from NativeProviderExecutor. This keeps client-facing responses unchanged while closing the remaining trace-leak path for raw_provider_response, final response logging, and post_usage_normalization_response. Tests: pytest tests/test_field_cache_paths.py tests/test_field_cache_engine.py tests/test_field_cache_trace.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_request_executor_native_routing.py tests/test_experimental_config.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py --- src/rotator_library/client/executor.py | 76 ++++++++++++++++++- tests/test_request_executor_native_routing.py | 23 ++++++ 2 files changed, 97 insertions(+), 2 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index c1a3477ce..a35294cc2 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -67,6 +67,8 @@ from ..routing import FallbackPolicy, clone_context_for_target from ..routing.types import RouteTarget from ..native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor +from ..field_cache.paths import FieldCachePathError, PathToken, parse_path +from ..transform_trace import REDACTED from ..usage.accounting import UsageRecord, extract_usage_record from ..usage.costs import CostBreakdown, CostCalculator @@ -1084,10 +1086,11 @@ async def _execute_non_streaming( kwargs, context, ) + trace_response = _redact_context_field_cache_paths(response, context, "response", plugin) self._log_executor_trace( context, "raw_provider_response", - response, + trace_response, direction="response", stage="provider", credential_id=cred_context.stable_id, @@ -1132,6 +1135,7 @@ async def _execute_non_streaming( if hasattr(response, "model_dump") else response ) + response_data = _redact_context_field_cache_paths(response_data, context, "response", plugin) context.transaction_logger.log_response( response_data ) @@ -1141,10 +1145,11 @@ async def _execute_non_streaming( ) normalized_response = self._normalize_response_usage(response, model) + trace_normalized_response = _redact_context_field_cache_paths(normalized_response, context, "response", plugin) self._log_executor_trace( context, "post_usage_normalization_response", - normalized_response, + trace_normalized_response, direction="response", stage="final", credential_id=cred_context.stable_id, @@ -2262,6 +2267,73 @@ def _provider_supports_native_streaming(plugin: Any, model: str) -> bool: return bool(support()) +def _redact_context_field_cache_paths(payload: Any, context: RequestContext, direction: str, plugin: Any) -> Any: + """Redact configured field-cache paths before executor-level traces. + + Native provider execution already redacts its internal trace passes. The + client executor also logs returned responses, so it must apply the same + rule-aware redaction there without mutating the client-facing response. + """ + + if direction not in {"response", "stream"} or not plugin or not _provider_native_protocol(plugin, context.model, _current_route_target(context)): + return payload + try: + rules = _merged_field_cache_rules(context.provider, context.model, plugin) + except RoutingExecutionError: + raise + except Exception: + return payload + if not rules: + return payload + redacted = deepcopy(payload) + for rule in rules: + if direction == "response" and getattr(rule, "source", None) not in {"response", "unified_response"}: + continue + if direction == "stream" and getattr(rule, "source", None) not in {"stream_event", "unified_stream_event"}: + continue + try: + _redact_trace_path(redacted, parse_path(rule.path)) + except (FieldCachePathError, TypeError, ValueError): + continue + return redacted + + +def _redact_trace_path(value: Any, tokens: tuple[PathToken, ...]) -> None: + if not tokens: + return + token = tokens[0] + rest = tokens[1:] + if token.kind == "key": + if isinstance(value, dict) and token.value in value: + if rest: + _redact_trace_path(value[token.value], rest) + else: + value[token.value] = REDACTED + return + if token.kind == "index": + if isinstance(value, list) and value: + index = int(token.value) + if -len(value) <= index < len(value): + if rest: + _redact_trace_path(value[index], rest) + else: + value[index] = REDACTED + return + if token.kind == "wildcard": + if isinstance(value, dict): + for key in list(value.keys()): + if rest: + _redact_trace_path(value[key], rest) + else: + value[key] = REDACTED + elif isinstance(value, list): + for index, item in enumerate(value): + if rest: + _redact_trace_path(item, rest) + else: + value[index] = REDACTED + + def _merged_field_cache_rules(provider: str, model: str, plugin: Any) -> tuple[Any, ...]: """Merge provider-declared and JSON-configured field-cache rules. diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index 538fa1ef1..b67d6dd44 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -77,6 +77,19 @@ def get_field_cache_rules(self, model=""): ) +class NativePluginWithVendorRule(NativePlugin): + def get_field_cache_rules(self, model=""): + return ( + FieldCacheRule( + name="vendor_state", + source="response", + path="choices.0.message.vendor_state", + inject=FieldCacheInjection(target="request", path="metadata.vendor_state"), + allow_missing_session=True, + ), + ) + + class CustomPlugin: def __init__(self): self.calls = [] @@ -196,6 +209,16 @@ def test_native_context_raises_on_invalid_field_cache_config(monkeypatch, tmp_pa assert exc.value.error_type == "configuration_error" +def test_executor_trace_redaction_uses_native_field_cache_response_paths() -> None: + context = _context(parse_route_target("provider/gpt-test")) + payload = {"choices": [{"message": {"role": "assistant", "content": "ok", "vendor_state": "opaque-vendor-state"}}]} + + redacted = executor_module._redact_context_field_cache_paths(payload, context, "response", NativePluginWithVendorRule()) + + assert redacted["choices"][0]["message"]["vendor_state"] == "[REDACTED]" + assert payload["choices"][0]["message"]["vendor_state"] == "opaque-vendor-state" + + @pytest.mark.asyncio async def test_litellm_fallback_execution_is_explicit(monkeypatch) -> None: calls = [] From 6740ed3c6cf06496f2f6e0b91a493e869050525e Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 11:12:07 +0200 Subject: [PATCH 098/182] fix(field-cache): redact native stream traces Extends active field-cache path redaction to RequestExecutor stream logging for raw SSE chunks, parsed stream chunks, stream error events, assembled stream responses, and final response logs. The stream bytes yielded to clients remain unchanged; only trace payloads are redacted. Adds a regression for arbitrary vendor_state cache fields in native stream traces. Tests: pytest tests/test_field_cache_paths.py tests/test_field_cache_engine.py tests/test_field_cache_trace.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_request_executor_native_routing.py tests/test_experimental_config.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py --- src/rotator_library/client/executor.py | 37 +++++++++++++++---- tests/test_request_executor_native_routing.py | 12 ++++++ 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index a35294cc2..67a774613 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -1457,6 +1457,8 @@ async def _execute_streaming( base_stream, context.transaction_logger, context.kwargs, + context=context, + plugin=plugin, ): last_streamed_chunk = chunk yield chunk @@ -2141,6 +2143,9 @@ async def _transaction_logging_stream_wrapper( stream: AsyncGenerator[str, None], transaction_logger: TransactionLogger, request_kwargs: Dict[str, Any], + *, + context: Optional[RequestContext] = None, + plugin: Any = None, ) -> AsyncGenerator[str, None]: """ Wrap a stream to log chunks and final response to TransactionLogger. @@ -2158,9 +2163,10 @@ async def _transaction_logging_stream_wrapper( chunks = [] async for sse_line in stream: + trace_sse_line = _redact_stream_sse_for_trace(sse_line, context, plugin) transaction_logger.log_transform_pass( "raw_stream_chunk", - sse_line, + trace_sse_line, direction="stream", stage="client", transport="sse", @@ -2169,7 +2175,7 @@ async def _transaction_logging_stream_wrapper( if sse_line.startswith("data: [DONE]"): transaction_logger.log_transform_pass( "stream_done_event", - {"raw": sse_line}, + {"raw": trace_sse_line}, direction="stream", stage="final", transport="sse", @@ -2186,11 +2192,12 @@ async def _transaction_logging_stream_wrapper( if content: chunk_data = json.loads(content) chunks.append(chunk_data) - transaction_logger.log_stream_chunk(chunk_data) + trace_chunk_data = _redact_context_field_cache_paths(chunk_data, context, "stream", plugin) if context else chunk_data + transaction_logger.log_stream_chunk(trace_chunk_data) if isinstance(chunk_data, dict) and chunk_data.get("error") is not None: transaction_logger.log_transform_pass( "stream_error_event", - chunk_data, + trace_chunk_data, direction="stream", stage="client", transport="sse", @@ -2205,14 +2212,15 @@ async def _transaction_logging_stream_wrapper( if chunks: try: final_response = TransactionLogger.assemble_streaming_response(chunks) + trace_final_response = _redact_context_field_cache_paths(final_response, context, "stream", plugin) if context else final_response transaction_logger.log_transform_pass( "assembled_stream_response", - final_response, + trace_final_response, direction="response", stage="client", transport="sse", ) - transaction_logger.log_response(final_response) + transaction_logger.log_response(trace_final_response) except Exception as e: lib_logger.debug( f"Failed to assemble/log final streaming response: {e}" @@ -2289,7 +2297,7 @@ def _redact_context_field_cache_paths(payload: Any, context: RequestContext, dir for rule in rules: if direction == "response" and getattr(rule, "source", None) not in {"response", "unified_response"}: continue - if direction == "stream" and getattr(rule, "source", None) not in {"stream_event", "unified_stream_event"}: + if direction == "stream" and getattr(rule, "source", None) not in {"stream_event", "unified_stream_event", "response", "unified_response"}: continue try: _redact_trace_path(redacted, parse_path(rule.path)) @@ -2298,6 +2306,21 @@ def _redact_context_field_cache_paths(payload: Any, context: RequestContext, dir return redacted +def _redact_stream_sse_for_trace(sse_line: str, context: Optional[RequestContext], plugin: Any) -> str: + """Return a trace-only SSE line with configured native cache paths redacted.""" + + if not context or not isinstance(sse_line, str) or not sse_line.startswith("data: ") or sse_line.startswith("data: [DONE]"): + return sse_line + try: + payload = json.loads(sse_line[6:].strip()) + except json.JSONDecodeError: + return sse_line + redacted = _redact_context_field_cache_paths(payload, context, "stream", plugin) + if redacted is payload: + return sse_line + return f"data: {json.dumps(redacted, separators=(',', ':'))}\n\n" + + def _redact_trace_path(value: Any, tokens: tuple[PathToken, ...]) -> None: if not tokens: return diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index b67d6dd44..0e7144bcf 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -219,6 +219,18 @@ def test_executor_trace_redaction_uses_native_field_cache_response_paths() -> No assert payload["choices"][0]["message"]["vendor_state"] == "opaque-vendor-state" +def test_executor_stream_trace_redaction_uses_native_field_cache_paths() -> None: + context = _context(parse_route_target("provider/gpt-test")) + sse_line = 'data: {"choices":[{"delta":{"content":"ok"},"message":{"vendor_state":"opaque-vendor-state"}}]}\n\n' + + redacted = executor_module._redact_stream_sse_for_trace(sse_line, context, NativePluginWithVendorRule()) + + parsed = json.loads(redacted[6:].strip()) + + assert "opaque-vendor-state" not in redacted + assert parsed["choices"][0]["message"]["vendor_state"] == "[REDACTED]" + + @pytest.mark.asyncio async def test_litellm_fallback_execution_is_explicit(monkeypatch) -> None: calls = [] From 7c969aac23abe48824f75459650c69176665d701 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 11:19:23 +0200 Subject: [PATCH 099/182] fix(field-cache): redact stream-event envelope paths Adds trace-only redaction fallbacks for stream_event rules that target unified raw.* paths while traces contain raw or formatted stream chunk shapes. Native and executor stream traces now strip the raw. envelope prefix where needed and recursively redact the configured terminal cache key across duplicated trace envelopes such as raw payload, delta extras, and extra.payload. Tests: pytest tests/test_field_cache_paths.py tests/test_field_cache_engine.py tests/test_field_cache_trace.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_request_executor_native_routing.py tests/test_experimental_config.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py tests/test_transform_trace.py tests/test_transaction_logger_transform_trace.py --- src/rotator_library/client/executor.py | 39 +++++++++++++++++-- .../native_provider/executor.py | 38 ++++++++++++++++-- tests/test_native_provider_streaming.py | 9 ++++- tests/test_request_executor_native_routing.py | 14 ++++++- 4 files changed, 89 insertions(+), 11 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 67a774613..e8af84f83 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -2299,13 +2299,27 @@ def _redact_context_field_cache_paths(payload: Any, context: RequestContext, dir continue if direction == "stream" and getattr(rule, "source", None) not in {"stream_event", "unified_stream_event", "response", "unified_response"}: continue - try: - _redact_trace_path(redacted, parse_path(rule.path)) - except (FieldCachePathError, TypeError, ValueError): - continue + for path in _trace_redaction_paths((rule.path,), direction=direction): + try: + tokens = parse_path(path) + _redact_trace_path(redacted, tokens) + _redact_trace_leaf_key(redacted, tokens) + except (FieldCachePathError, TypeError, ValueError): + continue return redacted +def _trace_redaction_paths(paths: tuple[str, ...] | list[str], *, direction: str) -> list[str]: + """Return configured paths plus raw-stream envelope fallbacks for traces.""" + + expanded: list[str] = [] + for path in paths: + expanded.append(path) + if direction == "stream" and path.startswith("raw."): + expanded.append(path[4:]) + return expanded + + def _redact_stream_sse_for_trace(sse_line: str, context: Optional[RequestContext], plugin: Any) -> str: """Return a trace-only SSE line with configured native cache paths redacted.""" @@ -2357,6 +2371,23 @@ def _redact_trace_path(value: Any, tokens: tuple[PathToken, ...]) -> None: value[index] = REDACTED +def _redact_trace_leaf_key(value: Any, tokens: tuple[PathToken, ...]) -> None: + """Redact terminal configured cache keys across duplicated trace envelopes.""" + + leaf = next((token.value for token in reversed(tokens) if token.kind == "key"), None) + if not leaf: + return + if isinstance(value, dict): + for key, item in list(value.items()): + if key == leaf: + value[key] = REDACTED + else: + _redact_trace_leaf_key(item, tokens) + elif isinstance(value, list): + for item in value: + _redact_trace_leaf_key(item, tokens) + + def _merged_field_cache_rules(provider: str, model: str, plugin: Any) -> tuple[Any, ...]: """Merge provider-declared and JSON-configured field-cache rules. diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index df1a9be3b..92bdb3d49 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -11,7 +11,7 @@ from ..adapters import get_adapter, run_adapter_chain from ..field_cache import FieldCacheEngine, InMemoryFieldCacheStore from ..field_cache.paths import FieldCachePathError, PathToken, parse_path -from ..protocols import get_protocol +from ..protocols import get_protocol, serialize_value from ..transform_trace import REDACTED from ..usage.accounting import extract_usage_record from .context import NativeProviderContext @@ -194,21 +194,34 @@ def _redact_field_cache_paths(data: Any, context: NativeProviderContext, directi if not context.field_cache_rules: return data - redacted = deepcopy(data) + redacted = serialize_value(deepcopy(data)) for rule in context.field_cache_rules: paths: list[str] = [] if direction == "request" and rule.inject: paths.append(rule.inject.path) if direction in {"response", "stream"}: paths.append(rule.path) - for path in paths: + for path in _trace_redaction_paths(paths, direction=direction): try: - _redact_path(redacted, parse_path(path)) + tokens = parse_path(path) + _redact_path(redacted, tokens) + _redact_leaf_key(redacted, tokens) except (FieldCachePathError, TypeError, ValueError): continue return redacted +def _trace_redaction_paths(paths: list[str], *, direction: str) -> list[str]: + """Return configured paths plus raw-stream envelope fallbacks for traces.""" + + expanded: list[str] = [] + for path in paths: + expanded.append(path) + if direction == "stream" and path.startswith("raw."): + expanded.append(path[4:]) + return expanded + + def _redact_path(value: Any, tokens: tuple[PathToken, ...]) -> None: if not tokens: return @@ -243,3 +256,20 @@ def _redact_path(value: Any, tokens: tuple[PathToken, ...]) -> None: _redact_path(item, rest) else: value[index] = REDACTED + + +def _redact_leaf_key(value: Any, tokens: tuple[PathToken, ...]) -> None: + """Redact the configured terminal key wherever stream traces duplicate it.""" + + leaf = next((token.value for token in reversed(tokens) if token.kind == "key"), None) + if not leaf: + return + if isinstance(value, dict): + for key, item in list(value.items()): + if key == leaf: + value[key] = REDACTED + else: + _redact_leaf_key(item, tokens) + elif isinstance(value, list): + for item in value: + _redact_leaf_key(item, tokens) diff --git a/tests/test_native_provider_streaming.py b/tests/test_native_provider_streaming.py index 62f94473a..e35fcb144 100644 --- a/tests/test_native_provider_streaming.py +++ b/tests/test_native_provider_streaming.py @@ -32,11 +32,14 @@ async def test_native_provider_stream_traces_and_yields_formatted_events(tmp_pat model="gpt-test", protocol_name="openai_chat", endpoint="https://example.test/chat", - field_cache_rules=(FieldCacheRule(name="stream_reasoning", source="stream_event", path="raw.choices.0.delta.reasoning_content", allow_missing_session=True),), + field_cache_rules=( + FieldCacheRule(name="stream_reasoning", source="stream_event", path="raw.choices.0.delta.reasoning_content", allow_missing_session=True), + FieldCacheRule(name="stream_vendor_state", source="stream_event", path="raw.choices.0.delta.vendor_state", allow_missing_session=True), + ), transaction_logger=logger, ) chunks = [ - {"choices": [{"delta": {"content": "hi", "reasoning_content": "hidden"}}]}, + {"choices": [{"delta": {"content": "hi", "reasoning_content": "hidden", "vendor_state": "opaque-vendor-state"}}]}, "[DONE]", ] client = FakeStreamingClient(chunks) @@ -52,6 +55,8 @@ async def test_native_provider_stream_traces_and_yields_formatted_events(tmp_pat assert "after_field_cache_extraction" in pass_names assert "after_field_cache_stream_extraction" in pass_names assert pass_names.count("formatted_client_stream_event") == 2 + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "opaque-vendor-state" not in trace_text @pytest.mark.asyncio diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index 0e7144bcf..fcd49c6b0 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -90,6 +90,18 @@ def get_field_cache_rules(self, model=""): ) +class NativePluginWithStreamVendorRule(NativePlugin): + def get_field_cache_rules(self, model=""): + return ( + FieldCacheRule( + name="vendor_state", + source="stream_event", + path="raw.choices.0.message.vendor_state", + allow_missing_session=True, + ), + ) + + class CustomPlugin: def __init__(self): self.calls = [] @@ -223,7 +235,7 @@ def test_executor_stream_trace_redaction_uses_native_field_cache_paths() -> None context = _context(parse_route_target("provider/gpt-test")) sse_line = 'data: {"choices":[{"delta":{"content":"ok"},"message":{"vendor_state":"opaque-vendor-state"}}]}\n\n' - redacted = executor_module._redact_stream_sse_for_trace(sse_line, context, NativePluginWithVendorRule()) + redacted = executor_module._redact_stream_sse_for_trace(sse_line, context, NativePluginWithStreamVendorRule()) parsed = json.loads(redacted[6:].strip()) From 3055003b45495889e23eb489b11a5bb36b5dabec Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 11:25:41 +0200 Subject: [PATCH 100/182] docs(experimental): plan responses correction Adds the Phase 4b corrective plan for Responses session anchoring, storage TTL/pruning, failed/current-state storage behavior, and transport-neutral streaming events. The plan is based on the current service, store, route, protocol, and streaming implementations and preserves SSE compatibility while preparing WebSocket support without protocol rewrites. Tests: not run (planning document only). --- ...-4b-responses-session-storage-transport.md | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 docs/experimental/phase-4b-responses-session-storage-transport.md diff --git a/docs/experimental/phase-4b-responses-session-storage-transport.md b/docs/experimental/phase-4b-responses-session-storage-transport.md new file mode 100644 index 000000000..81e352540 --- /dev/null +++ b/docs/experimental/phase-4b-responses-session-storage-transport.md @@ -0,0 +1,94 @@ +# Phase 4b Plan: Responses API Session, Storage, And Transport Corrections + +## Goal + +Correct the Phase 4 audit findings without over-claiming full native Responses provider execution. Phase 4 made `/v1/responses` usable through the chat bridge, but the validation pass found that important Responses semantics are still incomplete: `previous_response_id` is only used to replay parent output, not as session-routing evidence; default storage has only incidental expiry support with no runtime TTL policy or max-size pruning; stream generation is SSE-first inside the service instead of cleanly transport-neutral; and route/service tracing/storage should more clearly represent stored/skipped/current-state behavior. + +## Non-Goals + +- Do not implement a WebSocket FastAPI route in this phase unless the transport abstraction makes it trivial and tests stay small. +- Do not replace the bridge with fully native provider Responses execution here. +- Do not introduce SQLite or a database. +- Do not build multi-user security/admin features. +- Do not change current public SSE event order except where tests show it is wrong. +- Do not commit user-facing reports. +- Do not touch unrelated dirty `ARCHITECTURE.md`, `STRUCTURE.md`, `.opencode/`, `docs/issues/`, or old phase reports. + +## Current Code State + +- `ResponsesService.create_response()` parses `previous_response_id`, loads the parent from the store, and passes parent output to `ResponsesBridge.to_chat_kwargs()`. +- `ResponsesBridge.to_chat_kwargs()` stores `previous_response_id` only in private `_responses_bridge` metadata used for trace context, then `ResponsesService` pops that metadata before calling `client.acompletion()`. +- The actual `RotatingClient` request context cannot see `previous_response_id`, so `SessionTracker` cannot use it as a strong anchor for credential affinity. +- `StoredResponse` has `expires_at`, and stores check expiry on `get()`, but `ResponsesService` never sets a TTL policy and `InMemoryResponsesStore` has no max-size pruning. +- `ProviderCacheResponsesStore` exists but app startup always creates `ResponsesService()` with default in-memory store. +- `ResponsesService.stream_response()` directly instantiates `ResponsesSSEFormatter()` and yields formatted strings. +- `ResponsesWebSocketFormatter` exists as a placeholder with `NotImplementedError`, which is honest but does not prove that service logic is formatter-neutral. +- Responses stream storage stores only the completed payload after stream success and can skip storage when `store=false`. It does not store in-progress current state. +- `response.failed` events are yielded on exceptions, but failed responses are not stored as failed state even when `store=true`. + +## Implementation Plan + +1. Add an internal Responses session-hints carrier. + - Introduce private `_session_tracking_hints` request kwargs consumed by `RequestContextBuilder` and removed before provider execution. + - Merge service-level hints with provider hints before `SessionTracker.infer_session()`. + - Preserve provider hints and explain that these hints are proxy-internal evidence, never provider payload. + +2. Make `previous_response_id` a strong Responses session anchor. + - Attach `_session_tracking_hints` for continuation requests with a strong anchor like `responses_previous_response_id:{id}` and deterministic affinity key. + - Do not fake a strong first-turn anchor from a generated response ID before the request executes. + - Add tests proving continuation requests expose hidden hints to context construction and the hidden field is not sent to providers. + +3. Record generated Responses IDs as response-derived metadata. + - Store response IDs and parent IDs clearly in `StoredResponse.metadata`. + - Add a helper for Responses session hints so future native Responses execution can share it. + +4. Add runtime storage policy. + - Add `ResponsesStoreSettings` with `ttl_seconds`, `max_items`, `store_failed`, and `store_in_progress`. + - Preserve current defaults: no expiry, no max pruning, completed responses stored, in-progress updates disabled. + +5. Implement TTL assignment and max-size pruning. + - Set `StoredResponse.expires_at` in `_stored_response()` from settings. + - Add max-item pruning to `InMemoryResponsesStore.save()`. + - Keep provider-cache expiry via `StoredResponse.expires_at`; document max-size limitations if listing is unavailable. + +6. Store failed stream responses when `store=true`. + - Persist `response.failed` payloads when storage policy says failed responses are stored. + - Keep `store=false` skip behavior. + +7. Add a current-state storage seam. + - Add `store_in_progress` default false. + - When enabled, save created/intermediate/completed stream state using the same store API. + - Keep default behavior unchanged. + +8. Refactor stream generation to transport-neutral events. + - Add `ResponsesStreamEvent` and a formatter interface in `responses/streaming.py`. + - Add `ResponsesService.stream_events()` yielding event objects. + - Make `stream_response()` a thin SSE wrapper over `stream_events()`. + - Keep `ResponsesWebSocketFormatter` honest: either pure JSON event formatting or explicit route-level future support, but service logic must not be SSE-specific. + +9. Tighten trace boundaries around the event pipeline. + - Preserve existing `responses_stream_event_*` pass names. + - Avoid trace-only conversions when no transaction logger is present. + - Add trace pass for current-state storage updates if enabled. + +10. Update FastAPI route wiring minimally. + - Keep current `/v1/responses` route behavior. + - Continue using `service.stream_response()` for SSE. + - Defer provider-cache/durable config wiring to config work unless a safe existing setting exists. + +## Tests + +- Responses service tests for hidden session hints, TTL, max pruning, and metadata. +- Responses streaming tests for `stream_events()` order, SSE wrapper compatibility, failed storage, `store=false`, and WebSocket formatter seam. +- Responses store tests for memory max pruning and provider-cache expiry behavior. +- Request-builder/session tracking tests for `_session_tracking_hints` consumption and provider payload cleanup. +- Route regression tests for current HTTP behavior. + +## Acceptance Criteria + +- `previous_response_id` becomes strong session-routing evidence for continuation requests without leaking internal hints to provider payloads. +- Responses storage has explicit TTL/max-size policy and failed-stream storage behavior. +- Streaming service logic can emit transport-neutral event objects, with SSE as a formatter wrapper and WebSocket support no longer requiring service/protocol rewrite. +- Existing `/v1/responses`, retrieve, delete, input_items, and SSE behavior remain compatible. +- Trace-disabled paths avoid unnecessary trace-only conversions. +- Focused tests and dual-agent review pass. From 41cfedea974b80ef13680152b3ac2d507a936a2b Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 11:28:56 +0200 Subject: [PATCH 101/182] feat(responses): add internal session continuation hints Adds a proxy-internal _session_tracking_hints carrier consumed by RequestContextBuilder before provider execution so service-level session evidence never reaches provider payloads. Responses continuations now attach strong previous_response_id evidence and a deterministic affinity key, allowing the centralized SessionTracker to keep continuation requests sticky without bypassing provider hints. Tests: pytest tests/test_responses_bridge.py tests/test_request_builder_routing.py --- src/rotator_library/client/request_builder.py | 54 ++++++++++++++++--- src/rotator_library/responses/bridge.py | 16 ++++++ tests/test_request_builder_routing.py | 26 ++++++++- tests/test_responses_bridge.py | 2 + 4 files changed, 90 insertions(+), 8 deletions(-) diff --git a/src/rotator_library/client/request_builder.py b/src/rotator_library/client/request_builder.py index df1b400b2..80a4ec728 100644 --- a/src/rotator_library/client/request_builder.py +++ b/src/rotator_library/client/request_builder.py @@ -10,6 +10,7 @@ from ..core.types import RequestContext from ..routing import FallbackResolver, RoutingConfigError, load_routing_config_from_env from ..routing.types import RouteTarget, RoutingDecision +from ..session_tracking import SessionTrackingHints from ..transaction_logger import TransactionLogger @@ -37,13 +38,14 @@ def __init__( self._get_provider_instance = get_provider_instance @staticmethod - def _pop_scope_kwargs(kwargs: Dict[str, Any]) -> tuple[Optional[str], Any, Any, bool]: + def _pop_scope_kwargs(kwargs: Dict[str, Any]) -> tuple[Optional[str], Any, Any, bool, Any]: classifier = kwargs.pop("classifier", None) request_api_keys = kwargs.pop("api_keys", None) request_providers = kwargs.pop("providers", None) private = bool(kwargs.pop("private", False)) + session_tracking_hints = kwargs.pop("_session_tracking_hints", None) kwargs.pop("model_filters", None) - return classifier, request_api_keys, request_providers, private + return classifier, request_api_keys, request_providers, private, session_tracking_hints @staticmethod def _provider_from_model(model: str) -> str: @@ -129,13 +131,47 @@ async def _get_session_hints( ) return None + @staticmethod + def _merge_session_hints(*hints: Any) -> Any: + """Merge proxy-internal and provider session evidence. + + Internal hints are removed from request kwargs before provider execution. + They let services such as Responses expose stable continuation IDs to the + centralized tracker without adding provider-visible payload fields. + """ + + merged = SessionTrackingHints() + seen = False + for hint in hints: + if not hint: + continue + if isinstance(hint, dict): + hint = SessionTrackingHints( + strong_anchors=list(hint.get("strong_anchors") or []), + medium_anchors=list(hint.get("medium_anchors") or []), + weak_anchors=list(hint.get("weak_anchors") or []), + affinity_key=hint.get("affinity_key"), + session_scope=hint.get("session_scope"), + ) + if not isinstance(hint, SessionTrackingHints): + continue + seen = True + merged.strong_anchors.extend(hint.strong_anchors) + merged.medium_anchors.extend(hint.medium_anchors) + merged.weak_anchors.extend(hint.weak_anchors) + if not merged.affinity_key and hint.affinity_key: + merged.affinity_key = hint.affinity_key + if not merged.session_scope and hint.session_scope: + merged.session_scope = hint.session_scope + return merged if seen else None + async def build_completion_context( self, request: Optional[Any], pre_request_callback: Optional[Callable], kwargs: Dict[str, Any], ) -> RequestContext: - classifier, request_api_keys, request_providers, private = self._pop_scope_kwargs( + classifier, request_api_keys, request_providers, private, internal_session_hints = self._pop_scope_kwargs( kwargs ) model = kwargs.get("model", "") @@ -189,7 +225,10 @@ async def build_completion_context( provider=provider, model=resolved_model, scope_key=scope["usage_manager_key"], - hints=await self._get_session_hints(provider, resolved_model, kwargs), + hints=self._merge_session_hints( + internal_session_hints, + await self._get_session_hints(provider, resolved_model, kwargs), + ), ) if transaction_logger: transaction_logger.set_trace_context( @@ -229,7 +268,7 @@ async def build_embedding_context( pre_request_callback: Optional[Callable], kwargs: Dict[str, Any], ) -> RequestContext: - classifier, request_api_keys, request_providers, private = self._pop_scope_kwargs( + classifier, request_api_keys, request_providers, private, internal_session_hints = self._pop_scope_kwargs( kwargs ) model = kwargs.get("model", "") @@ -252,7 +291,10 @@ async def build_embedding_context( provider=provider, model=model, scope_key=scope["usage_manager_key"], - hints=await self._get_session_hints(provider, model, kwargs), + hints=self._merge_session_hints( + internal_session_hints, + await self._get_session_hints(provider, model, kwargs), + ), ) return RequestContext( diff --git a/src/rotator_library/responses/bridge.py b/src/rotator_library/responses/bridge.py index 3c2c9425c..ba7c539fb 100644 --- a/src/rotator_library/responses/bridge.py +++ b/src/rotator_library/responses/bridge.py @@ -80,6 +80,9 @@ def to_chat_kwargs( "extra": deepcopy(unified.extra), } kwargs["_responses_bridge"] = {k: v for k, v in unsupported.items() if v} + hints = responses_session_hints(unified.previous_response_id) + if hints: + kwargs["_session_tracking_hints"] = hints return kwargs def from_chat_response( @@ -118,6 +121,19 @@ def _chat_generation_key(key: str) -> str: return key +def responses_session_hints(previous_response_id: Optional[str]) -> dict[str, Any] | None: + """Return proxy-internal sticky routing evidence for Responses continuations.""" + + if not previous_response_id: + return None + anchor = f"responses_previous_response_id:{previous_response_id}" + return { + "strong_anchors": [anchor], + "affinity_key": anchor, + "session_scope": "responses", + } + + def _message_to_chat(message: UnifiedMessage) -> dict[str, Any]: payload = {"role": message.role, "content": _blocks_to_chat_content(message.content)} if message.name: diff --git a/tests/test_request_builder_routing.py b/tests/test_request_builder_routing.py index ce4dbcf25..53ff9e3a4 100644 --- a/tests/test_request_builder_routing.py +++ b/tests/test_request_builder_routing.py @@ -19,7 +19,11 @@ class FakeSession: class FakeSessionTracker: + def __init__(self): + self.calls = [] + def infer_session(self, *args, **kwargs): + self.calls.append((args, kwargs)) return FakeSession() @@ -33,11 +37,11 @@ async def _scope(provider, classifier, request_api_keys, request_providers, priv } -def _builder() -> RequestContextBuilder: +def _builder(session_tracker=None) -> RequestContextBuilder: return RequestContextBuilder( resolve_scope_for_provider=_scope, model_resolver=FakeModelResolver(), - session_tracker=FakeSessionTracker(), + session_tracker=session_tracker or FakeSessionTracker(), get_global_timeout=lambda: 30, get_enable_request_logging=lambda: False, ) @@ -76,3 +80,21 @@ async def test_request_builder_rejects_unprefixed_model_without_route(monkeypatc with pytest.raises(ValueError): await _builder().build_completion_context(None, None, {"model": "gpt-5.1", "messages": []}) + + +@pytest.mark.asyncio +async def test_request_builder_consumes_internal_session_tracking_hints(monkeypatch) -> None: + monkeypatch.delenv("FALLBACK_GROUPS", raising=False) + tracker = FakeSessionTracker() + kwargs = { + "model": "openai/gpt-5.1", + "messages": [], + "_session_tracking_hints": {"strong_anchors": ["responses_previous_response_id:resp_parent"], "affinity_key": "responses_previous_response_id:resp_parent"}, + } + + context = await _builder(tracker).build_completion_context(None, None, kwargs) + + assert "_session_tracking_hints" not in context.kwargs + hints = tracker.calls[0][1]["hints"] + assert hints.strong_anchors == ["responses_previous_response_id:resp_parent"] + assert hints.affinity_key == "responses_previous_response_id:resp_parent" diff --git a/tests/test_responses_bridge.py b/tests/test_responses_bridge.py index ecb50b0ad..63d973a73 100644 --- a/tests/test_responses_bridge.py +++ b/tests/test_responses_bridge.py @@ -41,6 +41,8 @@ def test_bridge_adds_parent_response_messages_for_previous_response_id() -> None {"role": "user", "content": "Continue"}, ] assert kwargs["_responses_bridge"]["previous_response_id"] == "resp_parent" + assert kwargs["_session_tracking_hints"]["strong_anchors"] == ["responses_previous_response_id:resp_parent"] + assert kwargs["_session_tracking_hints"]["affinity_key"] == "responses_previous_response_id:resp_parent" def test_bridge_preserves_tool_definitions() -> None: From ab70413b20bfa0f633f9e56436589d6247ca7743 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 11:31:11 +0200 Subject: [PATCH 102/182] feat(responses): add storage policy controls Adds explicit Responses storage settings for TTL, max memory items, failed stream persistence, and future in-progress updates while preserving existing defaults. In-memory storage now prunes expired entries and oldest overflow entries. ResponsesService assigns expires_at from policy metadata and stores failed stream responses when enabled. Tests: pytest tests/test_responses_service.py tests/test_responses_streaming.py tests/test_responses_store.py --- src/rotator_library/responses/__init__.py | 3 +- src/rotator_library/responses/service.py | 28 ++++++++++++++-- src/rotator_library/responses/store.py | 17 +++++++++- src/rotator_library/responses/types.py | 15 +++++++++ tests/test_responses_service.py | 39 ++++++++++++++++++++++- tests/test_responses_store.py | 19 +++++++++++ tests/test_responses_streaming.py | 20 ++++++++++-- 7 files changed, 133 insertions(+), 8 deletions(-) diff --git a/src/rotator_library/responses/__init__.py b/src/rotator_library/responses/__init__.py index 04de91582..01c420439 100644 --- a/src/rotator_library/responses/__init__.py +++ b/src/rotator_library/responses/__init__.py @@ -7,7 +7,7 @@ from .service import ResponsesService, ResponsesServiceError from .store import InMemoryResponsesStore, ProviderCacheResponsesStore, ResponsesStore from .streaming import ResponsesSSEFormatter, ResponsesWebSocketFormatter -from .types import StoredResponse, generate_response_id +from .types import ResponsesStoreSettings, StoredResponse, generate_response_id __all__ = [ "InMemoryResponsesStore", @@ -15,6 +15,7 @@ "ResponsesBridge", "ResponsesService", "ResponsesServiceError", + "ResponsesStoreSettings", "ResponsesSSEFormatter", "ResponsesStore", "ResponsesWebSocketFormatter", diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index 11714041b..a1ed6eefc 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -26,7 +26,7 @@ response_created_payload, response_failed_payload, ) -from .types import StoredResponse +from .types import ResponsesStoreSettings, StoredResponse from .types import generate_response_id @@ -53,10 +53,12 @@ def __init__( protocol: Optional[ResponsesProtocol] = None, bridge: Optional[ResponsesBridge] = None, store: Optional[ResponsesStore] = None, + store_settings: Optional[ResponsesStoreSettings] = None, ) -> None: + self.store_settings = store_settings or ResponsesStoreSettings() self.protocol = protocol or ResponsesProtocol() self.bridge = bridge or ResponsesBridge(self.protocol) - self.store = store or InMemoryResponsesStore() + self.store = store or InMemoryResponsesStore(max_items=self.store_settings.max_items) async def create_response( self, @@ -253,6 +255,9 @@ async def stream_response( monitor.record_event(StreamEvent("error", protocol="responses", data={"error_type": exc.__class__.__name__})) failed = response_failed_payload(response_id, unified.model, {"message": str(exc), "type": exc.__class__.__name__}) self._log_transform_error(transaction_logger, "responses_stream", exc, stream_request) + stored = await self._store_stream_response(stream_request, failed, parent, failed=True) + if stored: + self._trace(transaction_logger, "responses_stored_failed_stream_response", {"response_id": failed.get("id"), "status": "failed"}, direction="metadata", stage="final") self._trace(transaction_logger, "responses_stream_event_failed", failed, direction="stream", stage="final", metadata={"transport": "sse"}, scrub_strings=True) if transaction_logger: self._trace( @@ -329,7 +334,11 @@ def _stored_response( input_items=_input_items(raw_request), output_items=deepcopy(response_payload.get("output") or []), usage=deepcopy(response_payload.get("usage")) if isinstance(response_payload.get("usage"), dict) else None, - metadata={"previous_response_id": parent.id if parent else raw_request.get("previous_response_id")}, + metadata={ + "previous_response_id": parent.id if parent else raw_request.get("previous_response_id"), + "response_id": response_payload.get("id"), + }, + expires_at=_expires_at(self.store_settings), ) @staticmethod @@ -400,9 +409,13 @@ async def _store_stream_response( raw_request: dict[str, Any], response_payload: dict[str, Any], parent: Optional[StoredResponse], + *, + failed: bool = False, ) -> bool: if not raw_request.get("store", True): return False + if failed and not self.store_settings.store_failed: + return False await self.store.save(self._stored_response(raw_request, response_payload, parent)) return True @@ -414,6 +427,15 @@ def _input_items(raw_request: dict[str, Any]) -> list[Any]: return deepcopy(value if isinstance(value, list) else [value]) +def _expires_at(settings: ResponsesStoreSettings) -> Optional[float]: + """Return the expiration timestamp for a new stored response, if enabled.""" + + ttl = settings.ttl_seconds + if ttl is None or ttl <= 0: + return None + return time.time() + ttl + + def _chunk_text_delta(chunk: dict[str, Any]) -> str: choices = chunk.get("choices") if isinstance(chunk.get("choices"), list) else [] if not choices: diff --git a/src/rotator_library/responses/store.py b/src/rotator_library/responses/store.py index 549ee64ba..b5c507ee9 100644 --- a/src/rotator_library/responses/store.py +++ b/src/rotator_library/responses/store.py @@ -32,11 +32,14 @@ class InMemoryResponsesStore: later configuration code when disk persistence is desired. """ - def __init__(self) -> None: + def __init__(self, *, max_items: int | None = None) -> None: self._responses: dict[str, StoredResponse] = {} + self.max_items = max_items if max_items and max_items > 0 else None async def save(self, response: StoredResponse) -> None: + self._prune_expired() self._responses[response.id] = StoredResponse.from_dict(response.to_dict()) + self._prune_overflow() async def get(self, response_id: str) -> Optional[StoredResponse]: response = self._responses.get(response_id) @@ -56,6 +59,18 @@ async def list_input_items(self, response_id: str) -> Optional[list[Any]]: return None return deepcopy(response.input_items) + def _prune_expired(self) -> None: + for response_id, response in list(self._responses.items()): + if response.is_expired(): + self._responses.pop(response_id, None) + + def _prune_overflow(self) -> None: + if not self.max_items: + return + while len(self._responses) > self.max_items: + oldest_id = min(self._responses.values(), key=lambda response: response.created_at).id + self._responses.pop(oldest_id, None) + class ProviderCacheResponsesStore: """Responses store backed by an injected `ProviderCache` instance. diff --git a/src/rotator_library/responses/types.py b/src/rotator_library/responses/types.py index 5edb78297..9b2a91e29 100644 --- a/src/rotator_library/responses/types.py +++ b/src/rotator_library/responses/types.py @@ -24,6 +24,21 @@ def generate_response_id() -> str: return f"resp_{secrets.token_urlsafe(18).replace('-', '').replace('_', '')[:24]}" +@dataclass +class ResponsesStoreSettings: + """Runtime policy for Responses object storage. + + The defaults preserve Phase 4 behavior. Operators can opt into TTL, bounded + memory, failed-response persistence, or in-progress updates without changing + the store interface or introducing a database. + """ + + ttl_seconds: Optional[int] = None + max_items: Optional[int] = None + store_failed: bool = True + store_in_progress: bool = False + + @dataclass class StoredResponse: """Persisted response object used for retrieval and continuation. diff --git a/tests/test_responses_service.py b/tests/test_responses_service.py index 52585e589..976444dfc 100644 --- a/tests/test_responses_service.py +++ b/tests/test_responses_service.py @@ -5,7 +5,7 @@ import pytest import rotator_library.responses.service as responses_service_module -from rotator_library.responses import InMemoryResponsesStore, ResponsesService, ResponsesServiceError, StoredResponse +from rotator_library.responses import InMemoryResponsesStore, ResponsesService, ResponsesServiceError, ResponsesStoreSettings, StoredResponse from rotator_library.transaction_logger import TransactionLogger @@ -51,6 +51,43 @@ async def test_store_false_does_not_persist_response() -> None: assert await store.get(response["id"]) is None +@pytest.mark.asyncio +async def test_create_response_applies_storage_ttl() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store, store_settings=ResponsesStoreSettings(ttl_seconds=60)) + + response = await service.create_response({"model": "gpt-test", "input": "Hello"}, FakeClient()) + stored = await store.get(response["id"]) + + assert stored is not None + assert stored.expires_at is not None + assert stored.expires_at > stored.created_at + assert stored.metadata["response_id"] == response["id"] + + +@pytest.mark.asyncio +async def test_service_default_store_honors_max_items() -> None: + class SequencedClient(FakeClient): + def __init__(self) -> None: + super().__init__() + self.index = 0 + + async def acompletion(self, **kwargs): + self.index += 1 + response = await super().acompletion(**kwargs) + response["id"] = f"chat_response_{self.index}" + return response + + service = ResponsesService(store_settings=ResponsesStoreSettings(max_items=1)) + client = SequencedClient() + + first = await service.create_response({"model": "gpt-test", "input": "one"}, client) + second = await service.create_response({"model": "gpt-test", "input": "two"}, client) + + assert await service.store.get(first["id"]) is None + assert await service.store.get(second["id"]) is not None + + @pytest.mark.asyncio async def test_previous_response_id_loads_parent_context() -> None: store = InMemoryResponsesStore() diff --git a/tests/test_responses_store.py b/tests/test_responses_store.py index 0f8ff720c..b7025f74a 100644 --- a/tests/test_responses_store.py +++ b/tests/test_responses_store.py @@ -61,6 +61,25 @@ async def test_in_memory_store_returns_copies_and_expires() -> None: assert await store.get(stored.id) is None +@pytest.mark.asyncio +async def test_in_memory_store_prunes_oldest_when_max_items_exceeded() -> None: + store = InMemoryResponsesStore(max_items=2) + first = _stored("resp_first") + first.created_at = 1 + second = _stored("resp_second") + second.created_at = 2 + third = _stored("resp_third") + third.created_at = 3 + + await store.save(first) + await store.save(second) + await store.save(third) + + assert await store.get("resp_first") is None + assert await store.get("resp_second") is not None + assert await store.get("resp_third") is not None + + @pytest.mark.asyncio async def test_provider_cache_store_serializes_json_and_does_not_clear_without_key_delete() -> None: class FakeProviderCache: diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index 54f2d7c98..86562311d 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -2,7 +2,7 @@ import pytest -from rotator_library.responses import InMemoryResponsesStore, ResponsesSSEFormatter, ResponsesService, ResponsesWebSocketFormatter +from rotator_library.responses import InMemoryResponsesStore, ResponsesSSEFormatter, ResponsesService, ResponsesStoreSettings, ResponsesWebSocketFormatter from rotator_library.transaction_logger import TransactionLogger @@ -77,7 +77,8 @@ async def test_stream_response_store_false_does_not_persist(tmp_path) -> None: @pytest.mark.asyncio async def test_stream_response_errors_emit_failed_event() -> None: - service = ResponsesService(store=InMemoryResponsesStore()) + store = InMemoryResponsesStore() + service = ResponsesService(store=store) events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, FailingStreamingClient())] @@ -85,6 +86,21 @@ async def test_stream_response_errors_emit_failed_event() -> None: assert "event: response.failed" in event_text assert "stream exploded" in event_text assert event_text.endswith("data: [DONE]\n\n") + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.status == "failed" + + +@pytest.mark.asyncio +async def test_stream_response_can_skip_failed_storage() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store, store_settings=ResponsesStoreSettings(store_failed=False)) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, FailingStreamingClient())] + + response_id = events[0].split('"id": "')[1].split('"')[0] + assert await store.get(response_id) is None @pytest.mark.asyncio From 99955584bc6998b7bb75a38615686b186b1d7b12 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 11:33:57 +0200 Subject: [PATCH 103/182] feat(responses): emit transport neutral stream events Refactors Responses streaming so service logic yields transport-neutral ResponsesStreamEvent objects and HTTP SSE is only a formatter wrapper. Adds JSON WebSocket event formatting as a transport seam, keeps existing SSE route behavior, and implements default-off in-progress stream state storage for current-state retrieval surfaces. Tests: pytest tests/test_responses_streaming.py tests/test_responses_service.py tests/test_responses_store.py tests/test_responses_routes.py --- src/rotator_library/responses/__init__.py | 3 +- src/rotator_library/responses/service.py | 70 ++++++++++++++++++---- src/rotator_library/responses/streaming.py | 29 ++++++++- tests/test_responses_streaming.py | 38 +++++++++++- 4 files changed, 123 insertions(+), 17 deletions(-) diff --git a/src/rotator_library/responses/__init__.py b/src/rotator_library/responses/__init__.py index 01c420439..a16e76bde 100644 --- a/src/rotator_library/responses/__init__.py +++ b/src/rotator_library/responses/__init__.py @@ -6,7 +6,7 @@ from .bridge import ResponsesBridge from .service import ResponsesService, ResponsesServiceError from .store import InMemoryResponsesStore, ProviderCacheResponsesStore, ResponsesStore -from .streaming import ResponsesSSEFormatter, ResponsesWebSocketFormatter +from .streaming import ResponsesSSEFormatter, ResponsesStreamEvent, ResponsesWebSocketFormatter from .types import ResponsesStoreSettings, StoredResponse, generate_response_id __all__ = [ @@ -17,6 +17,7 @@ "ResponsesServiceError", "ResponsesStoreSettings", "ResponsesSSEFormatter", + "ResponsesStreamEvent", "ResponsesStore", "ResponsesWebSocketFormatter", "StoredResponse", diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index a1ed6eefc..a6203571b 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -17,6 +17,7 @@ from .store import InMemoryResponsesStore, ResponsesStore from .streaming import ( ResponsesSSEFormatter, + ResponsesStreamEvent, ResponsesStreamState, output_item_added_payload, output_item_done_payload, @@ -118,9 +119,22 @@ async def stream_response( ) -> AsyncGenerator[str, None]: """Stream a Responses API request as HTTP SSE events.""" + formatter = ResponsesSSEFormatter() + async for event in self.stream_events(raw_request, client, request=request, transaction_logger=transaction_logger): + yield formatter.format_stream_event(event) + + async def stream_events( + self, + raw_request: dict[str, Any], + client: Any, + *, + request: Optional[Any] = None, + transaction_logger: Optional[Any] = None, + ) -> AsyncGenerator[ResponsesStreamEvent, None]: + """Yield transport-neutral Responses events for streaming transports.""" + if not raw_request.get("model"): raise ResponsesServiceError("'model' is required", status_code=400) - formatter = ResponsesSSEFormatter() stream_request = dict(raw_request) stream_request["stream"] = True self._trace(transaction_logger, "responses_raw_request", stream_request, direction="request", stage="client") @@ -156,7 +170,8 @@ async def stream_response( ) created = response_created_payload(response_id, unified.model) self._trace(transaction_logger, "responses_stream_event_created", created, direction="stream", stage="final", metadata={"transport": "sse"}) - yield formatter.format_event("response.created", created) + await self._store_stream_current_state(stream_request, created, parent, transaction_logger=transaction_logger) + yield ResponsesStreamEvent("response.created", created) try: chat_stream = await client.acompletion(request=request, **chat_kwargs) async for raw_chunk in chat_stream: @@ -185,7 +200,7 @@ async def stream_response( item_started = True added = output_item_added_payload(state) self._trace(transaction_logger, "responses_stream_event_output_item_added", added, direction="stream", stage="final", metadata={"transport": "sse"}) - yield formatter.format_event("response.output_item.added", added) + yield ResponsesStreamEvent("response.output_item.added", added) state = ResponsesStreamState( response_id=state.response_id, model=state.model, @@ -213,15 +228,16 @@ async def stream_response( metadata={"transport": "sse"}, ) self._trace(transaction_logger, "responses_stream_event_output_text_delta", event, direction="stream", stage="final", metadata={"transport": "sse"}) - yield formatter.format_event("response.output_text.delta", event) + await self._store_stream_current_state(stream_request, _current_stream_payload(state), parent, transaction_logger=transaction_logger) + yield ResponsesStreamEvent("response.output_text.delta", event) if not item_started: added = output_item_added_payload(state) self._trace(transaction_logger, "responses_stream_event_output_item_added", added, direction="stream", stage="final", metadata={"transport": "sse"}) - yield formatter.format_event("response.output_item.added", added) + yield ResponsesStreamEvent("response.output_item.added", added) done_item = output_item_done_payload(state) self._trace(transaction_logger, "responses_stream_event_output_item_done", done_item, direction="stream", stage="final", metadata={"transport": "sse"}) - yield formatter.format_event("response.output_item.done", done_item) + yield ResponsesStreamEvent("response.output_item.done", done_item) completed = response_completed_payload(state, _usage_to_responses_stream(usage)) self._trace_responses_usage(transaction_logger, completed, unified.model, source="responses_stream") stored = await self._store_stream_response(stream_request, completed, parent) @@ -248,9 +264,9 @@ async def stream_response( metadata={"transport": "sse"}, ) self._trace(transaction_logger, "responses_stream_event_completed", completed, direction="stream", stage="final", metadata={"transport": "sse"}) - yield formatter.format_event("response.completed", completed) - self._trace(transaction_logger, "stream_done_event", {"raw": formatter.done()}, direction="stream", stage="final", metadata={"transport": "sse"}) - yield formatter.done() + yield ResponsesStreamEvent("response.completed", completed) + self._trace(transaction_logger, "stream_done_event", {"raw": "done"}, direction="stream", stage="final", metadata={"transport": "sse"}) + yield ResponsesStreamEvent("done", {}, terminal=True) except Exception as exc: monitor.record_event(StreamEvent("error", protocol="responses", data={"error_type": exc.__class__.__name__})) failed = response_failed_payload(response_id, unified.model, {"message": str(exc), "type": exc.__class__.__name__}) @@ -268,9 +284,9 @@ async def stream_response( stage="final", metadata={"transport": "sse", "failed": True}, ) - yield formatter.format_event("response.failed", failed) - self._trace(transaction_logger, "stream_done_event", {"raw": formatter.done()}, direction="stream", stage="final", metadata={"transport": "sse", "failed": True}) - yield formatter.done() + yield ResponsesStreamEvent("response.failed", failed) + self._trace(transaction_logger, "stream_done_event", {"raw": "done"}, direction="stream", stage="final", metadata={"transport": "sse", "failed": True}) + yield ResponsesStreamEvent("done", {}, terminal=True) async def get_response(self, response_id: str) -> dict[str, Any]: """Return a stored response payload or raise a 404-compatible error.""" @@ -419,6 +435,28 @@ async def _store_stream_response( await self.store.save(self._stored_response(raw_request, response_payload, parent)) return True + async def _store_stream_current_state( + self, + raw_request: dict[str, Any], + response_payload: dict[str, Any], + parent: Optional[StoredResponse], + *, + transaction_logger: Optional[Any], + ) -> bool: + """Optionally persist in-progress stream state for retrieval surfaces.""" + + if not self.store_settings.store_in_progress or not raw_request.get("store", True): + return False + await self.store.save(self._stored_response(raw_request, response_payload, parent)) + self._trace( + transaction_logger, + "responses_stored_stream_current_state", + {"response_id": response_payload.get("id"), "status": response_payload.get("status")}, + direction="metadata", + stage="final", + ) + return True + def _input_items(raw_request: dict[str, Any]) -> list[Any]: value = raw_request.get("input") @@ -427,6 +465,14 @@ def _input_items(raw_request: dict[str, Any]) -> list[Any]: return deepcopy(value if isinstance(value, list) else [value]) +def _current_stream_payload(state: ResponsesStreamState) -> dict[str, Any]: + """Return a retrievable in-progress Responses object for stream state.""" + + payload = response_completed_payload(state) + payload["status"] = "in_progress" + return payload + + def _expires_at(settings: ResponsesStoreSettings) -> Optional[float]: """Return the expiration timestamp for a new stored response, if enabled.""" diff --git a/src/rotator_library/responses/streaming.py b/src/rotator_library/responses/streaming.py index d358ab65a..94f91abe6 100644 --- a/src/rotator_library/responses/streaming.py +++ b/src/rotator_library/responses/streaming.py @@ -22,6 +22,19 @@ class ResponsesStreamState: output_item_id: str = "msg_0" +@dataclass(frozen=True) +class ResponsesStreamEvent: + """Transport-neutral Responses stream event. + + Service code yields these events. Formatters decide how to serialize them for + SSE, WebSocket, or future transports without duplicating protocol logic. + """ + + event_name: str + payload: dict[str, Any] + terminal: bool = False + + class ResponsesSSEFormatter: """Format Responses API events as HTTP Server-Sent Events.""" @@ -30,6 +43,13 @@ class ResponsesSSEFormatter: def format_event(self, event_name: str, payload: dict[str, Any]) -> str: return f"event: {event_name}\ndata: {json.dumps(serialize_value(payload), ensure_ascii=False)}\n\n" + def format_stream_event(self, event: ResponsesStreamEvent) -> str: + """Format one transport-neutral event for HTTP SSE.""" + + if event.terminal: + return self.done() + return self.format_event(event.event_name, event.payload) + def done(self) -> str: """Return the final compatibility sentinel used by many SSE clients.""" @@ -43,7 +63,14 @@ class ResponsesWebSocketFormatter: future_supported = True def format_event(self, event_name: str, payload: dict[str, Any]) -> str: - raise NotImplementedError("Responses WebSocket transport is planned but not implemented yet") + return json.dumps({"event": event_name, "data": serialize_value(payload)}, ensure_ascii=False) + + def format_stream_event(self, event: ResponsesStreamEvent) -> str: + """Format one transport-neutral event as a WebSocket message payload.""" + + if event.terminal: + return json.dumps({"event": "done", "data": {}}, ensure_ascii=False) + return self.format_event(event.event_name, event.payload) def parse_chat_sse_chunk(chunk: Any) -> dict[str, Any] | None: diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index 86562311d..511a522c7 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -2,7 +2,7 @@ import pytest -from rotator_library.responses import InMemoryResponsesStore, ResponsesSSEFormatter, ResponsesService, ResponsesStoreSettings, ResponsesWebSocketFormatter +from rotator_library.responses import InMemoryResponsesStore, ResponsesSSEFormatter, ResponsesService, ResponsesStoreSettings, ResponsesStreamEvent, ResponsesWebSocketFormatter from rotator_library.transaction_logger import TransactionLogger @@ -103,6 +103,20 @@ async def test_stream_response_can_skip_failed_storage() -> None: assert await store.get(response_id) is None +@pytest.mark.asyncio +async def test_stream_events_can_store_in_progress_state() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store, store_settings=ResponsesStoreSettings(store_in_progress=True)) + stream = service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, FakeStreamingClient()) + + created = await anext(stream) + stored = await store.get(created.payload["id"]) + await stream.aclose() + + assert stored is not None + assert stored.status == "in_progress" + + @pytest.mark.asyncio async def test_stream_response_failure_trace_scrubs_header_like_secret_text(tmp_path) -> None: logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) @@ -127,8 +141,7 @@ def test_transport_formatters_expose_sse_and_websocket_seam() -> None: websocket = ResponsesWebSocketFormatter() assert websocket.transport == "websocket" assert websocket.future_supported is True - with pytest.raises(NotImplementedError): - websocket.format_event("response.created", {}) + assert websocket.format_stream_event(ResponsesStreamEvent("response.created", {"id": "resp"})) == '{"event": "response.created", "data": {"id": "resp"}}' @pytest.mark.asyncio @@ -148,3 +161,22 @@ async def test_responses_stream_records_common_stream_metrics(tmp_path) -> None: assert "responses_stream_event_created" in trace_text assert "responses_stream_event_output_text_delta" in trace_text assert "responses_stream_event_completed" in trace_text + + +@pytest.mark.asyncio +async def test_stream_events_are_transport_neutral_and_sse_wraps_them() -> None: + service = ResponsesService(store=InMemoryResponsesStore()) + + events = [event async for event in service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, FakeStreamingClient())] + sse = [ResponsesSSEFormatter().format_stream_event(event) for event in events] + + assert [event.event_name for event in events if not event.terminal] == [ + "response.created", + "response.output_item.added", + "response.output_text.delta", + "response.output_text.delta", + "response.output_item.done", + "response.completed", + ] + assert events[-1].terminal is True + assert "".join(sse).endswith("data: [DONE]\n\n") From d647c60e5f84e9d79b95be94f3194c68095e8b04 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 11:50:16 +0200 Subject: [PATCH 104/182] fix(responses): harden continuation hints and stream failures Keeps Responses continuation hints proxy-internal by only passing them to the internal RotatingClient path and scrubbing them from bridge request traces. Continuation storage now captures non-secret request context metadata and reuses parent affinity for previous_response_id hints so follow-up requests anchor back to the parent response session. Streaming now treats upstream error chunks as failed responses, preserves partial visible output in failed storage, and validates stream requests before HTTP SSE starts. Tests: pytest tests/test_responses_service.py tests/test_responses_streaming.py tests/test_responses_store.py tests/test_responses_routes.py tests/test_responses_bridge.py tests/test_responses_usage_accounting.py tests/test_protocol_responses.py tests/test_request_builder_routing.py --- src/proxy_app/main.py | 1 + src/rotator_library/client/rotating_client.py | 3 + src/rotator_library/responses/bridge.py | 4 +- src/rotator_library/responses/service.py | 130 ++++++++++++++---- tests/test_responses_routes.py | 18 +++ tests/test_responses_service.py | 74 ++++++++++ tests/test_responses_streaming.py | 25 ++++ 7 files changed, 227 insertions(+), 28 deletions(-) diff --git a/src/proxy_app/main.py b/src/proxy_app/main.py index f2aa4dd1f..48f3e7b75 100644 --- a/src/proxy_app/main.py +++ b/src/proxy_app/main.py @@ -1056,6 +1056,7 @@ async def responses_create( transaction_logger = TransactionLogger("responses", request_data.get("model", "unknown")) if ENABLE_REQUEST_LOGGING else None try: if request_data.get("stream"): + await service.validate_stream_request(request_data) return StreamingResponse( service.stream_response(request_data, client, request=request, transaction_logger=transaction_logger), media_type="text/event-stream", diff --git a/src/rotator_library/client/rotating_client.py b/src/rotator_library/client/rotating_client.py index 7d496ecc9..e9cd3804a 100644 --- a/src/rotator_library/client/rotating_client.py +++ b/src/rotator_library/client/rotating_client.py @@ -535,9 +535,12 @@ async def acompletion( pre_request_callback: Optional[callable] = None, **kwargs, ) -> Union[Any, AsyncGenerator[str, None]]: + request_context_callback = kwargs.pop("_request_context_callback", None) context = await self._request_builder.build_completion_context( request, pre_request_callback, kwargs ) + if callable(request_context_callback): + request_context_callback(context) return await self._executor.execute(context) async def aembedding( diff --git a/src/rotator_library/responses/bridge.py b/src/rotator_library/responses/bridge.py index ba7c539fb..145436469 100644 --- a/src/rotator_library/responses/bridge.py +++ b/src/rotator_library/responses/bridge.py @@ -121,7 +121,7 @@ def _chat_generation_key(key: str) -> str: return key -def responses_session_hints(previous_response_id: Optional[str]) -> dict[str, Any] | None: +def responses_session_hints(previous_response_id: Optional[str], *, affinity_key: Optional[str] = None) -> dict[str, Any] | None: """Return proxy-internal sticky routing evidence for Responses continuations.""" if not previous_response_id: @@ -129,7 +129,7 @@ def responses_session_hints(previous_response_id: Optional[str]) -> dict[str, An anchor = f"responses_previous_response_id:{previous_response_id}" return { "strong_anchors": [anchor], - "affinity_key": anchor, + "affinity_key": affinity_key or anchor, "session_scope": "responses", } diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index a6203571b..e2c92278b 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -13,7 +13,7 @@ from ..streaming import StreamEvent, StreamMonitor from ..usage.accounting import extract_usage_record from ..protocols.responses import ResponsesProtocol -from .bridge import ResponsesBridge +from .bridge import ResponsesBridge, responses_session_hints from .store import InMemoryResponsesStore, ResponsesStore from .streaming import ( ResponsesSSEFormatter, @@ -84,13 +84,18 @@ async def create_response( parent = await self._load_previous_response(unified.previous_response_id, transaction_logger) chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_response=parent.response if parent else None) bridge_metadata = chat_kwargs.pop("_responses_bridge", {}) + session_hints = chat_kwargs.pop("_session_tracking_hints", None) + session_hints = _responses_session_hints(unified.previous_response_id, parent, session_hints) + session_info: dict[str, Any] = {} + chat_kwargs.update(_internal_client_kwargs(client, session_hints, session_info)) + trace_chat_kwargs = _without_internal_kwargs(chat_kwargs) self._trace( transaction_logger, "responses_bridge_chat_request", - chat_kwargs, + trace_chat_kwargs, direction="request", stage="adapter", - metadata={"bridge_metadata": bridge_metadata}, + metadata={"bridge_metadata": {**bridge_metadata, "has_session_hints": bool(session_hints)}}, ) chat_response = await client.acompletion(request=request, **chat_kwargs) @@ -102,7 +107,7 @@ async def create_response( self._trace_responses_usage(transaction_logger, response_payload, unified.model, source="responses_response") if raw_request.get("store", True): - stored = self._stored_response(raw_request, response_payload, parent) + stored = self._stored_response(raw_request, response_payload, parent, session_info=session_info) await self.store.save(stored) self._trace(transaction_logger, "responses_stored_response", stored.to_dict(), direction="metadata", stage="final") @@ -116,13 +121,23 @@ async def stream_response( *, request: Optional[Any] = None, transaction_logger: Optional[Any] = None, + transport: str = "sse", ) -> AsyncGenerator[str, None]: """Stream a Responses API request as HTTP SSE events.""" formatter = ResponsesSSEFormatter() - async for event in self.stream_events(raw_request, client, request=request, transaction_logger=transaction_logger): + async for event in self.stream_events(raw_request, client, request=request, transaction_logger=transaction_logger, transport=transport): yield formatter.format_stream_event(event) + async def validate_stream_request(self, raw_request: dict[str, Any]) -> None: + """Validate stream-only preconditions before an HTTP response starts.""" + + if not raw_request.get("model"): + raise ResponsesServiceError("'model' is required", status_code=400) + previous_response_id = raw_request.get("previous_response_id") + if previous_response_id: + await self._load_previous_response(str(previous_response_id), None) + async def stream_events( self, raw_request: dict[str, Any], @@ -130,6 +145,7 @@ async def stream_events( *, request: Optional[Any] = None, transaction_logger: Optional[Any] = None, + transport: str = "sse", ) -> AsyncGenerator[ResponsesStreamEvent, None]: """Yield transport-neutral Responses events for streaming transports.""" @@ -138,20 +154,25 @@ async def stream_events( stream_request = dict(raw_request) stream_request["stream"] = True self._trace(transaction_logger, "responses_raw_request", stream_request, direction="request", stage="client") - unified = self.protocol.parse_request(stream_request, ProtocolContext(source_protocol="responses", transport="sse")) + unified = self.protocol.parse_request(stream_request, ProtocolContext(source_protocol="responses", transport=transport)) if transaction_logger: self._trace(transaction_logger, "responses_parsed_request", unified.to_dict(), direction="request", stage="protocol") parent = await self._load_previous_response(unified.previous_response_id, transaction_logger) chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_response=parent.response if parent else None) bridge_metadata = chat_kwargs.pop("_responses_bridge", {}) + session_hints = chat_kwargs.pop("_session_tracking_hints", None) + session_hints = _responses_session_hints(unified.previous_response_id, parent, session_hints) + session_info: dict[str, Any] = {} + chat_kwargs.update(_internal_client_kwargs(client, session_hints, session_info)) chat_kwargs["stream"] = True + trace_chat_kwargs = _without_internal_kwargs(chat_kwargs) self._trace( transaction_logger, "responses_bridge_chat_request", - chat_kwargs, + trace_chat_kwargs, direction="request", stage="adapter", - metadata={"bridge_metadata": bridge_metadata, "transport": "sse"}, + metadata={"bridge_metadata": {**bridge_metadata, "has_session_hints": bool(session_hints)}, "transport": transport}, ) response_id = generate_response_id() @@ -166,10 +187,10 @@ async def stream_events( {"event": StreamEvent("started", protocol="responses").to_dict(), "metrics": monitor.metrics.to_dict()}, direction="stream", stage="client", - metadata={"transport": "sse"}, + metadata={"transport": transport}, ) created = response_created_payload(response_id, unified.model) - self._trace(transaction_logger, "responses_stream_event_created", created, direction="stream", stage="final", metadata={"transport": "sse"}) + self._trace(transaction_logger, "responses_stream_event_created", created, direction="stream", stage="final", metadata={"transport": transport}) await self._store_stream_current_state(stream_request, created, parent, transaction_logger=transaction_logger) yield ResponsesStreamEvent("response.created", created) try: @@ -184,13 +205,15 @@ async def stream_events( {"metrics": monitor.metrics.to_dict()}, direction="stream", stage="provider", - metadata={"transport": "sse"}, + metadata={"transport": transport}, ) self._trace(transaction_logger, "raw_chat_bridge_stream_chunk", raw_chunk, direction="stream", stage="provider") chunk = parse_chat_sse_chunk(raw_chunk) if not chunk or chunk.get("type") == "done": continue self._trace(transaction_logger, "parsed_unified_stream_event", chunk, direction="stream", stage="protocol") + if chunk.get("error") is not None: + raise ResponsesServiceError("Upstream stream error", status_code=502, error_type="upstream_error") if chunk.get("usage"): usage = chunk["usage"] delta = _chunk_text_delta(chunk) @@ -199,7 +222,7 @@ async def stream_events( if not item_started: item_started = True added = output_item_added_payload(state) - self._trace(transaction_logger, "responses_stream_event_output_item_added", added, direction="stream", stage="final", metadata={"transport": "sse"}) + self._trace(transaction_logger, "responses_stream_event_output_item_added", added, direction="stream", stage="final", metadata={"transport": transport}) yield ResponsesStreamEvent("response.output_item.added", added) state = ResponsesStreamState( response_id=state.response_id, @@ -225,22 +248,22 @@ async def stream_events( {"event": event, "metrics": monitor.metrics.to_dict()}, direction="stream", stage="final", - metadata={"transport": "sse"}, + metadata={"transport": transport}, ) - self._trace(transaction_logger, "responses_stream_event_output_text_delta", event, direction="stream", stage="final", metadata={"transport": "sse"}) + self._trace(transaction_logger, "responses_stream_event_output_text_delta", event, direction="stream", stage="final", metadata={"transport": transport}) await self._store_stream_current_state(stream_request, _current_stream_payload(state), parent, transaction_logger=transaction_logger) yield ResponsesStreamEvent("response.output_text.delta", event) if not item_started: added = output_item_added_payload(state) - self._trace(transaction_logger, "responses_stream_event_output_item_added", added, direction="stream", stage="final", metadata={"transport": "sse"}) + self._trace(transaction_logger, "responses_stream_event_output_item_added", added, direction="stream", stage="final", metadata={"transport": transport}) yield ResponsesStreamEvent("response.output_item.added", added) done_item = output_item_done_payload(state) - self._trace(transaction_logger, "responses_stream_event_output_item_done", done_item, direction="stream", stage="final", metadata={"transport": "sse"}) + self._trace(transaction_logger, "responses_stream_event_output_item_done", done_item, direction="stream", stage="final", metadata={"transport": transport}) yield ResponsesStreamEvent("response.output_item.done", done_item) completed = response_completed_payload(state, _usage_to_responses_stream(usage)) self._trace_responses_usage(transaction_logger, completed, unified.model, source="responses_stream") - stored = await self._store_stream_response(stream_request, completed, parent) + stored = await self._store_stream_response(stream_request, completed, parent, session_info=session_info) if stored: self._trace(transaction_logger, "responses_stored_stream_response", completed, direction="metadata", stage="final") else: @@ -253,7 +276,7 @@ async def stream_events( {"metrics": monitor.metrics.to_dict()}, direction="stream", stage="final", - metadata={"transport": "sse"}, + metadata={"transport": transport}, ) self._trace( transaction_logger, @@ -261,20 +284,22 @@ async def stream_events( {"metrics": monitor.metrics.to_dict()}, direction="stream", stage="final", - metadata={"transport": "sse"}, + metadata={"transport": transport}, ) - self._trace(transaction_logger, "responses_stream_event_completed", completed, direction="stream", stage="final", metadata={"transport": "sse"}) + self._trace(transaction_logger, "responses_stream_event_completed", completed, direction="stream", stage="final", metadata={"transport": transport}) yield ResponsesStreamEvent("response.completed", completed) - self._trace(transaction_logger, "stream_done_event", {"raw": "done"}, direction="stream", stage="final", metadata={"transport": "sse"}) + self._trace(transaction_logger, "stream_done_event", {"raw": "done"}, direction="stream", stage="final", metadata={"transport": transport}) yield ResponsesStreamEvent("done", {}, terminal=True) except Exception as exc: monitor.record_event(StreamEvent("error", protocol="responses", data={"error_type": exc.__class__.__name__})) failed = response_failed_payload(response_id, unified.model, {"message": str(exc), "type": exc.__class__.__name__}) + if state.output_text: + failed["output"] = [output_item_done_payload(state)["item"]] self._log_transform_error(transaction_logger, "responses_stream", exc, stream_request) - stored = await self._store_stream_response(stream_request, failed, parent, failed=True) + stored = await self._store_stream_response(stream_request, failed, parent, failed=True, session_info=session_info) if stored: self._trace(transaction_logger, "responses_stored_failed_stream_response", {"response_id": failed.get("id"), "status": "failed"}, direction="metadata", stage="final") - self._trace(transaction_logger, "responses_stream_event_failed", failed, direction="stream", stage="final", metadata={"transport": "sse"}, scrub_strings=True) + self._trace(transaction_logger, "responses_stream_event_failed", failed, direction="stream", stage="final", metadata={"transport": transport}, scrub_strings=True) if transaction_logger: self._trace( transaction_logger, @@ -282,10 +307,10 @@ async def stream_events( {"metrics": monitor.metrics.to_dict()}, direction="stream", stage="final", - metadata={"transport": "sse", "failed": True}, + metadata={"transport": transport, "failed": True}, ) yield ResponsesStreamEvent("response.failed", failed) - self._trace(transaction_logger, "stream_done_event", {"raw": "done"}, direction="stream", stage="final", metadata={"transport": "sse", "failed": True}) + self._trace(transaction_logger, "stream_done_event", {"raw": "done"}, direction="stream", stage="final", metadata={"transport": transport, "failed": True}) yield ResponsesStreamEvent("done", {}, terminal=True) async def get_response(self, response_id: str) -> dict[str, Any]: @@ -339,7 +364,11 @@ def _stored_response( raw_request: dict[str, Any], response_payload: dict[str, Any], parent: Optional[StoredResponse], + *, + session_info: Optional[dict[str, Any]] = None, ) -> StoredResponse: + session_info = session_info or {} + session_affinity_key = session_info.get("session_affinity_key") or (parent.metadata.get("session_affinity_key") if parent else None) return StoredResponse( id=str(response_payload["id"]), created_at=float(response_payload.get("created_at") or time.time()), @@ -353,7 +382,11 @@ def _stored_response( metadata={ "previous_response_id": parent.id if parent else raw_request.get("previous_response_id"), "response_id": response_payload.get("id"), + "session_affinity_key": session_affinity_key, }, + session_id=session_info.get("session_id") or (parent.session_id if parent else None), + scope_key=session_info.get("scope_key") or (parent.scope_key if parent else None), + classifier=session_info.get("classifier") or (parent.classifier if parent else None), expires_at=_expires_at(self.store_settings), ) @@ -427,12 +460,13 @@ async def _store_stream_response( parent: Optional[StoredResponse], *, failed: bool = False, + session_info: Optional[dict[str, Any]] = None, ) -> bool: if not raw_request.get("store", True): return False if failed and not self.store_settings.store_failed: return False - await self.store.save(self._stored_response(raw_request, response_payload, parent)) + await self.store.save(self._stored_response(raw_request, response_payload, parent, session_info=session_info)) return True async def _store_stream_current_state( @@ -482,6 +516,50 @@ def _expires_at(settings: ResponsesStoreSettings) -> Optional[float]: return time.time() + ttl +def _internal_client_kwargs(client: Any, session_hints: Any, session_info: dict[str, Any]) -> dict[str, Any]: + """Return hidden kwargs only for the internal RotatingClient path.""" + + if not _supports_internal_context_kwargs(client): + return {} + kwargs: dict[str, Any] = {"_request_context_callback": _capture_request_context(session_info)} + if session_hints: + kwargs["_session_tracking_hints"] = session_hints + return kwargs + + +def _supports_internal_context_kwargs(client: Any) -> bool: + """Return whether a client is the proxy's internal rotating client.""" + + return hasattr(client, "_request_builder") and hasattr(client, "_executor") + + +def _capture_request_context(session_info: dict[str, Any]): + """Build a callback that records non-secret request context metadata.""" + + def capture(context: Any) -> None: + session_info["session_id"] = getattr(context, "session_id", None) + session_info["session_affinity_key"] = getattr(context, "session_affinity_key", None) + session_info["scope_key"] = getattr(context, "usage_manager_key", None) + session_info["classifier"] = getattr(context, "classifier", None) + + return capture + + +def _without_internal_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: + """Return trace/provider-visible kwargs without proxy-internal controls.""" + + return {key: deepcopy(value) for key, value in kwargs.items() if not key.startswith("_")} + + +def _responses_session_hints(previous_response_id: Optional[str], parent: Optional[StoredResponse], fallback: Any) -> Any: + """Prefer parent stored affinity when building Responses continuation hints.""" + + if not previous_response_id: + return None + parent_affinity = parent.metadata.get("session_affinity_key") if parent else None + return responses_session_hints(previous_response_id, affinity_key=parent_affinity) or fallback + + def _chunk_text_delta(chunk: dict[str, Any]) -> str: choices = chunk.get("choices") if isinstance(chunk.get("choices"), list) else [] if not choices: diff --git a/tests/test_responses_routes.py b/tests/test_responses_routes.py index 4079fde5f..1cdb18fd6 100644 --- a/tests/test_responses_routes.py +++ b/tests/test_responses_routes.py @@ -52,6 +52,24 @@ def test_post_responses_missing_model_returns_400() -> None: assert response.json()["detail"]["error"]["type"] == "invalid_request_error" +def test_post_responses_stream_missing_model_returns_400_before_sse() -> None: + client = _client() + + response = client.post("/v1/responses", json={"input": "hello", "stream": True}) + + assert response.status_code == 400 + assert response.json()["detail"]["error"]["type"] == "invalid_request_error" + + +def test_post_responses_stream_missing_previous_response_returns_404_before_sse() -> None: + client = _client() + + response = client.post("/v1/responses", json={"model": "gpt-test", "input": "hello", "stream": True, "previous_response_id": "missing"}) + + assert response.status_code == 404 + assert response.json()["detail"]["error"]["type"] == "not_found_error" + + def test_get_delete_and_input_items_routes() -> None: client = _client() created = client.post("/v1/responses", json={"model": "gpt-test", "input": ["hello"]}).json() diff --git a/tests/test_responses_service.py b/tests/test_responses_service.py index 976444dfc..4374ae361 100644 --- a/tests/test_responses_service.py +++ b/tests/test_responses_service.py @@ -23,6 +23,32 @@ async def acompletion(self, **kwargs): } +class FakeInternalClient(FakeClient): + def __init__(self) -> None: + super().__init__() + self._request_builder = object() + self._executor = object() + + async def acompletion(self, **kwargs): + callback = kwargs.pop("_request_context_callback", None) + hints = kwargs.pop("_session_tracking_hints", None) + if callback: + callback( + type( + "Context", + (), + { + "session_id": "session-parent", + "session_affinity_key": "affinity-parent", + "usage_manager_key": "scope-parent", + "classifier": "global", + }, + )() + ) + self.internal_hints = hints + return await super().acompletion(**kwargs) + + def _trace_entries(log_dir): return [json.loads(line) for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] @@ -119,6 +145,54 @@ async def test_previous_response_id_loads_parent_context() -> None: ] +@pytest.mark.asyncio +async def test_internal_session_hints_do_not_leak_to_direct_clients_or_traces(tmp_path) -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + await store.save( + StoredResponse( + id="resp_parent", + model="gpt-test", + status="completed", + response={"id": "resp_parent", "object": "response", "output": []}, + metadata={"session_affinity_key": "affinity-parent"}, + ) + ) + client = FakeClient() + + await service.create_response({"model": "gpt-test", "input": "Continue", "previous_response_id": "resp_parent"}, client, transaction_logger=logger) + + assert "_session_tracking_hints" not in client.calls[0] + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "_session_tracking_hints" not in trace_text + assert "has_session_hints" in trace_text + + +@pytest.mark.asyncio +async def test_internal_client_context_metadata_is_stored_with_response() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + await store.save( + StoredResponse( + id="resp_parent", + model="gpt-test", + status="completed", + response={"id": "resp_parent", "object": "response", "output": []}, + metadata={"session_affinity_key": "affinity-parent"}, + ) + ) + client = FakeInternalClient() + + response = await service.create_response({"model": "gpt-test", "input": "Continue", "previous_response_id": "resp_parent"}, client) + stored = await store.get(response["id"]) + + assert client.internal_hints["affinity_key"] == "affinity-parent" + assert stored is not None + assert stored.session_id == "session-parent" + assert stored.metadata["session_affinity_key"] == "affinity-parent" + + @pytest.mark.asyncio async def test_missing_previous_response_id_raises_not_found() -> None: service = ResponsesService(store=InMemoryResponsesStore()) diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index 511a522c7..11b57c65b 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -35,6 +35,15 @@ async def chunks(): return chunks() +class ErrorChunkStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield 'data: {"choices":[{"delta":{"content":"partial"}}]}\n\n' + yield 'data: {"error":{"message":"provider failed"}}\n\n' + + return chunks() + + @pytest.mark.asyncio async def test_stream_response_emits_responses_sse_events_and_stores_final_response() -> None: store = InMemoryResponsesStore() @@ -92,6 +101,22 @@ async def test_stream_response_errors_emit_failed_event() -> None: assert stored.status == "failed" +@pytest.mark.asyncio +async def test_stream_response_error_chunks_store_failed_with_partial_output() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, ErrorChunkStreamingClient())] + + event_text = "".join(events) + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert "event: response.failed" in event_text + assert stored is not None + assert stored.status == "failed" + assert stored.output_items[0]["content"][0]["text"] == "partial" + + @pytest.mark.asyncio async def test_stream_response_can_skip_failed_storage() -> None: store = InMemoryResponsesStore() From ad417b7d61af8cde3e57ecaed9833d3b9fb1b0a8 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 11:58:08 +0200 Subject: [PATCH 105/182] fix(responses): bind continuations to response anchors Records emitted Responses IDs as strong response-derived session anchors so previous_response_id continuations can match the original parent session and credential instead of only forming a new continuation session. Also handles event:error and type:error upstream stream frames as failed Responses outputs, preserving partial visible output and client-safe error details. Tests: pytest tests/test_responses_service.py tests/test_responses_streaming.py tests/test_responses_store.py tests/test_responses_routes.py tests/test_responses_bridge.py tests/test_responses_usage_accounting.py tests/test_protocol_responses.py tests/test_request_builder_routing.py tests/test_session_tracking.py --- src/rotator_library/responses/service.py | 44 +++++++++++++++++-- src/rotator_library/responses/streaming.py | 16 +++++-- src/rotator_library/session_tracking.py | 19 ++++++++- tests/test_responses_service.py | 49 ++++++++++++++++++++++ tests/test_responses_streaming.py | 24 +++++++++++ tests/test_session_tracking.py | 24 +++++++++++ 6 files changed, 168 insertions(+), 8 deletions(-) diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index e2c92278b..508c2dd6d 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -103,6 +103,7 @@ async def create_response( self._trace(transaction_logger, "responses_bridge_chat_response", self._response_to_dict(chat_response), direction="response", stage="provider") response_payload = self.bridge.from_chat_response(chat_response, unified) + _record_responses_session_anchor(session_info, response_payload) self._trace(transaction_logger, "responses_parsed_response", response_payload, direction="response", stage="protocol") self._trace_responses_usage(transaction_logger, response_payload, unified.model, source="responses_response") @@ -212,8 +213,8 @@ async def stream_events( if not chunk or chunk.get("type") == "done": continue self._trace(transaction_logger, "parsed_unified_stream_event", chunk, direction="stream", stage="protocol") - if chunk.get("error") is not None: - raise ResponsesServiceError("Upstream stream error", status_code=502, error_type="upstream_error") + if chunk.get("error") is not None or chunk.get("type") == "error": + raise ResponsesServiceError(_stream_error_message(chunk), status_code=502, error_type="upstream_error") if chunk.get("usage"): usage = chunk["usage"] delta = _chunk_text_delta(chunk) @@ -251,7 +252,7 @@ async def stream_events( metadata={"transport": transport}, ) self._trace(transaction_logger, "responses_stream_event_output_text_delta", event, direction="stream", stage="final", metadata={"transport": transport}) - await self._store_stream_current_state(stream_request, _current_stream_payload(state), parent, transaction_logger=transaction_logger) + await self._store_stream_current_state(stream_request, _current_stream_payload(state), parent, transaction_logger=transaction_logger, session_info=session_info) yield ResponsesStreamEvent("response.output_text.delta", event) if not item_started: @@ -262,6 +263,7 @@ async def stream_events( self._trace(transaction_logger, "responses_stream_event_output_item_done", done_item, direction="stream", stage="final", metadata={"transport": transport}) yield ResponsesStreamEvent("response.output_item.done", done_item) completed = response_completed_payload(state, _usage_to_responses_stream(usage)) + _record_responses_session_anchor(session_info, completed) self._trace_responses_usage(transaction_logger, completed, unified.model, source="responses_stream") stored = await self._store_stream_response(stream_request, completed, parent, session_info=session_info) if stored: @@ -476,12 +478,13 @@ async def _store_stream_current_state( parent: Optional[StoredResponse], *, transaction_logger: Optional[Any], + session_info: Optional[dict[str, Any]] = None, ) -> bool: """Optionally persist in-progress stream state for retrieval surfaces.""" if not self.store_settings.store_in_progress or not raw_request.get("store", True): return False - await self.store.save(self._stored_response(raw_request, response_payload, parent)) + await self.store.save(self._stored_response(raw_request, response_payload, parent, session_info=session_info)) self._trace( transaction_logger, "responses_stored_stream_current_state", @@ -541,6 +544,10 @@ def capture(context: Any) -> None: session_info["session_affinity_key"] = getattr(context, "session_affinity_key", None) session_info["scope_key"] = getattr(context, "usage_manager_key", None) session_info["classifier"] = getattr(context, "classifier", None) + session_info["session_tracker"] = getattr(context, "session_tracker", None) + session_info["provider"] = getattr(context, "provider", None) + session_info["model"] = getattr(context, "model", None) + session_info["tracking_namespace"] = getattr(context, "session_tracking_namespace", None) return capture @@ -560,6 +567,35 @@ def _responses_session_hints(previous_response_id: Optional[str], parent: Option return responses_session_hints(previous_response_id, affinity_key=parent_affinity) or fallback +def _record_responses_session_anchor(session_info: dict[str, Any], response_payload: dict[str, Any]) -> None: + """Record emitted Responses IDs as response-derived session evidence.""" + + tracker = session_info.get("session_tracker") + session_id = session_info.get("session_id") + if not tracker or not session_id or not response_payload.get("id"): + return + tracker.record_response( + session_id, + provider=session_info.get("provider"), + model=session_info.get("model"), + scope_key=session_info.get("scope_key"), + tracking_namespace=session_info.get("tracking_namespace"), + response={"id": response_payload.get("id"), "object": "response"}, + ) + + +def _stream_error_message(chunk: dict[str, Any]) -> str: + """Return a compact, client-safe message for upstream stream error chunks.""" + + error = chunk.get("error") + if isinstance(error, dict): + message = error.get("message") or error.get("type") + if message: + return str(message) + message = chunk.get("message") + return str(message) if message else "Upstream stream error" + + def _chunk_text_delta(chunk: dict[str, Any]) -> str: choices = chunk.get("choices") if isinstance(chunk.get("choices"), list) else [] if not choices: diff --git a/src/rotator_library/responses/streaming.py b/src/rotator_library/responses/streaming.py index 94f91abe6..699a971f9 100644 --- a/src/rotator_library/responses/streaming.py +++ b/src/rotator_library/responses/streaming.py @@ -83,14 +83,24 @@ def parse_chat_sse_chunk(chunk: Any) -> dict[str, Any] | None: text = chunk.strip() if not text: return None - if text.startswith("data:"): - text = text[len("data:") :].strip() + event_name = None + data_lines: list[str] = [] + for line in text.splitlines(): + if line.startswith("event:"): + event_name = line[len("event:") :].strip() + elif line.startswith("data:"): + data_lines.append(line[len("data:") :].strip()) + if data_lines: + text = "\n".join(data_lines).strip() if text == "[DONE]": return {"type": "done"} try: - return json.loads(text) + payload = json.loads(text) except json.JSONDecodeError: return None + if isinstance(payload, dict) and event_name and payload.get("type") is None: + payload["type"] = event_name + return payload if isinstance(payload, dict) else None def response_created_payload(response_id: str, model: str) -> dict[str, Any]: diff --git a/src/rotator_library/session_tracking.py b/src/rotator_library/session_tracking.py index c30b3b738..f54588097 100644 --- a/src/rotator_library/session_tracking.py +++ b/src/rotator_library/session_tracking.py @@ -773,6 +773,21 @@ def _anchors_from_response(self, response: Any, namespace: str) -> List[SessionA data = response.model_dump() if hasattr(response, "model_dump") else response if not isinstance(data, dict): return [] + anchors: List[SessionAnchor] = [] + response_id = data.get("id") + if response_id: + # Responses API continuations identify their parent with + # previous_response_id. Record the emitted response id as strong + # response evidence so the next request can route back to the exact + # session/credential that produced it. + anchors.append( + SessionAnchor( + self._scoped(namespace, f"provider:responses_previous_response_id:{response_id}"), + "strong", + source="response", + group="responses_previous_response_id", + ) + ) messages: List[Dict[str, Any]] = [] for choice in data.get("choices") or []: if not isinstance(choice, dict): @@ -782,7 +797,9 @@ def _anchors_from_response(self, response: Any, namespace: str) -> List[SessionA response_message = dict(message) response_message.setdefault("role", "assistant") messages.append(response_message) - return self._anchors_from_messages(messages, namespace, source="response") if messages else [] + if messages: + anchors.extend(self._anchors_from_messages(messages, namespace, source="response")) + return anchors def _affinity_from_anchors( self, diff --git a/tests/test_responses_service.py b/tests/test_responses_service.py index 4374ae361..1c71ba2dc 100644 --- a/tests/test_responses_service.py +++ b/tests/test_responses_service.py @@ -42,6 +42,10 @@ async def acompletion(self, **kwargs): "session_affinity_key": "affinity-parent", "usage_manager_key": "scope-parent", "classifier": "global", + "session_tracker": None, + "provider": "openai", + "model": "gpt-test", + "session_tracking_namespace": "namespace", }, )() ) @@ -193,6 +197,51 @@ async def test_internal_client_context_metadata_is_stored_with_response() -> Non assert stored.metadata["session_affinity_key"] == "affinity-parent" +@pytest.mark.asyncio +async def test_responses_service_records_response_id_session_anchor() -> None: + class Tracker: + def __init__(self) -> None: + self.calls = [] + + def record_response(self, *args, **kwargs): + self.calls.append((args, kwargs)) + + tracker = Tracker() + + class Client(FakeInternalClient): + async def acompletion(self, **kwargs): + callback = kwargs.pop("_request_context_callback", None) + kwargs.pop("_session_tracking_hints", None) + if callback: + callback( + type( + "Context", + (), + { + "session_id": "session-parent", + "session_affinity_key": "affinity-parent", + "usage_manager_key": "scope-parent", + "classifier": "global", + "session_tracker": tracker, + "provider": "openai", + "model": "gpt-test", + "session_tracking_namespace": "namespace", + }, + )() + ) + self.calls.append(kwargs) + return { + "id": "resp_parent", + "model": kwargs["model"], + "choices": [{"message": {"role": "assistant", "content": "Hello back"}, "finish_reason": "stop"}], + } + + await ResponsesService(store=InMemoryResponsesStore()).create_response({"model": "gpt-test", "input": "Hello"}, Client()) + + assert tracker.calls[0][0][0] == "session-parent" + assert tracker.calls[0][1]["response"] == {"id": "resp_parent", "object": "response"} + + @pytest.mark.asyncio async def test_missing_previous_response_id_raises_not_found() -> None: service = ResponsesService(store=InMemoryResponsesStore()) diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index 11b57c65b..a3d5e280d 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -44,6 +44,15 @@ async def chunks(): return chunks() +class EventErrorStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield 'data: {"choices":[{"delta":{"content":"partial"}}]}\n\n' + yield 'event: error\ndata: {"message":"event failed"}\n\n' + + return chunks() + + @pytest.mark.asyncio async def test_stream_response_emits_responses_sse_events_and_stores_final_response() -> None: store = InMemoryResponsesStore() @@ -117,6 +126,21 @@ async def test_stream_response_error_chunks_store_failed_with_partial_output() - assert stored.output_items[0]["content"][0]["text"] == "partial" +@pytest.mark.asyncio +async def test_stream_response_event_error_frames_are_failed() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, EventErrorStreamingClient())] + + event_text = "".join(events) + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert "event: response.failed" in event_text + assert stored is not None + assert stored.response["error"]["message"] == "event failed" + + @pytest.mark.asyncio async def test_stream_response_can_skip_failed_storage() -> None: store = InMemoryResponsesStore() diff --git a/tests/test_session_tracking.py b/tests/test_session_tracking.py index 22262252a..adb85ec8c 100644 --- a/tests/test_session_tracking.py +++ b/tests/test_session_tracking.py @@ -166,6 +166,30 @@ def test_response_anchors_bridge_next_request(self): self.assertEqual(inferred.session_id, continued.session_id) + def test_response_id_anchor_bridges_responses_previous_response_id(self): + tracker = SessionTracker(ttl_seconds=3600) + inferred = tracker.infer_session( + {"messages": [{"role": "user", "content": "Start a Responses API conversation."}]}, + provider="openai", + model="gpt-test", + ) + tracker.record_response( + inferred.session_id, + provider="openai", + model="gpt-test", + response={"id": "resp_parent", "object": "response"}, + ) + + continued = tracker.infer_session( + {"messages": [{"role": "user", "content": "Continue."}]}, + provider="openai", + model="gpt-test", + hints={"strong_anchors": ["responses_previous_response_id:resp_parent"]}, + ) + + self.assertEqual(inferred.session_id, continued.session_id) + self.assertEqual(continued.confidence, "strong") + def test_compaction_probe_detects_user_summary_from_prior_response(self): tracker = SessionTracker(ttl_seconds=3600) request = { From 12cbc015ba1524899fd845c4d3342c61a1d05bab Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 12:02:39 +0200 Subject: [PATCH 106/182] fix(responses): align continuation anchor namespace Removes the forced Responses session_scope from previous_response_id hints so continuation requests use the same provider/model namespace where emitted response IDs are recorded. This lets the strong responses_previous_response_id anchor bridge back to the parent session and credential instead of creating a separate Responses-scoped session. Tests: pytest tests/test_responses_service.py tests/test_responses_streaming.py tests/test_responses_store.py tests/test_responses_routes.py tests/test_responses_bridge.py tests/test_responses_usage_accounting.py tests/test_protocol_responses.py tests/test_request_builder_routing.py tests/test_session_tracking.py --- src/rotator_library/responses/bridge.py | 1 - tests/test_responses_bridge.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rotator_library/responses/bridge.py b/src/rotator_library/responses/bridge.py index 145436469..8485f29bd 100644 --- a/src/rotator_library/responses/bridge.py +++ b/src/rotator_library/responses/bridge.py @@ -130,7 +130,6 @@ def responses_session_hints(previous_response_id: Optional[str], *, affinity_key return { "strong_anchors": [anchor], "affinity_key": affinity_key or anchor, - "session_scope": "responses", } diff --git a/tests/test_responses_bridge.py b/tests/test_responses_bridge.py index 63d973a73..29d822b48 100644 --- a/tests/test_responses_bridge.py +++ b/tests/test_responses_bridge.py @@ -43,6 +43,7 @@ def test_bridge_adds_parent_response_messages_for_previous_response_id() -> None assert kwargs["_responses_bridge"]["previous_response_id"] == "resp_parent" assert kwargs["_session_tracking_hints"]["strong_anchors"] == ["responses_previous_response_id:resp_parent"] assert kwargs["_session_tracking_hints"]["affinity_key"] == "responses_previous_response_id:resp_parent" + assert "session_scope" not in kwargs["_session_tracking_hints"] def test_bridge_preserves_tool_definitions() -> None: From 86d350306a68c73f3025d0e5dd68c1dda76bcb00 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 12:06:03 +0200 Subject: [PATCH 107/182] docs(experimental): plan provider native correction Adds the Phase 5b corrective plan for making priority provider integrations mock-live through native execution instead of declaration-only skeletons. The plan covers provider-aware native operations, model normalization, endpoint/header hardening, streaming declarations, RequestExecutor integration tests, Antigravity safety boundaries, and Gemini CLI custom-path parity. Tests: not run (planning document only). --- .../phase-5b-provider-native-integrations.md | 101 ++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 docs/experimental/phase-5b-provider-native-integrations.md diff --git a/docs/experimental/phase-5b-provider-native-integrations.md b/docs/experimental/phase-5b-provider-native-integrations.md new file mode 100644 index 000000000..e861a9df8 --- /dev/null +++ b/docs/experimental/phase-5b-provider-native-integrations.md @@ -0,0 +1,101 @@ +# Phase 5b Plan: Priority Providers From Skeletons To Mock-Live Native Integrations + +## Goal + +Correct the Phase 5 audit gap. Phase 5 added native-provider seams and priority provider declarations, but the validation pass found the priority providers are still mostly skeletons. Phase 5b will make Claude Code, Codex, Copilot, and Antigravity usable through the native execution path under mock HTTP tests, while preserving Gemini CLI's existing custom execution path and adding parity declarations where they are safe. + +## Non-Goals + +- Do not remove LiteLLM fallback. +- Do not replace Gemini CLI's existing custom provider implementation with the generic native executor in this phase. +- Do not invent device fingerprinting or brittle environment/device-profile behavior for Antigravity. +- Do not add real credential acquisition flows for Copilot/Codex/Claude Code; this phase consumes credentials supplied by the existing credential system. +- Do not use external files outside the project root. +- Do not introduce SQLite or new persistence. +- Do not touch unrelated dirty `ARCHITECTURE.md`, `STRUCTURE.md`, `.opencode/`, `docs/issues/`, or old phase reports. +- Do not commit user-facing phase reports. + +## Current Code State + +- `NativeProviderExecutor` can parse/build/adapter/cache/HTTP/format native requests and streams. +- `RequestExecutor` can select native execution when routing target execution is `native` or auto-detected by provider protocol declaration. +- `_build_native_provider_context()` currently asks every provider for endpoint and headers with operation `"chat"`, which is too generic. +- Priority provider files exist for Claude Code, Codex, Copilot, and Antigravity, but tests mostly cover declarations and helper methods. +- Provider-prefixed model names can leak into upstream native payloads unless each provider normalizes them. +- Native streaming exists, but provider support flags are conservative and not all priority providers have stream coverage. +- Gemini CLI has substantial existing custom logic and must not be silently bypassed by auto-native routing. + +## Implementation Plan + +1. Add provider-native operation resolution. + - Add default methods to `ProviderInterface`: `get_native_operation()`, `normalize_native_model()`, and optional `prepare_native_request()`. + - Preserve current behavior by default. + - Update `RequestExecutor._build_native_provider_context()` to ask providers for operation, endpoint, headers, normalized model, and prepared request metadata. + +2. Make model normalization explicit and tested. + - Claude Code strips `claude_code/`. + - Codex strips `codex/`. + - Copilot strips `copilot/`. + - Antigravity strips `antigravity/` and maps public aliases to internal upstream names. + - Gemini CLI remains custom-path first. + +3. Make provider endpoints operation-aware. + - Claude Code uses `messages` and `/v1/messages`. + - Codex uses `responses` and `/v1/responses`. + - Copilot uses `chat` and `/chat/completions`. + - Antigravity uses Gemini generate/stream-generate endpoints. + +4. Add minimal provider request preparation. + - Normalize model before protocol parsing. + - Allow provider `prepare_native_request()` to deep-copy and adjust request payloads before protocol parsing. + - Trace this pass without adding credentials. + +5. Strengthen provider auth/header behavior. + - Keep current supplied-credential model. + - Add tested header sets for each provider. + - Do not add secrets to JSON config. + +6. Native streaming support declarations. + - Enable only where tested safe. + - Prove native streaming selection through `RequestExecutor` tests. + - Keep unsupported providers on existing fallback/custom paths. + +7. Mock-live RequestExecutor integration tests. + - Prove Claude Code, Codex, Copilot, and Antigravity use the native executor with the correct protocol, operation, endpoint, headers, model normalization, and response formatting under fake HTTP. + - Cover streaming for providers that opt in. + +8. Provider model discovery hardening. + - Test fallback and successful discovery for each priority provider. + - Avoid duplicate prefixes and invalid aliases. + +9. Gemini CLI parity review. + - Keep custom execution path. + - Verify declarations align with Gemini protocol and field-cache paths. + - Explicitly prevent accidental auto-native routing if required. + +10. Documentation and comments. + - Update provider docstrings away from “skeleton” once behavior is mock-live. + - Explain model normalization, Antigravity safety boundaries, and Gemini CLI custom-path deferral. + +## Tests + +- `tests/test_claude_code_provider.py` +- `tests/test_codex_provider.py` +- `tests/test_copilot_provider.py` +- `tests/test_antigravity_provider_restore.py` +- `tests/test_gemini_cli_protocol_declarations.py` +- `tests/test_provider_protocol_declarations.py` +- `tests/test_native_provider_executor.py` +- `tests/test_native_provider_streaming.py` +- `tests/test_request_executor_native_routing.py` +- Relevant protocol, field-cache, and routing regressions. + +## Acceptance Criteria + +- Priority providers are no longer declaration-only skeletons; each has mock-live native `RequestExecutor` coverage or an explicit tested reason for custom-path deferral. +- Native operation and endpoint selection are provider-aware, not hardcoded to `"chat"`. +- Provider-prefixed model names are normalized before native upstream calls. +- Native streaming is enabled only where tested safe. +- Gemini CLI remains on its existing custom path unless explicitly routed otherwise. +- LiteLLM fallback remains available for uncovered providers/protocols. +- Focused tests and dual-agent review pass. From d366b84afc57538528a297ae4d1563d83e17e11f Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 12:10:57 +0200 Subject: [PATCH 108/182] feat(providers): resolve native operations per provider Adds provider-native operation, model normalization, and request-preparation hooks so native execution is not hardcoded to chat. Wires RequestExecutor and NativeProviderContext to carry provider-selected operations into protocol context, endpoints, headers, traces, adapters, and field-cache metadata. Hardens Claude Code, Codex, Copilot, and Antigravity declarations with operation-aware endpoints, normalized upstream model names, tested stream support flags, and provider-specific request preparation where needed. Tests: pytest tests/test_request_executor_native_routing.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py --- src/rotator_library/client/executor.py | 75 +++++++++++++++---- .../native_provider/context.py | 8 +- .../providers/antigravity_provider.py | 32 +++++++- .../providers/claude_code_provider.py | 18 ++++- .../providers/codex_provider.py | 18 ++++- .../providers/copilot_provider.py | 18 ++++- .../providers/provider_interface.py | 34 +++++++++ tests/test_antigravity_provider_restore.py | 11 +++ tests/test_claude_code_provider.py | 9 +++ tests/test_codex_provider.py | 9 +++ tests/test_copilot_provider.py | 9 +++ tests/test_request_executor_native_routing.py | 11 ++- 12 files changed, 229 insertions(+), 23 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index e8af84f83..827927f28 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -597,7 +597,7 @@ async def _execute_provider_request( return await plugin.acompletion(self._http_client, **kwargs) if execution == "native" or (execution == "auto" and _provider_native_protocol(plugin, model, target)): - native_context = self._build_native_provider_context( + native_context, native_request = self._build_native_provider_context( provider, model, plugin, @@ -605,14 +605,16 @@ async def _execute_provider_request( credential_id, context, target, + raw_request=kwargs, + return_request=True, ) self._log_routing_trace( context, "routing_native_execution_selected", _target_trace(target) if target else {"provider": provider, "model": model}, - metadata={"protocol": native_context.protocol_name}, + metadata={"protocol": native_context.protocol_name, "operation": native_context.operation}, ) - return await self._get_native_executor().execute(dict(kwargs), native_context, NativeHTTPTransport(self._http_client)) + return await self._get_native_executor().execute(native_request, native_context, NativeHTTPTransport(self._http_client)) return await self._execute_litellm_request(kwargs, credential_secret, context=context, credential_id=credential_id) @@ -659,8 +661,11 @@ def _build_native_provider_context( credential_id: str, context: RequestContext, target: Optional[RouteTarget], + raw_request: Optional[Dict[str, Any]] = None, transport: str = "http", - ) -> NativeProviderContext: + stream: bool = False, + return_request: bool = False, + ) -> NativeProviderContext | tuple[NativeProviderContext, Dict[str, Any]]: """Build native provider context from provider declarations.""" if not plugin: @@ -670,24 +675,48 @@ def _build_native_provider_context( raise RoutingExecutionError(f"Provider {provider} has no native protocol declaration") if not hasattr(plugin, "get_native_endpoint") or not hasattr(plugin, "get_native_headers"): raise RoutingExecutionError(f"Provider {provider} has no native endpoint/header helpers") - endpoint = plugin.get_native_endpoint(model=model, operation="chat") - headers = plugin.get_native_headers(credential_secret, model=model, operation="chat") - return NativeProviderContext( + public_model = model + native_model = plugin.normalize_native_model(model) if hasattr(plugin, "normalize_native_model") else _strip_provider_prefix(model) + request_payload = dict(raw_request or {}) + if native_model: + request_payload["model"] = native_model + operation = plugin.get_native_operation(native_model, request_payload, stream=stream) if hasattr(plugin, "get_native_operation") else "chat" + if hasattr(plugin, "prepare_native_request"): + prepared = plugin.prepare_native_request(request_payload, model=native_model, operation=operation) + if prepared is not request_payload: + request_payload = dict(prepared) + self._log_executor_trace( + context, + "provider_native_request_prepared", + request_payload, + direction="request", + stage="provider", + credential_id=credential_id, + metadata={"provider": provider, "model": public_model, "native_model": native_model, "operation": operation}, + ) + endpoint = plugin.get_native_endpoint(model=native_model, operation=operation) + headers = plugin.get_native_headers(credential_secret, model=native_model, operation=operation) + native_context = NativeProviderContext( provider=provider, - model=model, + model=native_model, protocol_name=protocol_name, endpoint=endpoint, + operation=operation, headers=headers, credential_id=credential_id, session_id=context.session_id, scope_key=context.usage_manager_key, classifier=context.classifier, transport=transport, - adapter_names=tuple(plugin.get_adapter_names(model) if hasattr(plugin, "get_adapter_names") else ()), - adapter_config=dict(plugin.get_adapter_config(model) if hasattr(plugin, "get_adapter_config") else {}), - field_cache_rules=_merged_field_cache_rules(provider, model, plugin), + adapter_names=tuple(plugin.get_adapter_names(native_model) if hasattr(plugin, "get_adapter_names") else ()), + adapter_config=dict(plugin.get_adapter_config(native_model) if hasattr(plugin, "get_adapter_config") else {}), + field_cache_rules=_merged_field_cache_rules(provider, public_model, plugin), transaction_logger=context.transaction_logger, + metadata={"public_model": public_model}, ) + if return_request: + return native_context, request_payload + return native_context async def execute( self, @@ -1370,7 +1399,7 @@ async def _execute_streaming( and _provider_native_protocol(plugin, model, target) and _provider_supports_native_streaming(plugin, model) ): - native_context = self._build_native_provider_context( + native_context, native_request = self._build_native_provider_context( provider, model, plugin, @@ -1378,15 +1407,18 @@ async def _execute_streaming( cred_context.stable_id, context, target, + raw_request=kwargs, transport="sse", + stream=True, + return_request=True, ) self._log_routing_trace( context, "routing_native_stream_execution_selected", _target_trace(target) if target else {"provider": provider, "model": model}, - metadata={"protocol": native_context.protocol_name}, + metadata={"protocol": native_context.protocol_name, "operation": native_context.operation}, ) - stream = self._get_native_executor().stream(dict(kwargs), native_context, NativeHTTPTransport(self._http_client)) + stream = self._get_native_executor().stream(native_request, native_context, NativeHTTPTransport(self._http_client)) elif plugin and plugin.has_custom_logic(): kwargs["credential_identifier"] = credential_secret self._log_executor_trace( @@ -2260,14 +2292,27 @@ def _provider_native_protocol(plugin: Any, model: str, target: Optional[RouteTar return None +def _strip_provider_prefix(model: str) -> str: + """Return model without the proxy-facing provider prefix.""" + + return model.split("/", 1)[1] if "/" in model else model + + def _provider_supports_native_streaming(plugin: Any, model: str) -> bool: """Return whether a provider declares native streaming support.""" support = getattr(plugin, "supports_native_streaming", None) if not callable(support): return False + operation = "chat" + resolver = getattr(plugin, "get_native_operation", None) + if callable(resolver): + try: + operation = resolver(model, {"model": model, "stream": True}, stream=True) + except TypeError: + operation = resolver(model) try: - return bool(support(model=model, operation="chat")) + return bool(support(model=model, operation=operation)) except TypeError: try: return bool(support(model)) diff --git a/src/rotator_library/native_provider/context.py b/src/rotator_library/native_provider/context.py index 52cc9130e..bacf49b3f 100644 --- a/src/rotator_library/native_provider/context.py +++ b/src/rotator_library/native_provider/context.py @@ -26,6 +26,7 @@ class NativeProviderContext: model: str protocol_name: str endpoint: str + operation: str = "chat" headers: dict[str, str] = field(default_factory=dict) credential_id: Optional[str] = None session_id: Optional[str] = None @@ -49,7 +50,8 @@ def protocol_context(self, *, target_protocol: Optional[str] = None) -> Protocol session_id=self.session_id, credential_stable_id=self.credential_id, transport=self.transport, - metadata=dict(self.metadata), + provider_options={"operation": self.operation}, + metadata={"operation": self.operation, **dict(self.metadata)}, ) def adapter_context(self) -> AdapterContext: @@ -64,7 +66,7 @@ def adapter_context(self) -> AdapterContext: scope_key=self.scope_key, classifier=self.classifier, transport=self.transport, - metadata=dict(self.metadata), + metadata={"operation": self.operation, **dict(self.metadata)}, adapter_config=dict(self.adapter_config), transaction_logger=self.transaction_logger, ) @@ -79,5 +81,5 @@ def field_cache_context(self) -> FieldCacheContext: session_id=self.session_id, conversation_id=self.scope_key, classifier=self.classifier, - metadata=dict(self.metadata), + metadata={"operation": self.operation, **dict(self.metadata)}, ) diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py index eac7e0562..a608aea62 100644 --- a/src/rotator_library/providers/antigravity_provider.py +++ b/src/rotator_library/providers/antigravity_provider.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-only # Copyright (c) 2026 Mirrowel -"""Antigravity provider integration skeleton restored from safe retired pieces.""" +"""Antigravity provider integration restored from safe retired pieces.""" from __future__ import annotations @@ -68,6 +68,7 @@ class AntigravityProvider(ProviderInterface): metadata={"purpose": "preserve Gemini thought signatures across Antigravity turns"}, ), ) + native_streaming_supported = True model_quota_groups = { "gemini": ["gemini-2.5-flash", "gemini-2.5-flash-lite", "gemini-3-pro-preview", "gemini-3-flash"], "claude": ["claude-sonnet-4.5", "claude-opus-4.5", "claude-opus-4.6"], @@ -108,6 +109,35 @@ def get_native_headers(self, credential_identifier: str, model: str = "", operat **ANTIGRAVITY_HEADERS, } + def get_native_operation(self, model: str = "", request: dict[str, Any] | None = None, stream: bool = False) -> str: + """Return the Gemini generate operation used by Antigravity endpoints.""" + + return "stream_generate" if stream else "generate" + + def normalize_native_model(self, model: str) -> str: + """Strip the proxy prefix and map public aliases to upstream names.""" + + clean = model.split("/", 1)[1] if model.startswith("antigravity/") else model + return self._alias_to_internal(clean) + + def prepare_native_request(self, request: dict[str, Any], model: str = "", operation: str = "") -> dict[str, Any]: + """Return a request with the upstream Antigravity model name. + + The provider intentionally keeps this to model alias handling only. Device + profile and fingerprint behavior stays out of the active integration until + it is verified against current service behavior. + """ + + prepared = dict(request) + if model: + prepared["model"] = model + return prepared + + def supports_native_streaming(self, model: str = "", operation: str = "generate") -> bool: + """Return true for the tested stream-generate endpoint only.""" + + return operation == "stream_generate" + def get_native_endpoint(self, model: str = "", operation: str = "generate") -> str: """Return Antigravity internal operation endpoints.""" diff --git a/src/rotator_library/providers/claude_code_provider.py b/src/rotator_library/providers/claude_code_provider.py index 12a03fb64..4094fa61c 100644 --- a/src/rotator_library/providers/claude_code_provider.py +++ b/src/rotator_library/providers/claude_code_provider.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-only # Copyright (c) 2026 Mirrowel -"""Claude Code provider integration skeleton for native protocol execution.""" +"""Claude Code provider integration for native Anthropic Messages execution.""" from __future__ import annotations @@ -29,6 +29,7 @@ class ClaudeCodeProvider(ProviderInterface): provider_env_name = "claude_code" protocol_name = "anthropic_messages" adapter_names = ("suppress_developer_role",) + native_streaming_supported = True field_cache_rules = ( FieldCacheRule( name="claude_code_thinking_signature", @@ -74,6 +75,21 @@ def get_native_headers(self, credential_identifier: str, model: str = "", operat "content-type": "application/json", } + def get_native_operation(self, model: str = "", request: dict[str, Any] | None = None, stream: bool = False) -> str: + """Claude Code uses the Anthropic Messages operation for completions.""" + + return "messages" + + def normalize_native_model(self, model: str) -> str: + """Strip the proxy provider prefix before sending upstream.""" + + return model.split("/", 1)[1] if model.startswith("claude_code/") else model + + def supports_native_streaming(self, model: str = "", operation: str = "messages") -> bool: + """Return true for the tested Anthropic Messages stream operation.""" + + return operation == "messages" + def get_native_endpoint(self, model: str = "", operation: str = "messages") -> str: """Return the provider endpoint for a native operation.""" diff --git a/src/rotator_library/providers/codex_provider.py b/src/rotator_library/providers/codex_provider.py index 68848902b..90b0ae720 100644 --- a/src/rotator_library/providers/codex_provider.py +++ b/src/rotator_library/providers/codex_provider.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-only # Copyright (c) 2026 Mirrowel -"""Codex provider integration skeleton for native Responses execution.""" +"""Codex provider integration for native Responses execution.""" from __future__ import annotations @@ -23,6 +23,7 @@ class CodexProvider(ProviderInterface): provider_env_name = "codex" protocol_name = "responses" adapter_names: tuple[str, ...] = () + native_streaming_supported = True field_cache_rules = ( FieldCacheRule( name="codex_previous_response_id", @@ -59,6 +60,21 @@ def get_native_headers(self, credential_identifier: str, model: str = "", operat return {"Authorization": f"Bearer {credential_identifier}", "content-type": "application/json"} + def get_native_operation(self, model: str = "", request: dict | None = None, stream: bool = False) -> str: + """Codex native calls use the Responses operation.""" + + return "responses" + + def normalize_native_model(self, model: str) -> str: + """Strip the proxy provider prefix before sending upstream.""" + + return model.split("/", 1)[1] if model.startswith("codex/") else model + + def supports_native_streaming(self, model: str = "", operation: str = "responses") -> bool: + """Return true for tested Responses stream payloads.""" + + return operation == "responses" + def get_native_endpoint(self, model: str = "", operation: str = "responses") -> str: """Return the native Codex endpoint for an operation.""" diff --git a/src/rotator_library/providers/copilot_provider.py b/src/rotator_library/providers/copilot_provider.py index 073bb2efa..d7f9f384b 100644 --- a/src/rotator_library/providers/copilot_provider.py +++ b/src/rotator_library/providers/copilot_provider.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-only # Copyright (c) 2026 Mirrowel -"""Copilot provider integration skeleton for native OpenAI Chat execution.""" +"""Copilot provider integration for native OpenAI Chat execution.""" from __future__ import annotations @@ -29,6 +29,7 @@ class CopilotProvider(ProviderInterface): adapter_names = ("suppress_developer_role",) field_cache_rules: tuple = () default_rotation_mode = "sequential" + native_streaming_supported = True async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: """Fetch Copilot-visible models with a safe fallback list.""" @@ -57,6 +58,21 @@ def get_native_headers(self, credential_identifier: str, model: str = "", operat "Copilot-Integration-Id": os.getenv("COPILOT_INTEGRATION_ID", "llm-api-key-proxy"), } + def get_native_operation(self, model: str = "", request: dict | None = None, stream: bool = False) -> str: + """Copilot exposes an OpenAI-compatible chat operation.""" + + return "chat" + + def normalize_native_model(self, model: str) -> str: + """Strip the proxy provider prefix before sending upstream.""" + + return model.split("/", 1)[1] if model.startswith("copilot/") else model + + def supports_native_streaming(self, model: str = "", operation: str = "chat") -> bool: + """Return true for OpenAI Chat SSE streams.""" + + return operation == "chat" + def get_native_endpoint(self, model: str = "", operation: str = "chat") -> str: """Return the Copilot endpoint for a native operation.""" diff --git a/src/rotator_library/providers/provider_interface.py b/src/rotator_library/providers/provider_interface.py index e73acebe0..3ce8817af 100644 --- a/src/rotator_library/providers/provider_interface.py +++ b/src/rotator_library/providers/provider_interface.py @@ -373,6 +373,40 @@ def supports_native_streaming(self, model: str = "", operation: str = "chat") -> return self.native_streaming_supported + def get_native_operation(self, model: str = "", request: Optional[Dict[str, Any]] = None, stream: bool = False) -> str: + """Return the provider-native operation for a request. + + Providers that expose native protocols often use operation names that are + not simply ``chat``: Anthropic-compatible providers use ``messages``, + Responses providers use ``responses``, and Gemini-style providers use a + generate operation. The default stays ``chat`` so existing OpenAI-chat + providers remain compatible unless they opt in to something richer. + """ + + return "chat" + + def normalize_native_model(self, model: str) -> str: + """Return the upstream model name for native provider calls. + + The proxy-facing model commonly includes a provider prefix such as + ``provider/model``. Native upstream APIs usually expect only ``model``. + Providers may override this for aliases, but stripping the first prefix + is the safe default for native execution. + """ + + return model.split("/", 1)[1] if "/" in model else model + + def prepare_native_request(self, request: Dict[str, Any], model: str = "", operation: str = "") -> Dict[str, Any]: + """Return a provider-adjusted native request payload. + + This hook is intentionally limited to payload shape/model aliases and is + called before protocol parsing. It must not add credentials; auth belongs + in ``get_native_headers()`` so traces never mix request payloads with + secrets. + """ + + return dict(request) + def get_model_pricing(self, model: str = "") -> Optional[Any]: """Return optional local pricing metadata for advisory cost tracking. diff --git a/tests/test_antigravity_provider_restore.py b/tests/test_antigravity_provider_restore.py index d2a79d23d..46209e33e 100644 --- a/tests/test_antigravity_provider_restore.py +++ b/tests/test_antigravity_provider_restore.py @@ -60,9 +60,20 @@ def test_antigravity_model_aliases_and_tracking_normalization() -> None: provider = AntigravityProvider() assert provider._alias_to_internal("claude-sonnet-4.5") == "claude-sonnet-4-5" + assert provider.normalize_native_model("antigravity/claude-sonnet-4.5") == "claude-sonnet-4-5" assert provider.normalize_model_for_tracking("antigravity/claude-sonnet-4-5") == "antigravity/claude-sonnet-4.5" +def test_antigravity_native_operation_model_and_stream_support() -> None: + provider = AntigravityProvider() + + assert provider.get_native_operation("gemini-3-flash", {}, stream=False) == "generate" + assert provider.get_native_operation("gemini-3-flash", {}, stream=True) == "stream_generate" + assert provider.supports_native_streaming("gemini-3-flash", operation="stream_generate") is True + assert provider.supports_native_streaming("gemini-3-flash", operation="generate") is False + assert provider.prepare_native_request({"model": "antigravity/gemini-3-pro-low"}, model="gemini-3-pro-preview", operation="generate")["model"] == "gemini-3-pro-preview" + + @pytest.mark.asyncio async def test_antigravity_get_models_filters_and_aliases_mocked_response(monkeypatch) -> None: monkeypatch.setenv("ANTIGRAVITY_API_BASE", "https://antigravity.test/v1internal") diff --git a/tests/test_claude_code_provider.py b/tests/test_claude_code_provider.py index ef910067e..8e1ffcada 100644 --- a/tests/test_claude_code_provider.py +++ b/tests/test_claude_code_provider.py @@ -70,6 +70,15 @@ def test_claude_code_provider_builds_native_headers_and_endpoint(monkeypatch) -> } +def test_claude_code_native_operation_model_and_stream_support() -> None: + provider = ClaudeCodeProvider() + + assert provider.get_native_operation("claude-sonnet-4-5", {}, stream=False) == "messages" + assert provider.normalize_native_model("claude_code/claude-sonnet-4-5") == "claude-sonnet-4-5" + assert provider.supports_native_streaming("claude-sonnet-4-5", operation="messages") is True + assert provider.supports_native_streaming("claude-sonnet-4-5", operation="chat") is False + + @pytest.mark.asyncio async def test_claude_code_provider_get_models_uses_mocked_models_endpoint(monkeypatch) -> None: monkeypatch.setenv("CLAUDE_CODE_API_BASE", "https://claude-code.test") diff --git a/tests/test_codex_provider.py b/tests/test_codex_provider.py index 779e4b4ba..a45b93623 100644 --- a/tests/test_codex_provider.py +++ b/tests/test_codex_provider.py @@ -50,6 +50,15 @@ def test_codex_provider_builds_native_headers_and_endpoint(monkeypatch) -> None: assert provider.get_native_headers("token") == {"Authorization": "Bearer token", "content-type": "application/json"} +def test_codex_native_operation_model_and_stream_support() -> None: + provider = CodexProvider() + + assert provider.get_native_operation("gpt-5.1-codex", {}, stream=False) == "responses" + assert provider.normalize_native_model("codex/gpt-5.1-codex") == "gpt-5.1-codex" + assert provider.supports_native_streaming("gpt-5.1-codex", operation="responses") is True + assert provider.supports_native_streaming("gpt-5.1-codex", operation="chat") is False + + @pytest.mark.asyncio async def test_codex_provider_get_models_filters_codex_models(monkeypatch) -> None: monkeypatch.setenv("CODEX_API_BASE", "https://codex.test") diff --git a/tests/test_copilot_provider.py b/tests/test_copilot_provider.py index bc0f3c929..9a5f252d5 100644 --- a/tests/test_copilot_provider.py +++ b/tests/test_copilot_provider.py @@ -69,6 +69,15 @@ def test_copilot_provider_builds_native_headers_and_endpoint(monkeypatch) -> Non } +def test_copilot_native_operation_model_and_stream_support() -> None: + provider = CopilotProvider() + + assert provider.get_native_operation("gpt-4.1", {}, stream=False) == "chat" + assert provider.normalize_native_model("copilot/gpt-4.1") == "gpt-4.1" + assert provider.supports_native_streaming("gpt-4.1", operation="chat") is True + assert provider.supports_native_streaming("gpt-4.1", operation="responses") is False + + @pytest.mark.asyncio async def test_copilot_provider_get_models_uses_mocked_endpoint(monkeypatch) -> None: monkeypatch.setenv("COPILOT_API_BASE", "https://copilot.test") diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index fcd49c6b0..037339bf6 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -52,7 +52,13 @@ def get_native_endpoint(self, model="", operation="chat"): return "https://native.test/chat" def get_native_headers(self, credential_identifier, model="", operation="chat"): - return {"Authorization": f"Bearer {credential_identifier}"} + return {"Authorization": f"Bearer {credential_identifier}", "X-Operation": operation, "X-Model": model} + + def get_native_operation(self, model="", request=None, stream=False): + return "messages" if stream else "chat" + + def normalize_native_model(self, model=""): + return model.split("/", 1)[1] if "/" in model else model def get_adapter_names(self, model=""): return () @@ -147,6 +153,9 @@ async def test_native_declared_provider_uses_native_executor_in_auto_mode() -> N assert response["id"] == "chat_native" assert http_client.calls[0]["endpoint"] == "https://native.test/chat" assert http_client.calls[0]["headers"]["Authorization"] == "Bearer secret" + assert http_client.calls[0]["headers"]["X-Operation"] == "chat" + assert http_client.calls[0]["headers"]["X-Model"] == "gpt-test" + assert http_client.calls[0]["json"]["model"] == "gpt-test" @pytest.mark.asyncio From 60fd7c1203410248c1f4632ef7c12cef55f7fd4f Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 12:14:10 +0200 Subject: [PATCH 109/182] test(providers): cover priority native execution paths Adds provider request preparation for Codex and Antigravity so routed chat-style payloads become native Responses input and Gemini contents before protocol parsing. Adds mock-live RequestExecutor tests for Claude Code, Codex, Copilot, and Antigravity to prove endpoint selection, model normalization, headers, adapter behavior, request shape conversion, and native response formatting through the actual native executor path. Tests: pytest tests/test_request_executor_native_routing.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py --- .../providers/antigravity_provider.py | 21 +++- .../providers/codex_provider.py | 26 ++++- tests/test_antigravity_provider_restore.py | 17 ++++ tests/test_codex_provider.py | 14 +++ tests/test_request_executor_native_routing.py | 99 +++++++++++++++++++ 5 files changed, 172 insertions(+), 5 deletions(-) diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py index a608aea62..6b2de1344 100644 --- a/src/rotator_library/providers/antigravity_provider.py +++ b/src/rotator_library/providers/antigravity_provider.py @@ -121,16 +121,18 @@ def normalize_native_model(self, model: str) -> str: return self._alias_to_internal(clean) def prepare_native_request(self, request: dict[str, Any], model: str = "", operation: str = "") -> dict[str, Any]: - """Return a request with the upstream Antigravity model name. + """Return a request with the upstream model and Gemini contents shape. - The provider intentionally keeps this to model alias handling only. Device - profile and fingerprint behavior stays out of the active integration until - it is verified against current service behavior. + The provider intentionally keeps this to model alias and message-shape + handling only. Device profile and fingerprint behavior stays out of the + active integration until it is verified against current service behavior. """ prepared = dict(request) if model: prepared["model"] = model + if "contents" not in prepared and isinstance(prepared.get("messages"), list): + prepared["contents"] = [_message_to_gemini_content(message) for message in prepared.pop("messages")] return prepared def supports_native_streaming(self, model: str = "", operation: str = "generate") -> bool: @@ -189,3 +191,14 @@ def _models_from_response(self, payload: dict[str, Any]) -> list[str]: if public not in EXCLUDED_MODELS and public in AVAILABLE_MODELS and public not in result: result.append(public) return result + + +def _message_to_gemini_content(message: Any) -> dict[str, Any]: + """Return a minimal Gemini content item from an OpenAI-style message.""" + + if not isinstance(message, dict): + return {"role": "user", "parts": [{"text": str(message)}]} + role = "model" if message.get("role") == "assistant" else "user" + content = message.get("content", "") + parts = content if isinstance(content, list) else [{"text": str(content)}] + return {"role": role, "parts": parts} diff --git a/src/rotator_library/providers/codex_provider.py b/src/rotator_library/providers/codex_provider.py index 90b0ae720..d7ac737be 100644 --- a/src/rotator_library/providers/codex_provider.py +++ b/src/rotator_library/providers/codex_provider.py @@ -6,7 +6,7 @@ from __future__ import annotations import os -from typing import List +from typing import Any, List import httpx @@ -75,6 +75,22 @@ def supports_native_streaming(self, model: str = "", operation: str = "responses return operation == "responses" + def prepare_native_request(self, request: dict[str, Any], model: str = "", operation: str = "") -> dict[str, Any]: + """Convert chat-style proxy input into Responses API input. + + Direct `/v1/responses` calls already provide `input`; routed chat calls + usually provide `messages`. Converting here keeps the reusable Responses + protocol focused on its native shape while still making Codex mock-live + through the generic native executor. + """ + + prepared = dict(request) + if model: + prepared["model"] = model + if "input" not in prepared and isinstance(prepared.get("messages"), list): + prepared["input"] = [_message_to_responses_input(message) for message in prepared.pop("messages")] + return prepared + def get_native_endpoint(self, model: str = "", operation: str = "responses") -> str: """Return the native Codex endpoint for an operation.""" @@ -86,3 +102,11 @@ def _with_prefix(model: str) -> str: if model.startswith("codex/"): return model return f"codex/{model}" + + +def _message_to_responses_input(message: Any) -> dict[str, Any]: + """Return a minimal Responses input item from an OpenAI-style message.""" + + if not isinstance(message, dict): + return {"role": "user", "content": str(message)} + return {"role": message.get("role", "user"), "content": message.get("content", "")} diff --git a/tests/test_antigravity_provider_restore.py b/tests/test_antigravity_provider_restore.py index 46209e33e..878be5948 100644 --- a/tests/test_antigravity_provider_restore.py +++ b/tests/test_antigravity_provider_restore.py @@ -74,6 +74,23 @@ def test_antigravity_native_operation_model_and_stream_support() -> None: assert provider.prepare_native_request({"model": "antigravity/gemini-3-pro-low"}, model="gemini-3-pro-preview", operation="generate")["model"] == "gemini-3-pro-preview" +def test_antigravity_prepare_native_request_converts_messages_to_gemini_contents() -> None: + provider = AntigravityProvider() + + prepared = provider.prepare_native_request( + {"messages": [{"role": "user", "content": "hello"}, {"role": "assistant", "content": "hi"}]}, + model="gemini-3-flash", + operation="generate", + ) + + assert prepared["model"] == "gemini-3-flash" + assert prepared["contents"] == [ + {"role": "user", "parts": [{"text": "hello"}]}, + {"role": "model", "parts": [{"text": "hi"}]}, + ] + assert "messages" not in prepared + + @pytest.mark.asyncio async def test_antigravity_get_models_filters_and_aliases_mocked_response(monkeypatch) -> None: monkeypatch.setenv("ANTIGRAVITY_API_BASE", "https://antigravity.test/v1internal") diff --git a/tests/test_codex_provider.py b/tests/test_codex_provider.py index a45b93623..d04a286a0 100644 --- a/tests/test_codex_provider.py +++ b/tests/test_codex_provider.py @@ -59,6 +59,20 @@ def test_codex_native_operation_model_and_stream_support() -> None: assert provider.supports_native_streaming("gpt-5.1-codex", operation="chat") is False +def test_codex_prepare_native_request_converts_messages_to_responses_input() -> None: + provider = CodexProvider() + + prepared = provider.prepare_native_request( + {"model": "codex/gpt-5.1-codex", "messages": [{"role": "user", "content": "hello"}]}, + model="gpt-5.1-codex", + operation="responses", + ) + + assert prepared["model"] == "gpt-5.1-codex" + assert prepared["input"] == [{"role": "user", "content": "hello"}] + assert "messages" not in prepared + + @pytest.mark.asyncio async def test_codex_provider_get_models_filters_codex_models(monkeypatch) -> None: monkeypatch.setenv("CODEX_API_BASE", "https://codex.test") diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index 037339bf6..7b7b91074 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -8,6 +8,10 @@ from rotator_library.client.executor import RequestExecutor, RoutingExecutionError from rotator_library.core.types import RequestContext from rotator_library.field_cache import FieldCacheInjection, FieldCacheRule +from rotator_library.providers.antigravity_provider import AntigravityProvider +from rotator_library.providers.claude_code_provider import ClaudeCodeProvider +from rotator_library.providers.codex_provider import CodexProvider +from rotator_library.providers.copilot_provider import CopilotProvider from rotator_library.routing import parse_route_target @@ -132,6 +136,18 @@ def _context(target=None) -> RequestContext: ) +def _provider_context(provider: str, model: str, kwargs: dict, target=None) -> RequestContext: + return RequestContext( + model=model, + provider=provider, + kwargs=kwargs, + streaming=False, + credentials=["cred"], + deadline=9999999999.0, + routing_targets=(target,) if target else None, + ) + + def _executor(http_client=None) -> RequestExecutor: executor = RequestExecutor.__new__(RequestExecutor) executor._http_client = http_client or FakeHTTPClient() @@ -219,6 +235,89 @@ def test_native_context_merges_json_field_cache_rules(monkeypatch, tmp_path) -> assert native_context.field_cache_rules[0].path == "json.path" +@pytest.mark.asyncio +async def test_claude_code_provider_runs_mock_live_native_request(monkeypatch) -> None: + monkeypatch.setenv("CLAUDE_CODE_API_BASE", "https://claude-code.test") + http_client = SequencedHTTPClient([ + {"id": "msg_1", "type": "message", "role": "assistant", "content": [{"type": "text", "text": "ok"}], "usage": {"input_tokens": 1, "output_tokens": 1}} + ]) + provider = ClaudeCodeProvider() + target = parse_route_target("claude_code/claude-sonnet-4-5") + context = _provider_context( + "claude_code", + "claude_code/claude-sonnet-4-5", + {"model": "claude_code/claude-sonnet-4-5", "messages": [{"role": "developer", "content": "rules"}, {"role": "user", "content": "hi"}]}, + target, + ) + context.routing_target_index = 0 + + response = await _executor(http_client)._execute_provider_request("claude_code", context.model, provider, "secret", "stable", dict(context.kwargs), context) + + assert response["id"] == "msg_1" + assert http_client.calls[0]["endpoint"] == "https://claude-code.test/v1/messages" + assert http_client.calls[0]["json"]["model"] == "claude-sonnet-4-5" + assert http_client.calls[0]["json"]["messages"][0]["role"] == "user" + + +@pytest.mark.asyncio +async def test_codex_provider_runs_mock_live_native_request(monkeypatch) -> None: + monkeypatch.setenv("CODEX_API_BASE", "https://codex.test") + http_client = SequencedHTTPClient([ + {"id": "resp_1", "object": "response", "output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "ok"}]}]} + ]) + provider = CodexProvider() + target = parse_route_target("codex/gpt-5.1-codex") + context = _provider_context("codex", "codex/gpt-5.1-codex", {"model": "codex/gpt-5.1-codex", "messages": [{"role": "user", "content": "hi"}]}, target) + context.routing_target_index = 0 + + response = await _executor(http_client)._execute_provider_request("codex", context.model, provider, "secret", "stable", dict(context.kwargs), context) + + assert response["id"] == "resp_1" + assert http_client.calls[0]["endpoint"] == "https://codex.test/v1/responses" + assert http_client.calls[0]["json"]["model"] == "gpt-5.1-codex" + assert http_client.calls[0]["json"]["input"][0]["content"] == [{"type": "text", "text": "hi"}] + assert "messages" not in http_client.calls[0]["json"] + + +@pytest.mark.asyncio +async def test_copilot_provider_runs_mock_live_native_request(monkeypatch) -> None: + monkeypatch.setenv("COPILOT_API_BASE", "https://copilot.test") + http_client = SequencedHTTPClient([ + {"id": "chat_1", "choices": [{"message": {"role": "assistant", "content": "ok"}}]} + ]) + provider = CopilotProvider() + target = parse_route_target("copilot/gpt-4.1") + context = _provider_context("copilot", "copilot/gpt-4.1", {"model": "copilot/gpt-4.1", "messages": [{"role": "developer", "content": "rules"}, {"role": "user", "content": "hi"}]}, target) + context.routing_target_index = 0 + + response = await _executor(http_client)._execute_provider_request("copilot", context.model, provider, "secret", "stable", dict(context.kwargs), context) + + assert response["id"] == "chat_1" + assert http_client.calls[0]["endpoint"] == "https://copilot.test/chat/completions" + assert http_client.calls[0]["json"]["model"] == "gpt-4.1" + assert http_client.calls[0]["json"]["messages"][0]["role"] == "system" + + +@pytest.mark.asyncio +async def test_antigravity_provider_runs_mock_live_native_request(monkeypatch) -> None: + monkeypatch.setenv("ANTIGRAVITY_API_BASE", "https://antigravity.test/v1internal") + http_client = SequencedHTTPClient([ + {"candidates": [{"content": {"role": "model", "parts": [{"text": "ok"}]}, "finishReason": "STOP"}], "usageMetadata": {"totalTokenCount": 2}} + ]) + provider = AntigravityProvider() + target = parse_route_target("antigravity/claude-sonnet-4.5") + context = _provider_context("antigravity", "antigravity/claude-sonnet-4.5", {"model": "antigravity/claude-sonnet-4.5", "messages": [{"role": "user", "content": "hi"}]}, target) + context.routing_target_index = 0 + + response = await _executor(http_client)._execute_provider_request("antigravity", context.model, provider, "secret", "stable", dict(context.kwargs), context) + + assert response["candidates"][0]["content"]["parts"][0]["text"] == "ok" + assert http_client.calls[0]["endpoint"] == "https://antigravity.test/v1internal:streamGenerateContent?alt=sse" + assert http_client.calls[0]["json"]["model"] == "claude-sonnet-4-5" + assert http_client.calls[0]["json"]["contents"][0]["parts"][0]["text"] == "hi" + assert "messages" not in http_client.calls[0]["json"] + + def test_native_context_raises_on_invalid_field_cache_config(monkeypatch, tmp_path) -> None: config_path = tmp_path / "config.json" config_path.write_text(json.dumps({"field_cache": {"provider": {"*": [{"name": "bad", "source": "response"}]}}}), encoding="utf-8") From 95b91355c47cbd5bee867facb7817c4b56c1c96f Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 12:15:07 +0200 Subject: [PATCH 110/182] test(providers): preserve gemini cli custom path Adds regression coverage that providers with custom logic and native declarations still use the custom execution path in auto mode. Also asserts Gemini CLI keeps native streaming disabled while exposing protocol declarations and normalized model metadata for future native work. Tests: pytest tests/test_request_executor_native_routing.py tests/test_gemini_cli_protocol_declarations.py --- .../test_gemini_cli_protocol_declarations.py | 2 ++ tests/test_request_executor_native_routing.py | 25 +++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/tests/test_gemini_cli_protocol_declarations.py b/tests/test_gemini_cli_protocol_declarations.py index 2a99adfe1..9364d35df 100644 --- a/tests/test_gemini_cli_protocol_declarations.py +++ b/tests/test_gemini_cli_protocol_declarations.py @@ -12,6 +12,8 @@ async def test_gemini_cli_declares_gemini_protocol_without_changing_custom_logic assert provider.has_custom_logic() is True assert provider.get_protocol_name("gemini-3-flash-preview") == "gemini" assert provider.get_adapter_names("gemini-3-flash-preview") == () + assert provider.supports_native_streaming("gemini-3-flash-preview", operation="generate") is False + assert provider.normalize_native_model("gemini_cli/gemini-3-flash-preview") == "gemini-3-flash-preview" @pytest.mark.asyncio diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index 7b7b91074..ea8666ccb 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -123,6 +123,15 @@ async def acompletion(self, client, **kwargs): self.calls.append(kwargs) return {"id": "custom"} + def get_protocol_name(self, model=""): + return "gemini" + + def get_native_endpoint(self, model="", operation="chat"): + return "https://native.test/should-not-run" + + def get_native_headers(self, credential_identifier, model="", operation="chat"): + return {"Authorization": f"Bearer {credential_identifier}"} + def _context(target=None) -> RequestContext: return RequestContext( @@ -174,6 +183,22 @@ async def test_native_declared_provider_uses_native_executor_in_auto_mode() -> N assert http_client.calls[0]["json"]["model"] == "gpt-test" +@pytest.mark.asyncio +async def test_auto_mode_prefers_custom_logic_over_native_declaration() -> None: + http_client = FakeHTTPClient() + plugin = CustomPlugin() + context = _context(parse_route_target("provider/gpt-test")) + context.routing_target_index = 0 + + response = await _executor(http_client)._execute_provider_request( + "provider", "provider/gpt-test", plugin, "secret", "stable", dict(context.kwargs), context + ) + + assert response == {"id": "custom"} + assert plugin.calls[0]["credential_identifier"] == "secret" + assert http_client.calls == [] + + @pytest.mark.asyncio async def test_request_executor_reuses_native_field_cache_store() -> None: http_client = SequencedHTTPClient( From 216684437e3178d1643526e20f269121146ec633 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 12:27:12 +0200 Subject: [PATCH 111/182] fix(providers): harden native request safety Disables priority-provider native streaming opt-ins until the generic stream wrapper and native transport can execute them safely through RequestExecutor. Drops LiteLLM-only kwargs before native protocol parsing so unknown-field preservation cannot leak LiteLLM routing/logging controls upstream. Adds a safe Antigravity envelope adapter for the restored internal endpoint and verifies the envelope in mock-live native routing tests. Tests: pytest tests/test_request_executor_native_routing.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_protocol_anthropic_messages.py tests/test_protocol_responses.py tests/test_protocol_openai_chat.py tests/test_protocol_gemini.py tests/test_adapter_registry.py --- src/rotator_library/adapters/builtin.py | 31 +++++++++++++++++++ src/rotator_library/client/executor.py | 30 +++++++++++++++++- .../providers/antigravity_provider.py | 18 +++++++---- .../providers/claude_code_provider.py | 6 ++-- .../providers/codex_provider.py | 6 ++-- .../providers/copilot_provider.py | 6 ++-- tests/test_antigravity_provider_restore.py | 6 ++-- tests/test_claude_code_provider.py | 2 +- tests/test_codex_provider.py | 2 +- tests/test_copilot_provider.py | 2 +- tests/test_request_executor_native_routing.py | 21 +++++++++++-- 11 files changed, 106 insertions(+), 24 deletions(-) diff --git a/src/rotator_library/adapters/builtin.py b/src/rotator_library/adapters/builtin.py index 7406729b0..05cf49230 100644 --- a/src/rotator_library/adapters/builtin.py +++ b/src/rotator_library/adapters/builtin.py @@ -7,6 +7,7 @@ from copy import deepcopy from typing import Any +from uuid import uuid4 from ..field_cache.paths import extract_path, inject_path from .base import AdapterContext, PayloadAdapter @@ -149,6 +150,36 @@ def _transform_stage(self, payload: Any, context: AdapterContext, stage: str) -> return updated +class AntigravityEnvelopeAdapter(PayloadAdapter): + """Wrap Gemini payloads in the Antigravity internal request envelope. + + The active provider restores only stable envelope fields. Device profiles, + fingerprints, and other volatile client-emulation fields are intentionally + not generated here until they are verified against current service behavior. + """ + + name = "antigravity_envelope" + supported_stages = ("request",) + + async def transform_request(self, payload: Any, context: AdapterContext) -> Any: + if not isinstance(payload, dict) or "request" in payload: + return payload + config = context.config_for(self.name) + model = payload.get("model") or context.model + request_payload = {key: deepcopy(value) for key, value in payload.items() if key != "model"} + envelope = { + "model": model, + "request": request_payload, + "requestType": config.get("request_type", "CHAT_COMPLETION"), + "requestId": str(uuid4()), + "userAgent": config.get("user_agent"), + } + project = config.get("project") + if project: + envelope["project"] = project + return {key: value for key, value in envelope.items() if value is not None} + + def _delete_simple_path(payload: dict[str, Any], path: str) -> None: """Delete a simple dotted dict path after a conservative move operation.""" diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 827927f28..3d6e3ef1d 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -677,7 +677,7 @@ def _build_native_provider_context( raise RoutingExecutionError(f"Provider {provider} has no native endpoint/header helpers") public_model = model native_model = plugin.normalize_native_model(model) if hasattr(plugin, "normalize_native_model") else _strip_provider_prefix(model) - request_payload = dict(raw_request or {}) + request_payload = _native_request_payload(raw_request or {}) if native_model: request_payload["model"] = native_model operation = plugin.get_native_operation(native_model, request_payload, stream=stream) if hasattr(plugin, "get_native_operation") else "chat" @@ -2298,6 +2298,34 @@ def _strip_provider_prefix(model: str) -> str: return model.split("/", 1)[1] if "/" in model else model +_NATIVE_REQUEST_DROP_KEYS = { + "api_base", + "api_key", + "api_type", + "api_version", + "base_url", + "custom_llm_provider", + "drop_params", + "logger_fn", + "litellm_call_id", + "mock_response", + "organization", + "project", + "transaction_context", +} + + +def _native_request_payload(kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Return provider-visible kwargs for native protocol parsing. + + Full executor calls prepare kwargs for LiteLLM before execution-mode routing. + Native providers must not receive LiteLLM-only routing, logging, or transport + controls because protocol adapters preserve unknown fields intentionally. + """ + + return {key: deepcopy(value) for key, value in kwargs.items() if key not in _NATIVE_REQUEST_DROP_KEYS and not key.startswith("litellm_")} + + def _provider_supports_native_streaming(plugin: Any, model: str) -> bool: """Return whether a provider declares native streaming support.""" diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py index 6b2de1344..fd3a66cab 100644 --- a/src/rotator_library/providers/antigravity_provider.py +++ b/src/rotator_library/providers/antigravity_provider.py @@ -56,7 +56,7 @@ class AntigravityProvider(ProviderInterface): provider_env_name = "antigravity" protocol_name = "gemini" - adapter_names: tuple[str, ...] = () + adapter_names: tuple[str, ...] = ("antigravity_envelope",) field_cache_rules = ( FieldCacheRule( name="antigravity_thought_signature", @@ -68,7 +68,7 @@ class AntigravityProvider(ProviderInterface): metadata={"purpose": "preserve Gemini thought signatures across Antigravity turns"}, ), ) - native_streaming_supported = True + native_streaming_supported = False model_quota_groups = { "gemini": ["gemini-2.5-flash", "gemini-2.5-flash-lite", "gemini-3-pro-preview", "gemini-3-flash"], "claude": ["claude-sonnet-4.5", "claude-opus-4.5", "claude-opus-4.6"], @@ -136,9 +136,9 @@ def prepare_native_request(self, request: dict[str, Any], model: str = "", opera return prepared def supports_native_streaming(self, model: str = "", operation: str = "generate") -> bool: - """Return true for the tested stream-generate endpoint only.""" + """Return false until native stream wrapping is provider-safe.""" - return operation == "stream_generate" + return False def get_native_endpoint(self, model: str = "", operation: str = "generate") -> str: """Return Antigravity internal operation endpoints.""" @@ -148,9 +148,15 @@ def get_native_endpoint(self, model: str = "", operation: str = "generate") -> s return f"{self.get_api_base()}:streamGenerateContent?alt=sse" def get_adapter_config(self, model: str = "") -> dict[str, dict[str, Any]]: - """Configure minimal payload path copies needed by native tests.""" + """Configure the safe Antigravity internal request envelope.""" - return {} + return { + "antigravity_envelope": { + "project": os.getenv("ANTIGRAVITY_PROJECT", ""), + "user_agent": ANTIGRAVITY_HEADERS["User-Agent"], + "request_type": os.getenv("ANTIGRAVITY_REQUEST_TYPE", "CHAT_COMPLETION"), + } + } def get_model_tier_requirement(self, model: str) -> Optional[int]: """Antigravity exposes no restored model-tier restriction.""" diff --git a/src/rotator_library/providers/claude_code_provider.py b/src/rotator_library/providers/claude_code_provider.py index 4094fa61c..7e6a198e6 100644 --- a/src/rotator_library/providers/claude_code_provider.py +++ b/src/rotator_library/providers/claude_code_provider.py @@ -29,7 +29,7 @@ class ClaudeCodeProvider(ProviderInterface): provider_env_name = "claude_code" protocol_name = "anthropic_messages" adapter_names = ("suppress_developer_role",) - native_streaming_supported = True + native_streaming_supported = False field_cache_rules = ( FieldCacheRule( name="claude_code_thinking_signature", @@ -86,9 +86,9 @@ def normalize_native_model(self, model: str) -> str: return model.split("/", 1)[1] if model.startswith("claude_code/") else model def supports_native_streaming(self, model: str = "", operation: str = "messages") -> bool: - """Return true for the tested Anthropic Messages stream operation.""" + """Return false until the generic native stream wrapper is compatible.""" - return operation == "messages" + return False def get_native_endpoint(self, model: str = "", operation: str = "messages") -> str: """Return the provider endpoint for a native operation.""" diff --git a/src/rotator_library/providers/codex_provider.py b/src/rotator_library/providers/codex_provider.py index d7ac737be..7a882a972 100644 --- a/src/rotator_library/providers/codex_provider.py +++ b/src/rotator_library/providers/codex_provider.py @@ -23,7 +23,7 @@ class CodexProvider(ProviderInterface): provider_env_name = "codex" protocol_name = "responses" adapter_names: tuple[str, ...] = () - native_streaming_supported = True + native_streaming_supported = False field_cache_rules = ( FieldCacheRule( name="codex_previous_response_id", @@ -71,9 +71,9 @@ def normalize_native_model(self, model: str) -> str: return model.split("/", 1)[1] if model.startswith("codex/") else model def supports_native_streaming(self, model: str = "", operation: str = "responses") -> bool: - """Return true for tested Responses stream payloads.""" + """Return false until the generic native stream wrapper is compatible.""" - return operation == "responses" + return False def prepare_native_request(self, request: dict[str, Any], model: str = "", operation: str = "") -> dict[str, Any]: """Convert chat-style proxy input into Responses API input. diff --git a/src/rotator_library/providers/copilot_provider.py b/src/rotator_library/providers/copilot_provider.py index d7f9f384b..9d48846d4 100644 --- a/src/rotator_library/providers/copilot_provider.py +++ b/src/rotator_library/providers/copilot_provider.py @@ -29,7 +29,7 @@ class CopilotProvider(ProviderInterface): adapter_names = ("suppress_developer_role",) field_cache_rules: tuple = () default_rotation_mode = "sequential" - native_streaming_supported = True + native_streaming_supported = False async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: """Fetch Copilot-visible models with a safe fallback list.""" @@ -69,9 +69,9 @@ def normalize_native_model(self, model: str) -> str: return model.split("/", 1)[1] if model.startswith("copilot/") else model def supports_native_streaming(self, model: str = "", operation: str = "chat") -> bool: - """Return true for OpenAI Chat SSE streams.""" + """Return false until the generic native stream wrapper is compatible.""" - return operation == "chat" + return False def get_native_endpoint(self, model: str = "", operation: str = "chat") -> str: """Return the Copilot endpoint for a native operation.""" diff --git a/tests/test_antigravity_provider_restore.py b/tests/test_antigravity_provider_restore.py index 878be5948..5809d465f 100644 --- a/tests/test_antigravity_provider_restore.py +++ b/tests/test_antigravity_provider_restore.py @@ -35,8 +35,8 @@ def test_antigravity_provider_restores_safe_declarations() -> None: provider = AntigravityProvider() assert provider.get_protocol_name("gemini-3-flash") == "gemini" - assert provider.get_adapter_names("gemini-3-flash") == () - assert provider.get_adapter_config("gemini-3-flash") == {} + assert provider.get_adapter_names("gemini-3-flash") == ("antigravity_envelope",) + assert provider.get_adapter_config("gemini-3-flash")["antigravity_envelope"]["request_type"] == "CHAT_COMPLETION" assert provider.get_model_tier_requirement("antigravity/gemini-3-flash") is None rules = provider.get_field_cache_rules("gemini-3-flash") assert rules[0].name == "antigravity_thought_signature" @@ -69,7 +69,7 @@ def test_antigravity_native_operation_model_and_stream_support() -> None: assert provider.get_native_operation("gemini-3-flash", {}, stream=False) == "generate" assert provider.get_native_operation("gemini-3-flash", {}, stream=True) == "stream_generate" - assert provider.supports_native_streaming("gemini-3-flash", operation="stream_generate") is True + assert provider.supports_native_streaming("gemini-3-flash", operation="stream_generate") is False assert provider.supports_native_streaming("gemini-3-flash", operation="generate") is False assert provider.prepare_native_request({"model": "antigravity/gemini-3-pro-low"}, model="gemini-3-pro-preview", operation="generate")["model"] == "gemini-3-pro-preview" diff --git a/tests/test_claude_code_provider.py b/tests/test_claude_code_provider.py index 8e1ffcada..cf47902cf 100644 --- a/tests/test_claude_code_provider.py +++ b/tests/test_claude_code_provider.py @@ -75,7 +75,7 @@ def test_claude_code_native_operation_model_and_stream_support() -> None: assert provider.get_native_operation("claude-sonnet-4-5", {}, stream=False) == "messages" assert provider.normalize_native_model("claude_code/claude-sonnet-4-5") == "claude-sonnet-4-5" - assert provider.supports_native_streaming("claude-sonnet-4-5", operation="messages") is True + assert provider.supports_native_streaming("claude-sonnet-4-5", operation="messages") is False assert provider.supports_native_streaming("claude-sonnet-4-5", operation="chat") is False diff --git a/tests/test_codex_provider.py b/tests/test_codex_provider.py index d04a286a0..60ba5790e 100644 --- a/tests/test_codex_provider.py +++ b/tests/test_codex_provider.py @@ -55,7 +55,7 @@ def test_codex_native_operation_model_and_stream_support() -> None: assert provider.get_native_operation("gpt-5.1-codex", {}, stream=False) == "responses" assert provider.normalize_native_model("codex/gpt-5.1-codex") == "gpt-5.1-codex" - assert provider.supports_native_streaming("gpt-5.1-codex", operation="responses") is True + assert provider.supports_native_streaming("gpt-5.1-codex", operation="responses") is False assert provider.supports_native_streaming("gpt-5.1-codex", operation="chat") is False diff --git a/tests/test_copilot_provider.py b/tests/test_copilot_provider.py index 9a5f252d5..3586c0995 100644 --- a/tests/test_copilot_provider.py +++ b/tests/test_copilot_provider.py @@ -74,7 +74,7 @@ def test_copilot_native_operation_model_and_stream_support() -> None: assert provider.get_native_operation("gpt-4.1", {}, stream=False) == "chat" assert provider.normalize_native_model("copilot/gpt-4.1") == "gpt-4.1" - assert provider.supports_native_streaming("gpt-4.1", operation="chat") is True + assert provider.supports_native_streaming("gpt-4.1", operation="chat") is False assert provider.supports_native_streaming("gpt-4.1", operation="responses") is False diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index ea8666ccb..e0b960bfd 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -339,8 +339,25 @@ async def test_antigravity_provider_runs_mock_live_native_request(monkeypatch) - assert response["candidates"][0]["content"]["parts"][0]["text"] == "ok" assert http_client.calls[0]["endpoint"] == "https://antigravity.test/v1internal:streamGenerateContent?alt=sse" assert http_client.calls[0]["json"]["model"] == "claude-sonnet-4-5" - assert http_client.calls[0]["json"]["contents"][0]["parts"][0]["text"] == "hi" - assert "messages" not in http_client.calls[0]["json"] + assert http_client.calls[0]["json"]["request"]["contents"][0]["parts"][0]["text"] == "hi" + assert http_client.calls[0]["json"]["requestType"] == "CHAT_COMPLETION" + assert "requestId" in http_client.calls[0]["json"] + assert "messages" not in http_client.calls[0]["json"]["request"] + + +def test_native_request_payload_drops_litellm_only_fields() -> None: + payload = executor_module._native_request_payload( + { + "model": "provider/gpt-test", + "messages": [{"role": "user", "content": "hi"}], + "custom_llm_provider": "openai", + "api_base": "https://litellm-only.test", + "transaction_context": {"id": "trace"}, + "litellm_call_id": "call", + } + ) + + assert payload == {"model": "provider/gpt-test", "messages": [{"role": "user", "content": "hi"}]} def test_native_context_raises_on_invalid_field_cache_config(monkeypatch, tmp_path) -> None: From 4247fa59c6712eb0ab3018dd93c5d052ee982637 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 12:44:19 +0200 Subject: [PATCH 112/182] fix(providers): close native streaming and antigravity gaps Makes explicit native streaming fail closed unless the provider supports it, preventing unsafe RequestExecutor streaming through the generic LiteLLM-oriented wrapper. Switches Antigravity non-stream native requests to the JSON generateContent endpoint, keeps streamGenerateContent only for stream operations, and injects thought signatures inside the safe request envelope. Tests: pytest tests/test_request_executor_native_routing.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_protocol_anthropic_messages.py tests/test_protocol_responses.py tests/test_protocol_openai_chat.py tests/test_protocol_gemini.py tests/test_adapter_registry.py --- src/rotator_library/client/executor.py | 25 ++++++++++++++----- .../providers/antigravity_provider.py | 12 ++++++--- tests/test_antigravity_provider_restore.py | 5 +++- tests/test_request_executor_native_routing.py | 19 +++++++++++++- 4 files changed, 49 insertions(+), 12 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 3d6e3ef1d..d6dd8312a 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -1393,12 +1393,7 @@ async def _execute_streaming( execution = target.execution if target else "auto" # Make the API call - if execution == "native" or ( - execution == "auto" - and plugin - and _provider_native_protocol(plugin, model, target) - and _provider_supports_native_streaming(plugin, model) - ): + if _should_use_native_streaming(plugin, model, target, execution, provider): native_context, native_request = self._build_native_provider_context( provider, model, @@ -2348,6 +2343,24 @@ def _provider_supports_native_streaming(plugin: Any, model: str) -> bool: return bool(support()) +def _should_use_native_streaming(plugin: Any, model: str, target: Optional[RouteTarget], execution: str, provider: str) -> bool: + """Return whether streaming may use the native executor. + + Explicit `@native` routing is still constrained by provider capability. + Falling through to native streaming when a provider has not opted in is unsafe + because the generic stream wrapper currently expects LiteLLM-shaped chunks. + """ + + if execution == "native": + if _provider_supports_native_streaming(plugin, model): + return True + raise RoutingExecutionError( + f"Provider {provider} does not support native streaming for {model}", + error_type="configuration_error", + ) + return bool(execution == "auto" and plugin and _provider_native_protocol(plugin, model, target) and _provider_supports_native_streaming(plugin, model)) + + def _redact_context_field_cache_paths(payload: Any, context: RequestContext, direction: str, plugin: Any) -> Any: """Redact configured field-cache paths before executor-level traces. diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py index fd3a66cab..c6d61e406 100644 --- a/src/rotator_library/providers/antigravity_provider.py +++ b/src/rotator_library/providers/antigravity_provider.py @@ -64,7 +64,7 @@ class AntigravityProvider(ProviderInterface): path="candidates.*.content.parts.*.thoughtSignature", mode="all", scope=("provider", "model", "credential", "session"), - inject=FieldCacheInjection(target="request", path="metadata.thoughtSignatures", as_list=True), + inject=FieldCacheInjection(target="request", path="request.metadata.thoughtSignatures", as_list=True), metadata={"purpose": "preserve Gemini thought signatures across Antigravity turns"}, ), ) @@ -102,12 +102,14 @@ def get_native_headers(self, credential_identifier: str, model: str = "", operat is verified with tests. """ - return { + headers = { "Authorization": f"Bearer {credential_identifier}", "Content-Type": "application/json", - "Accept": "text/event-stream", **ANTIGRAVITY_HEADERS, } + if operation == "stream_generate": + headers["Accept"] = "text/event-stream" + return headers def get_native_operation(self, model: str = "", request: dict[str, Any] | None = None, stream: bool = False) -> str: """Return the Gemini generate operation used by Antigravity endpoints.""" @@ -145,7 +147,9 @@ def get_native_endpoint(self, model: str = "", operation: str = "generate") -> s if operation == "models": return f"{self.get_api_base()}:fetchAvailableModels" - return f"{self.get_api_base()}:streamGenerateContent?alt=sse" + if operation == "stream_generate": + return f"{self.get_api_base()}:streamGenerateContent?alt=sse" + return f"{self.get_api_base()}:generateContent" def get_adapter_config(self, model: str = "") -> dict[str, dict[str, Any]]: """Configure the safe Antigravity internal request envelope.""" diff --git a/tests/test_antigravity_provider_restore.py b/tests/test_antigravity_provider_restore.py index 5809d465f..b8039ff57 100644 --- a/tests/test_antigravity_provider_restore.py +++ b/tests/test_antigravity_provider_restore.py @@ -49,10 +49,13 @@ def test_antigravity_provider_builds_static_headers_without_device_profile(monke headers = provider.get_native_headers("token") - assert provider.get_native_endpoint(operation="generate") == "https://antigravity.test/v1internal:streamGenerateContent?alt=sse" + assert provider.get_native_endpoint(operation="generate") == "https://antigravity.test/v1internal:generateContent" + assert provider.get_native_endpoint(operation="stream_generate") == "https://antigravity.test/v1internal:streamGenerateContent?alt=sse" assert provider.get_native_endpoint(operation="models") == "https://antigravity.test/v1internal:fetchAvailableModels" assert headers["Authorization"] == "Bearer token" assert headers["X-Goog-Api-Client"] == "google-cloud-sdk vscode_cloudshelleditor/0.1" + assert "Accept" not in headers + assert provider.get_native_headers("token", operation="stream_generate")["Accept"] == "text/event-stream" assert "X-Client-Device-Id" not in headers diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index e0b960bfd..6be1534ad 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -337,7 +337,7 @@ async def test_antigravity_provider_runs_mock_live_native_request(monkeypatch) - response = await _executor(http_client)._execute_provider_request("antigravity", context.model, provider, "secret", "stable", dict(context.kwargs), context) assert response["candidates"][0]["content"]["parts"][0]["text"] == "ok" - assert http_client.calls[0]["endpoint"] == "https://antigravity.test/v1internal:streamGenerateContent?alt=sse" + assert http_client.calls[0]["endpoint"] == "https://antigravity.test/v1internal:generateContent" assert http_client.calls[0]["json"]["model"] == "claude-sonnet-4-5" assert http_client.calls[0]["json"]["request"]["contents"][0]["parts"][0]["text"] == "hi" assert http_client.calls[0]["json"]["requestType"] == "CHAT_COMPLETION" @@ -360,6 +360,23 @@ def test_native_request_payload_drops_litellm_only_fields() -> None: assert payload == {"model": "provider/gpt-test", "messages": [{"role": "user", "content": "hi"}]} +def test_explicit_native_streaming_fails_when_provider_does_not_support_it() -> None: + target = parse_route_target("provider/gpt-test@native") + + with pytest.raises(RoutingExecutionError) as exc: + executor_module._should_use_native_streaming(NativePlugin(), "provider/gpt-test", target, "native", "provider") + + assert exc.value.error_type == "configuration_error" + + +def test_antigravity_cache_injection_targets_safe_envelope() -> None: + provider = AntigravityProvider() + + rules = provider.get_field_cache_rules("gemini-3-flash") + + assert rules[0].inject.path == "request.metadata.thoughtSignatures" + + def test_native_context_raises_on_invalid_field_cache_config(monkeypatch, tmp_path) -> None: config_path = tmp_path / "config.json" config_path.write_text(json.dumps({"field_cache": {"provider": {"*": [{"name": "bad", "source": "response"}]}}}), encoding="utf-8") From f0b9920e3b8bf9d8114f2f1f4a0390ff09f1b198 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 12:49:24 +0200 Subject: [PATCH 113/182] fix(providers): make antigravity envelope idempotence explicit Changes the Antigravity envelope adapter to skip wrapping only when the controlled envelope markers are present, not merely because a user payload contains a request key. Adds regression coverage for user-supplied request keys and already-wrapped Antigravity envelopes. Tests: pytest tests/test_request_executor_native_routing.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_protocol_anthropic_messages.py tests/test_protocol_responses.py tests/test_protocol_openai_chat.py tests/test_protocol_gemini.py tests/test_adapter_registry.py --- src/rotator_library/adapters/builtin.py | 10 +++++++++- tests/test_adapter_registry.py | 25 +++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/rotator_library/adapters/builtin.py b/src/rotator_library/adapters/builtin.py index 05cf49230..1475d70a7 100644 --- a/src/rotator_library/adapters/builtin.py +++ b/src/rotator_library/adapters/builtin.py @@ -162,7 +162,9 @@ class AntigravityEnvelopeAdapter(PayloadAdapter): supported_stages = ("request",) async def transform_request(self, payload: Any, context: AdapterContext) -> Any: - if not isinstance(payload, dict) or "request" in payload: + if not isinstance(payload, dict): + return payload + if _looks_like_antigravity_envelope(payload): return payload config = context.config_for(self.name) model = payload.get("model") or context.model @@ -180,6 +182,12 @@ async def transform_request(self, payload: Any, context: AdapterContext) -> Any: return {key: value for key, value in envelope.items() if value is not None} +def _looks_like_antigravity_envelope(payload: dict[str, Any]) -> bool: + """Return whether a payload already has the controlled envelope shape.""" + + return "request" in payload and "requestType" in payload and "requestId" in payload + + def _delete_simple_path(payload: dict[str, Any], path: str) -> None: """Delete a simple dotted dict path after a conservative move operation.""" diff --git a/tests/test_adapter_registry.py b/tests/test_adapter_registry.py index 739a079d8..cef410c85 100644 --- a/tests/test_adapter_registry.py +++ b/tests/test_adapter_registry.py @@ -113,6 +113,31 @@ async def test_field_rename_adapter_copies_and_moves_configured_fields() -> None assert payload["old"]["field"] == "value" +@pytest.mark.asyncio +async def test_antigravity_envelope_wraps_user_request_key() -> None: + adapter = get_adapter("antigravity_envelope") + + result = await adapter.transform_request( + {"model": "gemini-3-flash", "request": {"user_supplied": True}, "contents": [{"parts": [{"text": "hi"}]}]}, + AdapterContext(adapter_config={"antigravity_envelope": {"request_type": "CHAT_COMPLETION", "user_agent": "test-agent"}}), + ) + + assert result["requestType"] == "CHAT_COMPLETION" + assert result["requestId"] + assert result["request"]["request"] == {"user_supplied": True} + assert result["request"]["contents"][0]["parts"][0]["text"] == "hi" + + +@pytest.mark.asyncio +async def test_antigravity_envelope_is_idempotent_for_controlled_envelope() -> None: + adapter = get_adapter("antigravity_envelope") + payload = {"model": "gemini-3-flash", "request": {"contents": []}, "requestType": "CHAT_COMPLETION", "requestId": "id"} + + result = await adapter.transform_request(payload, AdapterContext()) + + assert result == payload + + @pytest.mark.asyncio async def test_adapter_chain_order_is_preserved() -> None: payload = {"model": "public", "messages": [{"role": "developer", "content": "rules"}]} From 124141c1ca4c2f4c951fefcbc7b52a4430721c68 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 12:59:18 +0200 Subject: [PATCH 114/182] docs(experimental): plan routing fallback correction Adds the Phase 6b corrective plan for ordered fallback safety, hard-stop routing errors, structured response classification, streaming fallback policy enforcement, and sanitized target summaries. Tests: not run (planning document only) --- .../phase-6b-routing-fallback-correctness.md | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 docs/experimental/phase-6b-routing-fallback-correctness.md diff --git a/docs/experimental/phase-6b-routing-fallback-correctness.md b/docs/experimental/phase-6b-routing-fallback-correctness.md new file mode 100644 index 000000000..3d53895bd --- /dev/null +++ b/docs/experimental/phase-6b-routing-fallback-correctness.md @@ -0,0 +1,91 @@ +# Phase 6b: Routing/Fallback Correctness And Structured Error Safety + +## Goal + +Correct the Phase 6 audit findings. Phase 6 built ordered fallback groups and live executor wrappers, but fallback still needs stronger safety around hard-stop errors, structured error responses, streaming policy enforcement, and sanitized target summaries. + +## Non-Goals + +- Do not replace `UsageManager`, `SessionTracker`, retry-after parsing, or credential rotation behavior. +- Do not implement rich target-group selector syntax beyond ordered fallback chains. +- Do not add security or multi-user features. +- Do not make native streaming work; unsupported native streaming remains fail-closed. +- Do not include raw provider messages or credentials in cross-target summaries. +- Do not commit user-facing reports. + +## Current State + +- `FallbackPolicy.should_fallback()` uses group `failover_on` and `stop_on` directly, so a group can currently override auth/permanent errors into fallback eligibility. +- `_route_error_type_from_response()` only recognizes retry-like summaries and a few proxy error types; it does not robustly inspect structured status/code/details fields. +- `FallbackGroup.streaming_policy` exists but live streaming fallback does not use it. +- Streaming fallback correctly tracks visible output, but needs stronger tests for error/control frames and `never` streaming policy. +- Non-streaming fallback summaries are structural, but call sites still pass raw exception strings that should be avoided entirely. + +## Implementation Plan + +1. Add hard-stop route error categories. + - Add a non-overridable hard-stop set for auth, forbidden, invalid request, context-window, credential reauth, pre-request callback, cancellation, and configuration errors. + - Document that these are safety boundaries and group policies cannot opt into cross-target fallback for them. + +2. Normalize routing policy vocabulary. + - Add `normalize_route_error_type()` for aliases such as `auth`, `permission_denied`, `bad_request`, `validation`, `transient`, `network`, and `configuration`. + - Use it in policy, routing runner, retry helper, and executor route classification. + +3. Make hard stops win in `FallbackPolicy`. + - Streaming visible output still blocks fallback first. + - Hard-stop categories return false before group policy evaluation. + - Group stop/failover sets are normalized before matching. + +4. Validate routing group policy. + - Reject configured `failover_on` entries that normalize to hard-stop categories. + - Parse and validate `streaming_policy` from JSON routing config. + - Preserve environment override precedence. + +5. Enforce `FallbackGroup.streaming_policy` in live executor paths. + - `never` prevents streaming fallback even before output. + - `pre_output_only` remains the default. + - Trace metadata should include the policy used. + +6. Harden structured response classification. + - Inspect `error.type`, `error.code`, `error.status`, `error.details.status_code`, and detail classification fields before summaries. + - Hard-stop signals win over retryable text. + - Keep retryable classification for 429/quota/rate-limit, 5xx/server, timeout, and connection signals. + +7. Sanitize target-failure summaries. + - Keep raw messages out of fallback details and traces. + - Remove unnecessary raw string arguments from summary call sites. + - Add tests proving secrets/provider text do not appear. + +8. Improve stream fallback frame handling. + - Verify error/control frames are non-visible output. + - Test `event: error`, `type: error`, `response.failed`, `[DONE]`, comments, and heartbeat-like frames. + - Ensure visible text/tool deltas still block fallback. + +9. Add deterministic exhaustion metadata. + - Include sanitized target summaries in routing exhaustion traces for stream and non-stream paths. + - Avoid raw exception text in trace metadata. + +10. Add execution-mode safety tests. + - Explicit native configuration errors must be hard stops. + - Unsupported operation behavior must be clear and tested. + - LiteLLM fallback remains available for retryable errors. + +## Tests + +- `tests/test_fallback_policy.py` +- `tests/test_retry_policy.py` +- `tests/test_routing_config.py` +- `tests/test_fallback_resolver.py` +- `tests/test_routing_executor.py` +- `tests/test_streaming_fallback_policy.py` +- `tests/test_request_executor_native_routing.py` +- Phase 5b provider/native regression subset if executor routing helpers change. + +## Acceptance Criteria + +- Auth, forbidden, invalid request, context-window, credential reauth, pre-request callback, cancellation, and configuration errors never fallback across targets. +- Structured error responses classify deterministically with hard-stop signals taking precedence. +- Streaming fallback respects `FallbackGroup.streaming_policy`. +- Error/control stream frames before visible output do not accidentally lock routing. +- Cross-target summaries and traces remain sanitized. +- Focused tests and dual-agent review pass. From a5d3464b09aa8700ef3ecb7b5f58b75738852716 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 13:01:22 +0200 Subject: [PATCH 115/182] fix(routing): enforce hard-stop fallback policy Adds normalized route error vocabulary and non-overridable hard-stop categories for auth, permission, validation, cancellation, credential reauth, and configuration failures. Validates configured fallback policies so failover_on cannot include hard-stop errors and parses streaming_policy overrides for ordered fallback groups. Tests: pytest tests/test_fallback_policy.py tests/test_routing_config.py tests/test_retry_policy.py --- src/rotator_library/routing/config.py | 28 +++++++++++++--- src/rotator_library/routing/policy.py | 47 ++++++++++++++++++++++++--- src/rotator_library/routing/types.py | 12 +++++++ tests/test_fallback_policy.py | 27 +++++++++++++-- tests/test_routing_config.py | 39 ++++++++++++++++++++++ 5 files changed, 142 insertions(+), 11 deletions(-) diff --git a/src/rotator_library/routing/config.py b/src/rotator_library/routing/config.py index efb3eb27f..3d7bffb71 100644 --- a/src/rotator_library/routing/config.py +++ b/src/rotator_library/routing/config.py @@ -8,7 +8,8 @@ import os from collections.abc import Mapping -from .types import DEFAULT_FAILOVER_ON, DEFAULT_STOP_ON, FallbackGroup, RouteTarget, RoutingConfig +from .policy import normalize_route_error_set +from .types import DEFAULT_FAILOVER_ON, DEFAULT_STOP_ON, HARD_STOP_ON, FallbackGroup, RouteTarget, RoutingConfig class RoutingConfigError(ValueError): @@ -57,7 +58,10 @@ def load_routing_config_from_env(env: Mapping[str, str] | None = None, config: o target_specs = _csv(source.get(key, "")) if not target_specs: raise RoutingConfigError(f"fallback group {name} has no targets") - groups[name] = FallbackGroup(name=name, targets=tuple(parse_route_target(spec) for spec in target_specs)) + failover_on = _policy_set(source.get(f"{key}_FAILOVER_ON"), DEFAULT_FAILOVER_ON, name) + stop_on = _policy_set(source.get(f"{key}_STOP_ON"), DEFAULT_STOP_ON, name, allow_hard_stop=True) + streaming_policy = _streaming_policy(source.get(f"{key}_STREAMING_POLICY", "pre_output_only"), name) + groups[name] = FallbackGroup(name=name, targets=tuple(parse_route_target(spec) for spec in target_specs), failover_on=failover_on, stop_on=stop_on, streaming_policy=streaming_policy) model_routes: dict[str, str] = _model_routes_from_json_config(config) for key, value in source.items(): @@ -87,8 +91,9 @@ def _groups_from_json_config(config: object) -> dict[str, FallbackGroup]: groups[str(name)] = FallbackGroup( name=str(name), targets=tuple(parse_route_target(str(spec)) for spec in raw_targets), - failover_on=_string_set(raw_group.get("failover_on"), DEFAULT_FAILOVER_ON), - stop_on=_string_set(raw_group.get("stop_on"), DEFAULT_STOP_ON), + failover_on=_policy_set(raw_group.get("failover_on"), DEFAULT_FAILOVER_ON, str(name)), + stop_on=_policy_set(raw_group.get("stop_on"), DEFAULT_STOP_ON, str(name), allow_hard_stop=True), + streaming_policy=_streaming_policy(raw_group.get("streaming_policy", "pre_output_only"), str(name)), max_targets=int(raw_group["max_targets"]) if raw_group.get("max_targets") is not None else None, metadata=dict(raw_group.get("metadata", {})) if isinstance(raw_group.get("metadata", {}), Mapping) else {}, ) @@ -121,6 +126,21 @@ def _string_set(value: object, default: frozenset[str]) -> frozenset[str]: raise RoutingConfigError("routing policy lists must be strings or arrays") +def _policy_set(value: object, default: frozenset[str], group_name: str, *, allow_hard_stop: bool = False) -> frozenset[str]: + values = normalize_route_error_set(_string_set(value, default)) + unsafe = values & HARD_STOP_ON + if unsafe and not allow_hard_stop: + raise RoutingConfigError(f"fallback group {group_name} failover_on cannot include hard-stop errors: {', '.join(sorted(unsafe))}") + return values + + +def _streaming_policy(value: object, group_name: str): + policy = str(value or "pre_output_only").strip().lower() + if policy not in {"pre_output_only", "never"}: + raise RoutingConfigError(f"fallback group {group_name} has unsupported streaming_policy {value!r}") + return policy + + def _csv(value: str) -> list[str]: return [part.strip() for part in value.split(",") if part.strip()] diff --git a/src/rotator_library/routing/policy.py b/src/rotator_library/routing/policy.py index 92ea5241d..02aaba63e 100644 --- a/src/rotator_library/routing/policy.py +++ b/src/rotator_library/routing/policy.py @@ -5,7 +5,44 @@ from __future__ import annotations -from .types import DEFAULT_FAILOVER_ON, DEFAULT_STOP_ON, FallbackGroup +from .types import DEFAULT_FAILOVER_ON, DEFAULT_STOP_ON, HARD_STOP_ON, FallbackGroup + + +_ALIASES = { + "auth": "authentication", + "permission": "forbidden", + "permission_denied": "forbidden", + "access_denied": "forbidden", + "bad_request": "invalid_request", + "validation": "invalid_request", + "permanent": "invalid_request", + "context_length": "context_window_exceeded", + "configuration": "configuration_error", + "config": "configuration_error", + "quota": "quota_exceeded", + "capacity": "rate_limit", + "transient": "server_error", + "network": "api_connection", + "connection": "api_connection", +} + + +def normalize_route_error_type(error_type: str | None) -> str: + """Return the policy vocabulary for a raw classifier or config value. + + Routing config is user-facing, while the executor consumes classifier output. + Keeping normalization in one place prevents aliases such as `auth` from + bypassing non-overridable hard-stop categories. + """ + + normalized = str(error_type or "").strip().lower().replace("-", "_").replace(" ", "_") + return _ALIASES.get(normalized, normalized) + + +def normalize_route_error_set(values: frozenset[str]) -> frozenset[str]: + """Normalize a policy set while preserving unknown custom categories.""" + + return frozenset(normalize_route_error_type(value) for value in values) class FallbackPolicy: @@ -21,11 +58,13 @@ def should_fallback( ) -> bool: """Return whether fallback is allowed for a classified failure.""" - normalized = error_type.lower() - active_stop = group.stop_on if group else DEFAULT_STOP_ON - active_failover = group.failover_on if group else DEFAULT_FAILOVER_ON + normalized = normalize_route_error_type(error_type) if stream and emitted_output: return False + if normalized in HARD_STOP_ON: + return False + active_stop = normalize_route_error_set(group.stop_on if group else DEFAULT_STOP_ON) + active_failover = normalize_route_error_set(group.failover_on if group else DEFAULT_FAILOVER_ON) if normalized in active_stop: return False return normalized in active_failover diff --git a/src/rotator_library/routing/types.py b/src/rotator_library/routing/types.py index 42d2a5fbb..b0598e00b 100644 --- a/src/rotator_library/routing/types.py +++ b/src/rotator_library/routing/types.py @@ -42,6 +42,18 @@ "pre_request_callback", } ) +HARD_STOP_ON = frozenset( + { + "authentication", + "forbidden", + "invalid_request", + "context_window_exceeded", + "credential_reauth_needed", + "pre_request_callback_error", + "cancelled", + "configuration_error", + } +) @dataclass(frozen=True) diff --git a/tests/test_fallback_policy.py b/tests/test_fallback_policy.py index eb7cb0e88..7ea82203a 100644 --- a/tests/test_fallback_policy.py +++ b/tests/test_fallback_policy.py @@ -1,6 +1,7 @@ from __future__ import annotations from rotator_library.routing import FallbackPolicy, parse_route_target +from rotator_library.routing.policy import normalize_route_error_type from rotator_library.routing.types import FallbackGroup @@ -33,13 +34,33 @@ def test_policy_allows_stream_fallback_before_visible_output() -> None: assert FallbackPolicy().should_fallback("rate_limit", stream=True, emitted_output=False) is True -def test_policy_respects_group_overrides() -> None: +def test_policy_respects_safe_group_overrides() -> None: group = FallbackGroup( name="auth_safe", targets=(parse_route_target("a/model"), parse_route_target("b/model")), - failover_on=frozenset({"auth"}), + failover_on=frozenset({"network"}), stop_on=frozenset({"validation"}), ) - assert FallbackPolicy().should_fallback("auth", group=group) is True + assert FallbackPolicy().should_fallback("api_connection", group=group) is True assert FallbackPolicy().should_fallback("validation", group=group) is False + + +def test_policy_hard_stops_cannot_be_overridden_by_group_failover() -> None: + group = FallbackGroup( + name="unsafe", + targets=(parse_route_target("a/model"), parse_route_target("b/model")), + failover_on=frozenset({"auth", "configuration"}), + stop_on=frozenset(), + ) + + assert FallbackPolicy().should_fallback("authentication", group=group) is False + assert FallbackPolicy().should_fallback("configuration_error", group=group) is False + + +def test_policy_normalizes_user_facing_aliases() -> None: + assert normalize_route_error_type("auth") == "authentication" + assert normalize_route_error_type("permission-denied") == "forbidden" + assert normalize_route_error_type("bad request") == "invalid_request" + assert FallbackPolicy().should_fallback("network") is True + assert FallbackPolicy().should_fallback("validation") is False diff --git a/tests/test_routing_config.py b/tests/test_routing_config.py index 3d856d61b..ad38f68eb 100644 --- a/tests/test_routing_config.py +++ b/tests/test_routing_config.py @@ -46,3 +46,42 @@ def test_load_routing_config_rejects_duplicate_group_names() -> None: def test_load_routing_config_rejects_unknown_group_route() -> None: with pytest.raises(RoutingConfigError): load_routing_config_from_env({"MODEL_ROUTE_CODEX": "group:missing"}) + + +def test_load_routing_config_parses_group_policy_overrides() -> None: + config = load_routing_config_from_env( + { + "FALLBACK_GROUPS": "chain", + "FALLBACK_GROUP_CHAIN": "openai/gpt,copilot/gpt", + "FALLBACK_GROUP_CHAIN_FAILOVER_ON": "network,transient", + "FALLBACK_GROUP_CHAIN_STOP_ON": "validation", + "FALLBACK_GROUP_CHAIN_STREAMING_POLICY": "never", + } + ) + + group = config.fallback_groups["chain"] + assert group.failover_on == frozenset({"api_connection", "server_error"}) + assert group.stop_on == frozenset({"invalid_request"}) + assert group.streaming_policy == "never" + + +def test_load_routing_config_rejects_hard_stop_failover() -> None: + with pytest.raises(RoutingConfigError): + load_routing_config_from_env( + { + "FALLBACK_GROUPS": "chain", + "FALLBACK_GROUP_CHAIN": "openai/gpt,copilot/gpt", + "FALLBACK_GROUP_CHAIN_FAILOVER_ON": "auth", + } + ) + + +def test_load_routing_config_rejects_unknown_streaming_policy() -> None: + with pytest.raises(RoutingConfigError): + load_routing_config_from_env( + { + "FALLBACK_GROUPS": "chain", + "FALLBACK_GROUP_CHAIN": "openai/gpt,copilot/gpt", + "FALLBACK_GROUP_CHAIN_STREAMING_POLICY": "always", + } + ) From 681f2dc2ff9ca98abb39b128f6ce795af94ad417 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 13:05:42 +0200 Subject: [PATCH 116/182] fix(routing): harden structured fallback decisions Classifies structured proxy error responses from typed/status fields before summaries so hard-stop errors win over retryable text. Sanitizes fallback target summaries by avoiding raw exception strings, enforces group streaming_policy in live streaming fallback, and normalizes route error types across retry/routing helpers. Tests: pytest tests/test_request_executor_native_routing.py tests/test_streaming_fallback_policy.py tests/test_fallback_policy.py tests/test_routing_config.py tests/test_retry_policy.py --- src/rotator_library/client/executor.py | 132 +++++++++++++++--- src/rotator_library/retry_policy.py | 5 +- src/rotator_library/routing/executor.py | 4 +- tests/test_request_executor_native_routing.py | 25 ++++ tests/test_streaming_fallback_policy.py | 52 ++++++- 5 files changed, 192 insertions(+), 26 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index d6dd8312a..a9d108bb5 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -65,6 +65,7 @@ from ..failure_logger import log_failure from ..retry_policy import decide_provider_cooldown, provider_cooldown_env from ..routing import FallbackPolicy, clone_context_for_target +from ..routing.policy import normalize_route_error_type from ..routing.types import RouteTarget from ..native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor from ..field_cache.paths import FieldCachePathError, PathToken, parse_path @@ -789,7 +790,7 @@ async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> _target_trace(target), metadata={"target_index": index, "error_type": error_type, "exception": exc.__class__.__name__}, ) - target_failures.append(_target_failure_summary(target, error_type, str(exc))) + target_failures.append(_target_failure_summary(target, error_type)) if index >= len(targets) - 1 or not policy.should_fallback(error_type, group=context.routing_group): self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type, "fallback_targets": target_failures}) raise @@ -805,7 +806,7 @@ async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> _target_trace(target), metadata={"target_index": index, "error_type": error_type}, ) - target_failures.append(_target_failure_summary(target, error_type, _route_error_message_from_response(result))) + target_failures.append(_target_failure_summary(target, error_type, status_code=_route_status_code_from_response(result))) if index < len(targets) - 1 and policy.should_fallback(error_type, group=context.routing_group): self._log_routing_trace(context, "routing_fallback_selected", _target_trace(targets[index + 1]), metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type}) continue @@ -833,11 +834,12 @@ async def _execute_streaming_with_fallback(self, context: RequestContext) -> Asy yield chunk return policy = FallbackPolicy() + target_failures: List[Dict[str, Any]] = [] self._log_routing_trace( context, "routing_decision", {"requested_model": context.model, "target_count": len(targets), "stream": True}, - metadata={"group": context.routing_group_name, "targets": [_target_trace(target) for target in targets]}, + metadata={"group": context.routing_group_name, "streaming_policy": _group_streaming_policy(context.routing_group), "targets": [_target_trace(target) for target in targets]}, ) for index, target in enumerate(targets): emitted_output = False @@ -876,6 +878,7 @@ async def _execute_streaming_with_fallback(self, context: RequestContext) -> Asy _target_trace(target), metadata={"target_index": index, "error_type": error_type, "emitted_output": emitted_output}, ) + target_failures.append(_target_failure_summary(target, error_type)) if emitted_output: self._log_routing_trace( context, @@ -884,7 +887,7 @@ async def _execute_streaming_with_fallback(self, context: RequestContext) -> Asy metadata={"target_index": index, "error_type": error_type}, ) raise - if index < len(targets) - 1 and policy.should_fallback(error_type, group=context.routing_group, stream=True, emitted_output=False): + if index < len(targets) - 1 and _streaming_policy_allows_fallback(context.routing_group) and policy.should_fallback(error_type, group=context.routing_group, stream=True, emitted_output=False): self._log_routing_trace( context, "routing_fallback_selected", @@ -892,7 +895,7 @@ async def _execute_streaming_with_fallback(self, context: RequestContext) -> Asy metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type, "stream": True}, ) continue - self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type, "stream": True}) + self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type, "stream": True, "streaming_policy": _group_streaming_policy(context.routing_group), "fallback_targets": target_failures}) raise @staticmethod @@ -2519,9 +2522,9 @@ def _route_error_type(error: BaseException, provider: Optional[str] = None) -> s return "cancelled" explicit = getattr(error, "error_type", None) if explicit: - return str(explicit).lower() + return normalize_route_error_type(str(explicit)) classified = classify_error(error, provider) - return classified.error_type + return normalize_route_error_type(classified.error_type) def _route_error_type_from_response(response: Any) -> Optional[str]: @@ -2530,42 +2533,135 @@ def _route_error_type_from_response(response: Any) -> Optional[str]: if not isinstance(response, dict) or not isinstance(response.get("error"), dict): return None error = response["error"] - error_type = str(error.get("type", "")).lower() details = error.get("details") if isinstance(error.get("details"), dict) else {} + candidates = _structured_error_candidates(error, details) + hard_stop = _first_route_error_candidate(candidates, _HARD_STOP_ROUTE_ERRORS) + if hard_stop: + return hard_stop + retryable = _first_route_error_candidate(candidates, _RETRYABLE_ROUTE_ERRORS) + if retryable: + return retryable normal_summary = str(details.get("normal_error_summary", "")).lower() + if any(token in normal_summary for token in ("authentication", "forbidden", "invalid_request", "context_window", "credential_reauth", "configuration_error")): + return _summary_hard_stop_type(normal_summary) if any(token in normal_summary for token in ("rate_limit", "quota", "capacity")): - return "rate_limit" + return "quota_exceeded" if "quota" in normal_summary else "rate_limit" if any(token in normal_summary for token in ("server_error", "api_connection", "transient")): - return "server_error" + return "api_connection" if "api_connection" in normal_summary else "server_error" + error_type = normalize_route_error_type(str(error.get("type", ""))) if error_type in {"proxy_timeout", "proxy_all_credentials_exhausted"}: return "rate_limit" return None -def _route_error_message_from_response(response: Any) -> str: - """Extract a short non-secret message from a structured proxy error.""" +def _route_status_code_from_response(response: Any) -> Optional[int]: + """Return a structured status code without reading raw provider text.""" if not isinstance(response, dict) or not isinstance(response.get("error"), dict): - return "" - message = str(response["error"].get("message", "")) - return message[:150] + return None + error = response["error"] + details = error.get("details") if isinstance(error.get("details"), dict) else {} + for candidate in (details.get("status_code"), error.get("status_code"), error.get("code")): + try: + return int(candidate) + except (TypeError, ValueError): + continue + return None -def _target_failure_summary(target: RouteTarget, error_type: str, message: str = "") -> Dict[str, Any]: +def _target_failure_summary(target: RouteTarget, error_type: str, *, status_code: Optional[int] = None) -> Dict[str, Any]: """Return a client-safe fallback target failure summary.""" - return { + summary = { "target": target.name, "provider": target.provider, "model": target.prefixed_model, "execution": target.execution, - "error_type": error_type, + "error_type": normalize_route_error_type(error_type), # Provider error text can contain raw upstream payload fragments or # credential-like identifiers. Keep the cross-target summary structural; # detailed per-credential errors remain in the existing sanitized error # accumulator. "message": "", } + if status_code is not None: + summary["status_code"] = status_code + return summary + + +def _group_streaming_policy(group: Any) -> str: + """Return the active streaming fallback policy for trace and decisions.""" + + return str(getattr(group, "streaming_policy", "pre_output_only") or "pre_output_only") + + +def _streaming_policy_allows_fallback(group: Any) -> bool: + """Return whether a group permits pre-output streaming fallback.""" + + return _group_streaming_policy(group) != "never" + + +_HARD_STOP_ROUTE_ERRORS = { + "authentication", + "forbidden", + "invalid_request", + "context_window_exceeded", + "credential_reauth_needed", + "pre_request_callback_error", + "cancelled", + "configuration_error", +} +_RETRYABLE_ROUTE_ERRORS = {"rate_limit", "quota_exceeded", "server_error", "api_connection", "unsupported_operation"} + + +def _structured_error_candidates(error: Dict[str, Any], details: Dict[str, Any]) -> List[str]: + """Return normalized structured error hints before reading free-form text.""" + + values: List[Any] = [ + error.get("type"), + error.get("code"), + error.get("status"), + details.get("classification"), + details.get("error_type"), + details.get("status"), + ] + status_code = _route_status_code_from_response({"error": error}) + if status_code == 400: + values.append("invalid_request") + elif status_code == 401: + values.append("authentication") + elif status_code == 403: + values.append("forbidden") + elif status_code == 429: + values.append("rate_limit") + elif status_code is not None and status_code >= 500: + values.append("server_error") + return [normalize_route_error_type(str(value)) for value in values if value not in (None, "")] + + +def _first_route_error_candidate(candidates: List[str], allowed: set[str]) -> Optional[str]: + """Return the first structured candidate in an allowed policy set.""" + + for candidate in candidates: + if candidate in allowed: + return candidate + return None + + +def _summary_hard_stop_type(summary: str) -> str: + """Map legacy summary text to a hard-stop category without using raw text.""" + + if "credential_reauth" in summary: + return "credential_reauth_needed" + if "context_window" in summary: + return "context_window_exceeded" + if "configuration_error" in summary or "config" in summary: + return "configuration_error" + if "forbidden" in summary: + return "forbidden" + if "authentication" in summary: + return "authentication" + return "invalid_request" def _with_fallback_summary(response: Any, target_failures: List[Dict[str, Any]]) -> Any: diff --git a/src/rotator_library/retry_policy.py b/src/rotator_library/retry_policy.py index 6e133262e..52bbb5fe1 100644 --- a/src/rotator_library/retry_policy.py +++ b/src/rotator_library/retry_policy.py @@ -17,6 +17,7 @@ from .error_handler import ClassifiedError, classify_error, should_retry_same_key, should_rotate_on_error from .routing import FallbackPolicy +from .routing.policy import normalize_route_error_type from .routing.types import FallbackGroup DEFAULT_PROVIDER_COOLDOWN_DEFAULT_SECONDS = 30 @@ -38,8 +39,8 @@ def classify_route_error(error: BaseException, provider: Optional[str] = None) - return "cancelled" explicit = getattr(error, "error_type", None) if explicit: - return str(explicit).lower() - return classify_error(error, provider).error_type + return normalize_route_error_type(str(explicit)) + return normalize_route_error_type(classify_error(error, provider).error_type) def should_retry_same_credential(classified_error: ClassifiedError, small_cooldown_threshold: int) -> bool: diff --git a/src/rotator_library/routing/executor.py b/src/rotator_library/routing/executor.py index 0e0c7e9b2..26af74ccb 100644 --- a/src/rotator_library/routing/executor.py +++ b/src/rotator_library/routing/executor.py @@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable from typing import Any -from .policy import FallbackPolicy +from .policy import FallbackPolicy, normalize_route_error_type from .types import FallbackGroup, RouteAttemptResult, RouteTarget, RoutingDecision AttemptCallback = Callable[[RouteTarget, int], Awaitable[Any]] @@ -65,4 +65,4 @@ async def run_group( def _error_type(error: BaseException) -> str: - return str(getattr(error, "error_type", error.__class__.__name__)).lower() + return normalize_route_error_type(str(getattr(error, "error_type", error.__class__.__name__))) diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index 6be1534ad..1625397ae 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -360,6 +360,31 @@ def test_native_request_payload_drops_litellm_only_fields() -> None: assert payload == {"model": "provider/gpt-test", "messages": [{"role": "user", "content": "hi"}]} +def test_route_error_type_from_response_hard_stop_wins_over_retry_summary() -> None: + response = { + "error": { + "type": "authentication", + "message": "provider said quota secret-token", + "details": {"normal_error_summary": "rate_limit quota capacity", "status_code": 401}, + } + } + + assert executor_module._route_error_type_from_response(response) == "authentication" + + +def test_route_error_type_from_response_uses_structured_status_codes() -> None: + assert executor_module._route_error_type_from_response({"error": {"code": 403}}) == "forbidden" + assert executor_module._route_error_type_from_response({"error": {"details": {"status_code": 503}}}) == "server_error" + + +def test_target_failure_summary_is_structural_and_sanitized() -> None: + summary = executor_module._target_failure_summary(parse_route_target("openai/gpt"), "rate-limit", status_code=429) + + assert summary["error_type"] == "rate_limit" + assert summary["status_code"] == 429 + assert summary["message"] == "" + + def test_explicit_native_streaming_fails_when_provider_does_not_support_it() -> None: target = parse_route_target("provider/gpt-test@native") diff --git a/tests/test_streaming_fallback_policy.py b/tests/test_streaming_fallback_policy.py index 77879d7c9..bf259967a 100644 --- a/tests/test_streaming_fallback_policy.py +++ b/tests/test_streaming_fallback_policy.py @@ -34,8 +34,14 @@ def _context(*, logger=None) -> RequestContext: ) +def _context_never_streaming_fallback(*, logger=None) -> RequestContext: + context = _context(logger=logger) + context.routing_group = FallbackGroup(name="code_chain", targets=context.routing_targets, failover_on=frozenset({"rate_limit"}), streaming_policy="never") + return context + + @pytest.mark.asyncio -async def test_streaming_fallback_honors_group_override_before_output() -> None: +async def test_streaming_fallback_hard_stops_auth_even_with_group_override() -> None: executor = RequestExecutor.__new__(RequestExecutor) attempts = [] @@ -47,10 +53,10 @@ async def fake_stream(self, context): executor._execute_streaming = MethodType(fake_stream, executor) - chunks = [chunk async for chunk in executor._execute_streaming_with_fallback(_context())] + with pytest.raises(StreamFailure): + [chunk async for chunk in executor._execute_streaming_with_fallback(_context())] - assert attempts == ["codex", "openai"] - assert chunks == ["data: [DONE]\n\n"] + assert attempts == ["codex"] @pytest.mark.asyncio @@ -125,3 +131,41 @@ async def fake_stream(self, context): assert attempts == ["codex", "openai"] assert chunks[-1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_streaming_fallback_respects_never_policy_before_output() -> None: + executor = RequestExecutor.__new__(RequestExecutor) + attempts = [] + + async def fake_stream(self, context): + attempts.append(context.provider) + raise StreamFailure("rate_limit") + yield "" + + executor._execute_streaming = MethodType(fake_stream, executor) + + with pytest.raises(StreamFailure): + [chunk async for chunk in executor._execute_streaming_with_fallback(_context_never_streaming_fallback())] + + assert attempts == ["codex"] + + +@pytest.mark.asyncio +async def test_streaming_fallback_exhaustion_trace_uses_sanitized_summaries(tmp_path) -> None: + executor = RequestExecutor.__new__(RequestExecutor) + logger = TransactionLogger("routing", "code", parent_dir=tmp_path) + + async def fake_stream(self, context): + raise StreamFailure("rate_limit") + yield "" + + executor._execute_streaming = MethodType(fake_stream, executor) + + with pytest.raises(StreamFailure): + [chunk async for chunk in executor._execute_streaming_with_fallback(_context_never_streaming_fallback(logger=logger))] + + entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + exhausted = [entry for entry in entries if entry["pass_name"] == "routing_fallback_exhausted"][-1] + assert exhausted["metadata"]["fallback_targets"][0]["message"] == "" + assert exhausted["metadata"]["streaming_policy"] == "never" From e15c185ebf806ec52beab24dfa41a913ecdce25f Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 13:06:57 +0200 Subject: [PATCH 117/182] fix(streaming): classify control frames before fallback Extends visible-output detection to parse SSE event/data frames, comments, heartbeats, error events, and Responses failure events without treating them as model output. This keeps routing fallback available before visible output while preserving fail-closed behavior for malformed non-SSE chunks. Tests: pytest tests/test_stream_policy.py tests/test_streaming_fallback_policy.py --- src/rotator_library/streaming/policy.py | 31 +++++++++++++++++++++---- tests/test_stream_policy.py | 4 ++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/rotator_library/streaming/policy.py b/src/rotator_library/streaming/policy.py index d69bca69a..76e8c5962 100644 --- a/src/rotator_library/streaming/policy.py +++ b/src/rotator_library/streaming/policy.py @@ -73,16 +73,39 @@ def is_visible_stream_output(chunk: Optional[str], *, protocol: str = "openai_ch def _sse_json(chunk: str, *, malformed_is_visible: bool) -> dict[str, Any] | object | None: payload = chunk.strip() - if not payload.startswith("data:"): - return _MALFORMED_VISIBLE if malformed_is_visible and payload else None - payload = payload[5:].strip() + if not payload: + return None + if all(line.startswith(":") for line in payload.splitlines() if line.strip()): + return None + event_type = None + data_lines: list[str] = [] + for line in payload.splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith(":"): + continue + if stripped.startswith("event:"): + event_type = stripped[6:].strip() + continue + if stripped.startswith("data:"): + data_lines.append(stripped[5:].strip()) + continue + return _MALFORMED_VISIBLE if malformed_is_visible else None + if not data_lines: + if event_type in {"error", "response.failed"}: + return {"event_type": event_type} + return None + payload = "\n".join(data_lines).strip() if not payload or payload == "[DONE]": return None try: parsed = json.loads(payload) except json.JSONDecodeError: return _MALFORMED_VISIBLE if malformed_is_visible else None - return parsed if isinstance(parsed, dict) else (_MALFORMED_VISIBLE if malformed_is_visible else None) + if not isinstance(parsed, dict): + return _MALFORMED_VISIBLE if malformed_is_visible else None + if event_type and "event_type" not in parsed: + parsed["event_type"] = event_type + return parsed def _openai_chat_visible(data: dict[str, Any]) -> bool: diff --git a/tests/test_stream_policy.py b/tests/test_stream_policy.py index a7e1a9397..aa7fcc329 100644 --- a/tests/test_stream_policy.py +++ b/tests/test_stream_policy.py @@ -19,6 +19,9 @@ def test_visible_output_detection_for_chat_chunks() -> None: assert is_visible_stream_output('data: {"choices":[{"delta":{"content":"hello"}}]}\n\n') is True assert is_visible_stream_output('data: {"choices":[{"delta":{"tool_calls":[{"id":"call_1"}]}}]}\n\n') is True assert is_visible_stream_output('data: {"error":{"type":"rate_limit"}}\n\n') is False + assert is_visible_stream_output('event: error\ndata: {"type":"error","error":{"message":"x"}}\n\n') is False + assert is_visible_stream_output('data: {"type":"error","error":{"message":"x"}}\n\n') is False + assert is_visible_stream_output(': heartbeat\n\n') is False assert is_visible_stream_output("data: [DONE]\n\n") is False assert is_visible_stream_output("not-sse") is True @@ -26,3 +29,4 @@ def test_visible_output_detection_for_chat_chunks() -> None: def test_visible_output_detection_for_responses_events() -> None: assert is_visible_stream_output('data: {"event_type":"response.output_text.delta","delta":"hi"}\n\n', protocol="responses") is True assert is_visible_stream_output('data: {"event_type":"response.failed","error":{"message":"x"}}\n\n', protocol="responses") is False + assert is_visible_stream_output('event: response.failed\ndata: {"error":{"message":"x"}}\n\n', protocol="responses") is False From a957f0806f8c68243ed0769209428cdf9f58c0eb Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 13:17:21 +0200 Subject: [PATCH 118/182] fix(routing): close fallback review gaps Reads accumulated abnormal credential failures before classifying proxy_all_credentials_exhausted so hard-stop auth and permission failures cannot fallback as rate limits. Treats terminal stream error frames as target failures before visible output, enforces streaming_policy in the exported fallback runner, and classifies numeric status fields consistently. Removes unreachable duplicate invalid-request code from the shared error classifier and updates stale routing tests for the hard-stop invariant. Tests: pytest tests/test_fallback_policy.py tests/test_fallback_attempt_runner.py tests/test_retry_policy.py tests/test_routing_config.py tests/test_fallback_resolver.py tests/test_routing_attempts.py tests/test_request_builder_routing.py tests/test_config_routing_json.py tests/test_streaming_fallback_policy.py tests/test_stream_policy.py tests/test_request_executor_native_routing.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py --- src/rotator_library/client/executor.py | 99 ++++++++++++++++++- src/rotator_library/error_handler.py | 5 - src/rotator_library/routing/executor.py | 2 +- src/rotator_library/routing/policy.py | 9 ++ tests/test_fallback_attempt_runner.py | 30 +++++- tests/test_request_executor_native_routing.py | 21 ++++ tests/test_routing_config.py | 9 ++ tests/test_streaming_fallback_policy.py | 5 +- 8 files changed, 165 insertions(+), 15 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index a9d108bb5..a103577b4 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -843,6 +843,8 @@ async def _execute_streaming_with_fallback(self, context: RequestContext) -> Asy ) for index, target in enumerate(targets): emitted_output = False + pending_chunks: List[str] = [] + terminal_error_type: Optional[str] = None target_context = clone_context_for_target( context, target, @@ -860,9 +862,42 @@ async def _execute_streaming_with_fallback(self, context: RequestContext) -> Asy ) try: async for chunk in self._execute_streaming(target_context): + chunk_error_type = _stream_chunk_error_type(chunk) + if chunk_error_type and not emitted_output: + terminal_error_type = chunk_error_type + pending_chunks.append(chunk) + continue if _stream_chunk_is_visible_output(chunk): + for pending in pending_chunks: + yield pending + pending_chunks.clear() emitted_output = True - yield chunk + yield chunk + continue + pending_chunks.append(chunk) + if terminal_error_type and not emitted_output: + error_type = terminal_error_type + self._log_routing_trace( + context, + "routing_stream_target_attempt_failed", + _target_trace(target), + metadata={"target_index": index, "error_type": error_type, "emitted_output": emitted_output, "terminal_error_frame": True}, + ) + target_failures.append(_target_failure_summary(target, error_type)) + if index < len(targets) - 1 and _streaming_policy_allows_fallback(context.routing_group) and policy.should_fallback(error_type, group=context.routing_group, stream=True, emitted_output=False): + self._log_routing_trace( + context, + "routing_fallback_selected", + _target_trace(targets[index + 1]), + metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type, "stream": True}, + ) + continue + self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type, "stream": True, "streaming_policy": _group_streaming_policy(context.routing_group), "fallback_targets": target_failures}) + for pending in pending_chunks: + yield pending + return + for pending in pending_chunks: + yield pending self._log_routing_trace( context, "routing_stream_target_attempt_succeeded", @@ -2561,7 +2596,7 @@ def _route_status_code_from_response(response: Any) -> Optional[int]: return None error = response["error"] details = error.get("details") if isinstance(error.get("details"), dict) else {} - for candidate in (details.get("status_code"), error.get("status_code"), error.get("code")): + for candidate in (details.get("status_code"), details.get("status"), error.get("status_code"), error.get("status"), error.get("code")): try: return int(candidate) except (TypeError, ValueError): @@ -2625,6 +2660,10 @@ def _structured_error_candidates(error: Dict[str, Any], details: Dict[str, Any]) details.get("error_type"), details.get("status"), ] + for abnormal in details.get("abnormal_errors") or []: + if isinstance(abnormal, dict): + values.append(abnormal.get("error_type")) + values.append(abnormal.get("status_code")) status_code = _route_status_code_from_response({"error": error}) if status_code == 400: values.append("invalid_request") @@ -2687,6 +2726,62 @@ def _stream_chunk_is_visible_output(chunk: str) -> bool: return is_visible_stream_output(chunk) +def _stream_chunk_error_type(chunk: str) -> Optional[str]: + """Return a route error type for terminal stream error frames. + + Per-target stream executors can emit a structured error SSE and `[DONE]` + instead of raising. Fallback wrappers must treat those frames as target + failures before visible output, while still forwarding them if no fallback is + available. + """ + + payload = _stream_chunk_payload(chunk) + if not isinstance(payload, dict): + return None + event_type = normalize_route_error_type(str(payload.get("event_type") or payload.get("type") or "")) + if event_type == "error" and isinstance(payload.get("error"), dict): + return _route_error_type_from_response({"error": payload["error"]}) or "server_error" + if event_type == "response.failed": + error = payload.get("error") if isinstance(payload.get("error"), dict) else {"type": "server_error"} + return _route_error_type_from_response({"error": error}) or "server_error" + if isinstance(payload.get("error"), dict): + return _route_error_type_from_response({"error": payload["error"]}) or "server_error" + return None + + +def _stream_chunk_payload(chunk: str) -> Optional[Dict[str, Any]]: + """Parse a minimal SSE payload for routing decisions only.""" + + text = str(chunk or "").strip() + if not text: + return None + event_type = None + data_lines: List[str] = [] + for line in text.splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith(":"): + continue + if stripped.startswith("event:"): + event_type = stripped[6:].strip() + continue + if stripped.startswith("data:"): + data_lines.append(stripped[5:].strip()) + if not data_lines: + return {"event_type": event_type} if event_type else None + data = "\n".join(data_lines).strip() + if not data or data == "[DONE]": + return None + try: + parsed = json.loads(data) + except json.JSONDecodeError: + return None + if not isinstance(parsed, dict): + return None + if event_type and "event_type" not in parsed: + parsed["event_type"] = event_type + return parsed + + def _can_start_stream_provider_cooldown(last_streamed_chunk: Optional[str]) -> bool: """Return whether a streaming failure occurred before visible output.""" diff --git a/src/rotator_library/error_handler.py b/src/rotator_library/error_handler.py index e0130c17c..d65d7eadb 100644 --- a/src/rotator_library/error_handler.py +++ b/src/rotator_library/error_handler.py @@ -925,11 +925,6 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr original_exception=e, status_code=status_code, ) - return ClassifiedError( - error_type="invalid_request", - original_exception=e, - status_code=status_code, - ) if 400 <= status_code < 500: # Other 4xx errors - generally client errors return ClassifiedError( diff --git a/src/rotator_library/routing/executor.py b/src/rotator_library/routing/executor.py index 26af74ccb..9be8f641b 100644 --- a/src/rotator_library/routing/executor.py +++ b/src/rotator_library/routing/executor.py @@ -59,7 +59,7 @@ async def run_group( emitted_output = bool(getattr(exc, "emitted_output", False)) attempts.append(RouteAttemptResult(target=target, success=False, error_type=error_type, emitted_output=emitted_output)) has_next = index < len(decision.targets) - 1 - if not has_next or not self.policy.should_fallback(error_type, group=group, emitted_output=emitted_output, stream=stream): + if not has_next or (stream and getattr(group, "streaming_policy", "pre_output_only") == "never") or not self.policy.should_fallback(error_type, group=group, emitted_output=emitted_output, stream=stream): raise FallbackExhaustedError(decision, tuple(attempts)) from exc raise FallbackExhaustedError(decision, tuple(attempts)) diff --git a/src/rotator_library/routing/policy.py b/src/rotator_library/routing/policy.py index 02aaba63e..152673b20 100644 --- a/src/rotator_library/routing/policy.py +++ b/src/rotator_library/routing/policy.py @@ -17,6 +17,7 @@ "validation": "invalid_request", "permanent": "invalid_request", "context_length": "context_window_exceeded", + "pre_request_callback": "pre_request_callback_error", "configuration": "configuration_error", "config": "configuration_error", "quota": "quota_exceeded", @@ -24,6 +25,14 @@ "transient": "server_error", "network": "api_connection", "connection": "api_connection", + "400": "invalid_request", + "401": "authentication", + "403": "forbidden", + "429": "rate_limit", + "500": "server_error", + "502": "server_error", + "503": "server_error", + "504": "server_error", } diff --git a/tests/test_fallback_attempt_runner.py b/tests/test_fallback_attempt_runner.py index 60219cd73..ce1a0c9ee 100644 --- a/tests/test_fallback_attempt_runner.py +++ b/tests/test_fallback_attempt_runner.py @@ -61,7 +61,7 @@ async def attempt(target, index): await FallbackAttemptRunner().run(_decision(), attempt) assert len(exc.value.attempts) == 1 - assert exc.value.attempts[0].error_type == "validation" + assert exc.value.attempts[0].error_type == "invalid_request" @pytest.mark.asyncio @@ -77,7 +77,7 @@ async def attempt(target, index): @pytest.mark.asyncio -async def test_attempt_runner_honors_group_policy_overrides() -> None: +async def test_attempt_runner_hard_stops_group_policy_overrides() -> None: group = FallbackGroup( name="custom", targets=_decision().targets, @@ -92,7 +92,27 @@ async def attempt(target, index): raise ClassifiedFailure("authentication") return {"target": target.prefixed_model} - result = await FallbackAttemptRunner().run_group(_decision(), group, attempt) + with pytest.raises(FallbackExhaustedError): + await FallbackAttemptRunner().run_group(_decision(), group, attempt) - assert result == {"target": "openai/gpt-5.1"} - assert calls == [0, 1] + assert calls == [0] + + +@pytest.mark.asyncio +async def test_attempt_runner_respects_never_streaming_policy() -> None: + group = FallbackGroup( + name="custom", + targets=_decision().targets, + failover_on=frozenset({"rate_limit"}), + streaming_policy="never", + ) + calls = [] + + async def attempt(target, index): + calls.append(index) + raise ClassifiedFailure("rate_limit") + + with pytest.raises(FallbackExhaustedError): + await FallbackAttemptRunner().run_group(_decision(), group, attempt, stream=True) + + assert calls == [0] diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index 1625397ae..0b76c5db8 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -374,9 +374,30 @@ def test_route_error_type_from_response_hard_stop_wins_over_retry_summary() -> N def test_route_error_type_from_response_uses_structured_status_codes() -> None: assert executor_module._route_error_type_from_response({"error": {"code": 403}}) == "forbidden" + assert executor_module._route_error_type_from_response({"error": {"status": 401}}) == "authentication" + assert executor_module._route_error_type_from_response({"error": {"details": {"status": 403}}}) == "forbidden" assert executor_module._route_error_type_from_response({"error": {"details": {"status_code": 503}}}) == "server_error" +def test_route_error_type_from_response_reads_abnormal_errors_before_proxy_summary() -> None: + response = { + "error": { + "type": "proxy_all_credentials_exhausted", + "details": { + "normal_error_summary": "rate_limit quota", + "abnormal_errors": [{"error_type": "authentication", "status_code": 401}], + }, + } + } + + assert executor_module._route_error_type_from_response(response) == "authentication" + + +def test_stream_chunk_error_type_detects_terminal_error_frames() -> None: + assert executor_module._stream_chunk_error_type('data: {"error":{"type":"rate_limit"}}\n\n') == "rate_limit" + assert executor_module._stream_chunk_error_type('event: response.failed\ndata: {"error":{"type":"authentication"}}\n\n') == "authentication" + + def test_target_failure_summary_is_structural_and_sanitized() -> None: summary = executor_module._target_failure_summary(parse_route_target("openai/gpt"), "rate-limit", status_code=429) diff --git a/tests/test_routing_config.py b/tests/test_routing_config.py index ad38f68eb..353f0da60 100644 --- a/tests/test_routing_config.py +++ b/tests/test_routing_config.py @@ -75,6 +75,15 @@ def test_load_routing_config_rejects_hard_stop_failover() -> None: } ) + with pytest.raises(RoutingConfigError): + load_routing_config_from_env( + { + "FALLBACK_GROUPS": "chain", + "FALLBACK_GROUP_CHAIN": "openai/gpt,copilot/gpt", + "FALLBACK_GROUP_CHAIN_FAILOVER_ON": "pre_request_callback", + } + ) + def test_load_routing_config_rejects_unknown_streaming_policy() -> None: with pytest.raises(RoutingConfigError): diff --git a/tests/test_streaming_fallback_policy.py b/tests/test_streaming_fallback_policy.py index bf259967a..d2088f674 100644 --- a/tests/test_streaming_fallback_policy.py +++ b/tests/test_streaming_fallback_policy.py @@ -122,7 +122,8 @@ async def fake_stream(self, context): attempts.append(context.provider) if len(attempts) == 1: yield 'data: {"error":{"type":"rate_limit"}}\n\n' - raise StreamFailure("rate_limit") + yield "data: [DONE]\n\n" + return yield "data: [DONE]\n\n" executor._execute_streaming = MethodType(fake_stream, executor) @@ -130,7 +131,7 @@ async def fake_stream(self, context): chunks = [chunk async for chunk in executor._execute_streaming_with_fallback(_context())] assert attempts == ["codex", "openai"] - assert chunks[-1] == "data: [DONE]\n\n" + assert chunks == ["data: [DONE]\n\n"] @pytest.mark.asyncio From bbe704d512defea6880b4b0469c471dfcecb5034 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 13:27:37 +0200 Subject: [PATCH 119/182] fix(routing): preserve stream hard-stop semantics Prevents hard-stop streaming failures from being converted into generic internal errors with raw messages, and recognizes terminal error SSE frames even when they use top-level type/code fields. Makes the public fallback runner honor RoutingDecision.group, auto-detects Responses stream visibility, and expands structured dict classification for numeric client/error statuses. Tests: pytest tests/test_fallback_policy.py tests/test_fallback_attempt_runner.py tests/test_retry_policy.py tests/test_routing_config.py tests/test_fallback_resolver.py tests/test_routing_attempts.py tests/test_request_builder_routing.py tests/test_config_routing_json.py tests/test_request_executor_fallback_groups.py tests/test_streaming_fallback_policy.py tests/test_stream_policy.py tests/test_request_executor_native_routing.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py --- src/rotator_library/client/executor.py | 8 +++--- src/rotator_library/error_handler.py | 13 +++++++++- src/rotator_library/routing/executor.py | 2 +- src/rotator_library/streaming/policy.py | 3 +++ tests/test_fallback_attempt_runner.py | 21 ++++++++++++++- .../test_request_executor_fallback_groups.py | 26 +++++++++---------- tests/test_request_executor_native_routing.py | 1 + tests/test_retry_policy.py | 8 +++++- tests/test_stream_policy.py | 1 + 9 files changed, 63 insertions(+), 20 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index a103577b4..e54885f3a 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -1844,7 +1844,8 @@ async def _execute_streaming( except Exception as e: lib_logger.error(f"Unhandled exception in streaming: {e}", exc_info=True) - error_data = {"error": {"message": str(e), "type": "proxy_internal_error"}} + classified = classify_error(e, context.provider) + error_data = {"error": {"message": "Streaming request failed", "type": classified.error_type, "details": {"status_code": classified.status_code}}} for line in self._terminal_stream_error_lines(context, error_data): yield line @@ -2739,8 +2740,9 @@ def _stream_chunk_error_type(chunk: str) -> Optional[str]: if not isinstance(payload, dict): return None event_type = normalize_route_error_type(str(payload.get("event_type") or payload.get("type") or "")) - if event_type == "error" and isinstance(payload.get("error"), dict): - return _route_error_type_from_response({"error": payload["error"]}) or "server_error" + if event_type == "error": + error = payload.get("error") if isinstance(payload.get("error"), dict) else payload + return _route_error_type_from_response({"error": error}) or "server_error" if event_type == "response.failed": error = payload.get("error") if isinstance(payload.get("error"), dict) else {"type": "server_error"} return _route_error_type_from_response({"error": error}) or "server_error" diff --git a/src/rotator_library/error_handler.py b/src/rotator_library/error_handler.py index d65d7eadb..11f54680e 100644 --- a/src/rotator_library/error_handler.py +++ b/src/rotator_library/error_handler.py @@ -759,12 +759,23 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr if isinstance(e, dict): payload = e.get("error", e) if isinstance(payload, dict): + details = payload.get("details") if isinstance(payload.get("details"), dict) else {} code = payload.get("code") + status_value = payload.get("status_code") or payload.get("status") or details.get("status_code") or details.get("status") or code status = str(payload.get("status", "")).upper() try: - status_code = int(code) if code is not None else None + status_code = int(status_value) if status_value is not None else None except (TypeError, ValueError): status_code = None + if status_code == 400: + return ClassifiedError(error_type="invalid_request", original_exception=e, status_code=status_code) + if status_code == 401: + return ClassifiedError(error_type="authentication", original_exception=e, status_code=status_code) + if status_code == 403: + return ClassifiedError(error_type="forbidden", original_exception=e, status_code=status_code) + if status_code == 429: + body = str(payload).lower() + return ClassifiedError(error_type="quota_exceeded" if "quota" in body or "resource_exhausted" in body else "rate_limit", original_exception=e, status_code=status_code) if (status_code is not None and status_code >= 500) or status in { "INTERNAL", "UNAVAILABLE", diff --git a/src/rotator_library/routing/executor.py b/src/rotator_library/routing/executor.py index 9be8f641b..dce897e9c 100644 --- a/src/rotator_library/routing/executor.py +++ b/src/rotator_library/routing/executor.py @@ -38,7 +38,7 @@ def __init__(self, policy: FallbackPolicy | None = None) -> None: async def run(self, decision: RoutingDecision, attempt: AttemptCallback, *, stream: bool = False) -> Any: """Try targets in order until one succeeds or fallback is exhausted.""" - return await self.run_group(decision, None, attempt, stream=stream) + return await self.run_group(decision, decision.group, attempt, stream=stream) async def run_group( self, diff --git a/src/rotator_library/streaming/policy.py b/src/rotator_library/streaming/policy.py index 76e8c5962..117d5ee37 100644 --- a/src/rotator_library/streaming/policy.py +++ b/src/rotator_library/streaming/policy.py @@ -63,6 +63,9 @@ def is_visible_stream_output(chunk: Optional[str], *, protocol: str = "openai_ch return False if data.get("error"): return False + event_type = data.get("event_type") or data.get("type") + if isinstance(event_type, str) and event_type.startswith("response."): + return _responses_visible(data) if protocol == "responses": return _responses_visible(data) return _openai_chat_visible(data) diff --git a/tests/test_fallback_attempt_runner.py b/tests/test_fallback_attempt_runner.py index ce1a0c9ee..167e5698d 100644 --- a/tests/test_fallback_attempt_runner.py +++ b/tests/test_fallback_attempt_runner.py @@ -14,10 +14,12 @@ def __init__(self, error_type: str, *, emitted_output: bool = False) -> None: def _decision() -> RoutingDecision: + targets = (parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")) return RoutingDecision( requested_model="code", group_name="code_chain", - targets=(parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")), + targets=targets, + group=FallbackGroup(name="code_chain", targets=targets), reason="model_route_group", ) @@ -116,3 +118,20 @@ async def attempt(target, index): await FallbackAttemptRunner().run_group(_decision(), group, attempt, stream=True) assert calls == [0] + + +@pytest.mark.asyncio +async def test_attempt_runner_run_uses_decision_group_streaming_policy() -> None: + decision = _decision() + never_group = FallbackGroup(name="code_chain", targets=decision.targets, failover_on=frozenset({"rate_limit"}), streaming_policy="never") + decision = RoutingDecision(requested_model=decision.requested_model, group_name=decision.group_name, targets=decision.targets, group=never_group) + calls = [] + + async def attempt(target, index): + calls.append(index) + raise ClassifiedFailure("rate_limit") + + with pytest.raises(FallbackExhaustedError): + await FallbackAttemptRunner().run(decision, attempt, stream=True) + + assert calls == [0] diff --git a/tests/test_request_executor_fallback_groups.py b/tests/test_request_executor_fallback_groups.py index ad7799999..d86bdcce8 100644 --- a/tests/test_request_executor_fallback_groups.py +++ b/tests/test_request_executor_fallback_groups.py @@ -99,7 +99,7 @@ async def fake_execute(self, context): @pytest.mark.asyncio -async def test_non_streaming_fallback_group_honors_group_override() -> None: +async def test_non_streaming_fallback_group_hard_stops_group_override() -> None: executor = RequestExecutor.__new__(RequestExecutor) attempts = [] @@ -112,20 +112,20 @@ async def fake_execute(self, context): executor._execute_non_streaming = MethodType(fake_execute, executor) targets = (parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")) - result = await executor._execute_non_streaming_with_fallback( - _context( - routing_targets=targets, - routing_group=FallbackGroup( - name="code_chain", - targets=targets, - failover_on=frozenset({"authentication"}), - stop_on=frozenset({"validation"}), - ), + with pytest.raises(ClassifiedFailure): + await executor._execute_non_streaming_with_fallback( + _context( + routing_targets=targets, + routing_group=FallbackGroup( + name="code_chain", + targets=targets, + failover_on=frozenset({"authentication"}), + stop_on=frozenset({"validation"}), + ), + ) ) - ) - assert result == {"id": "ok", "model": "openai/gpt-5.1"} - assert attempts == ["codex", "openai"] + assert attempts == ["codex"] @pytest.mark.asyncio diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index 0b76c5db8..dbad94165 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -396,6 +396,7 @@ def test_route_error_type_from_response_reads_abnormal_errors_before_proxy_summa def test_stream_chunk_error_type_detects_terminal_error_frames() -> None: assert executor_module._stream_chunk_error_type('data: {"error":{"type":"rate_limit"}}\n\n') == "rate_limit" assert executor_module._stream_chunk_error_type('event: response.failed\ndata: {"error":{"type":"authentication"}}\n\n') == "authentication" + assert executor_module._stream_chunk_error_type('event: error\ndata: {"type":"error","code":429}\n\n') == "rate_limit" def test_target_failure_summary_is_structural_and_sanitized() -> None: diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index 704c95eef..819ddfdc3 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -2,7 +2,7 @@ import asyncio -from rotator_library.error_handler import ClassifiedError, PreRequestCallbackError +from rotator_library.error_handler import ClassifiedError, PreRequestCallbackError, classify_error from rotator_library.retry_policy import ( classify_route_error, decide_provider_cooldown, @@ -86,3 +86,9 @@ def test_provider_cooldown_is_conservative_for_quota_by_default() -> None: assert disabled.should_start is False assert disabled.reason == "quota_cooldown_disabled" assert enabled.should_start is True + + +def test_shared_classifier_handles_structured_dict_status_codes() -> None: + assert classify_error({"error": {"status": 401}}).error_type == "authentication" + assert classify_error({"error": {"details": {"status_code": 403}}}).error_type == "forbidden" + assert classify_error({"error": {"code": 429}}).error_type == "rate_limit" diff --git a/tests/test_stream_policy.py b/tests/test_stream_policy.py index aa7fcc329..d72874951 100644 --- a/tests/test_stream_policy.py +++ b/tests/test_stream_policy.py @@ -28,5 +28,6 @@ def test_visible_output_detection_for_chat_chunks() -> None: def test_visible_output_detection_for_responses_events() -> None: assert is_visible_stream_output('data: {"event_type":"response.output_text.delta","delta":"hi"}\n\n', protocol="responses") is True + assert is_visible_stream_output('data: {"event_type":"response.output_text.delta","delta":"hi"}\n\n') is True assert is_visible_stream_output('data: {"event_type":"response.failed","error":{"message":"x"}}\n\n', protocol="responses") is False assert is_visible_stream_output('event: response.failed\ndata: {"error":{"message":"x"}}\n\n', protocol="responses") is False From 79c963388a3956a35b5a8b5f81ea265279018082 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 13:36:20 +0200 Subject: [PATCH 120/182] fix(routing): complete structured error classification Extends shared structured-dict classification for explicit error type/code/status values, preserving hard-stop streaming failures such as configuration_error and authentication instead of treating them as unknown. Adds route-policy aliases for context-window error codes and treats Responses function/tool-call deltas as visible stream output so fallback cannot duplicate client-visible tool data. Tests: pytest tests/test_fallback_policy.py tests/test_fallback_attempt_runner.py tests/test_retry_policy.py tests/test_routing_config.py tests/test_fallback_resolver.py tests/test_routing_attempts.py tests/test_request_builder_routing.py tests/test_config_routing_json.py tests/test_request_executor_fallback_groups.py tests/test_streaming_fallback_policy.py tests/test_stream_policy.py tests/test_request_executor_native_routing.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py --- src/rotator_library/error_handler.py | 39 +++++++++++++++++++ src/rotator_library/routing/policy.py | 3 ++ src/rotator_library/streaming/policy.py | 2 + tests/test_fallback_policy.py | 1 + tests/test_request_executor_native_routing.py | 1 + tests/test_retry_policy.py | 14 +++++++ tests/test_stream_policy.py | 1 + 7 files changed, 61 insertions(+) diff --git a/src/rotator_library/error_handler.py b/src/rotator_library/error_handler.py index 11f54680e..3e5b1157a 100644 --- a/src/rotator_library/error_handler.py +++ b/src/rotator_library/error_handler.py @@ -776,6 +776,9 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr if status_code == 429: body = str(payload).lower() return ClassifiedError(error_type="quota_exceeded" if "quota" in body or "resource_exhausted" in body else "rate_limit", original_exception=e, status_code=status_code) + structured_type = _classify_structured_error_text(payload, details) + if structured_type: + return ClassifiedError(error_type=structured_type, original_exception=e, status_code=status_code) if (status_code is not None and status_code >= 500) or status in { "INTERNAL", "UNAVAILABLE", @@ -786,6 +789,14 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr status_code=status_code or 503, ) + explicit_error_type = getattr(e, "error_type", None) + if explicit_error_type: + return ClassifiedError( + error_type=str(explicit_error_type).strip().lower().replace("-", "_").replace(" ", "_"), + original_exception=e, + status_code=getattr(e, "status_code", None), + ) + error_text = str(e) error_type_name = type(e).__name__ if ( @@ -1090,6 +1101,34 @@ def is_rate_limit_error(e: Exception) -> bool: return isinstance(e, RateLimitError) +def _classify_structured_error_text(payload: dict, details: dict) -> Optional[str]: + """Classify structured error type/code/status text without raw messages.""" + + values = [payload.get("type"), payload.get("code"), payload.get("status"), details.get("error_type"), details.get("classification"), details.get("status")] + normalized = {str(value).strip().lower().replace("-", "_").replace(" ", "_") for value in values if value not in (None, "")} + if normalized & {"authentication", "auth", "unauthorized", "invalid_api_key"}: + return "authentication" + if normalized & {"forbidden", "permission_denied", "access_denied"}: + return "forbidden" + if normalized & {"context_window_exceeded", "context_length", "context_length_exceeded", "too_many_tokens"}: + return "context_window_exceeded" + if normalized & {"invalid_request", "bad_request", "validation", "invalid_argument"}: + return "invalid_request" + if normalized & {"credential_reauth_needed"}: + return "credential_reauth_needed" + if normalized & {"configuration_error", "config", "configuration"}: + return "configuration_error" + if normalized & {"rate_limit", "rate_limited", "too_many_requests"}: + return "rate_limit" + if normalized & {"quota_exceeded", "resource_exhausted", "quota"}: + return "quota_exceeded" + if normalized & {"server_error", "internal", "unavailable"}: + return "server_error" + if normalized & {"api_connection", "network", "connection"}: + return "api_connection" + return None + + def is_server_error(e: Exception) -> bool: """Checks if the exception is a temporary server-side error.""" return isinstance( diff --git a/src/rotator_library/routing/policy.py b/src/rotator_library/routing/policy.py index 152673b20..d9becdec0 100644 --- a/src/rotator_library/routing/policy.py +++ b/src/rotator_library/routing/policy.py @@ -17,6 +17,9 @@ "validation": "invalid_request", "permanent": "invalid_request", "context_length": "context_window_exceeded", + "context_length_exceeded": "context_window_exceeded", + "context_window": "context_window_exceeded", + "too_many_tokens": "context_window_exceeded", "pre_request_callback": "pre_request_callback_error", "configuration": "configuration_error", "config": "configuration_error", diff --git a/src/rotator_library/streaming/policy.py b/src/rotator_library/streaming/policy.py index 117d5ee37..6c32af9aa 100644 --- a/src/rotator_library/streaming/policy.py +++ b/src/rotator_library/streaming/policy.py @@ -132,6 +132,8 @@ def _responses_visible(data: dict[str, Any]) -> bool: event_type = data.get("event_type") or data.get("type") if event_type == "response.output_text.delta": return bool(str(data.get("delta", "")).strip()) + if isinstance(event_type, str) and ("function_call" in event_type or "tool_call" in event_type): + return _has_visible_text(data.get("delta")) or _has_visible_text(data.get("arguments")) or _has_visible_text(data.get("item")) if event_type == "response.failed": return False return False diff --git a/tests/test_fallback_policy.py b/tests/test_fallback_policy.py index 7ea82203a..3cc8531d9 100644 --- a/tests/test_fallback_policy.py +++ b/tests/test_fallback_policy.py @@ -62,5 +62,6 @@ def test_policy_normalizes_user_facing_aliases() -> None: assert normalize_route_error_type("auth") == "authentication" assert normalize_route_error_type("permission-denied") == "forbidden" assert normalize_route_error_type("bad request") == "invalid_request" + assert normalize_route_error_type("context_length_exceeded") == "context_window_exceeded" assert FallbackPolicy().should_fallback("network") is True assert FallbackPolicy().should_fallback("validation") is False diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index dbad94165..7218472d8 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -397,6 +397,7 @@ def test_stream_chunk_error_type_detects_terminal_error_frames() -> None: assert executor_module._stream_chunk_error_type('data: {"error":{"type":"rate_limit"}}\n\n') == "rate_limit" assert executor_module._stream_chunk_error_type('event: response.failed\ndata: {"error":{"type":"authentication"}}\n\n') == "authentication" assert executor_module._stream_chunk_error_type('event: error\ndata: {"type":"error","code":429}\n\n') == "rate_limit" + assert executor_module._stream_chunk_error_type('event: error\ndata: {"type":"error","code":"context_length_exceeded"}\n\n') == "context_window_exceeded" def test_target_failure_summary_is_structural_and_sanitized() -> None: diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index 819ddfdc3..5b2d627fd 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -92,3 +92,17 @@ def test_shared_classifier_handles_structured_dict_status_codes() -> None: assert classify_error({"error": {"status": 401}}).error_type == "authentication" assert classify_error({"error": {"details": {"status_code": 403}}}).error_type == "forbidden" assert classify_error({"error": {"code": 429}}).error_type == "rate_limit" + + +def test_shared_classifier_handles_structured_dict_type_and_code_text() -> None: + assert classify_error({"error": {"type": "authentication"}}).error_type == "authentication" + assert classify_error({"error": {"code": "permission_denied"}}).error_type == "forbidden" + assert classify_error({"error": {"code": "context_length_exceeded"}}).error_type == "context_window_exceeded" + assert classify_error({"error": {"type": "rate_limit"}}).error_type == "rate_limit" + + +def test_shared_classifier_preserves_explicit_error_type_attributes() -> None: + class ConfigurationFailure(Exception): + error_type = "configuration_error" + + assert classify_error(ConfigurationFailure("raw secret message")).error_type == "configuration_error" diff --git a/tests/test_stream_policy.py b/tests/test_stream_policy.py index d72874951..729c7f0a4 100644 --- a/tests/test_stream_policy.py +++ b/tests/test_stream_policy.py @@ -29,5 +29,6 @@ def test_visible_output_detection_for_chat_chunks() -> None: def test_visible_output_detection_for_responses_events() -> None: assert is_visible_stream_output('data: {"event_type":"response.output_text.delta","delta":"hi"}\n\n', protocol="responses") is True assert is_visible_stream_output('data: {"event_type":"response.output_text.delta","delta":"hi"}\n\n') is True + assert is_visible_stream_output('data: {"event_type":"response.function_call_arguments.delta","delta":"{\\"x\\":"}\n\n') is True assert is_visible_stream_output('data: {"event_type":"response.failed","error":{"message":"x"}}\n\n', protocol="responses") is False assert is_visible_stream_output('event: response.failed\ndata: {"error":{"message":"x"}}\n\n', protocol="responses") is False From 49e07f85990b45cf398b7a26afa9032e5fd9bc64 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 13:42:18 +0200 Subject: [PATCH 121/182] fix(routing): preserve context-window dict classification Keeps structured context-length codes specific when providers include both HTTP 400 and context_length_exceeded in dict-shaped error payloads. This is a defense-in-depth routing classifier precision fix after Phase 6b review acceptance; fallback safety was already preserved because both categories are hard stops. Tests: pytest tests/test_fallback_policy.py tests/test_fallback_attempt_runner.py tests/test_retry_policy.py tests/test_routing_config.py tests/test_fallback_resolver.py tests/test_routing_attempts.py tests/test_request_builder_routing.py tests/test_config_routing_json.py tests/test_request_executor_fallback_groups.py tests/test_streaming_fallback_policy.py tests/test_stream_policy.py tests/test_request_executor_native_routing.py tests/test_provider_protocol_declarations.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py --- src/rotator_library/error_handler.py | 4 ++-- tests/test_retry_policy.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/rotator_library/error_handler.py b/src/rotator_library/error_handler.py index 3e5b1157a..7819a797a 100644 --- a/src/rotator_library/error_handler.py +++ b/src/rotator_library/error_handler.py @@ -767,8 +767,9 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr status_code = int(status_value) if status_value is not None else None except (TypeError, ValueError): status_code = None + structured_type = _classify_structured_error_text(payload, details) if status_code == 400: - return ClassifiedError(error_type="invalid_request", original_exception=e, status_code=status_code) + return ClassifiedError(error_type=structured_type or "invalid_request", original_exception=e, status_code=status_code) if status_code == 401: return ClassifiedError(error_type="authentication", original_exception=e, status_code=status_code) if status_code == 403: @@ -776,7 +777,6 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr if status_code == 429: body = str(payload).lower() return ClassifiedError(error_type="quota_exceeded" if "quota" in body or "resource_exhausted" in body else "rate_limit", original_exception=e, status_code=status_code) - structured_type = _classify_structured_error_text(payload, details) if structured_type: return ClassifiedError(error_type=structured_type, original_exception=e, status_code=status_code) if (status_code is not None and status_code >= 500) or status in { diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index 5b2d627fd..33c23da47 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -98,6 +98,7 @@ def test_shared_classifier_handles_structured_dict_type_and_code_text() -> None: assert classify_error({"error": {"type": "authentication"}}).error_type == "authentication" assert classify_error({"error": {"code": "permission_denied"}}).error_type == "forbidden" assert classify_error({"error": {"code": "context_length_exceeded"}}).error_type == "context_window_exceeded" + assert classify_error({"error": {"status_code": 400, "code": "context_length_exceeded"}}).error_type == "context_window_exceeded" assert classify_error({"error": {"type": "rate_limit"}}).error_type == "rate_limit" From e29b15d2421456c33f1e1ace021f7456760e06d1 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 13:44:53 +0200 Subject: [PATCH 122/182] docs(experimental): plan retry cooldown correction Adds the Phase 7b corrective plan for provider/model scoped cooldowns, model-capacity handling, bounded repeated-transient backoff, failure history, and streaming cooldown decision scope. Tests: not run (planning document only) --- .../phase-7b-retry-cooldown-backoff.md | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 docs/experimental/phase-7b-retry-cooldown-backoff.md diff --git a/docs/experimental/phase-7b-retry-cooldown-backoff.md b/docs/experimental/phase-7b-retry-cooldown-backoff.md new file mode 100644 index 000000000..5ceb14ff7 --- /dev/null +++ b/docs/experimental/phase-7b-retry-cooldown-backoff.md @@ -0,0 +1,81 @@ +# Phase 7b: Retry/Cooldown Backoff And Failure History + +## Goal + +Correct the Phase 7 audit findings. Phase 7 made provider cooldown activation real and centralized retry/failover decisions, but it still lacks provider/model cooldown scopes, scoped backoff, and structured in-memory failure history. + +## Non-Goals + +- Do not replace `UsageManager`, credential cooldowns, usage windows, `classify_error()`, or retry-after parsing. +- Do not introduce SQLite or a new persistence database. +- Do not make native streaming safe; unsupported native streaming remains fail-closed. +- Do not change Phase 6b fallback hard-stop behavior. +- Do not commit user-facing reports. + +## Current State + +- `CooldownManager` tracks provider cooldowns only. +- `RequestExecutor._wait_for_cooldown()` waits on provider cooldown only. +- `retry_policy.decide_provider_cooldown()` returns provider-level decisions without model scope or failure-history context. +- `MODEL_CAPACITY_EXHAUSTED` is noticed in logs but still behaves like a generic provider server error. +- There is no structured provider/model failure-history ring for future observability and bounded backoff decisions. + +## Implementation Plan + +1. Extend `CooldownManager` with scoped cooldown methods. + - Preserve `start_cooldown()`, `is_cooling_down()`, and `get_remaining_cooldown()`. + - Add provider/model scoped methods and extend-only semantics per scope key. + - Provider cooldown blocks all models; model cooldown blocks only that provider/model. + +2. Update executor cooldown waiting. + - Pass model into cooldown wait paths. + - Wait for the max of provider and model cooldown remaining. + - Trace waits without credentials or raw provider text. + +3. Add model-capacity detection. + - Detect `MODEL_CAPACITY_EXHAUSTED`, model capacity, and capacity-exhausted signals from exceptions and dict payloads. + - Keep `error_type="server_error"` for compatibility, but choose model cooldown scope. + +4. Extend cooldown decisions. + - Add `scope`, `model`, and optional backoff metadata to `ProviderCooldownDecision`. + - Large retry-after rate limits stay provider-scoped. + - Model-capacity failures become model-scoped. + - Quota cooldown remains disabled by default. + +5. Add in-memory failure history. + - Bounded ring with timestamp, provider, model, error type, scope, duration, and reason. + - No disk persistence. + - Executor records successful cooldown starts for future observability and backoff. + +6. Add bounded repeated-transient backoff. + - Track repeated transient `server_error`/`api_connection` failures within a configurable window. + - Escalate cooldown duration conservatively up to a max. + - Do not backoff hard-stop categories. + +7. Wire scoped cooldown start in `RequestExecutor`. + - Use scoped cooldown methods when available and fall back to provider-only fakes in tests. + - Keep existing trace pass names with added scope/model metadata. + +8. Update streaming error decisions. + - Keep `decide_streaming_error_action()` side-effect-free. + - Include cooldown scope/model in the decision. + - Visible output still suppresses provider/model cooldown. + +## Tests + +- `tests/test_cooldown_activation.py` +- `tests/test_retry_policy.py` +- `tests/test_streaming_error_handler.py` +- Executor-focused cooldown/trace tests. +- Phase 6b routing regression subset. + +## Acceptance Criteria + +- Provider cooldown behavior remains backwards-compatible and extend-only. +- Model cooldowns do not block unrelated models on the same provider. +- Provider cooldowns still block all models for that provider. +- Model-capacity errors produce model-scoped cooldown/backoff. +- Repeated transient failures can produce bounded scoped backoff from in-memory history. +- Cooldown starts/waits are traceable and sanitized. +- Streaming cooldown decisions expose scope and respect visible-output blocking. +- Focused tests and dual-agent review pass. From 4181413bed389538951cdaf33711179abb7896c3 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 13:46:18 +0200 Subject: [PATCH 123/182] feat(cooldown): add provider model scoped cooldowns Extends CooldownManager with provider/model scoped cooldown keys, max-remaining request checks, snapshots, and backward-compatible provider-only methods. Model scoped cooldowns are independent from provider cooldowns so model-capacity failures can pause one deployment without blocking every model on the provider. Tests: pytest tests/test_cooldown_activation.py --- src/rotator_library/cooldown_manager.py | 115 +++++++++++++++++++++--- tests/test_cooldown_activation.py | 35 ++++++++ 2 files changed, 136 insertions(+), 14 deletions(-) diff --git a/src/rotator_library/cooldown_manager.py b/src/rotator_library/cooldown_manager.py index 184056f6f..a4dbafefd 100644 --- a/src/rotator_library/cooldown_manager.py +++ b/src/rotator_library/cooldown_manager.py @@ -3,9 +3,21 @@ import asyncio import time +from dataclasses import dataclass from typing import Dict +@dataclass(frozen=True) +class CooldownSnapshot: + """Read-only view of one active cooldown scope for tests/observability.""" + + provider: str + scope: str + model: str | None + remaining: float + reason: str | None = None + + class CooldownManager: """ Manages global cooldown periods for API providers to handle IP-based rate limiting. @@ -15,14 +27,12 @@ class CooldownManager: def __init__(self): self._cooldowns: Dict[str, float] = {} + self._metadata: Dict[str, dict[str, str | None]] = {} self._lock = asyncio.Lock() async def is_cooling_down(self, provider: str) -> bool: """Checks if a provider is currently in a cooldown period.""" - async with self._lock: - return ( - provider in self._cooldowns and time.time() < self._cooldowns[provider] - ) + return await self.is_scoped_cooling_down(provider, scope="provider") async def start_cooldown(self, provider: str, duration: int): """ @@ -32,23 +42,100 @@ async def start_cooldown(self, provider: str, duration: int): A shorter new cooldown must not shorten an existing longer cooldown; provider-wide throttles often arrive concurrently from several requests. """ - async with self._lock: - new_expiry = time.time() + max(0, duration) - current_expiry = self._cooldowns.get(provider, 0) - if new_expiry > current_expiry: - self._cooldowns[provider] = new_expiry + await self.start_scoped_cooldown(provider, duration, scope="provider") async def get_cooldown_remaining(self, provider: str) -> float: """ Returns the remaining cooldown time in seconds for a provider. Returns 0 if the provider is not in a cooldown period. """ - async with self._lock: - if provider in self._cooldowns: - remaining = self._cooldowns[provider] - time.time() - return max(0, remaining) - return 0 + return await self.get_scoped_remaining(provider, scope="provider") async def get_remaining_cooldown(self, provider: str) -> float: """Backward-compatible alias for get_cooldown_remaining.""" return await self.get_cooldown_remaining(provider) + + async def start_scoped_cooldown( + self, + provider: str, + duration: int, + *, + model: str | None = None, + scope: str = "provider", + reason: str | None = None, + ) -> None: + """Start or extend a provider/model cooldown scope. + + Model scopes are intentionally separate from provider scopes because + capacity failures often belong to one model deployment rather than the + entire provider. Provider scopes remain available for provider-wide + throttles and block every model for the provider. + """ + + key = _cooldown_key(provider, scope=scope, model=model) + async with self._lock: + new_expiry = time.time() + max(0, duration) + current_expiry = self._cooldowns.get(key, 0) + if new_expiry > current_expiry: + self._cooldowns[key] = new_expiry + self._metadata[key] = {"provider": provider, "scope": scope, "model": model, "reason": reason} + + async def get_scoped_remaining(self, provider: str, *, model: str | None = None, scope: str = "provider") -> float: + """Return remaining seconds for one cooldown scope.""" + + key = _cooldown_key(provider, scope=scope, model=model) + async with self._lock: + return self._remaining_for_key(key, time.time()) + + async def get_max_remaining(self, provider: str, *, model: str | None = None) -> float: + """Return the max provider/model cooldown remaining for a request.""" + + async with self._lock: + now = time.time() + provider_remaining = self._remaining_for_key(_cooldown_key(provider, scope="provider"), now) + model_remaining = self._remaining_for_key(_cooldown_key(provider, scope="model", model=model), now) if model else 0 + return max(provider_remaining, model_remaining) + + async def is_scoped_cooling_down(self, provider: str, *, model: str | None = None, scope: str = "provider") -> bool: + """Return whether one scoped cooldown is currently active.""" + + return (await self.get_scoped_remaining(provider, model=model, scope=scope)) > 0 + + async def snapshot(self) -> tuple[CooldownSnapshot, ...]: + """Return active cooldown scopes for tests and future observability.""" + + async with self._lock: + now = time.time() + snapshots = [] + for key, expires_at in list(self._cooldowns.items()): + remaining = max(0, expires_at - now) + if remaining <= 0: + continue + metadata = self._metadata.get(key, {}) + snapshots.append( + CooldownSnapshot( + provider=str(metadata.get("provider") or key), + scope=str(metadata.get("scope") or "provider"), + model=metadata.get("model"), + remaining=remaining, + reason=metadata.get("reason"), + ) + ) + return tuple(snapshots) + + def _remaining_for_key(self, key: str, now: float) -> float: + expires_at = self._cooldowns.get(key) + if expires_at is None: + return 0 + remaining = expires_at - now + if remaining <= 0: + self._cooldowns.pop(key, None) + self._metadata.pop(key, None) + return 0 + return remaining + + +def _cooldown_key(provider: str, *, scope: str, model: str | None = None) -> str: + if scope == "model" and model: + return f"provider:{provider}:model:{model}" + return f"provider:{provider}" diff --git a/tests/test_cooldown_activation.py b/tests/test_cooldown_activation.py index a24f16941..11e0ed82a 100644 --- a/tests/test_cooldown_activation.py +++ b/tests/test_cooldown_activation.py @@ -36,6 +36,41 @@ async def test_start_cooldown_extends_but_does_not_shorten() -> None: assert after_longer > after_shorter +@pytest.mark.asyncio +async def test_model_cooldown_is_independent_from_provider_cooldown() -> None: + manager = CooldownManager() + + await manager.start_scoped_cooldown("provider", 30, model="model-a", scope="model", reason="capacity") + + assert await manager.is_scoped_cooling_down("provider", model="model-a", scope="model") is True + assert await manager.is_scoped_cooling_down("provider", model="model-b", scope="model") is False + assert await manager.is_cooling_down("provider") is False + + +@pytest.mark.asyncio +async def test_max_remaining_uses_provider_or_model_scope() -> None: + manager = CooldownManager() + + await manager.start_cooldown("provider", 1) + await manager.start_scoped_cooldown("provider", 30, model="model-a", scope="model") + + assert await manager.get_max_remaining("provider", model="model-a") > 20 + assert 0 < await manager.get_max_remaining("provider", model="model-b") <= 1 + + +@pytest.mark.asyncio +async def test_cooldown_snapshot_reports_scopes() -> None: + manager = CooldownManager() + + await manager.start_scoped_cooldown("provider", 30, model="model-a", scope="model", reason="capacity") + snapshot = await manager.snapshot() + + assert snapshot[0].provider == "provider" + assert snapshot[0].scope == "model" + assert snapshot[0].model == "model-a" + assert snapshot[0].reason == "capacity" + + def _classified(error_type: str, retry_after=None) -> ClassifiedError: return ClassifiedError(error_type, original_exception=Exception(error_type), retry_after=retry_after) From 9535baed7472db5d8b63ba75f997cf70164828f2 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 13:47:53 +0200 Subject: [PATCH 124/182] feat(retry): add scoped cooldown decisions Extends ProviderCooldownDecision with cooldown scope/model/backoff metadata and adds model-capacity detection for model-scoped cooldown decisions. Adds a bounded in-memory FailureHistory ring used for conservative repeated transient backoff without adding persistence or replacing credential usage tracking. Tests: pytest tests/test_retry_policy.py --- src/rotator_library/retry_policy.py | 116 +++++++++++++++++++++++++++- tests/test_retry_policy.py | 52 +++++++++++++ 2 files changed, 165 insertions(+), 3 deletions(-) diff --git a/src/rotator_library/retry_policy.py b/src/rotator_library/retry_policy.py index 52bbb5fe1..dc4a1ba67 100644 --- a/src/rotator_library/retry_policy.py +++ b/src/rotator_library/retry_policy.py @@ -12,8 +12,10 @@ import asyncio import os +import time +from collections import deque from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional from .error_handler import ClassifiedError, classify_error, should_retry_same_key, should_rotate_on_error from .routing import FallbackPolicy @@ -30,6 +32,22 @@ class ProviderCooldownDecision: should_start: bool duration: int = 0 reason: str = "not_applicable" + scope: str = "provider" + model: Optional[str] = None + backoff_level: int = 0 + + +@dataclass(frozen=True) +class FailureHistoryEntry: + """One sanitized provider/model failure event kept in memory only.""" + + timestamp: float + provider: str + model: Optional[str] + error_type: str + scope: str + duration: int + reason: str def classify_route_error(error: BaseException, provider: Optional[str] = None) -> str: @@ -62,6 +80,9 @@ def decide_provider_cooldown( provider_cooldown_min_seconds: int, default_duration: int = DEFAULT_PROVIDER_COOLDOWN_DEFAULT_SECONDS, cooldown_on_quota: bool = False, + model: Optional[str] = None, + original_error: Any = None, + failure_history: "FailureHistory | None" = None, ) -> ProviderCooldownDecision: """Return whether a provider-wide cooldown should be activated. @@ -76,6 +97,7 @@ def decide_provider_cooldown( return ProviderCooldownDecision(False, reason="quota_cooldown_disabled") if error_type not in {"rate_limit", "server_error", "api_connection", "quota_exceeded"}: return ProviderCooldownDecision(False, reason="non_provider_cooldown_error") + scope = "model" if model and is_model_capacity_error(original_error or classified_error.original_exception) else "provider" retry_after = classified_error.retry_after if retry_after is not None: @@ -85,13 +107,101 @@ def decide_provider_cooldown( return ProviderCooldownDecision(False, reason="small_retry_after") if retry_after < provider_cooldown_min_seconds: return ProviderCooldownDecision(False, reason="below_provider_cooldown_minimum") - return ProviderCooldownDecision(True, duration=int(retry_after), reason="retry_after") + return ProviderCooldownDecision(True, duration=int(retry_after), reason="retry_after", scope=scope, model=model if scope == "model" else None) if error_type in {"server_error", "api_connection"} and default_duration >= provider_cooldown_min_seconds: - return ProviderCooldownDecision(True, duration=int(default_duration), reason="default_transient_cooldown") + backoff_level = 0 + duration = int(default_duration) + if failure_history is not None: + backoff = failure_history.backoff_for(error_type=error_type, scope=scope, model=model if scope == "model" else None, default_duration=duration) + duration = backoff.duration + backoff_level = backoff.level + return ProviderCooldownDecision(True, duration=duration, reason="model_capacity_cooldown" if scope == "model" else "default_transient_cooldown", scope=scope, model=model if scope == "model" else None, backoff_level=backoff_level) return ProviderCooldownDecision(False, reason="missing_retry_after") +@dataclass(frozen=True) +class BackoffDecision: + """Bounded backoff duration derived from recent transient failures.""" + + duration: int + level: int = 0 + + +class FailureHistory: + """Bounded in-memory provider/model failure history. + + This is intentionally process-local. It provides enough recent context for + conservative cooldown backoff and future observability without introducing a + persistence layer or changing credential usage accounting. + """ + + def __init__(self, *, max_entries: int | None = None, clock: Any = None) -> None: + self.max_entries = max(1, max_entries if max_entries is not None else _env_int("FAILURE_HISTORY_MAX_ENTRIES", 200)) + self._entries: deque[FailureHistoryEntry] = deque(maxlen=self.max_entries) + self._clock = clock or time.time + + def record(self, *, provider: str, model: Optional[str], error_type: str, scope: str, duration: int, reason: str) -> None: + """Record one sanitized cooldown/failure event.""" + + self._entries.append( + FailureHistoryEntry( + timestamp=float(self._clock()), + provider=provider, + model=model, + error_type=error_type, + scope=scope, + duration=duration, + reason=reason, + ) + ) + + def snapshot(self) -> tuple[FailureHistoryEntry, ...]: + """Return recent entries for tests and future read-only reporting.""" + + return tuple(self._entries) + + def backoff_for(self, *, error_type: str, scope: str, model: Optional[str], default_duration: int) -> BackoffDecision: + """Return bounded backoff for repeated transient failures.""" + + window = _env_int("PROVIDER_BACKOFF_WINDOW_SECONDS", 60) + threshold = max(1, _env_int("PROVIDER_BACKOFF_THRESHOLD", 3)) + base = max(1, _env_int("PROVIDER_BACKOFF_BASE_SECONDS", default_duration)) + max_seconds = max(base, _env_int("PROVIDER_BACKOFF_MAX_SECONDS", 300)) + now = float(self._clock()) + recent = [ + entry + for entry in self._entries + if now - entry.timestamp <= window + and entry.error_type == error_type + and entry.scope == scope + and (scope != "model" or entry.model == model) + ] + if len(recent) + 1 < threshold: + return BackoffDecision(default_duration, level=0) + level = len(recent) + 1 - threshold + 1 + return BackoffDecision(min(max_seconds, base * (2 ** (level - 1))), level=level) + + +def is_model_capacity_error(error: Any) -> bool: + """Return whether an error indicates model/deployment capacity exhaustion.""" + + if error is None: + return False + if isinstance(error, dict): + text = str(error).lower() + else: + parts = [str(error)] + response = getattr(error, "response", None) + if response is not None: + parts.append(str(getattr(response, "text", ""))) + body = getattr(error, "body", None) + if body is not None: + parts.append(str(body)) + text = " ".join(parts).lower() + return "model_capacity_exhausted" in text or "model capacity" in text or "capacity exhausted" in text + + def provider_cooldown_env() -> tuple[int, int, bool]: """Read provider-cooldown env controls with conservative defaults.""" diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index 33c23da47..61d6a2270 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -4,8 +4,10 @@ from rotator_library.error_handler import ClassifiedError, PreRequestCallbackError, classify_error from rotator_library.retry_policy import ( + FailureHistory, classify_route_error, decide_provider_cooldown, + is_model_capacity_error, is_target_failover_eligible, should_retry_same_credential, should_rotate_credential, @@ -68,6 +70,7 @@ def test_provider_cooldown_uses_large_retry_after_not_small_retry_after() -> Non assert small.reason == "small_retry_after" assert large.should_start is True assert large.duration == 60 + assert large.scope == "provider" def test_provider_cooldown_is_conservative_for_quota_by_default() -> None: @@ -88,6 +91,55 @@ def test_provider_cooldown_is_conservative_for_quota_by_default() -> None: assert enabled.should_start is True +def test_model_capacity_error_uses_model_scoped_cooldown() -> None: + error = Exception("503 MODEL_CAPACITY_EXHAUSTED") + decision = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=30, + model="gpt-test", + original_error=error, + ) + + assert is_model_capacity_error(error) is True + assert decision.should_start is True + assert decision.scope == "model" + assert decision.model == "gpt-test" + assert decision.reason == "model_capacity_cooldown" + + +def test_failure_history_escalates_repeated_transient_backoff(monkeypatch) -> None: + now = 1000.0 + history = FailureHistory(clock=lambda: now) + monkeypatch.setenv("PROVIDER_BACKOFF_THRESHOLD", "3") + monkeypatch.setenv("PROVIDER_BACKOFF_BASE_SECONDS", "10") + monkeypatch.setenv("PROVIDER_BACKOFF_MAX_SECONDS", "40") + history.record(provider="openai", model=None, error_type="server_error", scope="provider", duration=10, reason="test") + history.record(provider="openai", model=None, error_type="server_error", scope="provider", duration=10, reason="test") + + decision = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=10, + failure_history=history, + ) + + assert decision.duration == 10 + assert decision.backoff_level == 1 + history.record(provider="openai", model=None, error_type="server_error", scope="provider", duration=10, reason="test") + decision = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=10, + failure_history=history, + ) + assert decision.duration == 20 + assert decision.backoff_level == 2 + + def test_shared_classifier_handles_structured_dict_status_codes() -> None: assert classify_error({"error": {"status": 401}}).error_type == "authentication" assert classify_error({"error": {"details": {"status_code": 403}}}).error_type == "forbidden" From 309f5a170311369c2d3dc7891ea0377621502d32 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 13:52:24 +0200 Subject: [PATCH 125/182] feat(retry): wire scoped cooldowns into executor Updates RequestExecutor to wait on the max provider/model cooldown, start scoped cooldowns when retry policy requests them, and record scoped cooldown starts in FailureHistory. Keeps compatibility with existing provider-only cooldown fakes and adds sanitized cooldown wait/start trace metadata with scope and model. Tests: pytest tests/test_cooldown_activation.py tests/test_retry_policy.py --- src/rotator_library/client/executor.py | 81 +++++++++++++++++++------- tests/test_cooldown_activation.py | 43 +++++++++++++- 2 files changed, 102 insertions(+), 22 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index e54885f3a..2959aa413 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -63,7 +63,7 @@ from ..request_sanitizer import sanitize_request_payload from ..transaction_logger import TransactionLogger from ..failure_logger import log_failure -from ..retry_policy import decide_provider_cooldown, provider_cooldown_env +from ..retry_policy import FailureHistory, decide_provider_cooldown, provider_cooldown_env from ..routing import FallbackPolicy, clone_context_for_target from ..routing.policy import normalize_route_error_type from ..routing.types import RouteTarget @@ -155,6 +155,7 @@ def __init__( # StreamingHandler no longer needs usage_manager - we pass cred_context directly self._streaming_handler = StreamingHandler() self._native_executor = NativeProviderExecutor() + self._failure_history = FailureHistory() def _get_transient_retry_delay(self) -> float: """Small jittered delay used before transient retries and rotations.""" @@ -1078,7 +1079,7 @@ async def _execute_non_streaming( break # Wait for provider cooldown - await self._wait_for_cooldown(provider, deadline) + await self._wait_for_cooldown(provider, deadline, model=model, context=context) # Acquire credential using context manager try: @@ -1347,7 +1348,7 @@ async def _execute_streaming( remaining = deadline - time.time() if remaining <= 0: break - await self._wait_for_cooldown(provider, deadline) + await self._wait_for_cooldown(provider, deadline, model=model, context=context) # Acquire credential using context manager try: @@ -1540,9 +1541,7 @@ async def _execute_streaming( if _can_start_stream_provider_cooldown( last_streamed_chunk ): - await self._maybe_start_provider_cooldown( - provider, classified, context=context - ) + await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=original) log_failure( api_key=cred, model=model, @@ -1631,9 +1630,7 @@ async def _execute_streaming( if _can_start_stream_provider_cooldown( last_streamed_chunk ): - await self._maybe_start_provider_cooldown( - provider, classified, context=context - ) + await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=e) log_failure( api_key=cred, model=model, @@ -1712,9 +1709,7 @@ async def _execute_streaming( if _can_start_stream_provider_cooldown( last_streamed_chunk ): - await self._maybe_start_provider_cooldown( - provider, classified, context=context - ) + await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=e) log_failure( api_key=cred, model=model, @@ -1754,7 +1749,7 @@ async def _execute_streaming( last_exception = e classified = classify_error(e, provider) if _can_start_stream_provider_cooldown(last_streamed_chunk): - await self._maybe_start_provider_cooldown(provider, classified, context=context) + await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=e) log_failure(api_key=cred, model=model, attempt=attempt + 1, error=e, request_headers=request_headers) error_accumulator.record_error(cred, classified, str(e)[:150]) if not should_rotate_on_error(classified): @@ -1769,9 +1764,7 @@ async def _execute_streaming( if _can_start_stream_provider_cooldown( last_streamed_chunk ): - await self._maybe_start_provider_cooldown( - provider, classified, context=context - ) + await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=e) log_failure( api_key=cred, model=model, @@ -1881,9 +1874,12 @@ async def _wait_for_cooldown( self, provider: str, deadline: float, + *, + model: Optional[str] = None, + context: Optional[RequestContext] = None, ) -> None: """ - Wait for provider-level cooldown to end. + Wait for provider/model cooldown to end. Args: provider: Provider name @@ -1892,7 +1888,10 @@ async def _wait_for_cooldown( if not self._cooldown: return - remaining = await self._cooldown.get_remaining_cooldown(provider) + if hasattr(self._cooldown, "get_max_remaining"): + remaining = await self._cooldown.get_max_remaining(provider, model=model) + else: + remaining = await self._cooldown.get_remaining_cooldown(provider) if remaining > 0: budget = deadline - time.time() if remaining > budget: @@ -1901,6 +1900,7 @@ async def _wait_for_cooldown( ) return # Will fail on no keys available lib_logger.info(f"Waiting {remaining:.1f}s for {provider} cooldown") + self._log_cooldown_wait_trace(context, provider, model, remaining) await asyncio.sleep(remaining) async def _handle_error_with_context( @@ -1962,7 +1962,7 @@ async def _handle_error_with_context( cred_context.mark_failure(classified) return ErrorAction.FAIL - await self._maybe_start_provider_cooldown(provider, classified, context=context) + await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=error) # Check if should retry same key (including small cooldown auto-retry) small_cooldown_threshold = int( @@ -2015,6 +2015,8 @@ async def _maybe_start_provider_cooldown( classified: ClassifiedError, *, context: Optional[RequestContext], + model: Optional[str] = None, + original_error: Any = None, ) -> None: """Start provider-wide cooldown for large provider-level throttles. @@ -2037,6 +2039,9 @@ async def _maybe_start_provider_cooldown( provider_cooldown_min_seconds=min_seconds, default_duration=default_seconds, cooldown_on_quota=cooldown_on_quota, + model=model, + original_error=original_error, + failure_history=getattr(self, "_failure_history", None), ) if not decision.should_start: self._log_provider_cooldown_trace( @@ -2046,10 +2051,19 @@ async def _maybe_start_provider_cooldown( classified, decision.duration, decision.reason, + scope=decision.scope, + model=decision.model, + backoff_level=decision.backoff_level, ) return try: - await self._cooldown.start_cooldown(provider, decision.duration) + if hasattr(self._cooldown, "start_scoped_cooldown"): + await self._cooldown.start_scoped_cooldown(provider, decision.duration, model=decision.model, scope=decision.scope, reason=decision.reason) + else: + await self._cooldown.start_cooldown(provider, decision.duration) + history = getattr(self, "_failure_history", None) + if history is not None: + history.record(provider=provider, model=decision.model, error_type=classified.error_type, scope=decision.scope, duration=decision.duration, reason=decision.reason) self._log_provider_cooldown_trace( context, "provider_cooldown_started", @@ -2057,6 +2071,9 @@ async def _maybe_start_provider_cooldown( classified, decision.duration, decision.reason, + scope=decision.scope, + model=decision.model, + backoff_level=decision.backoff_level, ) except Exception as exc: lib_logger.debug("Failed to start provider cooldown for %s: %s", provider, exc) @@ -2069,17 +2086,24 @@ def _log_provider_cooldown_trace( classified: ClassifiedError, duration: int, reason: str, + *, + scope: str = "provider", + model: Optional[str] = None, + backoff_level: int = 0, ) -> None: if not context or not context.transaction_logger: return context.transaction_logger.log_transform_pass( pass_name, - {"provider": provider, "error_type": classified.error_type, "duration": duration}, + {"provider": provider, "model": model, "scope": scope, "error_type": classified.error_type, "duration": duration}, direction="metadata", stage="retry", metadata={ "provider": provider, "duration": duration, + "scope": scope, + "model": model, + "backoff_level": backoff_level, "error_type": classified.error_type, "retry_after_present": classified.retry_after is not None, "reason": reason, @@ -2087,6 +2111,21 @@ def _log_provider_cooldown_trace( snapshot=False, ) + @staticmethod + def _log_cooldown_wait_trace(context: Optional[RequestContext], provider: str, model: Optional[str], remaining: float) -> None: + """Trace cooldown waits without exposing credentials or payloads.""" + + if not context or not context.transaction_logger: + return + context.transaction_logger.log_transform_pass( + "cooldown_wait", + {"provider": provider, "model": model, "remaining": remaining}, + direction="metadata", + stage="retry", + metadata={"provider": provider, "model": model, "remaining": remaining}, + snapshot=False, + ) + def _record_session_response(self, context: RequestContext, response: Any) -> None: """Let the tracker learn anchors emitted by a successful response. diff --git a/tests/test_cooldown_activation.py b/tests/test_cooldown_activation.py index 11e0ed82a..208b611f6 100644 --- a/tests/test_cooldown_activation.py +++ b/tests/test_cooldown_activation.py @@ -14,10 +14,19 @@ class FakeCooldown: def __init__(self) -> None: self.started = [] + self.scoped_started = [] + self.waits = [] async def start_cooldown(self, provider, duration): self.started.append((provider, duration)) + async def start_scoped_cooldown(self, provider, duration, *, model=None, scope="provider", reason=None): + self.scoped_started.append((provider, duration, model, scope, reason)) + + async def get_max_remaining(self, provider, *, model=None): + self.waits.append((provider, model)) + return 0 + @pytest.mark.asyncio async def test_start_cooldown_extends_but_does_not_shorten() -> None: @@ -78,6 +87,9 @@ def _classified(error_type: str, retry_after=None) -> ClassifiedError: def _executor(cooldown) -> RequestExecutor: executor = RequestExecutor.__new__(RequestExecutor) executor._cooldown = cooldown + from rotator_library.retry_policy import FailureHistory + + executor._failure_history = FailureHistory() return executor @@ -106,7 +118,7 @@ async def test_large_retry_after_starts_provider_cooldown_and_traces(tmp_path, m context=_context(logger), ) - assert cooldown.started == [("openai", 60)] + assert cooldown.scoped_started == [("openai", 60, None, "provider", "retry_after")] trace = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") assert "provider_cooldown_started" in trace @@ -123,6 +135,35 @@ async def test_small_retry_after_skips_provider_cooldown(monkeypatch) -> None: ) assert cooldown.started == [] + assert cooldown.scoped_started == [] + + +@pytest.mark.asyncio +async def test_model_capacity_starts_model_scoped_cooldown(monkeypatch) -> None: + monkeypatch.setenv("PROVIDER_COOLDOWN_DEFAULT_SECONDS", "30") + monkeypatch.setenv("PROVIDER_COOLDOWN_MIN_SECONDS", "10") + cooldown = FakeCooldown() + executor = _executor(cooldown) + + await executor._maybe_start_provider_cooldown( + "openai", + _classified("server_error"), + context=None, + model="gpt-5", + original_error=Exception("MODEL_CAPACITY_EXHAUSTED"), + ) + + assert cooldown.scoped_started == [("openai", 30, "gpt-5", "model", "model_capacity_cooldown")] + assert executor._failure_history.snapshot()[0].scope == "model" + + +@pytest.mark.asyncio +async def test_wait_for_cooldown_uses_model_scope_when_available() -> None: + cooldown = FakeCooldown() + + await _executor(cooldown)._wait_for_cooldown("openai", 9999999999.0, model="gpt-5") + + assert cooldown.waits == [("openai", "gpt-5")] def test_streaming_provider_cooldown_gate_allows_only_pre_output_failures() -> None: From 0b8551c5503515b4e461300fde369caa8d09a202 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 13:54:22 +0200 Subject: [PATCH 126/182] feat(streaming): expose scoped cooldown decisions Adds cooldown scope/model fields to side-effect-free streaming error decisions and passes model-capacity context into retry policy cooldown decisions. Classifies model-capacity exception text as server_error so existing retry semantics apply while retry policy can choose model-scoped cooldown. Tests: pytest tests/test_streaming_error_handler.py tests/test_retry_policy.py --- src/rotator_library/error_handler.py | 7 +++++++ src/rotator_library/streaming/errors.py | 11 +++++++++++ tests/test_streaming_error_handler.py | 19 +++++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/src/rotator_library/error_handler.py b/src/rotator_library/error_handler.py index 7819a797a..5388b1c7d 100644 --- a/src/rotator_library/error_handler.py +++ b/src/rotator_library/error_handler.py @@ -1090,6 +1090,13 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr status_code=status_code or 503, ) + if "model_capacity_exhausted" in error_text.lower() or "model capacity" in error_text.lower() or "capacity exhausted" in error_text.lower(): + return ClassifiedError( + error_type="server_error", + original_exception=e, + status_code=status_code or 503, + ) + # Fallback for any other unclassified errors return ClassifiedError( error_type="unknown", original_exception=e, status_code=status_code diff --git a/src/rotator_library/streaming/errors.py b/src/rotator_library/streaming/errors.py index 83c9624b5..2e5c894c3 100644 --- a/src/rotator_library/streaming/errors.py +++ b/src/rotator_library/streaming/errors.py @@ -20,6 +20,8 @@ class StreamingErrorDecision: action: str start_provider_cooldown: bool = False provider_cooldown_duration: int = 0 + provider_cooldown_scope: str = "provider" + provider_cooldown_model: str | None = None reason: str = "" @@ -35,6 +37,7 @@ def decide_streaming_error_action( provider_cooldown_default_seconds: int, cooldown_on_quota: bool = False, allow_reasoning_only_retry: bool = False, + model: str | None = None, ) -> StreamingErrorDecision: """Classify a stream failure without sleeping or mutating state. @@ -51,6 +54,8 @@ def decide_streaming_error_action( provider_cooldown_default_seconds=provider_cooldown_default_seconds, cooldown_on_quota=cooldown_on_quota, last_streamed_chunk=last_streamed_chunk, + model=model, + original_error=error, ) if not should_rotate_on_error(classified): return _decision(classified, "fail", cooldown, "non_rotatable") @@ -69,6 +74,8 @@ def _cooldown_decision( provider_cooldown_default_seconds: int, cooldown_on_quota: bool, last_streamed_chunk: str | None, + model: str | None, + original_error: Exception, ) -> ProviderCooldownDecision: if is_visible_stream_output(last_streamed_chunk): return ProviderCooldownDecision(False, reason="visible_output") @@ -78,6 +85,8 @@ def _cooldown_decision( provider_cooldown_min_seconds=provider_cooldown_min_seconds, default_duration=provider_cooldown_default_seconds, cooldown_on_quota=cooldown_on_quota, + model=model, + original_error=original_error, ) @@ -92,5 +101,7 @@ def _decision( action=action, start_provider_cooldown=cooldown.should_start, provider_cooldown_duration=cooldown.duration, + provider_cooldown_scope=cooldown.scope, + provider_cooldown_model=cooldown.model, reason=reason if not cooldown.should_start else f"{reason};cooldown:{cooldown.reason}", ) diff --git a/tests/test_streaming_error_handler.py b/tests/test_streaming_error_handler.py index c60b822e9..d21fe0e90 100644 --- a/tests/test_streaming_error_handler.py +++ b/tests/test_streaming_error_handler.py @@ -43,6 +43,7 @@ def test_streaming_error_decision_starts_cooldown_before_visible_output() -> Non assert decision.action == "rotate" assert decision.start_provider_cooldown is True assert decision.provider_cooldown_duration == 60 + assert decision.provider_cooldown_scope == "provider" def test_streaming_error_decision_blocks_after_visible_output_and_skips_cooldown() -> None: @@ -75,3 +76,21 @@ def test_streaming_error_decision_allows_reasoning_only_retry_when_enabled() -> ) assert decision.action == "retry_same" + + +def test_streaming_error_decision_reports_model_cooldown_scope() -> None: + decision = decide_streaming_error_action( + Exception("MODEL_CAPACITY_EXHAUSTED"), + provider="openai", + model="gpt-5", + last_streamed_chunk=None, + attempt=1, + max_retries=2, + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + provider_cooldown_default_seconds=30, + ) + + assert decision.start_provider_cooldown is True + assert decision.provider_cooldown_scope == "model" + assert decision.provider_cooldown_model == "gpt-5" From 45ebfa99f7354f6baaa1f6d9b7de3a0fd1aa7080 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 14:04:22 +0200 Subject: [PATCH 127/182] fix(retry): guard stream retries after visible output Tracks whether any visible stream output was emitted so cooldown, retry, and rotation decisions cannot be re-enabled by a later non-visible chunk. Scopes FailureHistory backoff by provider as well as model/scope to prevent one provider's transient failures from escalating another provider's cooldown. Tests: pytest tests/test_cooldown_activation.py tests/test_retry_policy.py tests/test_streaming_error_handler.py tests/test_stream_policy.py tests/test_fallback_policy.py tests/test_fallback_attempt_runner.py tests/test_request_executor_fallback_groups.py tests/test_streaming_fallback_policy.py tests/test_request_executor_native_routing.py --- src/rotator_library/client/executor.py | 64 +++++++++++++++++++++---- src/rotator_library/retry_policy.py | 6 ++- src/rotator_library/streaming/errors.py | 3 ++ tests/test_cooldown_activation.py | 1 + tests/test_retry_policy.py | 29 +++++++++++ 5 files changed, 91 insertions(+), 12 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 2959aa413..dfd068785 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -1417,6 +1417,7 @@ async def _execute_streaming( # Execute request with retries for attempt in range(self._max_retries): last_streamed_chunk: Optional[str] = None + stream_visible_output_emitted = False try: lib_logger.info( @@ -1527,10 +1528,14 @@ async def _execute_streaming( plugin=plugin, ): last_streamed_chunk = chunk + if _stream_chunk_is_visible_output(chunk): + stream_visible_output_emitted = True yield chunk else: async for chunk in base_stream: last_streamed_chunk = chunk + if _stream_chunk_is_visible_output(chunk): + stream_visible_output_emitted = True yield chunk return @@ -1539,7 +1544,8 @@ async def _execute_streaming( original = getattr(e, "data", e) classified = classify_error(original, provider) if _can_start_stream_provider_cooldown( - last_streamed_chunk + last_streamed_chunk, + emitted_output=stream_visible_output_emitted, ): await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=original) log_failure( @@ -1578,10 +1584,7 @@ async def _execute_streaming( cred_context.mark_failure(classified) raise - if not can_retry_stream_after_error( - last_streamed_chunk, - self._stream_retry_on_reasoning_only_enabled(), - ): + if not _can_retry_stream_after_error(last_streamed_chunk, self._stream_retry_on_reasoning_only_enabled(), emitted_output=stream_visible_output_emitted): cred_context.mark_failure(classified) error_data = { "error": { @@ -1628,7 +1631,8 @@ async def _execute_streaming( last_exception = e classified = classify_error(e, provider) if _can_start_stream_provider_cooldown( - last_streamed_chunk + last_streamed_chunk, + emitted_output=stream_visible_output_emitted, ): await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=e) log_failure( @@ -1667,6 +1671,13 @@ async def _execute_streaming( cred_context.mark_failure(classified) raise + if not _can_retry_stream_after_error(last_streamed_chunk, self._stream_retry_on_reasoning_only_enabled(), emitted_output=stream_visible_output_emitted): + cred_context.mark_failure(classified) + error_data = {"error": {"message": "Upstream stream failed after output began", "type": classified.error_type}} + for line in self._terminal_stream_error_lines(context, error_data): + yield line + return + # Check for small cooldown - retry same key instead of rotating small_cooldown_threshold = int( os.environ.get( @@ -1707,7 +1718,8 @@ async def _execute_streaming( last_exception = e classified = classify_error(e, provider) if _can_start_stream_provider_cooldown( - last_streamed_chunk + last_streamed_chunk, + emitted_output=stream_visible_output_emitted, ): await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=e) log_failure( @@ -1718,6 +1730,13 @@ async def _execute_streaming( request_headers=request_headers, ) + if not _can_retry_stream_after_error(last_streamed_chunk, self._stream_retry_on_reasoning_only_enabled(), emitted_output=stream_visible_output_emitted): + cred_context.mark_failure(classified) + error_data = {"error": {"message": "Upstream stream failed after output began", "type": classified.error_type}} + for line in self._terminal_stream_error_lines(context, error_data): + yield line + return + if attempt >= self._max_retries - 1: error_accumulator.record_error( cred, classified, str(e)[:150] @@ -1748,13 +1767,19 @@ async def _execute_streaming( raise last_exception = e classified = classify_error(e, provider) - if _can_start_stream_provider_cooldown(last_streamed_chunk): + if _can_start_stream_provider_cooldown(last_streamed_chunk, emitted_output=stream_visible_output_emitted): await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=e) log_failure(api_key=cred, model=model, attempt=attempt + 1, error=e, request_headers=request_headers) error_accumulator.record_error(cred, classified, str(e)[:150]) if not should_rotate_on_error(classified): cred_context.mark_failure(classified) raise + if not _can_retry_stream_after_error(last_streamed_chunk, self._stream_retry_on_reasoning_only_enabled(), emitted_output=stream_visible_output_emitted): + cred_context.mark_failure(classified) + error_data = {"error": {"message": "Upstream stream failed after output began", "type": classified.error_type}} + for line in self._terminal_stream_error_lines(context, error_data): + yield line + return cred_context.mark_failure(classified) break @@ -1762,7 +1787,8 @@ async def _execute_streaming( last_exception = e classified = classify_error(e, provider) if _can_start_stream_provider_cooldown( - last_streamed_chunk + last_streamed_chunk, + emitted_output=stream_visible_output_emitted, ): await self._maybe_start_provider_cooldown(provider, classified, context=context, model=model, original_error=e) log_failure( @@ -1780,6 +1806,13 @@ async def _execute_streaming( cred_context.mark_failure(classified) raise + if not _can_retry_stream_after_error(last_streamed_chunk, self._stream_retry_on_reasoning_only_enabled(), emitted_output=stream_visible_output_emitted): + cred_context.mark_failure(classified) + error_data = {"error": {"message": "Upstream stream failed after output began", "type": classified.error_type}} + for line in self._terminal_stream_error_lines(context, error_data): + yield line + return + small_cooldown_threshold = int( os.environ.get( "SMALL_COOLDOWN_RETRY_THRESHOLD", @@ -2039,6 +2072,7 @@ async def _maybe_start_provider_cooldown( provider_cooldown_min_seconds=min_seconds, default_duration=default_seconds, cooldown_on_quota=cooldown_on_quota, + provider=provider, model=model, original_error=original_error, failure_history=getattr(self, "_failure_history", None), @@ -2823,7 +2857,17 @@ def _stream_chunk_payload(chunk: str) -> Optional[Dict[str, Any]]: return parsed -def _can_start_stream_provider_cooldown(last_streamed_chunk: Optional[str]) -> bool: +def _can_start_stream_provider_cooldown(last_streamed_chunk: Optional[str], *, emitted_output: bool = False) -> bool: """Return whether a streaming failure occurred before visible output.""" + if emitted_output: + return False return last_streamed_chunk is None or not _stream_chunk_is_visible_output(last_streamed_chunk) + + +def _can_retry_stream_after_error(last_streamed_chunk: Optional[str], allow_reasoning_only_retry: bool, *, emitted_output: bool = False) -> bool: + """Return whether a stream can retry/rotate without duplicating output.""" + + if emitted_output: + return False + return can_retry_stream_after_error(last_streamed_chunk, allow_reasoning_only_retry) diff --git a/src/rotator_library/retry_policy.py b/src/rotator_library/retry_policy.py index dc4a1ba67..63445cc26 100644 --- a/src/rotator_library/retry_policy.py +++ b/src/rotator_library/retry_policy.py @@ -80,6 +80,7 @@ def decide_provider_cooldown( provider_cooldown_min_seconds: int, default_duration: int = DEFAULT_PROVIDER_COOLDOWN_DEFAULT_SECONDS, cooldown_on_quota: bool = False, + provider: Optional[str] = None, model: Optional[str] = None, original_error: Any = None, failure_history: "FailureHistory | None" = None, @@ -113,7 +114,7 @@ def decide_provider_cooldown( backoff_level = 0 duration = int(default_duration) if failure_history is not None: - backoff = failure_history.backoff_for(error_type=error_type, scope=scope, model=model if scope == "model" else None, default_duration=duration) + backoff = failure_history.backoff_for(provider=provider, error_type=error_type, scope=scope, model=model if scope == "model" else None, default_duration=duration) duration = backoff.duration backoff_level = backoff.level return ProviderCooldownDecision(True, duration=duration, reason="model_capacity_cooldown" if scope == "model" else "default_transient_cooldown", scope=scope, model=model if scope == "model" else None, backoff_level=backoff_level) @@ -161,7 +162,7 @@ def snapshot(self) -> tuple[FailureHistoryEntry, ...]: return tuple(self._entries) - def backoff_for(self, *, error_type: str, scope: str, model: Optional[str], default_duration: int) -> BackoffDecision: + def backoff_for(self, *, provider: Optional[str], error_type: str, scope: str, model: Optional[str], default_duration: int) -> BackoffDecision: """Return bounded backoff for repeated transient failures.""" window = _env_int("PROVIDER_BACKOFF_WINDOW_SECONDS", 60) @@ -173,6 +174,7 @@ def backoff_for(self, *, error_type: str, scope: str, model: Optional[str], defa entry for entry in self._entries if now - entry.timestamp <= window + and (provider is None or entry.provider == provider) and entry.error_type == error_type and entry.scope == scope and (scope != "model" or entry.model == model) diff --git a/src/rotator_library/streaming/errors.py b/src/rotator_library/streaming/errors.py index 2e5c894c3..1e9ffaf88 100644 --- a/src/rotator_library/streaming/errors.py +++ b/src/rotator_library/streaming/errors.py @@ -54,6 +54,7 @@ def decide_streaming_error_action( provider_cooldown_default_seconds=provider_cooldown_default_seconds, cooldown_on_quota=cooldown_on_quota, last_streamed_chunk=last_streamed_chunk, + provider=provider, model=model, original_error=error, ) @@ -74,6 +75,7 @@ def _cooldown_decision( provider_cooldown_default_seconds: int, cooldown_on_quota: bool, last_streamed_chunk: str | None, + provider: str, model: str | None, original_error: Exception, ) -> ProviderCooldownDecision: @@ -85,6 +87,7 @@ def _cooldown_decision( provider_cooldown_min_seconds=provider_cooldown_min_seconds, default_duration=provider_cooldown_default_seconds, cooldown_on_quota=cooldown_on_quota, + provider=provider, model=model, original_error=original_error, ) diff --git a/tests/test_cooldown_activation.py b/tests/test_cooldown_activation.py index 208b611f6..1bb238eeb 100644 --- a/tests/test_cooldown_activation.py +++ b/tests/test_cooldown_activation.py @@ -170,3 +170,4 @@ def test_streaming_provider_cooldown_gate_allows_only_pre_output_failures() -> N assert _can_start_stream_provider_cooldown(None) is True assert _can_start_stream_provider_cooldown('data: {"error":{"type":"rate_limit"}}\n\n') is True assert _can_start_stream_provider_cooldown('data: {"choices":[{"delta":{"content":"visible"}}]}\n\n') is False + assert _can_start_stream_provider_cooldown('data: {"usage":{"total_tokens":1}}\n\n', emitted_output=True) is False diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index 61d6a2270..f8f2a4343 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -123,6 +123,7 @@ def test_failure_history_escalates_repeated_transient_backoff(monkeypatch) -> No small_cooldown_threshold=10, provider_cooldown_min_seconds=10, default_duration=10, + provider="openai", failure_history=history, ) @@ -134,12 +135,40 @@ def test_failure_history_escalates_repeated_transient_backoff(monkeypatch) -> No small_cooldown_threshold=10, provider_cooldown_min_seconds=10, default_duration=10, + provider="openai", failure_history=history, ) assert decision.duration == 20 assert decision.backoff_level == 2 +def test_failure_history_backoff_is_provider_scoped(monkeypatch) -> None: + history = FailureHistory(clock=lambda: 1000.0) + monkeypatch.setenv("PROVIDER_BACKOFF_THRESHOLD", "2") + monkeypatch.setenv("PROVIDER_BACKOFF_BASE_SECONDS", "10") + history.record(provider="provider-a", model=None, error_type="server_error", scope="provider", duration=10, reason="test") + + provider_b = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=10, + provider="provider-b", + failure_history=history, + ) + provider_a = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=10, + provider="provider-a", + failure_history=history, + ) + + assert provider_b.backoff_level == 0 + assert provider_a.backoff_level == 1 + + def test_shared_classifier_handles_structured_dict_status_codes() -> None: assert classify_error({"error": {"status": 401}}).error_type == "authentication" assert classify_error({"error": {"details": {"status_code": 403}}}).error_type == "forbidden" From b1e0adef371699ac71ae72727a0210c25871f799 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 14:09:43 +0200 Subject: [PATCH 128/182] fix(retry): latch streaming policy decisions Adds an emitted_output latch to the reusable streaming error decision helper so cooldown/retry decisions stay blocked after any prior visible stream output, even if the latest chunk is non-visible. Tightens FailureHistory backoff matching so missing provider no longer acts as a cross-provider wildcard. Tests: pytest tests/test_cooldown_activation.py tests/test_retry_policy.py tests/test_streaming_error_handler.py tests/test_stream_policy.py tests/test_fallback_policy.py tests/test_fallback_attempt_runner.py tests/test_request_executor_fallback_groups.py tests/test_streaming_fallback_policy.py tests/test_request_executor_native_routing.py --- src/rotator_library/retry_policy.py | 2 +- src/rotator_library/streaming/errors.py | 7 +++++-- tests/test_retry_policy.py | 16 ++++++++++++++++ tests/test_streaming_error_handler.py | 17 +++++++++++++++++ 4 files changed, 39 insertions(+), 3 deletions(-) diff --git a/src/rotator_library/retry_policy.py b/src/rotator_library/retry_policy.py index 63445cc26..ff7f1c6ca 100644 --- a/src/rotator_library/retry_policy.py +++ b/src/rotator_library/retry_policy.py @@ -174,7 +174,7 @@ def backoff_for(self, *, provider: Optional[str], error_type: str, scope: str, m entry for entry in self._entries if now - entry.timestamp <= window - and (provider is None or entry.provider == provider) + and entry.provider == provider and entry.error_type == error_type and entry.scope == scope and (scope != "model" or entry.model == model) diff --git a/src/rotator_library/streaming/errors.py b/src/rotator_library/streaming/errors.py index 1e9ffaf88..94467e228 100644 --- a/src/rotator_library/streaming/errors.py +++ b/src/rotator_library/streaming/errors.py @@ -38,6 +38,7 @@ def decide_streaming_error_action( cooldown_on_quota: bool = False, allow_reasoning_only_retry: bool = False, model: str | None = None, + emitted_output: bool = False, ) -> StreamingErrorDecision: """Classify a stream failure without sleeping or mutating state. @@ -54,13 +55,14 @@ def decide_streaming_error_action( provider_cooldown_default_seconds=provider_cooldown_default_seconds, cooldown_on_quota=cooldown_on_quota, last_streamed_chunk=last_streamed_chunk, + emitted_output=emitted_output, provider=provider, model=model, original_error=error, ) if not should_rotate_on_error(classified): return _decision(classified, "fail", cooldown, "non_rotatable") - if not can_retry_stream_after_error(last_streamed_chunk, allow_reasoning_only_retry): + if emitted_output or not can_retry_stream_after_error(last_streamed_chunk, allow_reasoning_only_retry): return _decision(classified, "fallback_blocked_after_output", cooldown, "visible_output") if should_retry_same_key(classified, small_cooldown_threshold) and attempt < max_retries - 1: return _decision(classified, "retry_same", cooldown, "retry_same_credential") @@ -75,11 +77,12 @@ def _cooldown_decision( provider_cooldown_default_seconds: int, cooldown_on_quota: bool, last_streamed_chunk: str | None, + emitted_output: bool, provider: str, model: str | None, original_error: Exception, ) -> ProviderCooldownDecision: - if is_visible_stream_output(last_streamed_chunk): + if emitted_output or is_visible_stream_output(last_streamed_chunk): return ProviderCooldownDecision(False, reason="visible_output") return decide_provider_cooldown( classified, diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index f8f2a4343..1386ffc96 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -169,6 +169,22 @@ def test_failure_history_backoff_is_provider_scoped(monkeypatch) -> None: assert provider_a.backoff_level == 1 +def test_failure_history_backoff_requires_matching_provider(monkeypatch) -> None: + history = FailureHistory(clock=lambda: 1000.0) + monkeypatch.setenv("PROVIDER_BACKOFF_THRESHOLD", "2") + history.record(provider="provider-a", model=None, error_type="server_error", scope="provider", duration=10, reason="test") + + decision = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=10, + failure_history=history, + ) + + assert decision.backoff_level == 0 + + def test_shared_classifier_handles_structured_dict_status_codes() -> None: assert classify_error({"error": {"status": 401}}).error_type == "authentication" assert classify_error({"error": {"details": {"status_code": 403}}}).error_type == "forbidden" diff --git a/tests/test_streaming_error_handler.py b/tests/test_streaming_error_handler.py index d21fe0e90..f85b513d1 100644 --- a/tests/test_streaming_error_handler.py +++ b/tests/test_streaming_error_handler.py @@ -62,6 +62,23 @@ def test_streaming_error_decision_blocks_after_visible_output_and_skips_cooldown assert decision.start_provider_cooldown is False +def test_streaming_error_decision_blocks_after_prior_visible_output() -> None: + decision = decide_streaming_error_action( + _rate_limit(60), + provider="openai", + last_streamed_chunk='data: {"usage":{"total_tokens":1}}\n\n', + emitted_output=True, + attempt=0, + max_retries=2, + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + provider_cooldown_default_seconds=30, + ) + + assert decision.action == "fallback_blocked_after_output" + assert decision.start_provider_cooldown is False + + def test_streaming_error_decision_allows_reasoning_only_retry_when_enabled() -> None: decision = decide_streaming_error_action( _rate_limit(3), From 6b791cac92561db65e55c411ea6042e2279821e3 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 14:16:48 +0200 Subject: [PATCH 129/182] docs(experimental): plan streaming hardening correction Adds the Phase 8b corrective plan for stream cancellation, heartbeat support, TTFB/stall timeout policy, native HTTP streaming transport, and lifecycle traces. Tests: not run (planning document only) --- .../phase-8b-streaming-hardening.md | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 docs/experimental/phase-8b-streaming-hardening.md diff --git a/docs/experimental/phase-8b-streaming-hardening.md b/docs/experimental/phase-8b-streaming-hardening.md new file mode 100644 index 000000000..932bec89a --- /dev/null +++ b/docs/experimental/phase-8b-streaming-hardening.md @@ -0,0 +1,84 @@ +# Phase 8b: Streaming Hardening, Cancellation, Heartbeats, And Stall Policy + +## Goal + +Correct the Phase 8 audit findings. Phase 8 added stream events, metrics, formatters, and observability, but stream hardening still needs upstream cancellation, active TTFB/stall policy, heartbeat support, and generic native HTTP streaming support. + +## Non-Goals + +- Do not enable native streaming for priority providers from Phase 5b. +- Do not rewrite the entire streaming executor or change existing chat-completions SSE chunk format by default. +- Do not implement a WebSocket FastAPI route. +- Do not weaken Phase 6b fallback visible-output safety or Phase 7b cooldown/retry latch behavior. +- Do not introduce persistence. +- Do not commit user-facing reports. + +## Current State + +- `StreamMonitor` records TTFB, TTFT, chunk counts, errors, cancellation, and stall status, but no active timeout policy uses it. +- `StreamingHandler.wrap_stream()` detects client disconnect but does not guarantee upstream iterator closure. +- No heartbeat frames are emitted during long waits between chunks. +- `NativeHTTPTransport.stream_json_lines()` requires custom injected clients to expose `stream_json_lines()` and does not support generic `httpx.AsyncClient.stream()`. +- Existing tests cover metrics and fallback, but not upstream cancellation, heartbeats, TTFB timeout, stall timeout, or generic native stream transport. + +## Implementation Plan + +1. Extend stream runtime settings. + - Add `ttfb_timeout_seconds`, `stall_timeout_seconds`, `heartbeat_interval_seconds`, and `cancel_upstream_on_disconnect`. + - Add env overrides: `STREAM_TTFB_TIMEOUT_SECONDS`, `STREAM_STALL_TIMEOUT_SECONDS`, `STREAM_HEARTBEAT_INTERVAL_SECONDS`, `STREAM_CANCEL_UPSTREAM_ON_DISCONNECT`. + - Keep defaults behavior-compatible: no timeout/heartbeat unless configured; upstream cancellation enabled on disconnect. + +2. Add upstream stream close helper. + - Close via `aclose()` when available, otherwise `close()`. + - Use it on client disconnect, cancellation, or abnormal stream exit. + - Log close failures only at debug/trace level. + +3. Add heartbeat formatting. + - Add `format_heartbeat()` to SSE/WebSocket/JSONL formatters. + - SSE heartbeat is a comment frame such as `: heartbeat\n\n`. + - Heartbeats must not count as visible output, session evidence, or usage. + +4. Add heartbeat emission in `StreamingHandler.wrap_stream()`. + - When configured, wait for upstream chunks with heartbeat interval timeout and yield heartbeat comments while waiting. + - Default remains no heartbeat. + +5. Add TTFB timeout policy. + - If configured and no first byte arrives in time, raise `StreamedAPIError` with structured `api_connection` timeout payload. + - This occurs before visible output so existing retry/fallback can apply. + +6. Add stall timeout policy. + - If configured and no chunk arrives for the configured interval after first byte, raise `StreamedAPIError` with structured `api_connection` timeout payload. + - Phase 7b visible-output latch must still suppress retry/fallback/cooldown if output was already visible. + +7. Add native `httpx` stream support. + - Keep custom `stream_json_lines()` support first. + - Otherwise use `client.stream("POST", ...)` with `aiter_lines()` or `aiter_bytes()` fallback. + - Preserve `[DONE]`, ignore empty lines, and parse `data:` JSON when possible. + +8. Add lifecycle trace events. + - `stream_heartbeat` + - `stream_ttfb_timeout` + - `stream_stall_timeout` + - `stream_upstream_cancelled` + - `stream_upstream_close_failed` + - Keep snapshots disabled and metadata sanitized. + +## Tests + +- Formatter heartbeat tests. +- Streaming handler disconnect/upstream close tests. +- Heartbeat interval tests. +- TTFB timeout tests. +- Stall timeout tests before and after visible output. +- Native HTTP transport tests for `httpx`-style streaming. +- Phase 7b retry/cooldown/routing regression subset. + +## Acceptance Criteria + +- Client disconnect closes upstream async streams when possible. +- Heartbeats are supported, disabled by default, and non-visible. +- Configured TTFB/stall timeouts produce structured stream errors. +- Visible-output latch still prevents retry/fallback/cooldown after output. +- Native HTTP transport supports generic `httpx` streaming plus custom test clients. +- Stream traces include heartbeat/timeout/cancel metadata without secrets. +- Focused tests and dual-agent review pass. From 508d13bdf46f172d07eac9546bee8e7180998fa7 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 14:18:12 +0200 Subject: [PATCH 130/182] feat(streaming): add heartbeat runtime settings Extends experimental stream runtime settings with active hardening knobs and adds transport-level heartbeat formatting for SSE, WebSocket, and JSONL without changing default stream output. Tests: pytest tests/test_config_stream_settings.py tests/test_stream_transport.py --- src/rotator_library/config/experimental.py | 10 ++++++---- src/rotator_library/streaming/transport.py | 21 +++++++++++++++++++++ tests/test_config_stream_settings.py | 11 +++++++++-- tests/test_stream_transport.py | 8 ++++++++ 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/rotator_library/config/experimental.py b/src/rotator_library/config/experimental.py index 65a0a4ca2..19d8b9f64 100644 --- a/src/rotator_library/config/experimental.py +++ b/src/rotator_library/config/experimental.py @@ -57,14 +57,15 @@ def is_empty(self) -> bool: class StreamRuntimeSettings: """Runtime stream observability settings. - Timeout and heartbeat values are parsed here for future enforcement, but - Phase 10 only wires `trace_metrics` into runtime behavior. This avoids - surprising long-running reasoning streams while still validating config. + Timeout and heartbeat values default to disabled so existing long-running + reasoning streams keep working. Operators can opt into active stream + hardening through env or JSON config without changing provider code. """ ttfb_timeout_seconds: Optional[float] = None stall_timeout_seconds: Optional[float] = None heartbeat_seconds: Optional[float] = None + cancel_upstream_on_disconnect: bool = True trace_metrics: bool = True @@ -155,7 +156,8 @@ def get_stream_runtime_settings( return StreamRuntimeSettings( ttfb_timeout_seconds=_optional_positive_float(_env_or_json(source, "STREAM_TTFB_TIMEOUT_SECONDS", streaming, "ttfb_timeout_seconds"), "STREAM_TTFB_TIMEOUT_SECONDS"), stall_timeout_seconds=_optional_positive_float(_env_or_json(source, "STREAM_STALL_TIMEOUT_SECONDS", streaming, "stall_timeout_seconds"), "STREAM_STALL_TIMEOUT_SECONDS"), - heartbeat_seconds=_optional_positive_float(_env_or_json(source, "STREAM_HEARTBEAT_SECONDS", streaming, "heartbeat_seconds"), "STREAM_HEARTBEAT_SECONDS"), + heartbeat_seconds=_optional_positive_float(_env_or_json(source, "STREAM_HEARTBEAT_INTERVAL_SECONDS", streaming, "heartbeat_interval_seconds", default=_env_or_json(source, "STREAM_HEARTBEAT_SECONDS", streaming, "heartbeat_seconds")), "STREAM_HEARTBEAT_INTERVAL_SECONDS"), + cancel_upstream_on_disconnect=as_bool(_env_or_json(source, "STREAM_CANCEL_UPSTREAM_ON_DISCONNECT", streaming, "cancel_upstream_on_disconnect", default=True), name="STREAM_CANCEL_UPSTREAM_ON_DISCONNECT"), trace_metrics=as_bool(_env_or_json(source, "STREAM_TRACE_METRICS", streaming, "trace_metrics", default=True), name="STREAM_TRACE_METRICS"), ) diff --git a/src/rotator_library/streaming/transport.py b/src/rotator_library/streaming/transport.py index 7472757df..b7ebd1b43 100644 --- a/src/rotator_library/streaming/transport.py +++ b/src/rotator_library/streaming/transport.py @@ -24,6 +24,17 @@ def format_error(self, event: StreamEvent) -> str: def format_done(self) -> str: return "data: [DONE]\n\n" + def format_heartbeat(self, comment: str = "heartbeat") -> str: + """Return an SSE comment heartbeat frame. + + Comment frames keep HTTP connections active without becoming model + output, so routing/session code must continue treating them as + non-visible stream data. + """ + + safe_comment = comment.replace("\r", " ").replace("\n", " ") + return f": {safe_comment}\n\n" + def is_terminal_event(self, event: StreamEvent) -> bool: return event.event_type in {"completed", "cancelled", "error"} @@ -43,6 +54,11 @@ def format_error(self, event: StreamEvent) -> dict: def format_done(self) -> dict: return {"type": "completed", "payload": {"done": True}} + def format_heartbeat(self) -> dict: + """Return the future WebSocket heartbeat message shape.""" + + return {"type": "heartbeat", "payload": {"visible_output": False}} + def is_terminal_event(self, event: StreamEvent) -> bool: return event.event_type in {"completed", "cancelled", "error"} @@ -61,5 +77,10 @@ def format_error(self, event: StreamEvent) -> str: def format_done(self) -> str: return json.dumps({"event_type": "completed", "done": True}) + "\n" + def format_heartbeat(self) -> str: + """Return a JSONL heartbeat record for test transports.""" + + return json.dumps({"event_type": "heartbeat", "visible_output": False}) + "\n" + def is_terminal_event(self, event: StreamEvent) -> bool: return event.event_type in {"completed", "cancelled", "error"} diff --git a/tests/test_config_stream_settings.py b/tests/test_config_stream_settings.py index 2ef253354..5da2c5141 100644 --- a/tests/test_config_stream_settings.py +++ b/tests/test_config_stream_settings.py @@ -7,7 +7,7 @@ def test_stream_settings_parse_json_values() -> None: config = load_config_from_mapping( - {"streaming": {"ttfb_timeout_seconds": 5, "stall_timeout_seconds": 30, "heartbeat_seconds": 10, "trace_metrics": False}} + {"streaming": {"ttfb_timeout_seconds": 5, "stall_timeout_seconds": 30, "heartbeat_interval_seconds": 10, "cancel_upstream_on_disconnect": False, "trace_metrics": False}} ) settings = get_stream_runtime_settings(config=config, env={}) @@ -15,18 +15,25 @@ def test_stream_settings_parse_json_values() -> None: assert settings.ttfb_timeout_seconds == 5 assert settings.stall_timeout_seconds == 30 assert settings.heartbeat_seconds == 10 + assert settings.cancel_upstream_on_disconnect is False assert settings.trace_metrics is False def test_stream_settings_env_overrides_json_values() -> None: config = load_config_from_mapping({"streaming": {"trace_metrics": True, "heartbeat_seconds": 10}}) - settings = get_stream_runtime_settings(config=config, env={"STREAM_TRACE_METRICS": "false", "STREAM_HEARTBEAT_SECONDS": "2"}) + settings = get_stream_runtime_settings(config=config, env={"STREAM_TRACE_METRICS": "false", "STREAM_HEARTBEAT_INTERVAL_SECONDS": "2"}) assert settings.trace_metrics is False assert settings.heartbeat_seconds == 2 +def test_stream_settings_accept_legacy_heartbeat_env_name() -> None: + settings = get_stream_runtime_settings(env={"STREAM_HEARTBEAT_SECONDS": "3"}) + + assert settings.heartbeat_seconds == 3 + + def test_stream_settings_invalid_boolean_fails_clearly() -> None: with pytest.raises(ExperimentalConfigError): get_stream_runtime_settings(env={"STREAM_TRACE_METRICS": "maybe"}) diff --git a/tests/test_stream_transport.py b/tests/test_stream_transport.py index be377f13f..4b52e1e3f 100644 --- a/tests/test_stream_transport.py +++ b/tests/test_stream_transport.py @@ -10,11 +10,18 @@ def test_sse_formatter_outputs_named_event() -> None: assert "data:" in formatted +def test_sse_formatter_outputs_comment_heartbeat() -> None: + formatted = SSEStreamFormatter().format_heartbeat("keepalive") + + assert formatted == ": keepalive\n\n" + + def test_websocket_formatter_exposes_future_transport_shape() -> None: formatted = WebSocketStreamFormatter().format_event(StreamEvent("delta", data={"text": "hi"})) assert formatted["type"] == "delta" assert formatted["payload"]["transport"] == "sse" + assert WebSocketStreamFormatter().format_heartbeat()["type"] == "heartbeat" def test_jsonl_formatter_outputs_one_line_json() -> None: @@ -22,3 +29,4 @@ def test_jsonl_formatter_outputs_one_line_json() -> None: assert formatted.endswith("\n") assert "metadata" in formatted + assert "heartbeat" in JSONLineStreamFormatter().format_heartbeat() From 664e322dd3824d9873aa55c8c6d4d520a57a04dc Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 14:21:07 +0200 Subject: [PATCH 131/182] feat(streaming): harden stream lifecycle handling Adds upstream close-on-disconnect and abnormal-exit behavior, configured heartbeat emission, and structured TTFB/stall timeout errors while keeping all hardening knobs disabled by default except upstream cancellation. Tests: pytest tests/test_request_executor_stream_metrics.py --- src/rotator_library/client/streaming.py | 146 +++++++++++++++++- tests/test_request_executor_stream_metrics.py | 109 +++++++++++++ 2 files changed, 254 insertions(+), 1 deletion(-) diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py index 0298badde..dcf52e414 100644 --- a/src/rotator_library/client/streaming.py +++ b/src/rotator_library/client/streaming.py @@ -13,7 +13,9 @@ - Client disconnect handling """ +import asyncio import codecs +import contextlib import json import logging import re @@ -24,6 +26,7 @@ from ..core.types import ProcessedChunk from ..core.utils import normalize_usage_for_response from ..streaming import StreamEvent, StreamMonitor, stream_event_from_sse_chunk +from ..streaming.transport import SSEStreamFormatter from ..usage.accounting import UsageRecord, extract_usage_record from ..usage.costs import CostBreakdown, CostCalculator @@ -89,6 +92,10 @@ async def wrap_stream( from ..config.experimental import get_stream_runtime_settings stream_settings = get_stream_runtime_settings() + formatter = SSEStreamFormatter() + upstream_closed = False + stream_cancelled = False + last_heartbeat_at = monitor.metrics.started_at lifecycle_logger = transaction_logger if stream_settings.trace_metrics else None self._log_stream_lifecycle( lifecycle_logger, @@ -100,6 +107,35 @@ async def wrap_stream( # Use manual iteration to allow continue after partial JSON errors stream_iterator = stream.__aiter__() + async def close_upstream(reason: str) -> None: + """Best-effort close for upstream async streams. + + Client disconnects and timeout failures should not leave provider + HTTP streams running in the background. Close failures are logged as + lifecycle metadata only and never replace the original stream error. + """ + + nonlocal upstream_closed + if upstream_closed or not stream_settings.cancel_upstream_on_disconnect: + return + upstream_closed = True + for candidate in (stream_iterator, stream): + try: + closer = getattr(candidate, "aclose", None) + if closer: + await closer() + self._log_stream_lifecycle(lifecycle_logger, "stream_upstream_cancelled", monitor, StreamEvent("cancelled", protocol="openai_chat", data={"reason": reason})) + return + closer = getattr(candidate, "close", None) + if closer: + closer() + self._log_stream_lifecycle(lifecycle_logger, "stream_upstream_cancelled", monitor, StreamEvent("cancelled", protocol="openai_chat", data={"reason": reason})) + return + except Exception as exc: + lib_logger.debug("Failed to close upstream stream: %s", exc) + self._log_stream_lifecycle(lifecycle_logger, "stream_upstream_close_failed", monitor, StreamEvent("error", protocol="openai_chat", data={"reason": reason, "error_type": type(exc).__name__})) + return + try: while True: try: @@ -110,7 +146,35 @@ async def wrap_stream( ) break - chunk = await stream_iterator.__anext__() + next_task = asyncio.create_task(stream_iterator.__anext__()) + try: + while True: + wait_seconds = _next_stream_wait_seconds(monitor, stream_settings, last_heartbeat_at) + done, _ = await asyncio.wait({next_task}, timeout=wait_seconds) + if done: + chunk = next_task.result() + break + + timeout_error = _stream_timeout_error(monitor, stream_settings) + if timeout_error: + next_task.cancel() + with contextlib.suppress(asyncio.CancelledError, StopAsyncIteration): + await next_task + await close_upstream(timeout_error[0]) + self._log_stream_lifecycle(lifecycle_logger, timeout_error[2], monitor, StreamEvent("error", protocol="openai_chat", data={"error": timeout_error[1]})) + raise StreamedAPIError(timeout_error[1]["message"], data={"error": timeout_error[1]}) + + if _heartbeat_due(monitor, stream_settings, last_heartbeat_at): + heartbeat = formatter.format_heartbeat() + last_heartbeat_at = time.monotonic() + self._log_stream_lifecycle(lifecycle_logger, "stream_heartbeat", monitor, StreamEvent("heartbeat", protocol="openai_chat", visible_output=False)) + yield heartbeat + except Exception: + if not next_task.done(): + next_task.cancel() + with contextlib.suppress(asyncio.CancelledError, StopAsyncIteration): + await next_task + raise raw_event = StreamEvent( "raw_chunk", @@ -229,10 +293,24 @@ async def wrap_stream( # Not a JSON-related error, re-raise monitor.metrics.error_count += 1 + await close_upstream("stream_exception") raise except StreamedAPIError: # Re-raise for retry loop + await close_upstream("streamed_api_error") + raise + + except asyncio.CancelledError: + stream_cancelled = True + monitor.cancel() + self._log_stream_lifecycle( + lifecycle_logger, + "stream_cancelled", + monitor, + StreamEvent("cancelled", protocol="openai_chat", data={"reason": "task_cancelled"}), + ) + await close_upstream("task_cancelled") raise finally: @@ -297,6 +375,7 @@ async def wrap_stream( yield "data: [DONE]\n\n" elif request and await request.is_disconnected(): + stream_cancelled = True monitor.cancel() self._log_stream_lifecycle( lifecycle_logger, @@ -304,6 +383,9 @@ async def wrap_stream( monitor, StreamEvent("cancelled", protocol="openai_chat"), ) + await close_upstream("client_disconnect") + elif stream_cancelled: + await close_upstream("stream_cancelled") @staticmethod def _log_stream_lifecycle( @@ -542,6 +624,68 @@ def _log_stream_usage_accounting( ) +def _next_stream_wait_seconds( + monitor: StreamMonitor, + settings: Any, + last_heartbeat_at: float, +) -> Optional[float]: + """Return the next active stream policy deadline. + + A pending upstream `__anext__()` task is not cancelled for heartbeats. The + handler waits until the nearest policy deadline, emits heartbeat metadata if + needed, and keeps waiting on the same upstream task. + """ + + candidates: list[float] = [] + now = time.monotonic() + if settings.heartbeat_seconds: + candidates.append(max(0.0, last_heartbeat_at + settings.heartbeat_seconds - now)) + if settings.ttfb_timeout_seconds and monitor.metrics.first_byte_at is None: + candidates.append(max(0.0, monitor.metrics.started_at + settings.ttfb_timeout_seconds - now)) + if settings.stall_timeout_seconds and monitor.metrics.first_byte_at is not None: + last_chunk_at = monitor.metrics.last_chunk_at or monitor.metrics.first_byte_at + candidates.append(max(0.0, last_chunk_at + settings.stall_timeout_seconds - now)) + return min(candidates) if candidates else None + + +def _heartbeat_due(monitor: StreamMonitor, settings: Any, last_heartbeat_at: float) -> bool: + """Return whether a heartbeat should be emitted while upstream is idle.""" + + if not settings.heartbeat_seconds: + return False + return time.monotonic() - last_heartbeat_at >= settings.heartbeat_seconds + + +def _stream_timeout_error(monitor: StreamMonitor, settings: Any) -> Optional[tuple[str, dict[str, Any], str]]: + """Return a structured stream timeout error when a configured deadline expires.""" + + now = time.monotonic() + if settings.ttfb_timeout_seconds and monitor.metrics.first_byte_at is None: + if now - monitor.metrics.started_at >= settings.ttfb_timeout_seconds: + return ( + "ttfb_timeout", + { + "message": "Stream timed out before first byte", + "type": "api_connection", + "details": {"timeout_type": "ttfb", "timeout_seconds": settings.ttfb_timeout_seconds}, + }, + "stream_ttfb_timeout", + ) + if settings.stall_timeout_seconds and monitor.metrics.first_byte_at is not None: + last_chunk_at = monitor.metrics.last_chunk_at or monitor.metrics.first_byte_at + if now - last_chunk_at >= settings.stall_timeout_seconds: + return ( + "stall_timeout", + { + "message": "Stream stalled while waiting for provider data", + "type": "api_connection", + "details": {"timeout_type": "stall", "timeout_seconds": settings.stall_timeout_seconds}, + }, + "stream_stall_timeout", + ) + return None + + class StreamBuffer: """ Buffer for reassembling fragmented JSON in streams. diff --git a/tests/test_request_executor_stream_metrics.py b/tests/test_request_executor_stream_metrics.py index 8917b9008..04b87160b 100644 --- a/tests/test_request_executor_stream_metrics.py +++ b/tests/test_request_executor_stream_metrics.py @@ -1,10 +1,12 @@ from __future__ import annotations import json +import asyncio import pytest from rotator_library.client.streaming import StreamingHandler +from rotator_library.core.errors import StreamedAPIError from rotator_library.transaction_logger import TransactionLogger @@ -13,6 +15,61 @@ async def _chunks(): yield {"id": "chunk_2", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}} +class HangingStream: + def __init__(self) -> None: + self.closed = False + + def __aiter__(self): + return self + + async def __anext__(self): + await asyncio.sleep(1) + return {"choices": [{"delta": {"content": "late"}}]} + + async def aclose(self) -> None: + self.closed = True + + +class DelayedStream: + def __init__(self, delay: float = 0.03) -> None: + self.delay = delay + self.index = 0 + self.closed = False + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index == 0: + self.index += 1 + await asyncio.sleep(self.delay) + return {"id": "chunk_1", "choices": [{"delta": {"content": "hi"}}]} + if self.index == 1: + self.index += 1 + return {"id": "chunk_2", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}} + raise StopAsyncIteration + + async def aclose(self) -> None: + self.closed = True + + +class FirstThenHangStream(HangingStream): + def __init__(self) -> None: + super().__init__() + self.index = 0 + + async def __anext__(self): + if self.index == 0: + self.index += 1 + return {"id": "chunk_1", "choices": [{"delta": {}}]} + return await super().__anext__() + + +class DisconnectedRequest: + async def is_disconnected(self) -> bool: + return True + + def _trace_passes(log_dir): return [json.loads(line)["pass_name"] for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] @@ -46,3 +103,55 @@ async def test_stream_trace_metrics_can_be_disabled_without_changing_output(tmp_ pass_names = _trace_passes(logger.log_dir) assert "stream_started" not in pass_names assert "stream_metrics_final" not in pass_names + + +@pytest.mark.asyncio +async def test_streaming_handler_closes_upstream_on_client_disconnect(monkeypatch) -> None: + monkeypatch.delenv("STREAM_TTFB_TIMEOUT_SECONDS", raising=False) + stream = HangingStream() + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(stream, "cred", "openai/gpt-test", request=DisconnectedRequest())] + + assert chunks == [] + assert stream.closed is True + + +@pytest.mark.asyncio +async def test_streaming_handler_emits_configured_heartbeats(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + monkeypatch.delenv("STREAM_TTFB_TIMEOUT_SECONDS", raising=False) + monkeypatch.delenv("STREAM_STALL_TIMEOUT_SECONDS", raising=False) + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(DelayedStream(), "cred", "openai/gpt-test")] + + assert any(chunk.startswith(": heartbeat") for chunk in chunks) + assert chunks[-1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_streaming_handler_ttfb_timeout_closes_upstream(monkeypatch) -> None: + monkeypatch.setenv("STREAM_TTFB_TIMEOUT_SECONDS", "0.01") + monkeypatch.delenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", raising=False) + stream = HangingStream() + + with pytest.raises(StreamedAPIError) as exc: + _ = [chunk async for chunk in StreamingHandler().wrap_stream(stream, "cred", "openai/gpt-test")] + + assert stream.closed is True + assert exc.value.data["error"]["details"]["timeout_type"] == "ttfb" + + +@pytest.mark.asyncio +async def test_streaming_handler_stall_timeout_after_first_byte(monkeypatch) -> None: + monkeypatch.setenv("STREAM_STALL_TIMEOUT_SECONDS", "0.01") + monkeypatch.delenv("STREAM_TTFB_TIMEOUT_SECONDS", raising=False) + stream = FirstThenHangStream() + chunks = [] + + with pytest.raises(StreamedAPIError) as exc: + async for chunk in StreamingHandler().wrap_stream(stream, "cred", "openai/gpt-test"): + chunks.append(chunk) + + assert chunks and chunks[0].startswith("data: ") + assert stream.closed is True + assert exc.value.data["error"]["details"]["timeout_type"] == "stall" From ec7b1aa81c22fb945490a3161450a21f25c26df9 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 14:22:19 +0200 Subject: [PATCH 132/182] feat(native): support httpx stream transport Extends NativeHTTPTransport.stream_json_lines to use generic httpx-style client.stream responses with line and byte iteration while preserving the custom injected stream_json_lines seam. Tests: pytest tests/test_native_streaming_transport_seam.py --- src/rotator_library/native_provider/http.py | 58 +++++++++++++++-- tests/test_native_streaming_transport_seam.py | 65 +++++++++++++++++++ 2 files changed, 116 insertions(+), 7 deletions(-) diff --git a/src/rotator_library/native_provider/http.py b/src/rotator_library/native_provider/http.py index b322875ea..cebee80fd 100644 --- a/src/rotator_library/native_provider/http.py +++ b/src/rotator_library/native_provider/http.py @@ -5,7 +5,8 @@ from __future__ import annotations -from typing import Any +import json +from typing import Any, AsyncIterator class NativeHTTPTransport: @@ -28,17 +29,60 @@ async def post_json(self, endpoint: str, *, headers: dict[str, str], payload: di return response.json() return response - async def stream_json_lines(self, endpoint: str, *, headers: dict[str, str], payload: dict[str, Any]): + async def stream_json_lines(self, endpoint: str, *, headers: dict[str, str], payload: dict[str, Any]) -> AsyncIterator[Any]: """Yield provider stream chunks from an injected streaming-capable client. - Tests and provider-specific clients can expose `stream_json_lines()` to - avoid binding this foundation to one HTTP client's streaming API. A later - streaming phase can add richer `httpx.stream()` support without changing - native provider executor semantics. + Provider-specific test clients can still expose `stream_json_lines()`. + When a normal `httpx.AsyncClient`-style object is injected, this method + now uses `client.stream()` directly so native streaming has a real HTTP + seam without enabling any provider that has not opted in safely. """ if hasattr(self.client, "stream_json_lines"): async for chunk in self.client.stream_json_lines(endpoint, headers=headers, json=payload): yield chunk return - raise NotImplementedError("Injected native HTTP client does not expose stream_json_lines") + if hasattr(self.client, "stream"): + async with self.client.stream("POST", endpoint, headers=headers, json=payload) as response: + if hasattr(response, "raise_for_status"): + response.raise_for_status() + if hasattr(response, "aiter_lines"): + async for line in response.aiter_lines(): + parsed = _parse_stream_line(line) + if parsed is not None: + yield parsed + return + if hasattr(response, "aiter_bytes"): + buffer = "" + async for chunk in response.aiter_bytes(): + text = chunk.decode("utf-8", errors="replace") if isinstance(chunk, (bytes, bytearray)) else str(chunk) + buffer += text + while "\n" in buffer: + line, buffer = buffer.split("\n", 1) + parsed = _parse_stream_line(line) + if parsed is not None: + yield parsed + parsed = _parse_stream_line(buffer) + if parsed is not None: + yield parsed + return + raise NotImplementedError("Injected native HTTP client does not expose streaming support") + + +def _parse_stream_line(line: Any) -> Any: + """Parse one HTTP streaming line while preserving provider sentinels.""" + + if line is None: + return None + text = line.decode("utf-8", errors="replace") if isinstance(line, (bytes, bytearray)) else str(line) + text = text.strip() + if not text: + return None + if text.startswith("data:"): + text = text[len("data:") :].strip() + if text == "[DONE]": + return "[DONE]" + try: + return json.loads(text) + except json.JSONDecodeError: + return text diff --git a/tests/test_native_streaming_transport_seam.py b/tests/test_native_streaming_transport_seam.py index 88b3f7a8a..45480d0b3 100644 --- a/tests/test_native_streaming_transport_seam.py +++ b/tests/test_native_streaming_transport_seam.py @@ -1,5 +1,8 @@ from __future__ import annotations +import pytest + +from rotator_library.native_provider.http import NativeHTTPTransport from rotator_library.native_provider.streaming import native_stream_event_from_formatted, provider_supports_native_streaming from rotator_library.providers.provider_interface import ProviderInterface @@ -23,3 +26,65 @@ def test_native_formatted_sse_chunk_uses_common_stream_event_seam() -> None: assert event.visible_output is True assert event.event_type == "delta" + + +class FakeStreamResponse: + def __init__(self, lines): + self.lines = lines + self.raised = False + + def raise_for_status(self): + self.raised = True + + async def aiter_lines(self): + for line in self.lines: + yield line + + +class FakeStreamContext: + def __init__(self, response): + self.response = response + + async def __aenter__(self): + return self.response + + async def __aexit__(self, exc_type, exc, tb): + return False + + +class FakeHttpxClient: + def __init__(self, response): + self.response = response + self.calls = [] + + def stream(self, method, endpoint, *, headers, json): + self.calls.append((method, endpoint, headers, json)) + return FakeStreamContext(self.response) + + +@pytest.mark.asyncio +async def test_native_http_transport_streams_httpx_lines() -> None: + response = FakeStreamResponse(["", 'data: {"delta":"hi"}', "data: [DONE]"]) + client = FakeHttpxClient(response) + + chunks = [chunk async for chunk in NativeHTTPTransport(client).stream_json_lines("https://example.test/stream", headers={"h": "v"}, payload={"p": True})] + + assert client.calls == [("POST", "https://example.test/stream", {"h": "v"}, {"p": True})] + assert response.raised is True + assert chunks == [{"delta": "hi"}, "[DONE]"] + + +class FakeByteResponse: + def raise_for_status(self): + pass + + async def aiter_bytes(self): + yield b'data: {"a": 1}\n\n' + yield b'data: [DONE]\n' + + +@pytest.mark.asyncio +async def test_native_http_transport_streams_httpx_bytes() -> None: + chunks = [chunk async for chunk in NativeHTTPTransport(FakeHttpxClient(FakeByteResponse())).stream_json_lines("url", headers={}, payload={})] + + assert chunks == [{"a": 1}, "[DONE]"] From 1946c7560b13159122775d1f7d41a25a7a33faf8 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 14:32:32 +0200 Subject: [PATCH 133/182] fix(streaming): preserve retry safety around heartbeats Fixes Phase 8b review findings by polling disconnects while waiting on upstream chunks, forcing upstream close on timeout/abnormal exit, treating heartbeat comments as retry-neutral, passing heartbeats through routed streams, and preserving structured timeout details in terminal stream errors. Tests: pytest tests/test_config_stream_settings.py tests/test_stream_transport.py tests/test_request_executor_stream_metrics.py tests/test_native_streaming_transport_seam.py tests/test_stream_metrics.py tests/test_stream_events.py tests/test_stream_policy.py tests/test_streaming_error_handler.py tests/test_streaming_fallback_policy.py tests/test_streaming_usage_accounting.py tests/test_cooldown_activation.py tests/test_retry_policy.py --- src/rotator_library/client/executor.py | 34 +++++++++++++----- src/rotator_library/client/streaming.py | 35 ++++++++++++++----- src/rotator_library/streaming/policy.py | 11 ++++++ tests/test_request_executor_stream_metrics.py | 34 ++++++++++++++++++ tests/test_stream_policy.py | 4 +++ 5 files changed, 100 insertions(+), 18 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index dfd068785..d4789f8ab 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -77,7 +77,7 @@ from .filters import CredentialFilter from .transforms import ProviderTransforms from .streaming import StreamingHandler -from ..streaming.policy import can_retry_stream_after_error, is_visible_stream_output +from ..streaming.policy import can_retry_stream_after_error, is_stream_heartbeat_or_comment, is_visible_stream_output if TYPE_CHECKING: from ..usage import UsageManager @@ -863,6 +863,9 @@ async def _execute_streaming_with_fallback(self, context: RequestContext) -> Asy ) try: async for chunk in self._execute_streaming(target_context): + if is_stream_heartbeat_or_comment(chunk): + yield chunk + continue chunk_error_type = _stream_chunk_error_type(chunk) if chunk_error_type and not emitted_output: terminal_error_type = chunk_error_type @@ -1527,13 +1530,15 @@ async def _execute_streaming( context=context, plugin=plugin, ): - last_streamed_chunk = chunk + if not is_stream_heartbeat_or_comment(chunk): + last_streamed_chunk = chunk if _stream_chunk_is_visible_output(chunk): stream_visible_output_emitted = True yield chunk else: async for chunk in base_stream: - last_streamed_chunk = chunk + if not is_stream_heartbeat_or_comment(chunk): + last_streamed_chunk = chunk if _stream_chunk_is_visible_output(chunk): stream_visible_output_emitted = True yield chunk @@ -1586,12 +1591,7 @@ async def _execute_streaming( if not _can_retry_stream_after_error(last_streamed_chunk, self._stream_retry_on_reasoning_only_enabled(), emitted_output=stream_visible_output_emitted): cred_context.mark_failure(classified) - error_data = { - "error": { - "message": "Upstream stream failed after output began", - "type": classified.error_type, - } - } + error_data = _streamed_error_payload(e, classified) for line in self._terminal_stream_error_lines(context, error_data): yield line return @@ -2871,3 +2871,19 @@ def _can_retry_stream_after_error(last_streamed_chunk: Optional[str], allow_reas if emitted_output: return False return can_retry_stream_after_error(last_streamed_chunk, allow_reasoning_only_retry) + + +def _streamed_error_payload(error: StreamedAPIError, classified: Any) -> Dict[str, Any]: + """Return a terminal stream error while preserving structured timeout data.""" + + data = getattr(error, "data", None) + if isinstance(data, dict) and isinstance(data.get("error"), dict): + source = data["error"] + return { + "error": { + "message": source.get("message") or "Upstream stream failed after output began", + "type": source.get("type") or classified.error_type, + "details": source.get("details") or {"status_code": classified.status_code}, + } + } + return {"error": {"message": "Upstream stream failed after output began", "type": classified.error_type, "details": {"status_code": classified.status_code}}} diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py index dcf52e414..dde270fe6 100644 --- a/src/rotator_library/client/streaming.py +++ b/src/rotator_library/client/streaming.py @@ -107,7 +107,7 @@ async def wrap_stream( # Use manual iteration to allow continue after partial JSON errors stream_iterator = stream.__aiter__() - async def close_upstream(reason: str) -> None: + async def close_upstream(reason: str, *, force: bool = False) -> None: """Best-effort close for upstream async streams. Client disconnects and timeout failures should not leave provider @@ -116,7 +116,7 @@ async def close_upstream(reason: str) -> None: """ nonlocal upstream_closed - if upstream_closed or not stream_settings.cancel_upstream_on_disconnect: + if upstream_closed or (not force and not stream_settings.cancel_upstream_on_disconnect): return upstream_closed = True for candidate in (stream_iterator, stream): @@ -150,8 +150,25 @@ async def close_upstream(reason: str) -> None: try: while True: wait_seconds = _next_stream_wait_seconds(monitor, stream_settings, last_heartbeat_at) - done, _ = await asyncio.wait({next_task}, timeout=wait_seconds) - if done: + wait_tasks = {next_task} + disconnect_task = None + if request is not None: + disconnect_task = asyncio.create_task(request.is_disconnected()) + wait_tasks.add(disconnect_task) + done, _ = await asyncio.wait(wait_tasks, timeout=wait_seconds) + if disconnect_task is not None: + if disconnect_task in done and disconnect_task.result(): + stream_cancelled = True + next_task.cancel() + with contextlib.suppress(asyncio.CancelledError, StopAsyncIteration): + await next_task + await close_upstream("client_disconnect") + return + if not disconnect_task.done(): + disconnect_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await disconnect_task + if next_task in done: chunk = next_task.result() break @@ -160,7 +177,7 @@ async def close_upstream(reason: str) -> None: next_task.cancel() with contextlib.suppress(asyncio.CancelledError, StopAsyncIteration): await next_task - await close_upstream(timeout_error[0]) + await close_upstream(timeout_error[0], force=True) self._log_stream_lifecycle(lifecycle_logger, timeout_error[2], monitor, StreamEvent("error", protocol="openai_chat", data={"error": timeout_error[1]})) raise StreamedAPIError(timeout_error[1]["message"], data={"error": timeout_error[1]}) @@ -293,12 +310,12 @@ async def close_upstream(reason: str) -> None: # Not a JSON-related error, re-raise monitor.metrics.error_count += 1 - await close_upstream("stream_exception") + await close_upstream("stream_exception", force=True) raise except StreamedAPIError: # Re-raise for retry loop - await close_upstream("streamed_api_error") + await close_upstream("streamed_api_error", force=True) raise except asyncio.CancelledError: @@ -310,7 +327,7 @@ async def close_upstream(reason: str) -> None: monitor, StreamEvent("cancelled", protocol="openai_chat", data={"reason": "task_cancelled"}), ) - await close_upstream("task_cancelled") + await close_upstream("task_cancelled", force=True) raise finally: @@ -385,7 +402,7 @@ async def close_upstream(reason: str) -> None: ) await close_upstream("client_disconnect") elif stream_cancelled: - await close_upstream("stream_cancelled") + await close_upstream("stream_cancelled", force=True) @staticmethod def _log_stream_lifecycle( diff --git a/src/rotator_library/streaming/policy.py b/src/rotator_library/streaming/policy.py index 6c32af9aa..5973a9747 100644 --- a/src/rotator_library/streaming/policy.py +++ b/src/rotator_library/streaming/policy.py @@ -21,6 +21,8 @@ def can_retry_stream_after_error(last_streamed_chunk: Optional[str], allow_reaso if last_streamed_chunk is None: return True + if is_stream_heartbeat_or_comment(last_streamed_chunk): + return True if not allow_reasoning_only_retry: return False data = _sse_json(last_streamed_chunk, malformed_is_visible=False) @@ -71,6 +73,15 @@ def is_visible_stream_output(chunk: Optional[str], *, protocol: str = "openai_ch return _openai_chat_visible(data) +def is_stream_heartbeat_or_comment(chunk: Optional[str]) -> bool: + """Return true for SSE comment-only frames that must not affect retry state.""" + + if chunk is None: + return False + payload = chunk.strip() + return bool(payload) and all(line.startswith(":") for line in payload.splitlines() if line.strip()) + + _MALFORMED_VISIBLE = object() diff --git a/tests/test_request_executor_stream_metrics.py b/tests/test_request_executor_stream_metrics.py index 04b87160b..49f7fb32c 100644 --- a/tests/test_request_executor_stream_metrics.py +++ b/tests/test_request_executor_stream_metrics.py @@ -70,6 +70,16 @@ async def is_disconnected(self) -> bool: return True +class DelayedDisconnectedRequest: + def __init__(self) -> None: + self.calls = 0 + + async def is_disconnected(self) -> bool: + self.calls += 1 + await asyncio.sleep(0.01) + return self.calls >= 1 + + def _trace_passes(log_dir): return [json.loads(line)["pass_name"] for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] @@ -116,6 +126,18 @@ async def test_streaming_handler_closes_upstream_on_client_disconnect(monkeypatc assert stream.closed is True +@pytest.mark.asyncio +async def test_streaming_handler_closes_upstream_when_disconnect_happens_during_wait(monkeypatch) -> None: + monkeypatch.delenv("STREAM_TTFB_TIMEOUT_SECONDS", raising=False) + monkeypatch.delenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", raising=False) + stream = HangingStream() + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(stream, "cred", "openai/gpt-test", request=DelayedDisconnectedRequest())] + + assert chunks == [] + assert stream.closed is True + + @pytest.mark.asyncio async def test_streaming_handler_emits_configured_heartbeats(monkeypatch) -> None: monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") @@ -141,6 +163,18 @@ async def test_streaming_handler_ttfb_timeout_closes_upstream(monkeypatch) -> No assert exc.value.data["error"]["details"]["timeout_type"] == "ttfb" +@pytest.mark.asyncio +async def test_stream_timeout_closes_upstream_even_when_disconnect_close_disabled(monkeypatch) -> None: + monkeypatch.setenv("STREAM_TTFB_TIMEOUT_SECONDS", "0.01") + monkeypatch.setenv("STREAM_CANCEL_UPSTREAM_ON_DISCONNECT", "false") + stream = HangingStream() + + with pytest.raises(StreamedAPIError): + _ = [chunk async for chunk in StreamingHandler().wrap_stream(stream, "cred", "openai/gpt-test")] + + assert stream.closed is True + + @pytest.mark.asyncio async def test_streaming_handler_stall_timeout_after_first_byte(monkeypatch) -> None: monkeypatch.setenv("STREAM_STALL_TIMEOUT_SECONDS", "0.01") diff --git a/tests/test_stream_policy.py b/tests/test_stream_policy.py index 729c7f0a4..49d2a354c 100644 --- a/tests/test_stream_policy.py +++ b/tests/test_stream_policy.py @@ -15,6 +15,10 @@ def test_reasoning_only_retry_policy_is_preserved() -> None: assert compat_retry_policy(reasoning_chunk, True) is True +def test_heartbeat_comments_do_not_block_stream_retry() -> None: + assert can_retry_stream_after_error(': heartbeat\n\n', False) is True + + def test_visible_output_detection_for_chat_chunks() -> None: assert is_visible_stream_output('data: {"choices":[{"delta":{"content":"hello"}}]}\n\n') is True assert is_visible_stream_output('data: {"choices":[{"delta":{"tool_calls":[{"id":"call_1"}]}}]}\n\n') is True From 24e988f1692507507fb8066944a68ed37d2e9416 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 14:43:05 +0200 Subject: [PATCH 134/182] fix(streaming): harden native stream sentinels Ignores native SSE comments, handles native stream done sentinels without handing raw strings to the chat stream wrapper, passes through formatted SSE chunks safely, and preserves timeout details when retry-safe pre-output stream timeouts exhaust attempts. Tests: pytest tests/test_config_stream_settings.py tests/test_stream_transport.py tests/test_request_executor_stream_metrics.py tests/test_native_streaming_transport_seam.py tests/test_stream_metrics.py tests/test_stream_events.py tests/test_stream_policy.py tests/test_streaming_error_handler.py tests/test_streaming_fallback_policy.py tests/test_streaming_usage_accounting.py tests/test_cooldown_activation.py tests/test_retry_policy.py tests/test_protocol_openai_chat.py --- src/rotator_library/client/executor.py | 20 +++++++++++++ src/rotator_library/client/streaming.py | 29 +++++++++++++++++++ .../native_provider/executor.py | 2 ++ src/rotator_library/native_provider/http.py | 2 ++ src/rotator_library/protocols/openai_chat.py | 4 +-- tests/test_native_streaming_transport_seam.py | 2 +- tests/test_request_executor_stream_metrics.py | 11 +++++++ tests/test_streaming_fallback_policy.py | 12 ++++++++ 8 files changed, 79 insertions(+), 3 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index d4789f8ab..97ecff96d 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -1071,6 +1071,7 @@ async def _execute_non_streaming( retry_state = RetryState() last_exception: Optional[Exception] = None + last_stream_error_payload: Optional[Dict[str, Any]] = None while time.time() < deadline: # Check for untried credentials @@ -1548,6 +1549,7 @@ async def _execute_streaming( last_exception = e original = getattr(e, "data", e) classified = classify_error(original, provider) + last_stream_error_payload = _streamed_error_payload(e, classified) if _can_start_stream_provider_cooldown( last_streamed_chunk, emitted_output=stream_visible_output_emitted, @@ -1859,6 +1861,8 @@ async def _execute_streaming( # All credentials exhausted or timeout error_accumulator.timeout_occurred = time.time() >= deadline error_data = error_accumulator.build_client_error_response() + if last_stream_error_payload: + _merge_stream_error_details(error_data, last_stream_error_payload) for line in self._terminal_stream_error_lines(context, error_data): yield line @@ -2887,3 +2891,19 @@ def _streamed_error_payload(error: StreamedAPIError, classified: Any) -> Dict[st } } return {"error": {"message": "Upstream stream failed after output began", "type": classified.error_type, "details": {"status_code": classified.status_code}}} + + +def _merge_stream_error_details(error_data: Dict[str, Any], stream_error: Dict[str, Any]) -> None: + """Preserve structured stream timeout metadata in aggregate stream errors.""" + + target = error_data.get("error") if isinstance(error_data, dict) else None + source = stream_error.get("error") if isinstance(stream_error, dict) else None + if not isinstance(target, dict) or not isinstance(source, dict): + return + details = source.get("details") + if not isinstance(details, dict) or "timeout_type" not in details: + return + merged = dict(target.get("details") or {}) if isinstance(target.get("details"), dict) else {} + merged.update(details) + target["details"] = merged + target["type"] = source.get("type") or target.get("type") diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py index dde270fe6..4361e49a0 100644 --- a/src/rotator_library/client/streaming.py +++ b/src/rotator_library/client/streaming.py @@ -483,6 +483,13 @@ def _process_chunk( ProcessedChunk with SSE string and metadata """ # Convert chunk to dict + if isinstance(chunk, str): + stripped = chunk.strip() + if stripped == "[DONE]" or stripped == "data: [DONE]": + return ProcessedChunk(sse_string="data: [DONE]\n\n") + if stripped.startswith("data:") or stripped.startswith("event:") or stripped.startswith(":"): + usage = _usage_from_sse_string(chunk) + return ProcessedChunk(sse_string=chunk if chunk.endswith("\n\n") else f"{chunk}\n\n", usage=usage) if hasattr(chunk, "model_dump"): chunk_dict = chunk.model_dump() elif hasattr(chunk, "dict"): @@ -703,6 +710,28 @@ def _stream_timeout_error(monitor: StreamMonitor, settings: Any) -> Optional[tup return None +def _usage_from_sse_string(chunk: str) -> Optional[dict[str, Any]]: + """Extract usage from already formatted SSE chunks when native streams pass through.""" + + text = chunk.strip() + data_lines: list[str] = [] + for line in text.splitlines(): + stripped = line.strip() + if stripped.startswith("data:"): + data_lines.append(stripped[5:].strip()) + if not data_lines: + return None + payload = "\n".join(data_lines).strip() + if not payload or payload == "[DONE]": + return None + try: + data = json.loads(payload) + except json.JSONDecodeError: + return None + usage = data.get("usage") if isinstance(data, dict) else None + return usage if isinstance(usage, dict) else None + + class StreamBuffer: """ Buffer for reassembling fragmented JSON in streams. diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index 92bdb3d49..8e31ace4b 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -125,6 +125,8 @@ async def stream(self, raw_request: dict[str, Any], context: NativeProviderConte self._trace(context, "parsed_native_unified_stream_event", event, direction="stream", stage="protocol", snapshot=False) event_payload = stream_event_payload(event) self._trace(context, "parsed_native_stream_event", event_payload, direction="stream", stage="protocol") + if event.type == "done": + break await cache_engine.extract("stream_event", event_payload, context.field_cache_context(), transaction_logger=logger) self._trace( context, diff --git a/src/rotator_library/native_provider/http.py b/src/rotator_library/native_provider/http.py index cebee80fd..a5f7118e7 100644 --- a/src/rotator_library/native_provider/http.py +++ b/src/rotator_library/native_provider/http.py @@ -78,6 +78,8 @@ def _parse_stream_line(line: Any) -> Any: text = text.strip() if not text: return None + if text.startswith(":"): + return None if text.startswith("data:"): text = text[len("data:") :].strip() if text == "[DONE]": diff --git a/src/rotator_library/protocols/openai_chat.py b/src/rotator_library/protocols/openai_chat.py index 119572fcd..75c86d39a 100644 --- a/src/rotator_library/protocols/openai_chat.py +++ b/src/rotator_library/protocols/openai_chat.py @@ -207,10 +207,10 @@ def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = N ) def format_stream_event(self, unified_event: UnifiedStreamEvent, context: ProtocolContext | None = None) -> Any: - if unified_event.raw is not None: - return deepcopy(unified_event.raw) if unified_event.type == "done": return "data: [DONE]\n\n" + if unified_event.raw is not None: + return deepcopy(unified_event.raw) return f"data: {json.dumps(unified_event.to_dict())}\n\n" def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: diff --git a/tests/test_native_streaming_transport_seam.py b/tests/test_native_streaming_transport_seam.py index 45480d0b3..cdb5fde1b 100644 --- a/tests/test_native_streaming_transport_seam.py +++ b/tests/test_native_streaming_transport_seam.py @@ -64,7 +64,7 @@ def stream(self, method, endpoint, *, headers, json): @pytest.mark.asyncio async def test_native_http_transport_streams_httpx_lines() -> None: - response = FakeStreamResponse(["", 'data: {"delta":"hi"}', "data: [DONE]"]) + response = FakeStreamResponse(["", ": heartbeat", 'data: {"delta":"hi"}', "data: [DONE]"]) client = FakeHttpxClient(response) chunks = [chunk async for chunk in NativeHTTPTransport(client).stream_json_lines("https://example.test/stream", headers={"h": "v"}, payload={"p": True})] diff --git a/tests/test_request_executor_stream_metrics.py b/tests/test_request_executor_stream_metrics.py index 49f7fb32c..42462b0f3 100644 --- a/tests/test_request_executor_stream_metrics.py +++ b/tests/test_request_executor_stream_metrics.py @@ -189,3 +189,14 @@ async def test_streaming_handler_stall_timeout_after_first_byte(monkeypatch) -> assert chunks and chunks[0].startswith("data: ") assert stream.closed is True assert exc.value.data["error"]["details"]["timeout_type"] == "stall" + + +@pytest.mark.asyncio +async def test_streaming_handler_passes_through_formatted_sse_chunks(monkeypatch) -> None: + async def formatted_stream(): + yield 'data: {"choices":[{"delta":{"content":"hi"}}]}\n\n' + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(formatted_stream(), "cred", "openai/gpt-test")] + + assert chunks[0] == 'data: {"choices":[{"delta":{"content":"hi"}}]}\n\n' + assert chunks[-1] == "data: [DONE]\n\n" diff --git a/tests/test_streaming_fallback_policy.py b/tests/test_streaming_fallback_policy.py index d2088f674..7f4d656e3 100644 --- a/tests/test_streaming_fallback_policy.py +++ b/tests/test_streaming_fallback_policy.py @@ -5,6 +5,7 @@ import pytest +from rotator_library.client import executor as executor_module from rotator_library.client.executor import RequestExecutor from rotator_library.core.types import RequestContext from rotator_library.routing import parse_route_target @@ -170,3 +171,14 @@ async def fake_stream(self, context): exhausted = [entry for entry in entries if entry["pass_name"] == "routing_fallback_exhausted"][-1] assert exhausted["metadata"]["fallback_targets"][0]["message"] == "" assert exhausted["metadata"]["streaming_policy"] == "never" + + +def test_stream_timeout_details_merge_into_aggregate_error() -> None: + error_data = {"error": {"type": "proxy_error", "details": {"attempts": 1}}} + stream_error = {"error": {"type": "api_connection", "details": {"timeout_type": "ttfb", "timeout_seconds": 0.1}}} + + executor_module._merge_stream_error_details(error_data, stream_error) + + assert error_data["error"]["type"] == "api_connection" + assert error_data["error"]["details"]["attempts"] == 1 + assert error_data["error"]["details"]["timeout_type"] == "ttfb" From 9e082a52a382d05c200c025473fb150f1493d25e Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 14:52:05 +0200 Subject: [PATCH 135/182] fix(streaming): initialize stream timeout state Moves aggregate stream timeout state into the streaming executor, avoids duplicate direct DONE sentinels, and updates native streaming tests for the corrected done-event break behavior. Tests: pytest tests/test_config_stream_settings.py tests/test_stream_transport.py tests/test_request_executor_stream_metrics.py tests/test_native_streaming_transport_seam.py tests/test_native_provider_streaming.py tests/test_stream_metrics.py tests/test_stream_events.py tests/test_stream_policy.py tests/test_streaming_error_handler.py tests/test_streaming_fallback_policy.py tests/test_streaming_usage_accounting.py tests/test_cooldown_activation.py tests/test_retry_policy.py tests/test_protocol_openai_chat.py --- src/rotator_library/client/executor.py | 2 +- src/rotator_library/client/streaming.py | 5 ++++- tests/test_native_provider_streaming.py | 4 ++-- tests/test_request_executor_stream_metrics.py | 10 ++++++++++ 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 97ecff96d..054811f60 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -1071,7 +1071,6 @@ async def _execute_non_streaming( retry_state = RetryState() last_exception: Optional[Exception] = None - last_stream_error_payload: Optional[Dict[str, Any]] = None while time.time() < deadline: # Check for untried credentials @@ -1335,6 +1334,7 @@ async def _execute_streaming( retry_state = RetryState() last_exception: Optional[Exception] = None + last_stream_error_payload: Optional[Dict[str, Any]] = None try: while time.time() < deadline: diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py index 4361e49a0..7e4107422 100644 --- a/src/rotator_library/client/streaming.py +++ b/src/rotator_library/client/streaming.py @@ -217,6 +217,9 @@ async def close_upstream(reason: str, *, force: bool = False) -> None: has_tool_calls, model, ) + if not processed.sse_string: + stream_completed = True + break self._collect_session_response_anchors( processed.sse_string, assistant_parts, @@ -486,7 +489,7 @@ def _process_chunk( if isinstance(chunk, str): stripped = chunk.strip() if stripped == "[DONE]" or stripped == "data: [DONE]": - return ProcessedChunk(sse_string="data: [DONE]\n\n") + return ProcessedChunk(sse_string="", finish_reason="stop") if stripped.startswith("data:") or stripped.startswith("event:") or stripped.startswith(":"): usage = _usage_from_sse_string(chunk) return ProcessedChunk(sse_string=chunk if chunk.endswith("\n\n") else f"{chunk}\n\n", usage=usage) diff --git a/tests/test_native_provider_streaming.py b/tests/test_native_provider_streaming.py index e35fcb144..138ae2972 100644 --- a/tests/test_native_provider_streaming.py +++ b/tests/test_native_provider_streaming.py @@ -46,7 +46,7 @@ async def test_native_provider_stream_traces_and_yields_formatted_events(tmp_pat events = [event async for event in NativeProviderExecutor().stream({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(client))] - assert events == chunks + assert events == [chunks[0]] assert client.calls[0]["json"]["stream"] is True pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] assert "native_provider_stream_request" in pass_names @@ -54,7 +54,7 @@ async def test_native_provider_stream_traces_and_yields_formatted_events(tmp_pat assert pass_names.count("parsed_native_stream_event") == 2 assert "after_field_cache_extraction" in pass_names assert "after_field_cache_stream_extraction" in pass_names - assert pass_names.count("formatted_client_stream_event") == 2 + assert pass_names.count("formatted_client_stream_event") == 1 trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") assert "opaque-vendor-state" not in trace_text diff --git a/tests/test_request_executor_stream_metrics.py b/tests/test_request_executor_stream_metrics.py index 42462b0f3..db2301b96 100644 --- a/tests/test_request_executor_stream_metrics.py +++ b/tests/test_request_executor_stream_metrics.py @@ -200,3 +200,13 @@ async def formatted_stream(): assert chunks[0] == 'data: {"choices":[{"delta":{"content":"hi"}}]}\n\n' assert chunks[-1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_streaming_handler_does_not_duplicate_direct_done_sentinel(monkeypatch) -> None: + async def done_stream(): + yield "data: [DONE]\n\n" + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(done_stream(), "cred", "openai/gpt-test")] + + assert chunks == ["data: [DONE]\n\n"] From f13a3b82da9498c29f802a59709a24196e3561aa Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 15:00:22 +0200 Subject: [PATCH 136/182] docs(experimental): plan usage cost correction Adds the Phase 9b corrective plan for provider-reported cost preservation, SSE cost events, native/Responses cost trace hardening, and honest quota snapshot scope. Tests: not run (planning document only) --- .../phase-9b-usage-cost-corrections.md | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 docs/experimental/phase-9b-usage-cost-corrections.md diff --git a/docs/experimental/phase-9b-usage-cost-corrections.md b/docs/experimental/phase-9b-usage-cost-corrections.md new file mode 100644 index 000000000..8cff0bd49 --- /dev/null +++ b/docs/experimental/phase-9b-usage-cost-corrections.md @@ -0,0 +1,76 @@ +# Phase 9b: Provider-Reported Cost, SSE Cost Events, And Quota Snapshot Hardening + +## Goal + +Correct Phase 9 validation findings. Phase 9 normalized token usage and added advisory cost calculation, but provider-reported costs and streaming cost events are not yet first-class runtime accounting inputs. + +## Non-Goals + +- Do not replace `UsageManager`, JSON persistence, or window/limit engines. +- Do not introduce SQLite or any new database. +- Do not make advisory pricing authoritative when providers report actual cost. +- Do not build an admin quota UI. +- Do not commit user-facing reports. +- Do not alter stream fallback/cooldown safety from Phases 6b-8b. + +## Current State + +- `protocols.types.CostDetails` exists and some protocol adapters populate it. +- `UsageRecord` has token buckets but no first-class provider-reported cost fields. +- `extract_usage_record()` ignores `cost_details`, `cost`, `total_cost`, and protocol `Usage.cost`. +- `CostCalculator.calculate()` prefers advisory pricing and drops actual provider cost values. +- `StreamingHandler` does not parse SSE comment/event cost frames. +- `mark_success(approx_cost=...)` can already store an approximate cost total without changing `UsageManager` shape. + +## Implementation Plan + +1. Extend `UsageRecord` with provider-reported cost fields. + - `provider_reported_cost: Optional[float]` + - `cost_currency: str = "USD"` + - `cost_source: Optional[str]` + - Include them in `to_dict()`. + +2. Extract provider-reported costs from provider shapes. + - OpenAI-like `cost_details.total_cost`, `cost_details.cost`, `cost`, `total_cost`, and `provider_reported_cost`. + - Anthropic/Gemini generic cost fields and `costMetadata`. + - Protocol `Usage.cost` values. + +3. Make provider-reported actual cost win in `CostCalculator`. + - `skip_cost_calculation` remains the highest-priority behavior. + - Add `provider_reported_cost` to `CostBreakdown` and make `total_cost` use it without double counting. + - Preserve advisory pricing paths when no provider-reported cost exists. + +4. Parse SSE cost comment/event frames. + - Support `: cost {json}`, `: cost 0.001`, and `event: cost` frames. + - Keep cost comments non-visible and out of session anchors. + - Final provider usage cost overrides earlier cost comments. + +5. Extend formatted SSE usage extraction. + - Ensure formatted SSE chunks with `usage.cost_details` flow through `UsageRecord` and cost calculation. + +6. Harden native and Responses cost traces. + - Native usage traces should include cost breakdown. + - Responses usage traces should show provider-reported cost when present. + +7. Keep quota snapshots honest. + - Do not invent cost totals if usage state does not store them. + - Document/request-token-window scope explicitly. + +## Tests + +- Provider-reported cost extraction tests in `test_usage_accounting.py`. +- Provider-reported cost precedence tests in `test_usage_costs.py`. +- SSE cost comment/event tests in `test_streaming_usage_accounting.py`. +- Native/Responses trace cost tests. +- Quota snapshot honesty tests. +- Phase 8b stream regression subset. + +## Acceptance Criteria + +- Provider-reported actual cost is preserved in `UsageRecord`. +- `CostCalculator` uses provider-reported cost before advisory pricing, while skip-cost still wins. +- SSE cost comments/events update streaming cost accounting and `mark_success(approx_cost=...)`. +- Cost comments remain non-visible for fallback/retry/session purposes. +- Native and Responses usage traces include provider-reported cost when present. +- Quota snapshots remain read-only request/token window reports and do not invent unsupported cost totals. +- Focused tests and dual-agent review pass. From 5d8da9c594b5783450c8d10f4a3fb8c4875c4b00 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 15:02:31 +0200 Subject: [PATCH 137/182] feat(usage): preserve provider reported costs Adds provider-reported cost fields to UsageRecord and extracts actual provider cost metadata from OpenAI-like, Gemini, and protocol Usage shapes without changing token accounting behavior. Tests: pytest tests/test_usage_accounting.py --- src/rotator_library/usage/accounting.py | 61 +++++++++++++++++++++++++ tests/test_usage_accounting.py | 44 ++++++++++++++++++ 2 files changed, 105 insertions(+) diff --git a/src/rotator_library/usage/accounting.py b/src/rotator_library/usage/accounting.py index f045d9c6f..04247e849 100644 --- a/src/rotator_library/usage/accounting.py +++ b/src/rotator_library/usage/accounting.py @@ -36,6 +36,9 @@ class UsageRecord: source: str = "unknown" provider: Optional[str] = None model: Optional[str] = None + provider_reported_cost: Optional[float] = None + cost_currency: str = "USD" + cost_source: Optional[str] = None metadata: dict[str, Any] = field(default_factory=dict) @property @@ -74,6 +77,9 @@ def to_dict(self) -> dict[str, Any]: "source": self.source, "provider": self.provider, "model": self.model, + "provider_reported_cost": self.provider_reported_cost, + "cost_currency": self.cost_currency, + "cost_source": self.cost_source, "metadata": serialize_value(self.metadata), } @@ -115,6 +121,9 @@ def _unwrap_usage(value: Any) -> Any: def _as_dict(value: Any) -> dict[str, Any]: if isinstance(value, dict): return value + if hasattr(value, "to_dict"): + dumped = value.to_dict() + return dumped if isinstance(dumped, dict) else {} if hasattr(value, "model_dump"): dumped = value.model_dump() return dumped if isinstance(dumped, dict) else {} @@ -141,6 +150,12 @@ def _as_dict(value: Any) -> dict[str, Any]: "cache_write_tokens", "reasoning_tokens", "thinking_tokens", + "cost", + "cost_details", + "costMetadata", + "provider_reported_cost", + "total_cost", + "currency", ): if hasattr(value, key): result[key] = getattr(value, key) @@ -175,6 +190,7 @@ def _from_openai_like_usage(data: dict[str, Any], *, provider: Optional[str], mo if reasoning and completion_tokens >= reasoning: completion_tokens -= reasoning input_tokens = max(0, prompt_tokens - cache_read) + cost = _extract_cost(data) return UsageRecord( input_tokens=input_tokens, completion_tokens=completion_tokens, @@ -185,6 +201,9 @@ def _from_openai_like_usage(data: dict[str, Any], *, provider: Optional[str], mo source=source, provider=provider, model=model, + provider_reported_cost=cost[0], + cost_currency=cost[1], + cost_source=cost[2], metadata={"shape": "openai_like"}, ) @@ -197,6 +216,7 @@ def _from_anthropic_usage(data: dict[str, Any], *, provider: Optional[str], mode reasoning = _int(data.get("reasoning_tokens", data.get("thinking_tokens", 0))) if reasoning and output_tokens >= reasoning: output_tokens -= reasoning + cost = _extract_cost(data) return UsageRecord( input_tokens=input_tokens, completion_tokens=output_tokens, @@ -207,6 +227,9 @@ def _from_anthropic_usage(data: dict[str, Any], *, provider: Optional[str], mode source=source, provider=provider, model=model, + provider_reported_cost=cost[0], + cost_currency=cost[1], + cost_source=cost[2], metadata={"shape": "anthropic"}, ) @@ -218,6 +241,7 @@ def _from_gemini_usage(data: dict[str, Any], *, provider: Optional[str], model: completion = _int(data.get("candidatesTokenCount", data.get("completion_tokens", 0))) if reasoning and completion >= reasoning: completion -= reasoning + cost = _extract_cost(data) return UsageRecord( input_tokens=max(0, prompt_tokens - cache_read), completion_tokens=completion, @@ -227,10 +251,38 @@ def _from_gemini_usage(data: dict[str, Any], *, provider: Optional[str], model: source=source, provider=provider, model=model, + provider_reported_cost=cost[0], + cost_currency=cost[1], + cost_source=cost[2], metadata={"shape": "gemini"}, ) +def _extract_cost(data: dict[str, Any]) -> tuple[Optional[float], str, Optional[str]]: + """Extract actual provider-reported cost without guessing prices. + + Advisory pricing belongs in `usage.costs`. This helper only preserves cost + values explicitly reported by a provider or protocol adapter. + """ + + cost_payload = _as_dict(data.get("cost") or data.get("cost_details") or data.get("costMetadata") or {}) + raw_cost = ( + cost_payload.get("provider_reported_cost") + if cost_payload + else None + ) + if raw_cost is None and cost_payload: + raw_cost = cost_payload.get("total_cost", cost_payload.get("total")) + if raw_cost is None and cost_payload: + raw_cost = cost_payload.get("cost") + if raw_cost is None: + raw_cost = data.get("provider_reported_cost", data.get("total_cost", data.get("cost"))) + cost_value = _float_or_none(raw_cost) + currency = str(cost_payload.get("currency") or data.get("currency") or "USD") + source = cost_payload.get("source") or ("provider_reported" if cost_value is not None else None) + return cost_value, currency, str(source) if source else None + + def _looks_like_gemini(data: dict[str, Any]) -> bool: return any(key in data for key in ("promptTokenCount", "candidatesTokenCount", "thoughtsTokenCount", "cachedContentTokenCount")) @@ -244,3 +296,12 @@ def _int(value: Any) -> int: return max(0, int(value or 0)) except (TypeError, ValueError): return 0 + + +def _float_or_none(value: Any) -> Optional[float]: + if value is None: + return None + try: + return max(0.0, float(value)) + except (TypeError, ValueError): + return None diff --git a/tests/test_usage_accounting.py b/tests/test_usage_accounting.py index d74db9d6d..018793045 100644 --- a/tests/test_usage_accounting.py +++ b/tests/test_usage_accounting.py @@ -2,6 +2,7 @@ from types import SimpleNamespace +from rotator_library.protocols.types import CostDetails, Usage from rotator_library.usage.accounting import UsageRecord, extract_usage_record @@ -29,6 +30,22 @@ def test_openai_dict_usage_extracts_cache_and_reasoning_without_double_counting( assert record.raw_total_tokens == 130 +def test_openai_usage_extracts_provider_reported_cost_details() -> None: + record = extract_usage_record( + { + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "cost_details": {"total_cost": 0.0123, "currency": "EUR", "source": "provider_usage"}, + } + } + ) + + assert record.provider_reported_cost == 0.0123 + assert record.cost_currency == "EUR" + assert record.cost_source == "provider_usage" + + def test_openai_object_usage_extracts_attributes() -> None: response = SimpleNamespace( usage=SimpleNamespace( @@ -85,6 +102,33 @@ def test_gemini_usage_metadata_extracts_thought_and_cached_tokens() -> None: assert record.metadata["shape"] == "gemini" +def test_gemini_usage_extracts_cost_metadata() -> None: + record = extract_usage_record( + { + "usageMetadata": { + "promptTokenCount": 1, + "totalTokenCount": 1, + "costMetadata": {"cost": "0.004", "currency": "USD", "source": "gemini_usage"}, + } + } + ) + + assert record.provider_reported_cost == 0.004 + assert record.cost_source == "gemini_usage" + + +def test_protocol_usage_cost_details_are_preserved() -> None: + record = extract_usage_record( + Usage(input_tokens=2, output_tokens=3, cost=CostDetails(provider_reported_cost=0.02, currency="GBP", source="protocol_usage")) + ) + + assert record.input_tokens == 2 + assert record.completion_tokens == 3 + assert record.provider_reported_cost == 0.02 + assert record.cost_currency == "GBP" + assert record.cost_source == "protocol_usage" + + def test_responses_usage_extracts_nested_details() -> None: record = extract_usage_record( { From aa9ba903fa0586aa36720568e223975997f823a5 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 15:03:20 +0200 Subject: [PATCH 138/182] feat(usage): prefer reported provider costs Updates CostCalculator to preserve actual provider-reported costs ahead of advisory pricing while keeping skip_cost_calculation as the highest-priority override. Tests: pytest tests/test_usage_costs.py tests/test_usage_accounting.py --- src/rotator_library/usage/costs.py | 11 +++++++++++ tests/test_usage_costs.py | 21 +++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/src/rotator_library/usage/costs.py b/src/rotator_library/usage/costs.py index 12c850ea1..6a37aeb82 100644 --- a/src/rotator_library/usage/costs.py +++ b/src/rotator_library/usage/costs.py @@ -40,12 +40,15 @@ class CostBreakdown: cache_write_cost: float = 0.0 output_cost: float = 0.0 reasoning_cost: float = 0.0 + provider_reported_cost: float = 0.0 currency: str = "USD" pricing_source: str = "unavailable" metadata: dict[str, Any] = field(default_factory=dict) @property def total_cost(self) -> float: + if self.provider_reported_cost: + return self.provider_reported_cost return self.input_cost + self.cache_read_cost + self.cache_write_cost + self.output_cost + self.reasoning_cost def to_dict(self) -> dict[str, Any]: @@ -55,6 +58,7 @@ def to_dict(self) -> dict[str, Any]: "cache_write_cost": self.cache_write_cost, "output_cost": self.output_cost, "reasoning_cost": self.reasoning_cost, + "provider_reported_cost": self.provider_reported_cost, "total_cost": self.total_cost, "currency": self.currency, "pricing_source": self.pricing_source, @@ -76,6 +80,13 @@ def calculate(self, usage: UsageRecord, *, model: str, response: Any = None, pro if self.provider_plugin and getattr(self.provider_plugin, "skip_cost_calculation", False): return CostBreakdown(pricing_source="skipped", metadata={"reason": "provider_skip_cost_calculation"}) + if usage.provider_reported_cost is not None: + return CostBreakdown( + provider_reported_cost=usage.provider_reported_cost, + currency=usage.cost_currency, + pricing_source=usage.cost_source or "provider_reported", + metadata={"actual_provider_reported": True}, + ) pricing = self._provider_pricing(model) if pricing: return _calculate_from_pricing(usage, pricing) diff --git a/tests/test_usage_costs.py b/tests/test_usage_costs.py index 3dce73ed0..0ffe6ec5d 100644 --- a/tests/test_usage_costs.py +++ b/tests/test_usage_costs.py @@ -40,6 +40,27 @@ def test_skip_cost_provider_returns_zero_skipped_breakdown() -> None: assert cost.pricing_source == "skipped" +def test_provider_reported_cost_wins_over_advisory_pricing() -> None: + usage = UsageRecord(input_tokens=10, completion_tokens=10, provider_reported_cost=0.123, cost_currency="EUR", cost_source="provider_actual") + + cost = CostCalculator(provider_plugin=PricingProvider(), use_litellm_fallback=False).calculate(usage, model="test") + + assert cost.total_cost == 0.123 + assert cost.provider_reported_cost == 0.123 + assert cost.currency == "EUR" + assert cost.pricing_source == "provider_actual" + assert cost.input_cost == 0.0 + + +def test_skip_cost_still_wins_over_provider_reported_cost() -> None: + usage = UsageRecord(provider_reported_cost=0.123) + + cost = CostCalculator(provider_plugin=SkipCostProvider()).calculate(usage, model="test") + + assert cost.total_cost == 0.0 + assert cost.pricing_source == "skipped" + + def test_missing_pricing_returns_unavailable_zero() -> None: cost = CostCalculator(use_litellm_fallback=False).calculate(UsageRecord(input_tokens=100), model="unknown-model") From bf13508c708712c921ca7a84033cc458c35c05e5 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 15:05:05 +0200 Subject: [PATCH 139/182] feat(streaming): account for SSE cost events Parses provider-reported cost from SSE cost comments and event frames, carries it through streaming usage records, and lets final provider usage cost override earlier cost comments. Tests: pytest tests/test_streaming_usage_accounting.py tests/test_request_executor_stream_metrics.py tests/test_stream_policy.py --- src/rotator_library/client/streaming.py | 75 +++++++++++++++++++++++- tests/test_streaming_usage_accounting.py | 48 +++++++++++++++ 2 files changed, 122 insertions(+), 1 deletion(-) diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py index 7e4107422..7fe437581 100644 --- a/src/rotator_library/client/streaming.py +++ b/src/rotator_library/client/streaming.py @@ -20,6 +20,7 @@ import logging import re import time +from dataclasses import replace from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional, TYPE_CHECKING from ..core.errors import StreamedAPIError, CredentialNeedsReauthError @@ -211,12 +212,15 @@ async def close_upstream(reason: str, *, force: bool = False) -> None: error_buffer.reset() # Process chunk + cost_event_record = _usage_record_from_sse_cost_chunk(chunk, model=model) processed = self._process_chunk( chunk, accumulated_finish_reason, has_tool_calls, model, ) + if cost_event_record.provider_reported_cost is not None: + usage_record = _merge_usage_cost(usage_record, cost_event_record) if not processed.sse_string: stream_completed = True break @@ -247,11 +251,14 @@ async def close_upstream(reason: str, *, force: bool = False) -> None: # Only update if not already tool_calls (highest priority) accumulated_finish_reason = processed.finish_reason if processed.usage and isinstance(processed.usage, dict): - usage_record = extract_usage_record( + next_usage_record = extract_usage_record( processed.usage, model=model, source="stream_final_chunk", ) + if next_usage_record.provider_reported_cost is None and usage_record.provider_reported_cost is not None: + next_usage_record = _merge_usage_cost(next_usage_record, usage_record) + usage_record = next_usage_record prompt_tokens = usage_record.input_tokens + usage_record.cache_read_tokens completion_tokens = usage_record.completion_tokens thinking_tokens = usage_record.reasoning_tokens @@ -735,6 +742,72 @@ def _usage_from_sse_string(chunk: str) -> Optional[dict[str, Any]]: return usage if isinstance(usage, dict) else None +def _usage_record_from_sse_cost_chunk(chunk: Any, *, model: str) -> UsageRecord: + """Extract provider-reported stream cost from SSE comments/events.""" + + if not isinstance(chunk, str): + return UsageRecord(source="stream_cost_event", model=model) + cost_payload = _sse_cost_payload(chunk) + if cost_payload is None: + return UsageRecord(source="stream_cost_event", model=model) + if isinstance(cost_payload, (int, float, str)): + cost_payload = {"provider_reported_cost": cost_payload, "source": "sse_cost"} + if not isinstance(cost_payload, dict): + return UsageRecord(source="stream_cost_event", model=model) + return extract_usage_record( + {"usage": {"provider_reported_cost": cost_payload.get("provider_reported_cost", cost_payload.get("total_cost", cost_payload.get("cost"))), "currency": cost_payload.get("currency", "USD"), "cost_details": cost_payload}}, + model=model, + source="stream_cost_event", + ) + + +def _sse_cost_payload(chunk: str) -> Any: + """Parse `: cost ...` comments and `event: cost` frames.""" + + event_type: Optional[str] = None + data_lines: list[str] = [] + for line in chunk.strip().splitlines(): + stripped = line.strip() + if stripped.startswith(":"): + comment = stripped[1:].strip() + if comment.startswith("cost"): + return _parse_cost_text(comment[4:].strip()) + continue + if stripped.startswith("event:"): + event_type = stripped[6:].strip() + continue + if stripped.startswith("data:"): + data_lines.append(stripped[5:].strip()) + if event_type == "cost" and data_lines: + return _parse_cost_text("\n".join(data_lines).strip()) + return None + + +def _parse_cost_text(text: str) -> Any: + if not text: + return None + try: + return json.loads(text) + except json.JSONDecodeError: + try: + return float(text) + except ValueError: + return None + + +def _merge_usage_cost(base: UsageRecord, cost_record: UsageRecord) -> UsageRecord: + """Copy provider-reported cost onto an existing normalized usage record.""" + + if cost_record.provider_reported_cost is None: + return base + return replace( + base, + provider_reported_cost=cost_record.provider_reported_cost, + cost_currency=cost_record.cost_currency, + cost_source=cost_record.cost_source, + ) + + class StreamBuffer: """ Buffer for reassembling fragmented JSON in streams. diff --git a/tests/test_streaming_usage_accounting.py b/tests/test_streaming_usage_accounting.py index 0699b50c8..8db6c492a 100644 --- a/tests/test_streaming_usage_accounting.py +++ b/tests/test_streaming_usage_accounting.py @@ -35,6 +35,26 @@ async def _zero_usage_chunks(): yield {"id": "chunk_2", "choices": [{"delta": {}, "finish_reason": "stop"}]} +async def _cost_comment_chunks(): + yield ': cost {"total_cost":0.042,"currency":"USD","source":"provider_sse"}\n\n' + yield {"id": "chunk_1", "choices": [{"delta": {"content": "hi"}}]} + yield {"id": "chunk_2", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1}} + + +async def _cost_event_chunks(): + yield 'event: cost\ndata: {"total_cost":0.021,"currency":"EUR","source":"event_cost"}\n\n' + yield {"id": "chunk_1", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1}} + + +async def _cost_comment_overridden_by_final_usage_chunks(): + yield ': cost 0.042\n\n' + yield { + "id": "chunk_2", + "choices": [{"delta": {}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "cost_details": {"total_cost": 0.084, "source": "final_usage"}}, + } + + @pytest.mark.asyncio async def test_streaming_usage_uses_normalized_accounting_and_trace(tmp_path, monkeypatch) -> None: monkeypatch.setattr( @@ -87,3 +107,31 @@ async def test_streaming_usage_uses_configured_env_pricing(monkeypatch) -> None: _ = [chunk async for chunk in StreamingHandler().wrap_stream(_usage_chunks(), "cred", "openai/gpt-test", cred_context=cred_context)] assert cred_context.success_kwargs["approx_cost"] == 120.0 + + +@pytest.mark.asyncio +async def test_streaming_cost_comment_updates_approx_cost() -> None: + cred_context = FakeCredentialContext() + + chunks = [chunk async for chunk in StreamingHandler().wrap_stream(_cost_comment_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert chunks[0].startswith(": cost") + assert cred_context.success_kwargs["approx_cost"] == 0.042 + + +@pytest.mark.asyncio +async def test_streaming_cost_event_updates_approx_cost() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_cost_event_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 0.021 + + +@pytest.mark.asyncio +async def test_streaming_final_usage_cost_overrides_comment_cost() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_cost_comment_overridden_by_final_usage_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 0.084 From ac2e920c04ce6218a23060e00720d5bd063e85ab Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 15:09:48 +0200 Subject: [PATCH 140/182] feat(usage): trace reported costs across surfaces Adds cost breakdowns to native and Responses usage traces, preserves Responses bridge cost fields, and documents quota snapshots as request/token windows rather than unsupported cost ledgers. Tests: pytest tests/test_native_provider_executor.py tests/test_responses_service.py tests/test_usage_quota_snapshots.py --- .../native_provider/executor.py | 4 +++- src/rotator_library/responses/bridge.py | 3 +++ src/rotator_library/responses/service.py | 6 ++++-- src/rotator_library/usage/quota.py | 8 +++++++- tests/test_native_provider_executor.py | 3 +++ tests/test_responses_service.py | 19 +++++++++++++++++++ tests/test_usage_quota_snapshots.py | 2 ++ 7 files changed, 41 insertions(+), 4 deletions(-) diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index 8e31ace4b..ad94c9ea8 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -14,6 +14,7 @@ from ..protocols import get_protocol, serialize_value from ..transform_trace import REDACTED from ..usage.accounting import extract_usage_record +from ..usage.costs import CostCalculator from .context import NativeProviderContext from .http import NativeHTTPTransport from .streaming import stream_event_payload @@ -70,10 +71,11 @@ async def execute(self, raw_request: dict[str, Any], context: NativeProviderCont model=context.model, source="native_provider_response", ) + cost_breakdown = CostCalculator().calculate(usage_record, model=context.model, provider=context.provider) self._trace( context, "usage_accounting_summary", - {"usage": usage_record.to_dict()}, + {"usage": usage_record.to_dict(), "cost": cost_breakdown.to_dict()}, direction="metadata", stage="final", snapshot=False, diff --git a/src/rotator_library/responses/bridge.py b/src/rotator_library/responses/bridge.py index 8485f29bd..e90893c4b 100644 --- a/src/rotator_library/responses/bridge.py +++ b/src/rotator_library/responses/bridge.py @@ -251,6 +251,9 @@ def _usage_to_responses(usage: Any) -> Any: completion_details = usage.get("completion_tokens_details") or usage.get("output_tokens_details") if isinstance(completion_details, dict): result["output_tokens_details"] = {"reasoning_tokens": completion_details.get("reasoning_tokens", 0)} + for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "currency"): + if key in usage: + result[key] = deepcopy(usage[key]) return result diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index 508c2dd6d..cdd95d8bb 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -12,6 +12,7 @@ from ..protocols import ProtocolContext from ..streaming import StreamEvent, StreamMonitor from ..usage.accounting import extract_usage_record +from ..usage.costs import CostCalculator from ..protocols.responses import ResponsesProtocol from .bridge import ResponsesBridge, responses_session_hints from .store import InMemoryResponsesStore, ResponsesStore @@ -446,13 +447,14 @@ def _trace_responses_usage( if not usage: return record = extract_usage_record(usage, provider="responses", model=model, source=source) + cost_breakdown = CostCalculator().calculate(record, model=model, provider="responses") self._trace( transaction_logger, "usage_accounting_summary", - {"usage": record.to_dict()}, + {"usage": record.to_dict(), "cost": cost_breakdown.to_dict()}, direction="metadata", stage="final", - metadata={"source": source}, + metadata={"source": source, "pricing_source": cost_breakdown.pricing_source}, ) async def _store_stream_response( diff --git a/src/rotator_library/usage/quota.py b/src/rotator_library/usage/quota.py index b42759aa3..7d490c2a0 100644 --- a/src/rotator_library/usage/quota.py +++ b/src/rotator_library/usage/quota.py @@ -58,7 +58,12 @@ def build_quota_snapshots( quota_group: Optional[str] = None, include_credentials: bool = True, ) -> list[QuotaSnapshot]: - """Build read-only quota snapshots from credential states.""" + """Build read-only request/token quota snapshots from credential states. + + The current usage state stores request/token windows, not a reliable + provider-cost ledger. Snapshots therefore avoid inventing cost totals; cost + reporting can be added later only if the underlying state owns that data. + """ snapshots: list[QuotaSnapshot] = [] for stable_id, state in states.items(): @@ -113,6 +118,7 @@ def _snapshots_for_windows( remaining=window.remaining, reset_at=window.reset_at, source=source, + metadata={"scope": "request_token_window"}, ) for window in windows.values() ] diff --git a/tests/test_native_provider_executor.py b/tests/test_native_provider_executor.py index ab123145e..e8c6ec54f 100644 --- a/tests/test_native_provider_executor.py +++ b/tests/test_native_provider_executor.py @@ -59,6 +59,7 @@ async def test_native_provider_executor_runs_protocol_adapter_cache_and_trace(tm "id": "chat_1", "model": "provider/gpt-test", "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "hidden"}}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "cost_details": {"total_cost": 0.01, "source": "native_provider"}}, } client = FakeHTTPClient(response) @@ -89,6 +90,8 @@ async def test_native_provider_executor_runs_protocol_adapter_cache_and_trace(tm assert "field_cache_extraction_start" in pass_names assert "after_field_cache_extraction" in pass_names assert "field_cache_extraction_complete" in pass_names + usage_entries = [entry for entry in _trace_entries(logger.log_dir) if entry["pass_name"] == "usage_accounting_summary"] + assert usage_entries[-1]["data"]["cost"]["provider_reported_cost"] == 0.01 assert "final_client_response" in pass_names diff --git a/tests/test_responses_service.py b/tests/test_responses_service.py index 1c71ba2dc..3970a3a90 100644 --- a/tests/test_responses_service.py +++ b/tests/test_responses_service.py @@ -23,6 +23,13 @@ async def acompletion(self, **kwargs): } +class FakeCostClient(FakeClient): + async def acompletion(self, **kwargs): + response = await super().acompletion(**kwargs) + response["usage"]["cost_details"] = {"total_cost": 0.033, "source": "responses_provider"} + return response + + class FakeInternalClient(FakeClient): def __init__(self) -> None: super().__init__() @@ -284,6 +291,18 @@ async def test_service_emits_transform_trace_passes(tmp_path) -> None: ] +@pytest.mark.asyncio +async def test_service_usage_trace_includes_provider_reported_cost(tmp_path) -> None: + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + service = ResponsesService(store=InMemoryResponsesStore()) + + await service.create_response({"model": "gpt-test", "input": "Hello"}, FakeCostClient(), transaction_logger=logger) + + usage_entry = [entry for entry in _trace_entries(logger.log_dir) if entry["pass_name"] == "usage_accounting_summary"][-1] + assert usage_entry["data"]["cost"]["provider_reported_cost"] == 0.033 + assert usage_entry["metadata"]["pricing_source"] == "usage.cost_details" + + def test_trace_responses_usage_returns_before_conversion_without_logger(monkeypatch) -> None: service = ResponsesService(store=InMemoryResponsesStore()) diff --git a/tests/test_usage_quota_snapshots.py b/tests/test_usage_quota_snapshots.py index cc18de095..f110ac4ef 100644 --- a/tests/test_usage_quota_snapshots.py +++ b/tests/test_usage_quota_snapshots.py @@ -24,6 +24,8 @@ def test_build_quota_snapshots_for_model_window() -> None: assert snapshot.used == 3 assert snapshot.remaining == 7 assert snapshot.credential_id != "credential-secret" + assert snapshot.metadata == {"scope": "request_token_window"} + assert "cost_used" not in snapshot.to_dict() def test_build_quota_snapshots_for_group_window_without_credentials() -> None: From d127a08e4ef5b352a0cc934345ccdf0e427d0434 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 15:16:24 +0200 Subject: [PATCH 141/182] fix(usage): preserve streaming and native costs Preserves provider-reported cost fields in Responses streaming usage and recovers raw provider cost metadata for native execution traces when protocol formatting drops provider-specific cost fields. Tests: pytest tests/test_usage_accounting.py tests/test_usage_costs.py tests/test_streaming_usage_accounting.py tests/test_request_executor_stream_metrics.py tests/test_stream_policy.py tests/test_native_provider_executor.py tests/test_responses_service.py tests/test_responses_streaming.py tests/test_usage_quota_snapshots.py tests/test_protocol_openai_chat.py tests/test_protocol_responses.py --- src/rotator_library/native_provider/executor.py | 14 ++++++++++++++ src/rotator_library/responses/service.py | 3 +++ tests/test_native_provider_executor.py | 2 +- tests/test_responses_streaming.py | 8 ++++++-- 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index ad94c9ea8..61d554d46 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -6,6 +6,7 @@ from __future__ import annotations from copy import deepcopy +from dataclasses import replace from typing import Any, AsyncGenerator from ..adapters import get_adapter, run_adapter_chain @@ -71,6 +72,19 @@ async def execute(self, raw_request: dict[str, Any], context: NativeProviderCont model=context.model, source="native_provider_response", ) + raw_usage_record = extract_usage_record( + raw_response, + provider=context.provider, + model=context.model, + source="native_provider_raw_response", + ) + if usage_record.provider_reported_cost is None and raw_usage_record.provider_reported_cost is not None: + usage_record = replace( + usage_record, + provider_reported_cost=raw_usage_record.provider_reported_cost, + cost_currency=raw_usage_record.cost_currency, + cost_source=raw_usage_record.cost_source, + ) cost_breakdown = CostCalculator().calculate(usage_record, model=context.model, provider=context.provider) self._trace( context, diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index cdd95d8bb..fb535bced 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -623,4 +623,7 @@ def _usage_to_responses_stream(usage: Any) -> Any: completion_details = usage.get("completion_tokens_details") or usage.get("output_tokens_details") if isinstance(completion_details, dict): result["output_tokens_details"] = {"reasoning_tokens": completion_details.get("reasoning_tokens", 0)} + for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "currency"): + if key in usage: + result[key] = deepcopy(usage[key]) return result diff --git a/tests/test_native_provider_executor.py b/tests/test_native_provider_executor.py index e8c6ec54f..a597d7a5f 100644 --- a/tests/test_native_provider_executor.py +++ b/tests/test_native_provider_executor.py @@ -59,7 +59,7 @@ async def test_native_provider_executor_runs_protocol_adapter_cache_and_trace(tm "id": "chat_1", "model": "provider/gpt-test", "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "hidden"}}], - "usage": {"prompt_tokens": 1, "completion_tokens": 1, "cost_details": {"total_cost": 0.01, "source": "native_provider"}}, + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_cost": 0.01}, } client = FakeHTTPClient(response) diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index a3d5e280d..80b4baa9c 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -1,5 +1,7 @@ from __future__ import annotations +import json + import pytest from rotator_library.responses import InMemoryResponsesStore, ResponsesSSEFormatter, ResponsesService, ResponsesStoreSettings, ResponsesStreamEvent, ResponsesWebSocketFormatter @@ -15,7 +17,7 @@ async def acompletion(self, **kwargs): async def chunks(): yield 'data: {"id":"chat_stream","model":"gpt-test","choices":[{"delta":{"role":"assistant"}}]}\n\n' yield 'data: {"id":"chat_stream","model":"gpt-test","choices":[{"delta":{"content":"Hel"}}]}\n\n' - yield 'data: {"id":"chat_stream","model":"gpt-test","choices":[{"delta":{"content":"lo"}}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}\n\n' + yield 'data: {"id":"chat_stream","model":"gpt-test","choices":[{"delta":{"content":"lo"}}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"cost_details":{"total_cost":0.044,"source":"stream_provider"}}}\n\n' yield "data: [DONE]\n\n" return chunks() @@ -75,7 +77,7 @@ async def test_stream_response_emits_responses_sse_events_and_stores_final_respo stored = await store.get(response_id) assert stored is not None assert stored.output_items[0]["content"][0]["text"] == "Hello" - assert stored.usage == {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3} + assert stored.usage == {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3, "cost_details": {"total_cost": 0.044, "source": "stream_provider"}} @pytest.mark.asyncio @@ -210,6 +212,8 @@ async def test_responses_stream_records_common_stream_metrics(tmp_path) -> None: assert "responses_stream_event_created" in trace_text assert "responses_stream_event_output_text_delta" in trace_text assert "responses_stream_event_completed" in trace_text + usage_entry = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines() if '"usage_accounting_summary"' in line][-1] + assert usage_entry["data"]["cost"]["provider_reported_cost"] == 0.044 @pytest.mark.asyncio From 41eebe8dbde4ab57e1b10e38fe6c53a94ab08930 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 15:22:40 +0200 Subject: [PATCH 142/182] docs(experimental): plan config wiring correction Adds the Phase 10b corrective plan for retry/cooldown, Responses store policy, field-cache config hardening, and env example updates. Tests: not run (planning document only) --- .../phase-10b-config-surface-wiring.md | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 docs/experimental/phase-10b-config-surface-wiring.md diff --git a/docs/experimental/phase-10b-config-surface-wiring.md b/docs/experimental/phase-10b-config-surface-wiring.md new file mode 100644 index 000000000..01aef40b4 --- /dev/null +++ b/docs/experimental/phase-10b-config-surface-wiring.md @@ -0,0 +1,55 @@ +# Phase 10b: Config Surface Wiring And Runtime Settings Completion + +## Goal + +Correct the Phase 10 validation finding that the config layer exists but does not wire enough runtime surfaces introduced in corrective phases. The proxy remains `.env`-first, optional JSON config is structured and secret-free, and environment variables override JSON values. + +## Non-Goals + +- Do not create a full application settings framework. +- Do not move secrets into JSON config. +- Do not replace provider credential discovery. +- Do not add security or multi-user config. +- Do not introduce SQLite or durable DB config. +- Do not commit user-facing reports. + +## Implementation Plan + +1. Add retry/cooldown runtime settings. + - Parse provider cooldown, provider backoff, and failure-history settings from JSON and env. + - Preserve existing env var names and defaults. + - Env overrides JSON. + +2. Wire retry/cooldown settings into retry policy. + - Keep imports lazy to avoid startup cycles. + - Preserve existing monkeypatch/env test behavior. + +3. Add Responses store settings config. + - Parse `RESPONSES_STORE_TTL_SECONDS`, `RESPONSES_STORE_MAX_ITEMS`, `RESPONSES_STORE_FAILED`, and `RESPONSES_STORE_IN_PROGRESS`. + - Support JSON under `responses.store`. + - Preserve default behavior. + +4. Wire Responses store settings into proxy startup. + - Construct `ResponsesService(store_settings=get_responses_store_settings())` in the FastAPI app path. + - Direct tests can still inject explicit settings. + +5. Harden field-cache/provider config parser coverage. + - Cover TTL, metadata mode hints, inject insert behavior, invalid rule errors, and secret rejection in new sections. + +6. Update `.env.example`. + - Document active streaming timeout/heartbeat settings. + - Document provider/model cooldown and backoff knobs. + - Document Responses store policy knobs. + - Document provider-reported cost precedence and SSE cost comments. + - Show safe structured-config section names. + +## Acceptance Criteria + +- Optional JSON config can configure retry/cooldown/backoff and Responses storage policies. +- Env vars override JSON for all new settings. +- Existing env-only behavior remains compatible. +- ResponsesService at proxy startup uses configured store settings. +- Streaming/pricing/routing/field-cache config regressions remain passing. +- Secret-like JSON keys are still rejected in all new sections. +- `.env.example` documents Phase 7b-10b runtime knobs accurately. +- Focused tests and dual-agent review pass. From 6859ec5e5fcd0271c4804c7ffcff284ecd9021f1 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 15:24:30 +0200 Subject: [PATCH 143/182] feat(config): add retry and responses settings Adds structured retry/cooldown/backoff and Responses store settings helpers with env-over-JSON precedence, and wires retry policy to read the helper lazily while preserving defaults. Tests: pytest tests/test_experimental_config.py tests/test_retry_policy.py --- src/rotator_library/config/experimental.py | 78 +++++++++++++++++++++- src/rotator_library/retry_policy.py | 26 +++++--- tests/test_experimental_config.py | 47 +++++++++++++ 3 files changed, 140 insertions(+), 11 deletions(-) diff --git a/src/rotator_library/config/experimental.py b/src/rotator_library/config/experimental.py index 19d8b9f64..25ac3aff8 100644 --- a/src/rotator_library/config/experimental.py +++ b/src/rotator_library/config/experimental.py @@ -22,7 +22,7 @@ from ..usage.costs import ModelPricing _CONFIG_ENV_KEYS = ("LLM_PROXY_CONFIG_FILE", "PROXY_CONFIG_FILE") -_KNOWN_SECTIONS = {"routing", "pricing", "streaming", "field_cache", "providers"} +_KNOWN_SECTIONS = {"routing", "pricing", "streaming", "field_cache", "providers", "retry", "responses"} _SECRET_KEY_PARTS = ("api_key", "authorization", "access_token", "refresh_token", "client_secret", "bearer_token", "password") @@ -44,13 +44,15 @@ class ExperimentalConfig: streaming: dict[str, Any] = field(default_factory=dict) field_cache: dict[str, Any] = field(default_factory=dict) providers: dict[str, Any] = field(default_factory=dict) + retry: dict[str, Any] = field(default_factory=dict) + responses: dict[str, Any] = field(default_factory=dict) unknown_sections: dict[str, Any] = field(default_factory=dict) warnings: tuple[str, ...] = () path: Optional[str] = None @property def is_empty(self) -> bool: - return not (self.routing or self.pricing or self.streaming or self.field_cache or self.providers or self.unknown_sections) + return not (self.routing or self.pricing or self.streaming or self.field_cache or self.providers or self.retry or self.responses or self.unknown_sections) @dataclass(frozen=True) @@ -69,6 +71,20 @@ class StreamRuntimeSettings: trace_metrics: bool = True +@dataclass(frozen=True) +class RetryRuntimeSettings: + """Runtime retry/cooldown settings layered from JSON and env.""" + + provider_cooldown_min_seconds: int = 10 + provider_cooldown_default_seconds: int = 30 + provider_cooldown_on_quota: bool = False + provider_backoff_window_seconds: int = 60 + provider_backoff_threshold: int = 3 + provider_backoff_base_seconds: Optional[int] = None + provider_backoff_max_seconds: int = 300 + failure_history_max_entries: int = 200 + + def load_experimental_config(path: str | os.PathLike[str] | None = None, env: Mapping[str, str] | None = None) -> ExperimentalConfig: """Load optional JSON config from an explicit path or config env var.""" @@ -91,6 +107,8 @@ def load_experimental_config(path: str | os.PathLike[str] | None = None, env: Ma streaming=_dict_section(data, "streaming"), field_cache=_dict_section(data, "field_cache"), providers=_dict_section(data, "providers"), + retry=_dict_section(data, "retry"), + responses=_dict_section(data, "responses"), unknown_sections=unknown, warnings=warnings, path=str(resolved), @@ -108,6 +126,8 @@ def load_config_from_mapping(data: Mapping[str, Any]) -> ExperimentalConfig: streaming=_dict_section(data, "streaming"), field_cache=_dict_section(data, "field_cache"), providers=_dict_section(data, "providers"), + retry=_dict_section(data, "retry"), + responses=_dict_section(data, "responses"), unknown_sections={key: value for key, value in data.items() if key not in _KNOWN_SECTIONS}, warnings=warnings, ) @@ -162,6 +182,53 @@ def get_stream_runtime_settings( ) +def get_retry_runtime_settings( + *, + config: ExperimentalConfig | None = None, + env: Mapping[str, str] | None = None, +) -> RetryRuntimeSettings: + """Return retry/cooldown settings with environment overriding JSON.""" + + source = env if env is not None else os.environ + active = config if config is not None else load_experimental_config(env=source) + retry = active.retry if isinstance(active.retry, dict) else {} + cooldown = retry.get("provider_cooldown", {}) if isinstance(retry.get("provider_cooldown"), dict) else retry + backoff = retry.get("backoff", {}) if isinstance(retry.get("backoff"), dict) else retry + return RetryRuntimeSettings( + provider_cooldown_min_seconds=max(0, as_int(_env_or_json(source, "PROVIDER_COOLDOWN_MIN_SECONDS", cooldown, "provider_cooldown_min_seconds", default=10), name="PROVIDER_COOLDOWN_MIN_SECONDS")), + provider_cooldown_default_seconds=max(0, as_int(_env_or_json(source, "PROVIDER_COOLDOWN_DEFAULT_SECONDS", cooldown, "provider_cooldown_default_seconds", default=30), name="PROVIDER_COOLDOWN_DEFAULT_SECONDS")), + provider_cooldown_on_quota=as_bool(_env_or_json(source, "PROVIDER_COOLDOWN_ON_QUOTA", cooldown, "provider_cooldown_on_quota", default=False), name="PROVIDER_COOLDOWN_ON_QUOTA"), + provider_backoff_window_seconds=max(0, as_int(_env_or_json(source, "PROVIDER_BACKOFF_WINDOW_SECONDS", backoff, "provider_backoff_window_seconds", default=60), name="PROVIDER_BACKOFF_WINDOW_SECONDS")), + provider_backoff_threshold=max(1, as_int(_env_or_json(source, "PROVIDER_BACKOFF_THRESHOLD", backoff, "provider_backoff_threshold", default=3), name="PROVIDER_BACKOFF_THRESHOLD")), + provider_backoff_base_seconds=_optional_positive_int(_env_or_json(source, "PROVIDER_BACKOFF_BASE_SECONDS", backoff, "provider_backoff_base_seconds"), "PROVIDER_BACKOFF_BASE_SECONDS"), + provider_backoff_max_seconds=max(1, as_int(_env_or_json(source, "PROVIDER_BACKOFF_MAX_SECONDS", backoff, "provider_backoff_max_seconds", default=300), name="PROVIDER_BACKOFF_MAX_SECONDS")), + failure_history_max_entries=max(1, as_int(_env_or_json(source, "FAILURE_HISTORY_MAX_ENTRIES", backoff, "failure_history_max_entries", default=200), name="FAILURE_HISTORY_MAX_ENTRIES")), + ) + + +def get_responses_store_settings( + *, + config: ExperimentalConfig | None = None, + env: Mapping[str, str] | None = None, +) -> Any: + """Return Responses store settings with environment overriding JSON.""" + + from ..responses import ResponsesStoreSettings + + source = env if env is not None else os.environ + active = config if config is not None else load_experimental_config(env=source) + responses = active.responses if isinstance(active.responses, dict) else {} + store = responses.get("store", {}) if isinstance(responses.get("store"), dict) else responses + ttl_seconds = _optional_positive_int(_env_or_json(source, "RESPONSES_STORE_TTL_SECONDS", store, "ttl_seconds"), "RESPONSES_STORE_TTL_SECONDS") + max_items = _optional_positive_int(_env_or_json(source, "RESPONSES_STORE_MAX_ITEMS", store, "max_items"), "RESPONSES_STORE_MAX_ITEMS") + return ResponsesStoreSettings( + ttl_seconds=ttl_seconds, + max_items=max_items, + store_failed=as_bool(_env_or_json(source, "RESPONSES_STORE_FAILED", store, "store_failed", default=True), name="RESPONSES_STORE_FAILED"), + store_in_progress=as_bool(_env_or_json(source, "RESPONSES_STORE_IN_PROGRESS", store, "store_in_progress", default=False), name="RESPONSES_STORE_IN_PROGRESS"), + ) + + def parse_field_cache_rules(config: ExperimentalConfig, provider: str, model: str) -> tuple[FieldCacheRule, ...]: """Parse configured field-cache rules for a provider/model. @@ -296,6 +363,13 @@ def _optional_positive_float(value: Any, name: str) -> Optional[float]: return parsed if parsed > 0 else None +def _optional_positive_int(value: Any, name: str) -> Optional[int]: + if value in (None, ""): + return None + parsed = as_int(value, name=name) + return parsed if parsed > 0 else None + + def _field_cache_rule_from_dict(data: Mapping[str, Any]) -> FieldCacheRule: inject_data = data.get("inject") inject = None diff --git a/src/rotator_library/retry_policy.py b/src/rotator_library/retry_policy.py index ff7f1c6ca..80dce1578 100644 --- a/src/rotator_library/retry_policy.py +++ b/src/rotator_library/retry_policy.py @@ -138,7 +138,8 @@ class FailureHistory: """ def __init__(self, *, max_entries: int | None = None, clock: Any = None) -> None: - self.max_entries = max(1, max_entries if max_entries is not None else _env_int("FAILURE_HISTORY_MAX_ENTRIES", 200)) + settings = _retry_settings() + self.max_entries = max(1, max_entries if max_entries is not None else settings.failure_history_max_entries) self._entries: deque[FailureHistoryEntry] = deque(maxlen=self.max_entries) self._clock = clock or time.time @@ -165,10 +166,11 @@ def snapshot(self) -> tuple[FailureHistoryEntry, ...]: def backoff_for(self, *, provider: Optional[str], error_type: str, scope: str, model: Optional[str], default_duration: int) -> BackoffDecision: """Return bounded backoff for repeated transient failures.""" - window = _env_int("PROVIDER_BACKOFF_WINDOW_SECONDS", 60) - threshold = max(1, _env_int("PROVIDER_BACKOFF_THRESHOLD", 3)) - base = max(1, _env_int("PROVIDER_BACKOFF_BASE_SECONDS", default_duration)) - max_seconds = max(base, _env_int("PROVIDER_BACKOFF_MAX_SECONDS", 300)) + settings = _retry_settings() + window = settings.provider_backoff_window_seconds + threshold = max(1, settings.provider_backoff_threshold) + base = max(1, settings.provider_backoff_base_seconds or default_duration) + max_seconds = max(base, settings.provider_backoff_max_seconds) now = float(self._clock()) recent = [ entry @@ -207,10 +209,8 @@ def is_model_capacity_error(error: Any) -> bool: def provider_cooldown_env() -> tuple[int, int, bool]: """Read provider-cooldown env controls with conservative defaults.""" - min_seconds = _env_int("PROVIDER_COOLDOWN_MIN_SECONDS", 10) - default_seconds = _env_int("PROVIDER_COOLDOWN_DEFAULT_SECONDS", DEFAULT_PROVIDER_COOLDOWN_DEFAULT_SECONDS) - cooldown_on_quota = os.environ.get("PROVIDER_COOLDOWN_ON_QUOTA", "").strip().lower() in {"1", "true", "yes", "on"} - return max(0, min_seconds), max(0, default_seconds), cooldown_on_quota + settings = _retry_settings() + return settings.provider_cooldown_min_seconds, settings.provider_cooldown_default_seconds, settings.provider_cooldown_on_quota def is_target_failover_eligible( @@ -230,3 +230,11 @@ def _env_int(name: str, default: int) -> int: return int(os.environ.get(name, default)) except (TypeError, ValueError): return default + + +def _retry_settings() -> Any: + """Load retry settings lazily to avoid config import cycles at startup.""" + + from .config.experimental import get_retry_runtime_settings + + return get_retry_runtime_settings() diff --git a/tests/test_experimental_config.py b/tests/test_experimental_config.py index 6d78c7175..72e3f8cf9 100644 --- a/tests/test_experimental_config.py +++ b/tests/test_experimental_config.py @@ -6,6 +6,8 @@ ExperimentalConfigError, as_int, env_price_key, + get_responses_store_settings, + get_retry_runtime_settings, get_stream_runtime_settings, load_config_from_mapping, load_experimental_config, @@ -52,6 +54,51 @@ def test_stream_runtime_settings_env_overrides_json() -> None: assert settings.stall_timeout_seconds == 30 +def test_retry_runtime_settings_env_overrides_json() -> None: + config = load_config_from_mapping( + { + "retry": { + "provider_cooldown": {"provider_cooldown_min_seconds": 20, "provider_cooldown_on_quota": True}, + "backoff": {"provider_backoff_threshold": 4, "failure_history_max_entries": 50}, + } + } + ) + + settings = get_retry_runtime_settings(config=config, env={"PROVIDER_COOLDOWN_MIN_SECONDS": "5"}) + + assert settings.provider_cooldown_min_seconds == 5 + assert settings.provider_cooldown_on_quota is True + assert settings.provider_backoff_threshold == 4 + assert settings.failure_history_max_entries == 50 + + +def test_responses_store_settings_env_overrides_json() -> None: + config = load_config_from_mapping( + { + "responses": { + "store": { + "ttl_seconds": 60, + "max_items": 10, + "store_failed": False, + "store_in_progress": True, + } + } + } + ) + + settings = get_responses_store_settings(config=config, env={"RESPONSES_STORE_MAX_ITEMS": "5", "RESPONSES_STORE_FAILED": "true"}) + + assert settings.ttl_seconds == 60 + assert settings.max_items == 5 + assert settings.store_failed is True + assert settings.store_in_progress is True + + +def test_new_config_sections_still_reject_secret_like_keys() -> None: + with pytest.raises(ExperimentalConfigError): + load_config_from_mapping({"responses": {"store": {"authorization": "hidden"}}}) + + def test_field_cache_rules_parse_wildcard_then_model_specific() -> None: config = load_config_from_mapping( { From bbeba37241f690a6a2caf284a878f53fecb539b4 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 15:26:21 +0200 Subject: [PATCH 144/182] feat(config): wire runtime settings surfaces Adds JSON/env helpers for retry cooldown/backoff and Responses store policy, wires retry policy and proxy startup to use them, hardens field-cache config coverage, and updates the environment example for Phase 7b-10b runtime knobs. Tests: pytest tests/test_experimental_config.py tests/test_retry_policy.py tests/test_responses_service.py tests/test_routing_config.py tests/test_streaming_usage_accounting.py tests/test_config_stream_settings.py --- .env.example | 38 +++++++++++++++++++++++++------ src/proxy_app/main.py | 8 +++++-- tests/test_experimental_config.py | 29 +++++++++++++++++++++++ 3 files changed, 66 insertions(+), 9 deletions(-) diff --git a/.env.example b/.env.example index 9d638a8f8..5638a1f66 100644 --- a/.env.example +++ b/.env.example @@ -338,6 +338,9 @@ # groups, model routes, advisory pricing, streaming observability, and field # cache rules. Environment variables override JSON config. Do not put API keys, # OAuth tokens, bearer tokens, or authorization headers in this file. +# Supported top-level sections include: routing, pricing, streaming, retry, +# responses, field_cache, and providers. The providers section is for non-secret +# provider tuning only; credentials must stay in env or provider credential files. # LLM_PROXY_CONFIG_FILE=./config/llm-proxy.json # --- Fallback Groups and Model Routes --- @@ -353,23 +356,44 @@ # --- Provider Cooldown Activation --- # Provider-level cooldown is conservative and only intended for large/global -# retry-after events, not every per-credential quota error. +# retry-after events, not every per-credential quota error. Model-capacity +# errors can start a model-scoped cooldown without blocking unrelated models. # PROVIDER_COOLDOWN_MIN_SECONDS=10 # PROVIDER_COOLDOWN_DEFAULT_SECONDS=60 # PROVIDER_COOLDOWN_ON_QUOTA=false +# Repeated transient provider/model failures can increase bounded backoff. +# PROVIDER_BACKOFF_WINDOW_SECONDS=60 +# PROVIDER_BACKOFF_THRESHOLD=3 +# PROVIDER_BACKOFF_BASE_SECONDS=60 +# PROVIDER_BACKOFF_MAX_SECONDS=300 +# FAILURE_HISTORY_MAX_ENTRIES=200 + +# --- Responses API Store Policy --- +# Responses are stored in-memory by default. TTL/max limits are disabled when +# unset or <= 0. Failed stream responses are stored by default; in-progress +# streaming state is disabled unless explicitly enabled. +# RESPONSES_STORE_TTL_SECONDS=0 +# RESPONSES_STORE_MAX_ITEMS=0 +# RESPONSES_STORE_FAILED=true +# RESPONSES_STORE_IN_PROGRESS=false # --- Streaming Observability --- -# Stream lifecycle metrics are traced by default. Timeout/heartbeat values are -# parsed for future enforcement but are not used to abort streams by default. +# Stream lifecycle metrics are traced by default. TTFB/stall timeout values are +# active when set to >0; keep at 0 to disable. Heartbeats are SSE comments and +# do not count as visible output. # STREAM_TRACE_METRICS=true # STREAM_TTFB_TIMEOUT_SECONDS=0 # STREAM_STALL_TIMEOUT_SECONDS=0 -# STREAM_HEARTBEAT_SECONDS=0 +# STREAM_HEARTBEAT_INTERVAL_SECONDS=0 +# STREAM_HEARTBEAT_SECONDS=0 # Legacy alias +# STREAM_CANCEL_UPSTREAM_ON_DISCONNECT=true # --- Advisory Model Pricing --- -# Per-token advisory prices used for usage/cost traces and stored approximate -# cost values when provider/LiteLLM pricing is unavailable. Provider code can -# still supply authoritative pricing, and skip-cost providers always return zero. +# Per-token advisory prices used only when providers do not report actual cost. +# Precedence: skip-cost provider setting > provider-reported cost/SSE cost event +# > provider explicit pricing > env pricing > JSON pricing > LiteLLM metadata. +# Streaming providers can report actual cost with `: cost {...}` comments or +# `event: cost` frames. # Env names sanitize provider/model by replacing non-alphanumerics with `_`. # MODEL_PRICE_OPENAI_GPT_5_1_INPUT=0.000001 # MODEL_PRICE_OPENAI_GPT_5_1_OUTPUT=0.00001 diff --git a/src/proxy_app/main.py b/src/proxy_app/main.py index 48f3e7b75..fe7f91ff3 100644 --- a/src/proxy_app/main.py +++ b/src/proxy_app/main.py @@ -597,7 +597,9 @@ async def process_credential(provider: str, path: str, provider_instance): # Phase 4 Responses API compatibility service. It currently bridges through # the existing chat-completions client path; later native providers can reuse # the same route/storage surface without changing clients. - app.state.responses_service = ResponsesService() + from rotator_library.config.experimental import get_responses_store_settings + + app.state.responses_service = ResponsesService(store_settings=get_responses_store_settings()) # Warn if no provider credentials are configured if not client.all_credentials: @@ -674,7 +676,9 @@ def get_responses_service(request: Request) -> ResponsesService: service = getattr(request.app.state, "responses_service", None) if service is None: - service = ResponsesService() + from rotator_library.config.experimental import get_responses_store_settings + + service = ResponsesService(store_settings=get_responses_store_settings()) request.app.state.responses_service = service return service diff --git a/tests/test_experimental_config.py b/tests/test_experimental_config.py index 72e3f8cf9..812137dd1 100644 --- a/tests/test_experimental_config.py +++ b/tests/test_experimental_config.py @@ -133,6 +133,35 @@ def test_field_cache_rules_match_unprefixed_model_alias() -> None: assert [rule.name for rule in rules] == ["signature"] +def test_field_cache_rule_parses_ttl_metadata_and_insert_injection() -> None: + config = load_config_from_mapping( + { + "field_cache": { + "provider": { + "*": [ + { + "name": "tool_state", + "source": "stream_event", + "path": "raw.tool.state", + "mode": "per_tool_call", + "ttl_seconds": 120, + "metadata": {"tool_container_path": "tools"}, + "inject": {"target": "request", "path": "metadata.tool_state", "insert": True}, + } + ] + } + } + } + ) + + rule = parse_field_cache_rules(config, "provider", "model")[0] + + assert rule.ttl_seconds == 120 + assert rule.metadata == {"tool_container_path": "tools"} + assert rule.inject is not None + assert rule.inject.insert is True + + def test_field_cache_rule_rejects_invalid_config_values() -> None: config = load_config_from_mapping( { From 2f47569cf5883bf39e15761f2ae6edeb3546c757 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 15:33:57 +0200 Subject: [PATCH 145/182] fix(config): harden experimental settings parsing Broadens secret-like key rejection, keeps retry env parsing tolerant for malformed legacy values, fails clearly on malformed field-cache rule shapes, and updates env-example defaults and drift coverage. Tests: pytest tests/test_experimental_config.py tests/test_env_example_experimental_config.py tests/test_retry_policy.py tests/test_responses_service.py tests/test_routing_config.py tests/test_config_stream_settings.py tests/test_field_cache_engine.py tests/test_native_provider_executor.py tests/test_usage_costs.py tests/test_usage_accounting.py --- .env.example | 4 +- src/rotator_library/config/experimental.py | 61 +++++++++++++++---- tests/test_env_example_experimental_config.py | 12 ++++ tests/test_experimental_config.py | 37 +++++++++++ 4 files changed, 100 insertions(+), 14 deletions(-) diff --git a/.env.example b/.env.example index 5638a1f66..88ff74cb5 100644 --- a/.env.example +++ b/.env.example @@ -359,12 +359,12 @@ # retry-after events, not every per-credential quota error. Model-capacity # errors can start a model-scoped cooldown without blocking unrelated models. # PROVIDER_COOLDOWN_MIN_SECONDS=10 -# PROVIDER_COOLDOWN_DEFAULT_SECONDS=60 +# PROVIDER_COOLDOWN_DEFAULT_SECONDS=30 # PROVIDER_COOLDOWN_ON_QUOTA=false # Repeated transient provider/model failures can increase bounded backoff. # PROVIDER_BACKOFF_WINDOW_SECONDS=60 # PROVIDER_BACKOFF_THRESHOLD=3 -# PROVIDER_BACKOFF_BASE_SECONDS=60 +# PROVIDER_BACKOFF_BASE_SECONDS=0 # 0/unset means use provider cooldown default # PROVIDER_BACKOFF_MAX_SECONDS=300 # FAILURE_HISTORY_MAX_ENTRIES=200 diff --git a/src/rotator_library/config/experimental.py b/src/rotator_library/config/experimental.py index 25ac3aff8..d4cf56218 100644 --- a/src/rotator_library/config/experimental.py +++ b/src/rotator_library/config/experimental.py @@ -23,7 +23,7 @@ _CONFIG_ENV_KEYS = ("LLM_PROXY_CONFIG_FILE", "PROXY_CONFIG_FILE") _KNOWN_SECTIONS = {"routing", "pricing", "streaming", "field_cache", "providers", "retry", "responses"} -_SECRET_KEY_PARTS = ("api_key", "authorization", "access_token", "refresh_token", "client_secret", "bearer_token", "password") +_SECRET_KEY_PARTS = ("api_key", "apikey", "authorization", "access_token", "accesstoken", "refresh_token", "refreshtoken", "client_secret", "clientsecret", "secret_key", "secretkey", "bearer_token", "bearertoken", "password") class ExperimentalConfigError(ValueError): @@ -195,14 +195,14 @@ def get_retry_runtime_settings( cooldown = retry.get("provider_cooldown", {}) if isinstance(retry.get("provider_cooldown"), dict) else retry backoff = retry.get("backoff", {}) if isinstance(retry.get("backoff"), dict) else retry return RetryRuntimeSettings( - provider_cooldown_min_seconds=max(0, as_int(_env_or_json(source, "PROVIDER_COOLDOWN_MIN_SECONDS", cooldown, "provider_cooldown_min_seconds", default=10), name="PROVIDER_COOLDOWN_MIN_SECONDS")), - provider_cooldown_default_seconds=max(0, as_int(_env_or_json(source, "PROVIDER_COOLDOWN_DEFAULT_SECONDS", cooldown, "provider_cooldown_default_seconds", default=30), name="PROVIDER_COOLDOWN_DEFAULT_SECONDS")), - provider_cooldown_on_quota=as_bool(_env_or_json(source, "PROVIDER_COOLDOWN_ON_QUOTA", cooldown, "provider_cooldown_on_quota", default=False), name="PROVIDER_COOLDOWN_ON_QUOTA"), - provider_backoff_window_seconds=max(0, as_int(_env_or_json(source, "PROVIDER_BACKOFF_WINDOW_SECONDS", backoff, "provider_backoff_window_seconds", default=60), name="PROVIDER_BACKOFF_WINDOW_SECONDS")), - provider_backoff_threshold=max(1, as_int(_env_or_json(source, "PROVIDER_BACKOFF_THRESHOLD", backoff, "provider_backoff_threshold", default=3), name="PROVIDER_BACKOFF_THRESHOLD")), - provider_backoff_base_seconds=_optional_positive_int(_env_or_json(source, "PROVIDER_BACKOFF_BASE_SECONDS", backoff, "provider_backoff_base_seconds"), "PROVIDER_BACKOFF_BASE_SECONDS"), - provider_backoff_max_seconds=max(1, as_int(_env_or_json(source, "PROVIDER_BACKOFF_MAX_SECONDS", backoff, "provider_backoff_max_seconds", default=300), name="PROVIDER_BACKOFF_MAX_SECONDS")), - failure_history_max_entries=max(1, as_int(_env_or_json(source, "FAILURE_HISTORY_MAX_ENTRIES", backoff, "failure_history_max_entries", default=200), name="FAILURE_HISTORY_MAX_ENTRIES")), + provider_cooldown_min_seconds=max(0, _int_setting(source, "PROVIDER_COOLDOWN_MIN_SECONDS", cooldown, "provider_cooldown_min_seconds", 10)), + provider_cooldown_default_seconds=max(0, _int_setting(source, "PROVIDER_COOLDOWN_DEFAULT_SECONDS", cooldown, "provider_cooldown_default_seconds", 30)), + provider_cooldown_on_quota=_bool_setting(source, "PROVIDER_COOLDOWN_ON_QUOTA", cooldown, "provider_cooldown_on_quota", False), + provider_backoff_window_seconds=max(0, _int_setting(source, "PROVIDER_BACKOFF_WINDOW_SECONDS", backoff, "provider_backoff_window_seconds", 60)), + provider_backoff_threshold=max(1, _int_setting(source, "PROVIDER_BACKOFF_THRESHOLD", backoff, "provider_backoff_threshold", 3)), + provider_backoff_base_seconds=_optional_int_setting(source, "PROVIDER_BACKOFF_BASE_SECONDS", backoff, "provider_backoff_base_seconds"), + provider_backoff_max_seconds=max(1, _int_setting(source, "PROVIDER_BACKOFF_MAX_SECONDS", backoff, "provider_backoff_max_seconds", 300)), + failure_history_max_entries=max(1, _int_setting(source, "FAILURE_HISTORY_MAX_ENTRIES", backoff, "failure_history_max_entries", 200)), ) @@ -240,7 +240,7 @@ def parse_field_cache_rules(config: ExperimentalConfig, provider: str, model: st provider_rules = config.field_cache.get(provider, {}) if isinstance(config.field_cache, dict) else {} if not isinstance(provider_rules, dict): - return () + raise ExperimentalConfigError("field_cache provider section must be an object") raw_rules: list[Any] = [] keys = ["*"] if "/" in model: @@ -250,7 +250,14 @@ def parse_field_cache_rules(config: ExperimentalConfig, provider: str, model: st value = provider_rules.get(key, []) if isinstance(value, list): raw_rules.extend(value) - return tuple(_field_cache_rule_from_dict(rule) for rule in raw_rules if isinstance(rule, dict)) + elif value not in (None, []): + raise ExperimentalConfigError("field_cache model rules must be a list") + parsed_rules = [] + for rule in raw_rules: + if not isinstance(rule, dict): + raise ExperimentalConfigError("field_cache rule entries must be objects") + parsed_rules.append(_field_cache_rule_from_dict(rule)) + return tuple(parsed_rules) def as_bool(value: Any, *, name: str) -> bool: @@ -314,7 +321,9 @@ def _reject_secret_keys(value: Any, path: str = "config") -> None: if isinstance(value, Mapping): for key, nested in value.items(): key_text = str(key).lower() - if any(part in key_text for part in _SECRET_KEY_PARTS): + compact_key = re.sub(r"[^a-z0-9]+", "", key_text) + underscored_key = re.sub(r"[^a-z0-9]+", "_", key_text) + if any(part in key_text or part in compact_key or part in underscored_key for part in _SECRET_KEY_PARTS): raise ExperimentalConfigError(f"Unsafe secret-like key in JSON config at {path}.{key}") _reject_secret_keys(nested, f"{path}.{key}") elif isinstance(value, list): @@ -354,6 +363,34 @@ def _env_or_json(env: Mapping[str, str], env_key: str, data: Mapping[str, Any], return data.get(json_key, default) +def _int_setting(env: Mapping[str, str], env_key: str, data: Mapping[str, Any], json_key: str, default: int) -> int: + if env_key in env: + try: + return int(env.get(env_key) or default) + except (TypeError, ValueError): + return default + return as_int(data.get(json_key, default), name=env_key) + + +def _optional_int_setting(env: Mapping[str, str], env_key: str, data: Mapping[str, Any], json_key: str) -> Optional[int]: + if env_key in env: + try: + parsed = int(env.get(env_key) or 0) + except (TypeError, ValueError): + return None + return parsed if parsed > 0 else None + return _optional_positive_int(data.get(json_key), env_key) + + +def _bool_setting(env: Mapping[str, str], env_key: str, data: Mapping[str, Any], json_key: str, default: bool) -> bool: + if env_key in env: + try: + return as_bool(env.get(env_key), name=env_key) + except ExperimentalConfigError: + return default + return as_bool(data.get(json_key, default), name=env_key) + + def _optional_positive_float(value: Any, name: str) -> Optional[float]: if value in (None, ""): return None diff --git a/tests/test_env_example_experimental_config.py b/tests/test_env_example_experimental_config.py index 5435c17c2..905c961ce 100644 --- a/tests/test_env_example_experimental_config.py +++ b/tests/test_env_example_experimental_config.py @@ -14,13 +14,25 @@ def test_env_example_documents_experimental_config_knobs() -> None: "PROVIDER_COOLDOWN_MIN_SECONDS", "PROVIDER_COOLDOWN_DEFAULT_SECONDS", "PROVIDER_COOLDOWN_ON_QUOTA", + "PROVIDER_BACKOFF_WINDOW_SECONDS", + "PROVIDER_BACKOFF_THRESHOLD", + "PROVIDER_BACKOFF_BASE_SECONDS", + "PROVIDER_BACKOFF_MAX_SECONDS", + "FAILURE_HISTORY_MAX_ENTRIES", + "RESPONSES_STORE_TTL_SECONDS", + "RESPONSES_STORE_MAX_ITEMS", + "RESPONSES_STORE_FAILED", + "RESPONSES_STORE_IN_PROGRESS", "STREAM_TRACE_METRICS", "STREAM_TTFB_TIMEOUT_SECONDS", "STREAM_STALL_TIMEOUT_SECONDS", + "STREAM_HEARTBEAT_INTERVAL_SECONDS", "STREAM_HEARTBEAT_SECONDS", + "STREAM_CANCEL_UPSTREAM_ON_DISCONNECT", "MODEL_PRICE_OPENAI_GPT_5_1_INPUT", "MODEL_PRICE_OPENAI_GPT_5_1_REASONING", ): assert key in text assert "Do not put API keys" in text + assert "PROVIDER_COOLDOWN_DEFAULT_SECONDS=30" in text diff --git a/tests/test_experimental_config.py b/tests/test_experimental_config.py index 812137dd1..9e355edd9 100644 --- a/tests/test_experimental_config.py +++ b/tests/test_experimental_config.py @@ -99,6 +99,29 @@ def test_new_config_sections_still_reject_secret_like_keys() -> None: load_config_from_mapping({"responses": {"store": {"authorization": "hidden"}}}) +@pytest.mark.parametrize("secret_key", ["secret_key", "secret-key", "apiKey", "client-secret"]) +def test_secret_key_variants_are_rejected(secret_key: str) -> None: + with pytest.raises(ExperimentalConfigError): + load_config_from_mapping({"retry": {secret_key: "hidden"}}) + + +def test_retry_runtime_settings_malformed_env_preserves_defaults() -> None: + config = load_config_from_mapping({"retry": {"provider_cooldown_default_seconds": 45}}) + + settings = get_retry_runtime_settings( + config=config, + env={ + "PROVIDER_COOLDOWN_DEFAULT_SECONDS": "not-an-int", + "PROVIDER_COOLDOWN_ON_QUOTA": "not-a-bool", + "PROVIDER_BACKOFF_BASE_SECONDS": "not-an-int", + }, + ) + + assert settings.provider_cooldown_default_seconds == 30 + assert settings.provider_cooldown_on_quota is False + assert settings.provider_backoff_base_seconds is None + + def test_field_cache_rules_parse_wildcard_then_model_specific() -> None: config = load_config_from_mapping( { @@ -179,6 +202,20 @@ def test_field_cache_rule_rejects_invalid_config_values() -> None: parse_field_cache_rules(config, "provider", "model") +def test_field_cache_rules_reject_malformed_shapes() -> None: + config = load_config_from_mapping({"field_cache": {"provider": {"*": ["not-a-rule"]}}}) + + with pytest.raises(ExperimentalConfigError, match="rule entries"): + parse_field_cache_rules(config, "provider", "model") + + +def test_field_cache_rules_reject_non_list_model_rules() -> None: + config = load_config_from_mapping({"field_cache": {"provider": {"*": {"name": "bad"}}}}) + + with pytest.raises(ExperimentalConfigError, match="model rules"): + parse_field_cache_rules(config, "provider", "model") + + def test_env_price_key_sanitizes_provider_and_model() -> None: assert env_price_key("openai", "gpt-5.1-mini", "cache_read") == "MODEL_PRICE_OPENAI_GPT_5_1_MINI_CACHE_READ" From 36c11ce7491be20e32185b0b264537ded8f12a4a Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 15:41:35 +0200 Subject: [PATCH 146/182] fix(config): reject more secret-like keys Rejects OAuth and id-token key variants in JSON config, fails clearly on malformed nested field-cache shapes, and expands env-example default drift tests. Tests: pytest tests/test_experimental_config.py tests/test_env_example_experimental_config.py tests/test_retry_policy.py tests/test_responses_service.py tests/test_routing_config.py tests/test_config_stream_settings.py tests/test_field_cache_engine.py tests/test_native_provider_executor.py tests/test_usage_costs.py tests/test_usage_accounting.py --- src/rotator_library/config/experimental.py | 12 +++++++++-- tests/test_env_example_experimental_config.py | 21 ++++++++++++++++++- tests/test_experimental_config.py | 14 ++++++++++++- 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/src/rotator_library/config/experimental.py b/src/rotator_library/config/experimental.py index d4cf56218..ca0d00bf3 100644 --- a/src/rotator_library/config/experimental.py +++ b/src/rotator_library/config/experimental.py @@ -23,7 +23,7 @@ _CONFIG_ENV_KEYS = ("LLM_PROXY_CONFIG_FILE", "PROXY_CONFIG_FILE") _KNOWN_SECTIONS = {"routing", "pricing", "streaming", "field_cache", "providers", "retry", "responses"} -_SECRET_KEY_PARTS = ("api_key", "apikey", "authorization", "access_token", "accesstoken", "refresh_token", "refreshtoken", "client_secret", "clientsecret", "secret_key", "secretkey", "bearer_token", "bearertoken", "password") +_SECRET_KEY_PARTS = ("api_key", "apikey", "authorization", "access_token", "accesstoken", "refresh_token", "refreshtoken", "oauth_token", "oauthtoken", "oauth_token_secret", "oauthtokensecret", "id_token", "idtoken", "token_secret", "tokensecret", "client_secret", "clientsecret", "secret_key", "secretkey", "bearer_token", "bearertoken", "password") class ExperimentalConfigError(ValueError): @@ -418,6 +418,8 @@ def _field_cache_rule_from_dict(data: Mapping[str, Any]) -> FieldCacheRule: insert=as_bool(inject_data.get("insert", False), name="field_cache.inject.insert"), as_list=as_bool(inject_data.get("as_list", False), name="field_cache.inject.as_list"), ) + elif inject_data is not None: + raise ExperimentalConfigError("field_cache.inject must be an object") elif data.get("target_path"): inject = FieldCacheInjection(target=str(data.get("target", "request")), path=str(data["target_path"])) scope = data.get("scope", ("provider", "model", "classifier", "session")) @@ -437,7 +439,7 @@ def _field_cache_rule_from_dict(data: Mapping[str, Any]) -> FieldCacheRule: inject=inject, enabled=as_bool(data.get("enabled", True), name="field_cache.enabled"), ttl_seconds=int(data["ttl_seconds"]) if data.get("ttl_seconds") is not None else None, - metadata=dict(data.get("metadata", {})) if isinstance(data.get("metadata", {}), dict) else {}, + metadata=_metadata_dict(data.get("metadata", {})), allow_missing_session=as_bool(data.get("allow_missing_session", False), name="field_cache.allow_missing_session"), ) except KeyError as exc: @@ -446,5 +448,11 @@ def _field_cache_rule_from_dict(data: Mapping[str, Any]) -> FieldCacheRule: raise ExperimentalConfigError(str(exc)) from exc +def _metadata_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return dict(value) + raise ExperimentalConfigError("field_cache.metadata must be an object") + + def _env_part(value: str) -> str: return re.sub(r"[^A-Z0-9]+", "_", value.upper()).strip("_") diff --git a/tests/test_env_example_experimental_config.py b/tests/test_env_example_experimental_config.py index 905c961ce..cac0f0d61 100644 --- a/tests/test_env_example_experimental_config.py +++ b/tests/test_env_example_experimental_config.py @@ -35,4 +35,23 @@ def test_env_example_documents_experimental_config_knobs() -> None: assert key in text assert "Do not put API keys" in text - assert "PROVIDER_COOLDOWN_DEFAULT_SECONDS=30" in text + for default_line in ( + "PROVIDER_COOLDOWN_MIN_SECONDS=10", + "PROVIDER_COOLDOWN_DEFAULT_SECONDS=30", + "PROVIDER_COOLDOWN_ON_QUOTA=false", + "PROVIDER_BACKOFF_WINDOW_SECONDS=60", + "PROVIDER_BACKOFF_THRESHOLD=3", + "PROVIDER_BACKOFF_BASE_SECONDS=0", + "PROVIDER_BACKOFF_MAX_SECONDS=300", + "FAILURE_HISTORY_MAX_ENTRIES=200", + "RESPONSES_STORE_TTL_SECONDS=0", + "RESPONSES_STORE_MAX_ITEMS=0", + "RESPONSES_STORE_FAILED=true", + "RESPONSES_STORE_IN_PROGRESS=false", + "STREAM_TRACE_METRICS=true", + "STREAM_TTFB_TIMEOUT_SECONDS=0", + "STREAM_STALL_TIMEOUT_SECONDS=0", + "STREAM_HEARTBEAT_INTERVAL_SECONDS=0", + "STREAM_CANCEL_UPSTREAM_ON_DISCONNECT=true", + ): + assert default_line in text diff --git a/tests/test_experimental_config.py b/tests/test_experimental_config.py index 9e355edd9..9506edbe1 100644 --- a/tests/test_experimental_config.py +++ b/tests/test_experimental_config.py @@ -99,7 +99,7 @@ def test_new_config_sections_still_reject_secret_like_keys() -> None: load_config_from_mapping({"responses": {"store": {"authorization": "hidden"}}}) -@pytest.mark.parametrize("secret_key", ["secret_key", "secret-key", "apiKey", "client-secret"]) +@pytest.mark.parametrize("secret_key", ["secret_key", "secret-key", "apiKey", "client-secret", "oauth_token", "oauthToken", "oauth-token", "id_token", "oauth_token_secret"]) def test_secret_key_variants_are_rejected(secret_key: str) -> None: with pytest.raises(ExperimentalConfigError): load_config_from_mapping({"retry": {secret_key: "hidden"}}) @@ -216,6 +216,18 @@ def test_field_cache_rules_reject_non_list_model_rules() -> None: parse_field_cache_rules(config, "provider", "model") +def test_field_cache_rules_reject_malformed_nested_shapes() -> None: + bad_inject = load_config_from_mapping({"field_cache": {"provider": {"*": [{"name": "bad", "source": "response", "path": "x", "inject": "not-object"}]}}}) + + with pytest.raises(ExperimentalConfigError, match="inject"): + parse_field_cache_rules(bad_inject, "provider", "model") + + bad_metadata = load_config_from_mapping({"field_cache": {"provider": {"*": [{"name": "bad", "source": "response", "path": "x", "metadata": "not-object"}]}}}) + + with pytest.raises(ExperimentalConfigError, match="metadata"): + parse_field_cache_rules(bad_metadata, "provider", "model") + + def test_env_price_key_sanitizes_provider_and_model() -> None: assert env_price_key("openai", "gpt-5.1-mini", "cache_read") == "MODEL_PRICE_OPENAI_GPT_5_1_MINI_CACHE_READ" From 8b1f1bcb6e3966e81375ea9e0cefbf56fb6bd8f3 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 19:02:47 +0200 Subject: [PATCH 147/182] docs(experimental): record third-pass audit findings Preserves the complete third-pass validation findings across phases 1-10, including blockers, highs, mediums, low residuals, and remediation order. Tests: not run (documentation only) --- .../experimental/third-pass-audit-findings.md | 361 ++++++++++++++++++ 1 file changed, 361 insertions(+) create mode 100644 docs/experimental/third-pass-audit-findings.md diff --git a/docs/experimental/third-pass-audit-findings.md b/docs/experimental/third-pass-audit-findings.md new file mode 100644 index 000000000..29658805c --- /dev/null +++ b/docs/experimental/third-pass-audit-findings.md @@ -0,0 +1,361 @@ +# Third-Pass Audit Findings + +This document preserves the complete third-pass validation findings so the remediation pass cannot lose track of them. + +References used by the reviewers: + +- `C:\Projects\test\new 1.txt` +- `C:\Projects\test\new 2.txt` +- General docs: `00-master-plan.md`, `01-protocol-architecture.md`, `02-transform-logging.md`, `03-field-cache-rules.md`, `04-provider-roadmap.md`, `05-routing-retry-usage-roadmap.md`, `06-phase-workflow.md`, `07-detailed-phase-roadmap.md` +- Phase first-pass plans: `phase-1-protocol-core.md` through `phase-10-config-polish.md` +- Phase second-pass plans: `phase-1b-protocol-breadth-operation-model.md` through `phase-10b-config-surface-wiring.md` + +## Overall Verdict + +The third-pass review is not clean. Each phase received one normal `explore` review and one `explore-heavy` review. The reviewers found remaining blockers, highs, mediums, and low-risk residuals across the implemented work. + +## Severity Summary + +| Phase | Highest Severity | Clean? | +|---|---:|---:| +| Phase 1 / 1b Protocols | High | No | +| Phase 2 / 2b Transform Tracing | High | No | +| Phase 3 / 3b Field Cache | Medium | No | +| Phase 4 / 4b Responses | Medium | No | +| Phase 5 / 5b Providers | Blocker | No | +| Phase 6 / 6b Routing/Fallback | Blocker | No | +| Phase 7 / 7b Retry/Cooldown | High | No | +| Phase 8 / 8b Streaming | Medium | No | +| Phase 9 / 9b Usage/Cost | High | No | +| Phase 10 / 10b Config | High | No | + +## Phase 1 / 1b: Protocols + +### Blockers + +- None reported. + +### High + +- Native OpenAI Chat and Responses formatting returns unified usage shape instead of protocol-native usage shape. + - OpenAI Chat `format_response()` can emit `Usage.to_dict()` fields like `input_tokens` / `output_tokens` instead of `prompt_tokens` / `completion_tokens` / `total_tokens`. + - Responses `format_response()` can do the same instead of Responses-compatible usage fields. + - Native executor returns formatted protocol bodies directly to clients, so this can leak wrong usage schemas into live native responses. + +### Medium + +- OpenAI legacy `function_call` is preserved only in `extra`, not modeled as a unified `ToolCall`. +- Ollama lacks a `format_response()` override, so adapter-mutated unified responses can be ignored when `raw` is present. +- Operation compatibility is not enforced in live native provider execution. + - Native execution resolves protocol and operation but does not call `supports_operation()`. + - Misconfigured provider/protocol/operation combinations can silently fall back or drift. + +### Low / Residual + +- No cross-protocol conversion tests such as Anthropic -> unified -> OpenAI or Gemini -> unified -> OpenAI. +- `format_stream_event()` is not overridden by Anthropic, Gemini, Responses, or Ollama; they rely on base raw-pass-through behavior. +- Ollama declares `jsonl` support without JSONL-specific stream formatting. +- Unified types have `to_dict()` but no `from_dict()` deserialization helpers. +- Duplicate helper functions exist across protocol modules. +- No explicit `format_response()` tests for OpenAI Chat, Anthropic Messages, or Gemini. +- `CostDetails` protocol extraction is present in OpenAI Chat and Responses only. +- Registry discovery imports protocol modules without import-error isolation. +- Missing tests for formatted usage schemas, legacy `function_call`, Ollama mutated response formatting, registry failure isolation, and native unsupported operation rejection. + +## Phase 2 / 2b: Transform Tracing + +### Blockers + +- None reported. + +### High + +- Anthropic compatibility path lacks transform-pass tracing for live conversion boundaries. + - Missing explicit traces for Anthropic raw request, Anthropic -> OpenAI conversion, OpenAI -> Anthropic final response, and per-event stream conversion. + - `AnthropicHandler` and `anthropic_streaming_wrapper` mostly rely on legacy transaction logger calls. + +### Medium + +- `raw_provider_stream_response` logs the stream iterator object rather than provider stream data. +- Streaming handler lifecycle traces are gated by `trace_metrics`; chunk traces remain, but lifecycle diagnostics can disappear. +- Non-streaming final response trace can be emitted before usage normalization, making `final_client_response` not truly final. +- Native streaming does not run or trace the stream-event adapter chain. +- Responses streaming does not trace final SSE formatting boundaries. +- Redaction misses common camelCase secret keys such as `apiKey`, `accessToken`, `refreshToken`, and `clientSecret`. +- Non-streaming Responses errors are not traced with standardized transform-failure shape. +- Built-in provider transform exceptions are not logged as transform failures. + +### Low / Residual + +- `ProviderTransforms.apply_sync()` has no trace support; currently not used by live execution. +- `AnthropicHandler.count_tokens()` has no transaction logger or trace entries. +- Provider logger trace entries cannot override per-entry `credential_id`; they rely on context correlation. +- No dedicated test for Anthropic transform tracing. +- Responses service has redundant `if transaction_logger` guards before `_trace()` calls. +- Native provider streaming does not apply adapter chains to stream events. +- Explicit LiteLLM fallback is traceable by metadata but has no pass named `litellm_fallback_request`. + +## Phase 3 / 3b: Adapters And Field Cache + +### Blockers + +- None reported. + +### High + +- None reported. + +### Medium + +- Runtime does not execute all declared field-cache sources and targets. + - `request`, `unified_request`, `unified_response`, and `unified_stream_event` are accepted, but live native execution only injects `request` and extracts `response` / `stream_event`. +- Adapter-chain traces can leak arbitrary configured provider-state fields. + - Native executor applies rule-aware redaction around its own traces, but `run_adapter_chain()` logs payloads directly. +- Credential-scoped field-cache rules do not fail closed when `credential_id` is missing. + - Missing credential can become a shared hashed `_none` bucket in direct/library native executor usage. + +### Low / Residual + +- Copilot provider has no field-cache rules by design. +- `ProviderCacheFieldStore.clear()` depends on injected cache exposing `clear()`. +- Turn-mode inference defaults to common `messages` shape; custom providers need explicit metadata. +- `per_tool_call` injection requires tool-call ID in context metadata or payload path. +- `FieldCacheEngine` defaults to a fresh in-memory store if used directly; `NativeProviderExecutor` correctly shares one per executor instance. +- Native streaming path extracts from stream events but does not inject into stream events, which is expected. +- `ProviderCacheFieldStore.append()` may serialize already wrapped values twice in edge cases. +- Process-local default store is per `NativeProviderExecutor` instance, not a global singleton. +- JSON parsing accepts non-positive `ttl_seconds`, which effectively disables expiry. +- Provider declarations are still conservative/mock-live for several priority providers. + +## Phase 4 / 4b: Responses API + +### Blockers + +- None reported. + +### High + +- None reported. + +### Medium + +- Responses storage is process-local by default; durable JSON/current-state or provider-cache-backed storage is not wired into app startup. +- `previous_response_id` bridge context is lossy. + - The bridge replays parent output only, not parent input or full lineage. +- Responses route errors are wrapped as FastAPI `detail.error` instead of top-level OpenAI-compatible `error` bodies. + +### Low / Residual + +- `ProviderCacheResponsesStore.delete()` returns `False` if the injected cache lacks `delete_async`. +- WebSocket formatter exists but has no FastAPI route or service-neutral integration test. +- `_record_responses_session_anchor()` no-ops if `session_tracker` or `session_id` is unavailable. +- No periodic TTL cleanup task; expired entries prune on access/write. +- Bridge output conversion handles text/tool calls but not all richer Responses item types. +- `store_in_progress` returns opt-in `in_progress` state without a formal retrieval schema. +- `ResponsesStreamState` uses fixed `msg_0` for bridge streams. +- No explicit concurrent access tests for `InMemoryResponsesStore`. +- No cancel endpoint. +- `validate_stream_request()` and `stream_events()` can load the same parent response twice. +- `stream_response(..., transport=...)` always uses SSE formatting. +- Formatted SSE frames themselves are not traced. +- Tests do not include end-to-end FastAPI + real `RotatingClient` sticky continuation or durable app-level storage wiring. + +## Phase 5 / 5b: Provider-Native Integrations + +### Blockers + +- Native provider responses are formatted back to the provider protocol, not the originating client protocol. + - `/v1/chat/completions` routed to Claude Code, Codex, or Antigravity can return Anthropic Messages, Responses API, or Gemini payloads instead of OpenAI Chat payloads. + - Native context does not carry a client target protocol; native executor uses the provider protocol for `format_response()`. + +### High + +- Claude Code native requests can omit required Anthropic `max_tokens`. +- Antigravity model normalization is unsafe for duplicated aliases and can lose low/high thinking-level behavior. + +### Medium + +- `get_native_headers()` and `get_native_endpoint()` are not declared on `ProviderInterface` even though the executor requires them dynamically. +- Native streaming is disabled for all priority providers. This is safe, but it leaves the native streaming path unexercised by real priority providers. +- Claude Code auth header convention may be wrong for API-key credentials because it always uses `Authorization: Bearer` with `anthropic-version`. +- ProviderInterface native contract lacks `supports_native_operation()` / `should_use_native_protocol()` style hooks, making non-stream native auto-selection all-or-nothing by protocol declaration. +- Native streaming fail-closed behavior is duplicated and only partially fail-closed; the executor duplicate catches fewer exceptions than the safe helper. +- Antigravity quota grouping regresses retired parity by grouping models too broadly. +- Field-cache declarations exist, but thinking/signature reinjection is not proven semantically correct for protocol-native continuation requirements. + +### Low / Residual + +- Copilot has no field-cache rules by design. +- Codex message-to-Responses conversion is minimal and does not handle tool calls, multipart content, images, or rich system formatting. +- Antigravity intentionally restores only a conservative safe subset; many retired features remain absent. +- Antigravity has quota groups but no usage reset configs or tier priorities. +- Priority provider model discovery catches broad exceptions and silently falls back to hardcoded lists. +- Execution-mode routing priority is correct and fail-closed. +- Native streaming fail-closed behavior is tested for explicit native streaming. +- Mock-live tests cover direct native request execution for all four priority providers, but not full end-to-end credential rotation/usage manager execution. +- Some provider docstrings still describe skeleton behavior. + +## Phase 6 / 6b: Routing And Fallback + +### Blockers + +- Stale fallback test file imports a missing module, so broad test collection can fail. + - `tests/test_fallback_groups.py` imports `rotator_library.fallback_groups.FallbackGroupManager`, which no longer exists. + - The same file asserts obsolete `RotatingClient(model_fallback_groups=...)` / `fallback_manager` behavior. + +### High + +- Requested-model-in-group promotion is not implemented. + - If the requested model is inside a fallback group, the router should try it first and preserve remaining group order. +- Streaming execution modes do not match non-streaming execution-mode behavior. + - Streaming can ignore explicit `@litellm_fallback` when a plugin has custom logic and can fall through differently for `@custom`. +- Explicit native non-streaming configuration errors can be treated as retryable fallback errors because some `RoutingExecutionError`s default to `unsupported_operation`. +- Structured route-error classification misses common status-less aliases such as `invalid_api_key`, `unauthorized`, `invalid_argument`, `rate_limited`, `too_many_requests`, `resource_exhausted`, and `unavailable`. + +### Medium + +- Fallback target session namespace is cloned from the first target instead of being re-inferred or adjusted per target, risking response anchors under the wrong provider/model namespace. + +### Low / Residual + +- No end-to-end test for env config -> `RequestContextBuilder` -> `RequestExecutor.execute()` full fallback scenario. +- `build_embedding_context()` does not resolve routing groups. +- No explicit test for `_execute_non_streaming_with_fallback()` when all targets return structured error responses. +- `_route_error_type_from_response()` falls back to text-summary scanning when structured fields are ambiguous. +- Fallback exhaustion summaries are sanitized and appear safe. +- Routing route aliases are lowercase keys only; there is no hyphen/space alias normalization. +- Wrapper-level streaming tests are thinner than helper-level coverage for some event frames. + +## Phase 7 / 7b: Retry, Cooldown, Backoff, Failure History + +### Blockers + +- None reported. + +### High + +- Provider/model cooldown can be bypassed when the remaining cooldown exceeds the request deadline budget. + - `_wait_for_cooldown()` logs and returns, and callers continue into credential acquisition/execution. +- Generic transient errors start provider-wide cooldown too aggressively. + - A single `server_error` or `api_connection` without `retry_after` can start provider cooldown immediately. + +### Medium + +- `decide_streaming_error_action()` is implemented and tested but not used by the live executor. +- Failure history lacks success reset; successes do not clear or reduce provider/model failure history. +- Structured retry attempt history is not populated beyond traces/summaries; `routing_attempt_history` exists but is unused in live executor. + +### Low / Residual + +- `FallbackAttemptRunner` is tested but not used by the live executor; inline wrappers remain. +- `decide_streaming_error_action()` does not accept `failure_history`, which would matter if it is adopted live. +- No `test_request_executor_group_policy.py` file despite the plan mentioning it. +- Streaming error branches in the executor are duplicated across several handlers. +- Model-capacity detection is narrow and misses variants such as `CAPACITY_EXHAUSTED` or `capacity-exhausted`. +- Tests miss cooldown-over-budget fail-fast/fallback, single generic transient non-cooldown, success reset, live model cooldown not blocking unrelated models, and live streaming parity with the decision helper. + +## Phase 8 / 8b: Streaming Hardening + +### Blockers + +- None reported. + +### High + +- None reported. + +### Medium + +- Responses streaming bypasses active Phase 8b hardening at the Responses layer. + - No direct TTFB/stall/heartbeat/disconnect/close policy is applied in `ResponsesService.stream_events()`. + - Responses formatter has no heartbeat formatter. +- Anthropic compatibility streaming can stop on disconnect without explicitly closing the upstream OpenAI stream. + +### Low / Residual + +- No formal `StreamTransportFormatter` protocol/ABC. +- Executor streaming error handling remains duplicated across branches. +- Responses streaming does not emit heartbeats or enforce TTFB/stall policy directly. +- Native streaming executor does not integrate with `StreamingHandler.wrap_stream()`. +- Native `_parse_stream_line()` returns raw text on JSON parse failure. +- No specific test for stall timeout after visible output. +- Normalized stream event parser comment says malformed chunks fail closed for visibility, but parser returns `visible_output=False`; routing safety is protected elsewhere. +- `decide_streaming_error_action()` is not used by the live executor. + +## Phase 9 / 9b: Usage, Quota, Cost + +### Blockers + +- None reported. + +### High + +- Provider-reported top-level response costs are ignored when a `usage` object exists. + - `extract_usage_record()` unwraps full responses to `response["usage"]`, discarding sibling `cost`, `total_cost`, `cost_details`, or `provider_reported_cost`. +- OpenAI-like cache-write tokens can be double-counted in normalized totals/costs. + - `input_tokens = prompt_tokens - cache_read`, while `cache_write_tokens` is stored separately and also added to total. +- Responses streaming drops SSE cost comments/events unless final usage carries cost. + +### Medium + +- `event: cost` frames can block streaming retry/fallback despite being metadata. +- Native streaming executor does not emit `usage_accounting_summary` trace. +- Native streaming executor does not preserve provider-reported cost from raw response. +- Structured provider cost breakdowns without `total_cost` are not preserved or summed. + +### Low / Residual + +- Responses bridge passes cost fields through verbatim rather than structuring them, which is acceptable but should be documented. +- No explicit native streaming cost-trace test. +- Streaming cost calculator does not pass `provider_plugin`, so provider explicit pricing can be missed in streaming mode when no provider-reported cost exists. +- Quota snapshot `used` maps to request count, not token count; metadata says request/token window but field naming is ambiguous. +- No integration-level test for Responses usage with provider-reported cost. +- No quota snapshot test for combined model and group filters. +- Planned trace pass names like `usage_cost_calculated` or `quota_snapshot_built` are not implemented if considered distinct from `usage_accounting_summary`. + +## Phase 10 / 10b: Config + +### Blockers + +- None reported. + +### High + +- Full `PROXY_API_KEY` is printed to console at startup. +- Provider/protocol/adapter/model/quota-checker config surfaces are still not live-wired, despite the `providers` JSON section being accepted/documented. + +### Medium + +- Responses streaming ignores Phase 10 streaming runtime settings. +- Pricing env parsing can fail requests instead of warning/ignoring invalid values. +- Routing config validation is not fully startup-safe; direct model-route target specs can parse later at request time. +- Secret rejection is key-name based and misses generic `credential` / `credentials` key names. + +### Low / Residual + +- `STREAM_RETRY_ON_REASONING_ONLY` is env-only and not part of `StreamRuntimeSettings` JSON config. +- `get_responses_store_settings()` returns `Any` rather than a precise type. +- `provider_cooldown_env()` remains a legacy tuple wrapper over richer retry settings. +- `STREAM_HEARTBEAT_SECONDS` legacy alias resolution is subtle. +- No standalone `config-reference.md` was created. +- `providers` section accepts raw dicts but does not validate protocol/adapter names; this is conservative but incomplete for future wiring. +- `.env.example` drift test does not verify the legacy `STREAM_HEARTBEAT_SECONDS=0` default line. + +## Cross-Phase Remediation Order + +1. Phase 5 blocker: native provider output must be converted back to the originating client protocol. +2. Phase 6 blocker: retire/update stale fallback tests and fix routing execution correctness. +3. Phase 10 high: stop printing secrets and make accepted config surfaces honest/live-wired or reject unsupported shapes. +4. Phase 1 high/medium protocol formatting and operation enforcement. +5. Phase 9 high usage/cost correctness. +6. Phase 7 high cooldown budget and transient cooldown behavior. +7. Phase 2 tracing gaps, especially Anthropic compatibility and redaction. +8. Phase 3 field-cache runtime/redaction gaps. +9. Phase 4 Responses storage/continuation/error-shape gaps. +10. Phase 8 streaming hardening gaps for Responses and Anthropic compatibility. + +## Required Completion Bar + +- Every blocker/high/medium listed above must be fixed or explicitly re-reviewed as no longer applicable. +- Low-risk items should be fixed when small/safe; otherwise they must be carried into phase reports as explicit residuals. +- Each remediation phase must follow the established loop: plan in conversation, write plan doc, implement focused commits, run focused tests, run `explore` and `explore-heavy`, fix findings, re-review, write user-facing report without committing the report. From c9472733b646444d4328ba46895c1665494aedb6 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 19:04:34 +0200 Subject: [PATCH 148/182] docs(experimental): plan protocol guardrail fixes Adds the Phase 1c corrective plan for protocol-native usage formatting, legacy function_call modeling, Ollama response formatting, and native operation enforcement. Tests: not run (planning document only) --- ...1c-protocol-output-operation-guardrails.md | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 docs/experimental/phase-1c-protocol-output-operation-guardrails.md diff --git a/docs/experimental/phase-1c-protocol-output-operation-guardrails.md b/docs/experimental/phase-1c-protocol-output-operation-guardrails.md new file mode 100644 index 000000000..a9327d3fa --- /dev/null +++ b/docs/experimental/phase-1c-protocol-output-operation-guardrails.md @@ -0,0 +1,59 @@ +# Phase 1c: Protocol Output Correctness And Native Operation Guardrails + +## Goal + +Close the Phase 1/1b third-pass findings that still affect protocol correctness and later native provider behavior. + +## Scope + +- Fix OpenAI Chat formatted usage so it emits public OpenAI-compatible usage fields rather than unified internal usage fields. +- Fix Responses formatted usage so it emits Responses-compatible usage fields rather than unified internal usage fields. +- Promote OpenAI legacy `function_call` into unified `ToolCall` while preserving legacy round-trip shape. +- Add Ollama response formatting that respects mutated unified responses instead of stale raw payloads. +- Enforce protocol operation compatibility in native execution before transport calls. + +## Non-Goals + +- Do not solve every Phase 5 provider issue here beyond the protocol guardrail needed by native execution. +- Do not enable native streaming for priority providers. +- Do not rewrite protocol registry discovery unless a focused test exposes a direct failure. +- Do not touch unrelated dirty files or user-facing reports. + +## Implementation Plan + +1. OpenAI Chat usage formatting. + - Add a helper that converts `Usage` to `prompt_tokens`, `completion_tokens`, `total_tokens`, `prompt_tokens_details`, `completion_tokens_details`, and `cost_details`. + - Ensure `input_tokens`, `output_tokens`, `raw`, and `extra` do not leak into formatted OpenAI usage. + +2. Responses usage formatting. + - Add a helper that converts `Usage` to `input_tokens`, `output_tokens`, `total_tokens`, `input_tokens_details`, `output_tokens_details`, and `cost_details`. + - Preserve cache-write information only as a safe extension inside `input_tokens_details` when present. + +3. Legacy OpenAI `function_call`. + - Parse legacy `function_call` into a unified `ToolCall` when modern `tool_calls` are not present. + - Preserve the raw legacy field in `extra` so formatting emits `function_call` instead of incorrectly upgrading to `tool_calls`. + +4. Ollama response formatting. + - Implement `OllamaProtocol.format_response()` for chat, generate, and embeddings shapes. + - Merge non-core `extra` and usage/timing fields while honoring mutated unified messages/output/data. + +5. Native operation enforcement. + - In `NativeProviderExecutor.execute()` and `stream()`, reject unsupported protocol operations before parse/build/transport. + - Keep errors sanitized and explicit so bad provider/protocol pairings fail early. + +## Tests + +- `tests/test_protocol_openai_chat.py` +- `tests/test_protocol_responses.py` +- `tests/test_protocol_ollama_mcp.py` +- `tests/test_native_provider_executor.py` or `tests/test_request_executor_native_routing.py` +- Full protocol suite plus native routing smoke tests before review. + +## Acceptance Criteria + +- OpenAI Chat formatted responses expose OpenAI-compatible usage fields. +- Responses formatted responses expose Responses-compatible usage fields. +- Legacy OpenAI `function_call` is represented in unified tool calls and round-trips as legacy `function_call`. +- Ollama formatted responses reflect mutated unified state. +- Native execution rejects unsupported protocol operations before network transport. +- Focused tests pass and both `explore` and `explore-heavy` reviewers report no blockers/highs/mediums. From 5d2113f39e377bb282444acfa32e3eb30d625d3d Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 19:09:49 +0200 Subject: [PATCH 149/182] fix(protocols): format public usage shapes Formats OpenAI Chat and Responses usage with protocol-native field names, models legacy OpenAI function_call as unified tool calls while preserving legacy output, adds Ollama response formatting from mutated unified state, and rejects unsupported native operations before transport. Tests: pytest tests/test_protocol_operation_model.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_gemini.py tests/test_protocol_responses.py tests/test_protocol_openai_embeddings.py tests/test_protocol_openai_images_audio.py tests/test_protocol_ollama_mcp.py tests/test_protocol_registry.py tests/test_native_provider_executor.py tests/test_request_executor_native_routing.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_gemini_cli_protocol_declarations.py --- .../native_provider/executor.py | 17 ++++- src/rotator_library/protocols/gemini.py | 6 +- src/rotator_library/protocols/ollama.py | 39 +++++++++++- src/rotator_library/protocols/openai_chat.py | 62 ++++++++++++++++++- src/rotator_library/protocols/responses.py | 40 +++++++++++- tests/test_native_provider_executor.py | 22 +++++++ tests/test_protocol_ollama_mcp.py | 17 +++++ tests/test_protocol_openai_chat.py | 32 ++++++++++ tests/test_protocol_responses.py | 10 +++ 9 files changed, 236 insertions(+), 9 deletions(-) diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index 61d554d46..a69c9fa02 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -12,7 +12,7 @@ from ..adapters import get_adapter, run_adapter_chain from ..field_cache import FieldCacheEngine, InMemoryFieldCacheStore from ..field_cache.paths import FieldCachePathError, PathToken, parse_path -from ..protocols import get_protocol, serialize_value +from ..protocols import ProtocolError, get_protocol, serialize_value from ..transform_trace import REDACTED from ..usage.accounting import extract_usage_record from ..usage.costs import CostCalculator @@ -37,6 +37,7 @@ async def execute(self, raw_request: dict[str, Any], context: NativeProviderCont logger = context.transaction_logger protocol = get_protocol(context.protocol_name) + self._ensure_supported_operation(protocol, context) self._trace(context, "native_protocol_selected", {"protocol": protocol.name}, direction="metadata", stage="protocol") try: self._trace(context, "raw_native_client_request", raw_request, direction="request", stage="client") @@ -113,6 +114,7 @@ async def stream(self, raw_request: dict[str, Any], context: NativeProviderConte logger = context.transaction_logger protocol = get_protocol(context.protocol_name) + self._ensure_supported_operation(protocol, context) self._trace(context, "native_protocol_selected", {"protocol": protocol.name}, direction="metadata", stage="protocol") try: self._trace(context, "raw_native_client_request", raw_request, direction="request", stage="client") @@ -200,6 +202,19 @@ def _trace( snapshot=snapshot, ) + @staticmethod + def _ensure_supported_operation(protocol: Any, context: NativeProviderContext) -> None: + """Fail before transport when provider and protocol operations disagree.""" + + if protocol.supports_operation(context.operation): + return + raise ProtocolError( + f"provider {context.provider} requested unsupported operation {context.operation!r}", + protocol=protocol.name, + pass_name="native_operation_check", + payload={"provider": context.provider, "model": context.model, "operation": context.operation}, + ) + def _redact_field_cache_paths(data: Any, context: NativeProviderContext, direction: str) -> Any: """Redact configured cache paths before broad native payload traces. diff --git a/src/rotator_library/protocols/gemini.py b/src/rotator_library/protocols/gemini.py index 203810ad8..de9157d23 100644 --- a/src/rotator_library/protocols/gemini.py +++ b/src/rotator_library/protocols/gemini.py @@ -55,7 +55,7 @@ class GeminiProtocol(ProtocolAdapter): name: ClassVar[str] = "gemini" aliases: ClassVar[tuple[str, ...]] = ("google_gemini", "generate_content") supported_transports: ClassVar[tuple[str, ...]] = ("http", "sse") - supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_CHAT, OPERATION_COUNT_TOKENS) + supported_operations: ClassVar[tuple[str, ...]] = (OPERATION_CHAT, OPERATION_COUNT_TOKENS, "generate", "stream_generate") def parse_request(self, raw_request: dict[str, Any], context: ProtocolContext | None = None) -> UnifiedRequest: request = dict(raw_request or {}) @@ -356,7 +356,7 @@ def _as_dict(value: Any) -> dict[str, Any]: def _operation_from_context(context: ProtocolContext | None, default: str) -> str: - supported = {OPERATION_CHAT, OPERATION_COUNT_TOKENS} + supported = {OPERATION_CHAT, OPERATION_COUNT_TOKENS, "generate", "stream_generate"} if context and isinstance(context.provider_options, dict): operation = normalize_operation(context.provider_options.get("operation")) if operation in supported: @@ -374,7 +374,7 @@ def _response_operation(response: dict[str, Any], context: ProtocolContext | Non return OPERATION_COUNT_TOKENS if "totalTokens" in response and "candidates" not in response: return OPERATION_COUNT_TOKENS - return OPERATION_CHAT + return requested if requested in {"generate", "stream_generate"} else OPERATION_CHAT def _decode_sse_data(raw_event: Any) -> Any: diff --git a/src/rotator_library/protocols/ollama.py b/src/rotator_library/protocols/ollama.py index 45e613820..029a71fac 100644 --- a/src/rotator_library/protocols/ollama.py +++ b/src/rotator_library/protocols/ollama.py @@ -11,7 +11,7 @@ from .base import ProtocolAdapter from .operation import OPERATION_EMBEDDINGS, OPERATION_OLLAMA_CHAT, OPERATION_OLLAMA_GENERATE, normalize_operation -from .types import ContentBlock, ProtocolContext, UnifiedMessage, UnifiedRequest, UnifiedResponse, UnifiedStreamEvent, Usage +from .types import ContentBlock, ProtocolContext, UnifiedMessage, UnifiedRequest, UnifiedResponse, UnifiedStreamEvent, Usage, first_text _OPTION_FIELDS = {"options", "format", "keep_alive", "template", "context", "raw", "suffix"} _CORE_FIELDS = {"operation", "model", "messages", "prompt", "input", "stream", "system", *_OPTION_FIELDS} @@ -84,6 +84,33 @@ def parse_response(self, raw_response: Any, context: ProtocolContext | None = No extra={k: deepcopy(v) for k, v in response.items() if k not in {"model", "message", "response", "embeddings", "embedding", "prompt_eval_count", "eval_count", "total_duration", "load_duration", "prompt_eval_duration", "eval_duration"}}, ) + def format_response(self, unified_response: UnifiedResponse, context: ProtocolContext | None = None) -> dict[str, Any]: + """Format a unified response back to an Ollama response shape. + + Ollama responses are often mutated by adapters after parsing. Do not + return `raw` wholesale here; rebuild the public fields from unified state + and then merge preserved extras/timing values. + """ + + payload = deepcopy(unified_response.extra) + if unified_response.model: + payload["model"] = unified_response.model + operation = unified_response.operation or _ollama_operation(payload, context) + if operation == OPERATION_OLLAMA_CHAT: + if unified_response.messages: + payload["message"] = _message_to_ollama(unified_response.messages[0]) + elif operation == OPERATION_EMBEDDINGS: + raw = unified_response.raw if isinstance(unified_response.raw, dict) else {} + key = "embedding" if "embedding" in raw and "embeddings" not in raw else "embeddings" + payload[key] = deepcopy(unified_response.data) + else: + payload["response"] = _ollama_response_text(unified_response) + if unified_response.usage and isinstance(unified_response.usage.raw, dict): + for key, value in unified_response.usage.raw.items(): + if key.endswith("count") or key.endswith("duration"): + payload[key] = deepcopy(value) + return {k: v for k, v in payload.items() if v is not None} + def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: data = _json_event(raw_event) if not isinstance(data, dict): @@ -128,6 +155,16 @@ def _message_to_ollama(message: UnifiedMessage) -> dict[str, Any]: return payload +def _ollama_response_text(response: UnifiedResponse) -> str: + if response.output: + return "".join(str(item) for item in response.output if item is not None) + if response.messages: + text = first_text(response.messages[0].content) + if text is not None: + return text + return "" + + def _json_event(raw_event: Any) -> Any: if isinstance(raw_event, dict): return raw_event diff --git a/src/rotator_library/protocols/openai_chat.py b/src/rotator_library/protocols/openai_chat.py index 75c86d39a..20bc63695 100644 --- a/src/rotator_library/protocols/openai_chat.py +++ b/src/rotator_library/protocols/openai_chat.py @@ -167,7 +167,7 @@ def format_response(self, unified_response: UnifiedResponse, context: ProtocolCo "created": unified_response.metadata.get("created"), "model": unified_response.model, "choices": choices, - "usage": unified_response.usage.to_dict() if unified_response.usage else None, + "usage": _format_openai_usage(unified_response.usage), } payload.update(deepcopy(unified_response.extra)) return {k: v for k, v in payload.items() if v is not None} @@ -255,7 +255,7 @@ def _parse_message(self, message: dict[str, Any]) -> UnifiedMessage: content=self._parse_content(payload.get("content")), name=payload.get("name"), tool_call_id=payload.get("tool_call_id"), - tool_calls=[self._parse_tool_call(call) for call in payload.get("tool_calls") or []], + tool_calls=self._parse_message_tool_calls(payload), reasoning=reasoning, raw=deepcopy(message), extra={k: deepcopy(v) for k, v in payload.items() if k not in {"role", "content", "name", "tool_call_id", "tool_calls", "reasoning", "reasoning_content"}}, @@ -270,7 +270,8 @@ def _format_message(self, message: UnifiedMessage) -> dict[str, Any]: content = self._format_content(message.content) if content is not None: payload["content"] = content - if message.tool_calls: + legacy_function_call = message.extra.get("function_call") + if message.tool_calls and not legacy_function_call: payload["tool_calls"] = [self._format_tool_call(call) for call in message.tool_calls] if message.reasoning: # OpenAI-compatible providers use multiple names for reasoning text. @@ -281,6 +282,26 @@ def _format_message(self, message: UnifiedMessage) -> dict[str, Any]: payload.update(deepcopy(message.extra)) return payload + def _parse_message_tool_calls(self, payload: dict[str, Any]) -> list[ToolCall]: + """Return modern and legacy OpenAI function calls as unified tools.""" + + modern_calls = payload.get("tool_calls") or [] + if modern_calls: + return [self._parse_tool_call(call) for call in modern_calls] + legacy_call = payload.get("function_call") + if isinstance(legacy_call, dict): + return [ + ToolCall( + id=None, + name=legacy_call.get("name"), + arguments=legacy_call.get("arguments"), + type="function", + raw=deepcopy(legacy_call), + extra={"legacy_function_call": True}, + ) + ] + return [] + def _parse_content(self, content: Any) -> list[ContentBlock]: if content is None: return [] @@ -426,5 +447,40 @@ def _format_arguments(arguments: Any) -> str: return json.dumps(serialize_value(arguments), separators=(",", ":")) +def _format_openai_usage(usage: Usage | None) -> dict[str, Any] | None: + """Format normalized usage using OpenAI Chat's public field names.""" + + if usage is None: + return None + payload: dict[str, Any] = { + "prompt_tokens": usage.input_tokens, + "completion_tokens": usage.output_tokens, + "total_tokens": usage.total_tokens or (usage.input_tokens + usage.output_tokens), + } + prompt_details: dict[str, Any] = {} + if usage.cache_read_tokens: + prompt_details["cached_tokens"] = usage.cache_read_tokens + if usage.cache_write_tokens: + prompt_details["cache_creation_tokens"] = usage.cache_write_tokens + if prompt_details: + payload["prompt_tokens_details"] = prompt_details + completion_details: dict[str, Any] = {} + if usage.reasoning_tokens: + completion_details["reasoning_tokens"] = usage.reasoning_tokens + if completion_details: + payload["completion_tokens_details"] = completion_details + if usage.cost: + cost_details: dict[str, Any] = dict(usage.cost.metadata) + if usage.cost.provider_reported_cost is not None: + cost_details["total_cost"] = usage.cost.provider_reported_cost + elif usage.cost.estimated_cost is not None: + cost_details["estimated_cost"] = usage.cost.estimated_cost + cost_details["currency"] = usage.cost.currency + if usage.cost.source: + cost_details["source"] = usage.cost.source + payload["cost_details"] = cost_details + return payload + + def _without(payload: dict[str, Any], keys: set[str]) -> dict[str, Any]: return {k: deepcopy(v) for k, v in payload.items() if k not in keys} diff --git a/src/rotator_library/protocols/responses.py b/src/rotator_library/protocols/responses.py index ffeb89e25..a612e3061 100644 --- a/src/rotator_library/protocols/responses.py +++ b/src/rotator_library/protocols/responses.py @@ -150,7 +150,7 @@ def format_response(self, unified_response: UnifiedResponse, context: ProtocolCo "model": unified_response.model, "status": unified_response.stop_reason, "output": output, - "usage": unified_response.usage.to_dict() if unified_response.usage else None, + "usage": _format_responses_usage(unified_response.usage), } payload.update(deepcopy(unified_response.extra)) return {k: v for k, v in payload.items() if v is not None} @@ -390,3 +390,41 @@ def _reasoning_text(item: dict[str, Any]) -> str | None: def _without(payload: dict[str, Any], keys: set[str]) -> dict[str, Any]: return {k: deepcopy(v) for k, v in payload.items() if k not in keys} + + +def _format_responses_usage(usage: Usage | None) -> dict[str, Any] | None: + """Format normalized usage using OpenAI Responses public field names.""" + + if usage is None: + return None + payload: dict[str, Any] = { + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + "total_tokens": usage.total_tokens or (usage.input_tokens + usage.output_tokens), + } + input_details: dict[str, Any] = {} + if usage.cache_read_tokens: + input_details["cached_tokens"] = usage.cache_read_tokens + if usage.cache_write_tokens: + # OpenAI Responses does not have a universal cache-write field, but this + # extension keeps provider-reported cache creation visible without + # leaking the unified internal `cache_write_tokens` key. + input_details["cache_creation_tokens"] = usage.cache_write_tokens + if input_details: + payload["input_tokens_details"] = input_details + output_details: dict[str, Any] = {} + if usage.reasoning_tokens: + output_details["reasoning_tokens"] = usage.reasoning_tokens + if output_details: + payload["output_tokens_details"] = output_details + if usage.cost: + cost_details: dict[str, Any] = dict(usage.cost.metadata) + if usage.cost.provider_reported_cost is not None: + cost_details["total_cost"] = usage.cost.provider_reported_cost + elif usage.cost.estimated_cost is not None: + cost_details["estimated_cost"] = usage.cost.estimated_cost + cost_details["currency"] = usage.cost.currency + if usage.cost.source: + cost_details["source"] = usage.cost.source + payload["cost_details"] = cost_details + return payload diff --git a/tests/test_native_provider_executor.py b/tests/test_native_provider_executor.py index a597d7a5f..e7c9a45be 100644 --- a/tests/test_native_provider_executor.py +++ b/tests/test_native_provider_executor.py @@ -6,6 +6,7 @@ from rotator_library.field_cache import FieldCacheInjection, FieldCacheRule from rotator_library.native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor +from rotator_library.protocols import ProtocolError from rotator_library.transaction_logger import TransactionLogger @@ -186,3 +187,24 @@ async def test_native_provider_executor_logs_transform_errors(tmp_path) -> None: pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] assert "transform_log_error" in pass_names + + +@pytest.mark.asyncio +async def test_native_provider_executor_rejects_unsupported_operation_before_transport() -> None: + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + operation="embeddings", + endpoint="https://example.test/chat", + ) + client = FakeHTTPClient({"id": "should_not_call"}) + + with pytest.raises(ProtocolError, match="unsupported operation"): + await NativeProviderExecutor().execute( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(client), + ) + + assert client.calls == [] diff --git a/tests/test_protocol_ollama_mcp.py b/tests/test_protocol_ollama_mcp.py index fbd494565..6cbee9ab9 100644 --- a/tests/test_protocol_ollama_mcp.py +++ b/tests/test_protocol_ollama_mcp.py @@ -40,6 +40,23 @@ def test_ollama_chat_generate_and_stream_shapes() -> None: assert chunk.usage.output_tokens == 2 +def test_ollama_format_response_uses_mutated_unified_fields() -> None: + adapter = get_protocol("ollama") + chat = adapter.parse_response({"model": "llama3", "message": {"role": "assistant", "content": "before"}, "done": True}) + chat.messages[0].content[0].text = "after" + + generate = adapter.parse_response({"model": "llama3", "response": "before", "eval_count": 2}) + generate.output[0] = "after" + + embeddings = adapter.parse_response({"model": "llama3", "embedding": [0.1, 0.2]}) + embeddings.data = [0.3, 0.4] + + assert adapter.format_response(chat)["message"]["content"] == "after" + assert adapter.format_response(generate)["response"] == "after" + assert adapter.format_response(generate)["eval_count"] == 2 + assert adapter.format_response(embeddings)["embedding"] == [0.3, 0.4] + + def test_mcp_jsonrpc_round_trip_and_error_preservation() -> None: adapter = get_protocol("mcp") request = {"jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": {"name": "lookup"}} diff --git a/tests/test_protocol_openai_chat.py b/tests/test_protocol_openai_chat.py index fce7f39db..e2c8af7a3 100644 --- a/tests/test_protocol_openai_chat.py +++ b/tests/test_protocol_openai_chat.py @@ -135,6 +135,38 @@ def test_openai_chat_response_extracts_usage_cost_and_reasoning() -> None: assert unified.usage.cost is not None assert unified.usage.cost.provider_reported_cost == 0.01 + formatted = adapter.format_response(unified) + assert formatted["usage"]["prompt_tokens"] == 10 + assert formatted["usage"]["completion_tokens"] == 5 + assert formatted["usage"]["total_tokens"] == 18 + assert formatted["usage"]["prompt_tokens_details"] == {"cached_tokens": 3, "cache_creation_tokens": 2} + assert formatted["usage"]["completion_tokens_details"] == {"reasoning_tokens": 3} + assert formatted["usage"]["cost_details"]["total_cost"] == 0.01 + assert "input_tokens" not in formatted["usage"] + assert "output_tokens" not in formatted["usage"] + + +def test_openai_legacy_function_call_is_unified_and_round_trips() -> None: + adapter = get_protocol("openai_chat") + raw = { + "model": "openai/gpt-test", + "messages": [ + { + "role": "assistant", + "content": None, + "function_call": {"name": "lookup", "arguments": "{\"q\":\"x\"}"}, + } + ], + } + + unified = adapter.parse_request(raw) + rebuilt = adapter.build_request(unified) + + assert unified.messages[0].tool_calls[0].name == "lookup" + assert unified.messages[0].tool_calls[0].extra["legacy_function_call"] is True + assert rebuilt["messages"][0]["function_call"] == {"name": "lookup", "arguments": "{\"q\":\"x\"}"} + assert "tool_calls" not in rebuilt["messages"][0] + def test_openai_chat_stream_event_parses_sse_delta_and_done() -> None: adapter = get_protocol("openai_chat") diff --git a/tests/test_protocol_responses.py b/tests/test_protocol_responses.py index 0ae2d752c..7b578a923 100644 --- a/tests/test_protocol_responses.py +++ b/tests/test_protocol_responses.py @@ -80,6 +80,16 @@ def test_responses_response_extracts_output_items_reasoning_calls_and_usage() -> assert unified.usage.cost is not None assert unified.usage.cost.provider_reported_cost == 0.02 + formatted = adapter.format_response(unified) + assert formatted["usage"]["input_tokens"] == 10 + assert formatted["usage"]["output_tokens"] == 5 + assert formatted["usage"]["total_tokens"] == 18 + assert formatted["usage"]["input_tokens_details"] == {"cached_tokens": 3} + assert formatted["usage"]["output_tokens_details"] == {"reasoning_tokens": 3} + assert formatted["usage"]["cost_details"]["total_cost"] == 0.02 + assert "raw" not in formatted["usage"] + assert "extra" not in formatted["usage"] + def test_responses_format_preserves_unknown_output_items() -> None: adapter = get_protocol("responses") From bfcfa33dfb094e43a754c8cc01f31e313b9f1ec4 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 19:15:45 +0200 Subject: [PATCH 150/182] fix(protocols): close Phase 1c review gaps Preserves Responses cache-write usage through parse and format, reflects mutated legacy function_call tool calls, formats mutated Ollama usage counts, and covers streaming unsupported-operation rejection. Tests: pytest tests/test_protocol_openai_chat.py tests/test_protocol_responses.py tests/test_protocol_ollama_mcp.py tests/test_native_provider_executor.py tests/test_protocol_operation_model.py tests/test_request_executor_native_routing.py --- src/rotator_library/protocols/ollama.py | 6 +++++- src/rotator_library/protocols/openai_chat.py | 8 +++++-- src/rotator_library/protocols/responses.py | 1 + tests/test_native_provider_executor.py | 22 ++++++++++++++++++++ tests/test_protocol_ollama_mcp.py | 3 ++- tests/test_protocol_openai_chat.py | 4 ++++ tests/test_protocol_responses.py | 5 +++-- 7 files changed, 43 insertions(+), 6 deletions(-) diff --git a/src/rotator_library/protocols/ollama.py b/src/rotator_library/protocols/ollama.py index 029a71fac..748cca1c7 100644 --- a/src/rotator_library/protocols/ollama.py +++ b/src/rotator_library/protocols/ollama.py @@ -107,8 +107,12 @@ def format_response(self, unified_response: UnifiedResponse, context: ProtocolCo payload["response"] = _ollama_response_text(unified_response) if unified_response.usage and isinstance(unified_response.usage.raw, dict): for key, value in unified_response.usage.raw.items(): - if key.endswith("count") or key.endswith("duration"): + if key.endswith("duration"): payload[key] = deepcopy(value) + if unified_response.usage.input_tokens: + payload["prompt_eval_count"] = unified_response.usage.input_tokens + if unified_response.usage.output_tokens: + payload["eval_count"] = unified_response.usage.output_tokens return {k: v for k, v in payload.items() if v is not None} def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = None) -> UnifiedStreamEvent: diff --git a/src/rotator_library/protocols/openai_chat.py b/src/rotator_library/protocols/openai_chat.py index 20bc63695..1a426449d 100644 --- a/src/rotator_library/protocols/openai_chat.py +++ b/src/rotator_library/protocols/openai_chat.py @@ -270,16 +270,20 @@ def _format_message(self, message: UnifiedMessage) -> dict[str, Any]: content = self._format_content(message.content) if content is not None: payload["content"] = content - legacy_function_call = message.extra.get("function_call") + extra = deepcopy(message.extra) + legacy_function_call = extra.get("function_call") if message.tool_calls and not legacy_function_call: payload["tool_calls"] = [self._format_tool_call(call) for call in message.tool_calls] + elif message.tool_calls and legacy_function_call: + call = message.tool_calls[0] + extra["function_call"] = {"name": call.name or "", "arguments": _format_arguments(call.arguments)} if message.reasoning: # OpenAI-compatible providers use multiple names for reasoning text. # Prefer the common extension field while keeping all blocks in extra. text = "".join(block.text or "" for block in message.reasoning if block.text) if text: payload["reasoning_content"] = text - payload.update(deepcopy(message.extra)) + payload.update(extra) return payload def _parse_message_tool_calls(self, payload: dict[str, Any]) -> list[ToolCall]: diff --git a/src/rotator_library/protocols/responses.py b/src/rotator_library/protocols/responses.py index a612e3061..20afbaba7 100644 --- a/src/rotator_library/protocols/responses.py +++ b/src/rotator_library/protocols/responses.py @@ -198,6 +198,7 @@ def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = N output_tokens=int(usage.get("output_tokens") or 0), total_tokens=int(usage.get("total_tokens") or 0), cache_read_tokens=int(input_details.get("cached_tokens") or 0), + cache_write_tokens=int(input_details.get("cache_creation_tokens") or usage.get("cache_creation_tokens") or 0), reasoning_tokens=int(output_details.get("reasoning_tokens") or 0), cost=cost, raw=deepcopy(usage), diff --git a/tests/test_native_provider_executor.py b/tests/test_native_provider_executor.py index e7c9a45be..4d50c8d19 100644 --- a/tests/test_native_provider_executor.py +++ b/tests/test_native_provider_executor.py @@ -208,3 +208,25 @@ async def test_native_provider_executor_rejects_unsupported_operation_before_tra ) assert client.calls == [] + + +@pytest.mark.asyncio +async def test_native_provider_stream_rejects_unsupported_operation_before_transport() -> None: + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + operation="embeddings", + endpoint="https://example.test/chat", + ) + client = FakeHTTPClient({"id": "should_not_call"}) + + with pytest.raises(ProtocolError, match="unsupported operation"): + async for _ in NativeProviderExecutor().stream( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(client), + ): + pass + + assert client.calls == [] diff --git a/tests/test_protocol_ollama_mcp.py b/tests/test_protocol_ollama_mcp.py index 6cbee9ab9..7c859203b 100644 --- a/tests/test_protocol_ollama_mcp.py +++ b/tests/test_protocol_ollama_mcp.py @@ -47,13 +47,14 @@ def test_ollama_format_response_uses_mutated_unified_fields() -> None: generate = adapter.parse_response({"model": "llama3", "response": "before", "eval_count": 2}) generate.output[0] = "after" + generate.usage.output_tokens = 5 embeddings = adapter.parse_response({"model": "llama3", "embedding": [0.1, 0.2]}) embeddings.data = [0.3, 0.4] assert adapter.format_response(chat)["message"]["content"] == "after" assert adapter.format_response(generate)["response"] == "after" - assert adapter.format_response(generate)["eval_count"] == 2 + assert adapter.format_response(generate)["eval_count"] == 5 assert adapter.format_response(embeddings)["embedding"] == [0.3, 0.4] diff --git a/tests/test_protocol_openai_chat.py b/tests/test_protocol_openai_chat.py index e2c8af7a3..7bbb5dfab 100644 --- a/tests/test_protocol_openai_chat.py +++ b/tests/test_protocol_openai_chat.py @@ -164,9 +164,13 @@ def test_openai_legacy_function_call_is_unified_and_round_trips() -> None: assert unified.messages[0].tool_calls[0].name == "lookup" assert unified.messages[0].tool_calls[0].extra["legacy_function_call"] is True + unified.messages[0].tool_calls[0].arguments = {"q": "changed"} assert rebuilt["messages"][0]["function_call"] == {"name": "lookup", "arguments": "{\"q\":\"x\"}"} assert "tool_calls" not in rebuilt["messages"][0] + rebuilt = adapter.build_request(unified) + assert rebuilt["messages"][0]["function_call"] == {"name": "lookup", "arguments": "{\"q\":\"changed\"}"} + def test_openai_chat_stream_event_parses_sse_delta_and_done() -> None: adapter = get_protocol("openai_chat") diff --git a/tests/test_protocol_responses.py b/tests/test_protocol_responses.py index 7b578a923..77476adf3 100644 --- a/tests/test_protocol_responses.py +++ b/tests/test_protocol_responses.py @@ -60,7 +60,7 @@ def test_responses_response_extracts_output_items_reasoning_calls_and_usage() -> "input_tokens": 10, "output_tokens": 5, "total_tokens": 18, - "input_tokens_details": {"cached_tokens": 3}, + "input_tokens_details": {"cached_tokens": 3, "cache_creation_tokens": 2}, "output_tokens_details": {"reasoning_tokens": 3}, "cost_details": {"total_cost": 0.02, "currency": "USD"}, }, @@ -77,6 +77,7 @@ def test_responses_response_extracts_output_items_reasoning_calls_and_usage() -> assert unified.usage.input_tokens == 10 assert unified.usage.reasoning_tokens == 3 assert unified.usage.cache_read_tokens == 3 + assert unified.usage.cache_write_tokens == 2 assert unified.usage.cost is not None assert unified.usage.cost.provider_reported_cost == 0.02 @@ -84,7 +85,7 @@ def test_responses_response_extracts_output_items_reasoning_calls_and_usage() -> assert formatted["usage"]["input_tokens"] == 10 assert formatted["usage"]["output_tokens"] == 5 assert formatted["usage"]["total_tokens"] == 18 - assert formatted["usage"]["input_tokens_details"] == {"cached_tokens": 3} + assert formatted["usage"]["input_tokens_details"] == {"cached_tokens": 3, "cache_creation_tokens": 2} assert formatted["usage"]["output_tokens_details"] == {"reasoning_tokens": 3} assert formatted["usage"]["cost_details"]["total_cost"] == 0.02 assert "raw" not in formatted["usage"] From 9d69ce164a9f7405c7ac6a9002a68a5043a6218a Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 19:20:33 +0200 Subject: [PATCH 151/182] docs(experimental): plan transform trace completion Adds the Phase 2c corrective plan for Anthropic compatibility traces, native stream-event adapter tracing, Responses SSE formatting traces, transform-failure coverage, final-response ordering, and camelCase redaction. Tests: not run (planning document only) --- .../phase-2c-transform-trace-completion.md | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 docs/experimental/phase-2c-transform-trace-completion.md diff --git a/docs/experimental/phase-2c-transform-trace-completion.md b/docs/experimental/phase-2c-transform-trace-completion.md new file mode 100644 index 000000000..0de390162 --- /dev/null +++ b/docs/experimental/phase-2c-transform-trace-completion.md @@ -0,0 +1,73 @@ +# Phase 2c: Transform Trace Completion And Redaction Hardening + +## Goal + +Close the Phase 2/2b third-pass tracing findings: Anthropic compatibility coverage, native stream-event adapter tracing, Responses SSE formatting traces, transform-failure traces, final-response ordering, and camelCase secret redaction. + +## Scope + +- Add transform-pass traces around live Anthropic Messages compatibility request/response conversion. +- Add Anthropic streaming conversion traces and explicit upstream close on disconnect or abnormal exit. +- Ensure final client response traces represent the final post-normalization returned payload. +- Run and trace native stream-event adapter chains. +- Trace Responses final SSE formatting boundaries. +- Harden redaction for camelCase secret-bearing keys. +- Emit transform-failure traces for built-in provider transforms and Responses conversion errors. + +## Non-Goals + +- Do not redesign transaction log storage. +- Do not remove legacy request/response JSON logs. +- Do not enable native streaming for priority providers. +- Do not change public API response shapes. + +## Implementation Plan + +1. Anthropic non-stream traces. + - Trace `anthropic_raw_request`, `anthropic_to_openai_request`, `anthropic_openai_response`, `openai_to_anthropic_response`, and `anthropic_final_response`. + +2. Anthropic streaming traces and close safety. + - Trace source OpenAI chunks and emitted Anthropic SSE frames. + - Close upstream stream via `aclose()` / `close()` on disconnect or abnormal exit. + - Trace stream transform errors before emitting the Anthropic error frames. + +3. Provider transform failure traces. + - Wrap built-in provider transforms and emit `transform_log_error` before re-raising. + - Preserve provider-hook behavior but keep sanitized metadata. + +4. Native stream-event adapter chain. + - Run adapter chain for parsed native stream events. + - Trace `after_stream_event_adapter_chain` before formatting. + +5. Responses SSE formatting trace. + - Trace each formatted `ResponsesStreamEvent` frame and terminal frame emitted by `stream_response()`. + +6. Final response trace ordering. + - Verify or adjust `final_client_response` ordering so it occurs after usage normalization/cost accounting. + +7. Redaction hardening. + - Normalize camelCase keys such as `apiKey`, `accessToken`, `refreshToken`, `clientSecret`, and `idToken`. + +8. Responses transform-failure traces. + - Emit standardized transform error traces for Responses parse/bridge/storage conversion errors without changing raised errors. + +## Tests + +- Anthropic non-stream and stream trace tests. +- Anthropic disconnect upstream close test. +- Built-in provider transform failure trace test. +- Native stream-event adapter mutation/trace test. +- Responses SSE formatting trace test. +- CamelCase redaction tests. +- Final response trace-ordering tests. + +## Acceptance Criteria + +- Anthropic compatibility has transform-pass coverage for request conversion, response conversion, and stream conversion. +- Anthropic stream disconnect closes upstream when possible. +- Built-in provider transform exceptions emit transform-failure traces. +- Native stream-event adapter chains are executed and traced. +- Responses stream SSE formatting boundaries are traced. +- Final client response traces reflect post-normalization final payloads. +- CamelCase secret-bearing keys are redacted. +- Focused tests pass and both `explore` and `explore-heavy` reviewers report no blockers/highs/mediums. From 8d2c2fd007f6e991eba98c3959425bd84fb91ec8 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 19:30:47 +0200 Subject: [PATCH 152/182] fix(trace): complete transform trace coverage Adds Anthropic compatibility transform traces, Anthropic stream conversion traces with upstream close safety, built-in transform failure tracing, native stream-event adapter execution, Responses SSE formatting traces, and camelCase secret redaction. Tests: pytest tests/test_anthropic_transform_tracing.py tests/test_transaction_logger_transform_trace.py tests/test_transform_trace.py tests/test_native_provider_streaming.py tests/test_native_provider_executor.py tests/test_responses_streaming.py tests/test_responses_service.py tests/test_adapter_registry.py tests/test_field_cache_trace.py tests/test_request_executor_stream_metrics.py tests/test_protocol_openai_chat.py --- .../anthropic_compat/streaming.py | 120 ++++++++++++++---- src/rotator_library/client/anthropic.py | 58 +++++++++ src/rotator_library/client/transforms.py | 21 ++- .../native_provider/executor.py | 13 +- src/rotator_library/protocols/openai_chat.py | 11 +- src/rotator_library/responses/service.py | 36 +++++- src/rotator_library/transform_trace.py | 6 +- tests/test_anthropic_transform_tracing.py | 105 +++++++++++++++ tests/test_native_provider_streaming.py | 36 ++++++ tests/test_responses_streaming.py | 1 + ...test_transaction_logger_transform_trace.py | 39 ++++++ 11 files changed, 413 insertions(+), 33 deletions(-) create mode 100644 tests/test_anthropic_transform_tracing.py diff --git a/src/rotator_library/anthropic_compat/streaming.py b/src/rotator_library/anthropic_compat/streaming.py index ecb074baa..4ad588d87 100644 --- a/src/rotator_library/anthropic_compat/streaming.py +++ b/src/rotator_library/anthropic_compat/streaming.py @@ -66,11 +66,73 @@ async def anthropic_streaming_wrapper( accumulated_text = "" # Track accumulated text for logging accumulated_thinking = "" # Track accumulated thinking for logging stop_reason_final = "end_turn" # Track final stop reason for logging + upstream_closed = False + + def trace_frame(pass_name: str, frame: str, payload: Any | None = None) -> str: + """Trace one Anthropic stream conversion frame and return it for yield.""" + + if transaction_logger: + transaction_logger.log_transform_pass( + pass_name, + payload if payload is not None else frame, + direction="stream", + stage="adapter", + protocol="anthropic_messages", + transport="sse", + snapshot=False, + ) + return frame + + async def close_upstream(reason: str) -> None: + """Close the wrapped OpenAI stream when Anthropic streaming stops early.""" + + nonlocal upstream_closed + if upstream_closed: + return + upstream_closed = True + close = getattr(openai_stream, "aclose", None) + if callable(close): + await close() + else: + sync_close = getattr(openai_stream, "close", None) + if callable(sync_close): + sync_close() + if transaction_logger: + transaction_logger.log_transform_pass( + "anthropic_stream_upstream_closed", + {"reason": reason}, + direction="stream", + stage="adapter", + protocol="anthropic_messages", + transport="sse", + snapshot=False, + ) try: async for chunk_str in openai_stream: + if transaction_logger: + transaction_logger.log_transform_pass( + "anthropic_stream_source_chunk", + chunk_str, + direction="stream", + stage="provider", + protocol="openai_chat", + transport="sse", + snapshot=False, + ) # Check for client disconnection if callback provided if is_disconnected is not None and await is_disconnected(): + if transaction_logger: + transaction_logger.log_transform_pass( + "anthropic_stream_disconnected", + {"reason": "client_disconnected"}, + direction="stream", + stage="adapter", + protocol="anthropic_messages", + transport="sse", + snapshot=False, + ) + await close_upstream("client_disconnected") break if not chunk_str.strip() or not chunk_str.startswith("data:"): @@ -103,25 +165,25 @@ async def anthropic_streaming_wrapper( "usage": usage_dict, }, } - yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n" + yield trace_frame("anthropic_stream_message_start", f"event: message_start\ndata: {json.dumps(message_start)}\n\n", message_start) message_started = True # Close any open thinking block if thinking_block_started: - yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + yield trace_frame("anthropic_stream_content_block_stop", f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n') current_block_index += 1 thinking_block_started = False # Close any open text block if content_block_started: - yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + yield trace_frame("anthropic_stream_content_block_stop", f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n') current_block_index += 1 content_block_started = False # Close all open tool_use blocks for tc_index in sorted(tool_block_indices.keys()): block_idx = tool_block_indices[tc_index] - yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {block_idx}}}\n\n' + yield trace_frame("anthropic_stream_content_block_stop", f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {block_idx}}}\n\n') # Determine stop_reason based on whether we had tool calls stop_reason = "tool_use" if tool_calls_by_index else "end_turn" @@ -134,10 +196,10 @@ async def anthropic_streaming_wrapper( final_usage["cache_creation_input_tokens"] = 0 # Send message_delta with final info - yield f'event: message_delta\ndata: {{"type": "message_delta", "delta": {{"stop_reason": "{stop_reason}", "stop_sequence": null}}, "usage": {json.dumps(final_usage)}}}\n\n' + yield trace_frame("anthropic_stream_message_delta", f'event: message_delta\ndata: {{"type": "message_delta", "delta": {{"stop_reason": "{stop_reason}", "stop_sequence": null}}, "usage": {json.dumps(final_usage)}}}\n\n') # Send message_stop - yield 'event: message_stop\ndata: {"type": "message_stop"}\n\n' + yield trace_frame("anthropic_stream_message_stop", 'event: message_stop\ndata: {"type": "message_stop"}\n\n') # Log final Anthropic response if logger provided if transaction_logger: @@ -242,7 +304,7 @@ async def anthropic_streaming_wrapper( "usage": usage_dict, }, } - yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n" + yield trace_frame("anthropic_stream_message_start", f"event: message_start\ndata: {json.dumps(message_start)}\n\n", message_start) message_started = True choices = chunk.get("choices") or [] @@ -261,7 +323,7 @@ async def anthropic_streaming_wrapper( "index": current_block_index, "content_block": {"type": "thinking", "thinking": ""}, } - yield f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n" + yield trace_frame("anthropic_stream_content_block_start", f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n", block_start) thinking_block_started = True # Send thinking delta @@ -270,7 +332,7 @@ async def anthropic_streaming_wrapper( "index": current_block_index, "delta": {"type": "thinking_delta", "thinking": reasoning_content}, } - yield f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n" + yield trace_frame("anthropic_stream_content_delta", f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n", block_delta) # Accumulate thinking for logging accumulated_thinking += reasoning_content @@ -279,7 +341,7 @@ async def anthropic_streaming_wrapper( if content: # If we were in a thinking block, close it first if thinking_block_started and not content_block_started: - yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + yield trace_frame("anthropic_stream_content_block_stop", f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n') current_block_index += 1 thinking_block_started = False @@ -290,7 +352,7 @@ async def anthropic_streaming_wrapper( "index": current_block_index, "content_block": {"type": "text", "text": ""}, } - yield f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n" + yield trace_frame("anthropic_stream_content_block_start", f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n", block_start) content_block_started = True # Send content delta @@ -299,7 +361,7 @@ async def anthropic_streaming_wrapper( "index": current_block_index, "delta": {"type": "text_delta", "text": content}, } - yield f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n" + yield trace_frame("anthropic_stream_content_delta", f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n", block_delta) # Accumulate text for logging accumulated_text += content @@ -312,13 +374,13 @@ async def anthropic_streaming_wrapper( if tc_index not in tool_calls_by_index: # Close previous thinking block if open if thinking_block_started: - yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + yield trace_frame("anthropic_stream_content_block_stop", f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n') current_block_index += 1 thinking_block_started = False # Close previous text block if open if content_block_started: - yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + yield trace_frame("anthropic_stream_content_block_stop", f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n') current_block_index += 1 content_block_started = False @@ -341,7 +403,7 @@ async def anthropic_streaming_wrapper( "input": {}, }, } - yield f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n" + yield trace_frame("anthropic_stream_content_block_start", f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n", block_start) # Increment for the next block current_block_index += 1 @@ -361,7 +423,7 @@ async def anthropic_streaming_wrapper( "partial_json": func["arguments"], }, } - yield f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n" + yield trace_frame("anthropic_stream_content_delta", f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n", block_delta) # Note: We intentionally ignore finish_reason here. # Block closing is handled when we receive [DONE] to avoid @@ -369,6 +431,15 @@ async def anthropic_streaming_wrapper( except Exception as e: logger.error(f"Error in Anthropic streaming wrapper: {e}") + if transaction_logger: + transaction_logger.log_transform_error( + "anthropic_stream_transform", + e, + payload={"request_id": request_id, "model": original_model}, + stage="adapter", + protocol="anthropic_messages", + transport="sse", + ) # If we haven't sent message_start yet, send it now so the client can display the error # Claude Code and other clients may ignore events that come before message_start @@ -395,7 +466,7 @@ async def anthropic_streaming_wrapper( "usage": usage_dict, }, } - yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n" + yield trace_frame("anthropic_stream_message_start", f"event: message_start\ndata: {json.dumps(message_start)}\n\n", message_start) # Send the error as a text content block so it's visible to the user error_message = f"Error: {str(e)}" @@ -404,16 +475,16 @@ async def anthropic_streaming_wrapper( "index": current_block_index, "content_block": {"type": "text", "text": ""}, } - yield f"event: content_block_start\ndata: {json.dumps(error_block_start)}\n\n" + yield trace_frame("anthropic_stream_content_block_start", f"event: content_block_start\ndata: {json.dumps(error_block_start)}\n\n", error_block_start) error_block_delta = { "type": "content_block_delta", "index": current_block_index, "delta": {"type": "text_delta", "text": error_message}, } - yield f"event: content_block_delta\ndata: {json.dumps(error_block_delta)}\n\n" + yield trace_frame("anthropic_stream_content_delta", f"event: content_block_delta\ndata: {json.dumps(error_block_delta)}\n\n", error_block_delta) - yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + yield trace_frame("anthropic_stream_content_block_stop", f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n') # Build final usage with cached tokens final_usage = {"output_tokens": 0} @@ -422,12 +493,15 @@ async def anthropic_streaming_wrapper( final_usage["cache_creation_input_tokens"] = 0 # Send message_delta and message_stop to properly close the stream - yield f'event: message_delta\ndata: {{"type": "message_delta", "delta": {{"stop_reason": "end_turn", "stop_sequence": null}}, "usage": {json.dumps(final_usage)}}}\n\n' - yield 'event: message_stop\ndata: {"type": "message_stop"}\n\n' + yield trace_frame("anthropic_stream_message_delta", f'event: message_delta\ndata: {{"type": "message_delta", "delta": {{"stop_reason": "end_turn", "stop_sequence": null}}, "usage": {json.dumps(final_usage)}}}\n\n') + yield trace_frame("anthropic_stream_message_stop", 'event: message_stop\ndata: {"type": "message_stop"}\n\n') # Also send the formal error event for clients that handle it error_event = { "type": "error", "error": {"type": "api_error", "message": str(e)}, } - yield f"event: error\ndata: {json.dumps(error_event)}\n\n" + yield trace_frame("anthropic_stream_error", f"event: error\ndata: {json.dumps(error_event)}\n\n", error_event) + finally: + if not upstream_closed: + await close_upstream("wrapper_exit") diff --git a/src/rotator_library/client/anthropic.py b/src/rotator_library/client/anthropic.py index 507e82fb9..47f1d1e31 100644 --- a/src/rotator_library/client/anthropic.py +++ b/src/rotator_library/client/anthropic.py @@ -94,9 +94,23 @@ async def messages( request.model_dump(exclude_none=True), filename="anthropic_request.json", ) + _trace_anthropic( + anthropic_logger, + "anthropic_raw_request", + request.model_dump(exclude_none=True), + direction="request", + stage="client", + ) # Translate Anthropic request to OpenAI format openai_request = translate_anthropic_request(request) + _trace_anthropic( + anthropic_logger, + "anthropic_to_openai_request", + openai_request, + direction="request", + stage="adapter", + ) # Pass parent log directory to acompletion for nested logging if anthropic_logger and anthropic_logger.log_dir: @@ -138,9 +152,23 @@ async def messages( if hasattr(response, "model_dump") else dict(response) ) + _trace_anthropic( + anthropic_logger, + "anthropic_openai_response", + openai_response, + direction="response", + stage="provider", + ) anthropic_response = openai_to_anthropic_response( openai_response, original_model ) + _trace_anthropic( + anthropic_logger, + "openai_to_anthropic_response", + anthropic_response, + direction="response", + stage="adapter", + ) # Override the ID with our request ID anthropic_response["id"] = request_id @@ -151,6 +179,13 @@ async def messages( anthropic_response, filename="anthropic_response.json", ) + _trace_anthropic( + anthropic_logger, + "anthropic_final_response", + anthropic_response, + direction="response", + stage="final", + ) return anthropic_response @@ -201,3 +236,26 @@ async def count_tokens( total_tokens = message_tokens + tool_tokens return {"input_tokens": total_tokens} + + +def _trace_anthropic( + logger: Optional[TransactionLogger], + pass_name: str, + payload: Any, + *, + direction: str, + stage: str, + transport: Optional[str] = None, +) -> None: + """Emit an Anthropic compatibility transform trace when logging is enabled.""" + + if not logger: + return + logger.log_transform_pass( + pass_name, + payload, + direction=direction, + stage=stage, + protocol="anthropic_messages", + transport=transport, + ) diff --git a/src/rotator_library/client/transforms.py b/src/rotator_library/client/transforms.py index 3b9558f93..78f052ba4 100644 --- a/src/rotator_library/client/transforms.py +++ b/src/rotator_library/client/transforms.py @@ -119,7 +119,26 @@ async def apply( if transform_provider == provider or transform_provider in model.lower(): for transform in transforms: before = deepcopy(kwargs) if transaction_logger else None - result = transform(kwargs, model, provider) + try: + result = transform(kwargs, model, provider) + except Exception as exc: + if transaction_logger: + transaction_logger.log_transform_error( + "builtin_provider_transform", + exc, + payload=before if before is not None else kwargs, + stage="client", + transport=transport, + metadata={ + "provider": provider, + "model": model, + "credential_id": credential_id, + "transform_provider": transform_provider, + "transform_name": getattr(transform, "__name__", repr(transform)), + **trace_metadata, + }, + ) + raise if result: modifications.append(result) _trace_transform_pass( diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index a69c9fa02..f626e7880 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -141,10 +141,19 @@ async def stream(self, raw_request: dict[str, Any], context: NativeProviderConte self._trace(context, "raw_native_provider_stream_chunk", raw_chunk, direction="stream", stage="provider") event = protocol.parse_stream_event(raw_chunk, protocol_context) self._trace(context, "parsed_native_unified_stream_event", event, direction="stream", stage="protocol", snapshot=False) - event_payload = stream_event_payload(event) - self._trace(context, "parsed_native_stream_event", event_payload, direction="stream", stage="protocol") if event.type == "done": + event_payload = stream_event_payload(event) + self._trace(context, "parsed_native_stream_event", event_payload, direction="stream", stage="protocol") break + adapter_context = context.adapter_context() + # Native stream traces apply field-cache path redaction below. + # Suppress generic adapter-chain snapshots here so provider state + # cannot leak before rule-aware redaction runs. + adapter_context.transaction_logger = None + event = await run_adapter_chain(adapters, event, adapter_context, stage="stream_event") + self._trace(context, "after_stream_event_adapter_chain", event, direction="stream", stage="adapter", snapshot=False) + event_payload = stream_event_payload(event) + self._trace(context, "parsed_native_stream_event", event_payload, direction="stream", stage="protocol") await cache_engine.extract("stream_event", event_payload, context.field_cache_context(), transaction_logger=logger) self._trace( context, diff --git a/src/rotator_library/protocols/openai_chat.py b/src/rotator_library/protocols/openai_chat.py index 1a426449d..27c4659df 100644 --- a/src/rotator_library/protocols/openai_chat.py +++ b/src/rotator_library/protocols/openai_chat.py @@ -210,7 +210,16 @@ def format_stream_event(self, unified_event: UnifiedStreamEvent, context: Protoc if unified_event.type == "done": return "data: [DONE]\n\n" if unified_event.raw is not None: - return deepcopy(unified_event.raw) + payload = deepcopy(unified_event.raw) + if unified_event.delta is not None and isinstance(payload, dict) and isinstance(payload.get("choices"), list) and payload["choices"]: + choice = payload["choices"][0] + if isinstance(choice, dict): + formatted_delta = self._format_message(unified_event.delta) + original_delta = choice.get("delta") if isinstance(choice.get("delta"), dict) else {} + if "role" not in original_delta: + formatted_delta.pop("role", None) + choice["delta"] = formatted_delta + return payload return f"data: {json.dumps(unified_event.to_dict())}\n\n" def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index fb535bced..63f9a42e4 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -78,12 +78,20 @@ async def create_response( raise ResponsesServiceError("Use stream_response for streaming requests", status_code=400) self._trace(transaction_logger, "responses_raw_request", raw_request, direction="request", stage="client") - unified = self.protocol.parse_request(raw_request, ProtocolContext(source_protocol="responses")) + try: + unified = self.protocol.parse_request(raw_request, ProtocolContext(source_protocol="responses")) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_parse_request", exc, raw_request) + raise if transaction_logger: self._trace(transaction_logger, "responses_parsed_request", unified.to_dict(), direction="request", stage="protocol") parent = await self._load_previous_response(unified.previous_response_id, transaction_logger) - chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_response=parent.response if parent else None) + try: + chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_response=parent.response if parent else None) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_bridge_chat_request", exc, unified.to_dict()) + raise bridge_metadata = chat_kwargs.pop("_responses_bridge", {}) session_hints = chat_kwargs.pop("_session_tracking_hints", None) session_hints = _responses_session_hints(unified.previous_response_id, parent, session_hints) @@ -129,7 +137,17 @@ async def stream_response( formatter = ResponsesSSEFormatter() async for event in self.stream_events(raw_request, client, request=request, transaction_logger=transaction_logger, transport=transport): - yield formatter.format_stream_event(event) + formatted = formatter.format_stream_event(event) + self._trace( + transaction_logger, + "responses_sse_formatted_event", + formatted, + direction="stream", + stage="final", + metadata={"event_name": event.event_name, "terminal": event.terminal, "transport": transport}, + scrub_strings=True, + ) + yield formatted async def validate_stream_request(self, raw_request: dict[str, Any]) -> None: """Validate stream-only preconditions before an HTTP response starts.""" @@ -156,11 +174,19 @@ async def stream_events( stream_request = dict(raw_request) stream_request["stream"] = True self._trace(transaction_logger, "responses_raw_request", stream_request, direction="request", stage="client") - unified = self.protocol.parse_request(stream_request, ProtocolContext(source_protocol="responses", transport=transport)) + try: + unified = self.protocol.parse_request(stream_request, ProtocolContext(source_protocol="responses", transport=transport)) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_parse_request", exc, stream_request) + raise if transaction_logger: self._trace(transaction_logger, "responses_parsed_request", unified.to_dict(), direction="request", stage="protocol") parent = await self._load_previous_response(unified.previous_response_id, transaction_logger) - chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_response=parent.response if parent else None) + try: + chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_response=parent.response if parent else None) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_bridge_chat_request", exc, unified.to_dict()) + raise bridge_metadata = chat_kwargs.pop("_responses_bridge", {}) session_hints = chat_kwargs.pop("_session_tracking_hints", None) session_hints = _responses_session_hints(unified.previous_response_id, parent, session_hints) diff --git a/src/rotator_library/transform_trace.py b/src/rotator_library/transform_trace.py index c54039876..513e37a56 100644 --- a/src/rotator_library/transform_trace.py +++ b/src/rotator_library/transform_trace.py @@ -37,8 +37,11 @@ "x-api-key", "x-goog-api-key", "openai-api-key", + "api-key", + "api-secret", "access-token", "refresh-token", + "id-token", "client-secret", "password", "secret", @@ -71,7 +74,8 @@ def _normalise_key(key: Any) -> str: - return str(key).strip().lower().replace("_", "-") + camel_split = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "-", str(key).strip()) + return camel_split.lower().replace("_", "-") def scrub_sensitive_text(value: str) -> str: diff --git a/tests/test_anthropic_transform_tracing.py b/tests/test_anthropic_transform_tracing.py new file mode 100644 index 000000000..98bc9faa7 --- /dev/null +++ b/tests/test_anthropic_transform_tracing.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import json + +import pytest + +import rotator_library.client.anthropic as anthropic_client_module +from rotator_library.anthropic_compat import AnthropicMessagesRequest +from rotator_library.anthropic_compat.streaming import anthropic_streaming_wrapper +from rotator_library.client.anthropic import AnthropicHandler +from rotator_library.transaction_logger import TransactionLogger + + +def _trace_entries(log_dir): + return [json.loads(line) for line in (log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + + +class FakeRotatingClient: + enable_request_logging = True + + async def acompletion(self, **kwargs): + return { + "id": "chat_1", + "model": kwargs["model"], + "choices": [{"message": {"role": "assistant", "content": "ok"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + } + + +@pytest.mark.asyncio +async def test_anthropic_handler_traces_conversion_boundaries(tmp_path, monkeypatch) -> None: + created: list[TransactionLogger] = [] + + def logger_factory(provider, model, enabled=True, api_format="ant", parent_dir=None): + logger = TransactionLogger(provider, model, enabled=enabled, api_format=api_format, parent_dir=tmp_path) + created.append(logger) + return logger + + monkeypatch.setattr(anthropic_client_module, "TransactionLogger", logger_factory) + request = AnthropicMessagesRequest(model="openai/gpt-test", max_tokens=16, messages=[{"role": "user", "content": "hi"}]) + + response = await AnthropicHandler(FakeRotatingClient()).messages(request) + + assert response["content"][0]["text"] == "ok" + pass_names = [entry["pass_name"] for entry in _trace_entries(created[0].log_dir)] + assert "anthropic_raw_request" in pass_names + assert "anthropic_to_openai_request" in pass_names + assert "anthropic_openai_response" in pass_names + assert "openai_to_anthropic_response" in pass_names + assert "anthropic_final_response" in pass_names + + +class ClosingOpenAIStream: + def __init__(self) -> None: + self.closed = False + self._chunks = iter(['data: {"choices":[{"delta":{"content":"hi"}}]}\n\n']) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._chunks) + except StopIteration: + raise StopAsyncIteration + + async def aclose(self) -> None: + self.closed = True + + +@pytest.mark.asyncio +async def test_anthropic_stream_traces_and_closes_on_disconnect(tmp_path) -> None: + logger = TransactionLogger("anthropic", "claude-test", parent_dir=tmp_path) + stream = ClosingOpenAIStream() + + async def disconnected() -> bool: + return True + + chunks = [chunk async for chunk in anthropic_streaming_wrapper(stream, "claude-test", is_disconnected=disconnected, transaction_logger=logger)] + + assert chunks == [] + assert stream.closed is True + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "anthropic_stream_source_chunk" in pass_names + assert "anthropic_stream_disconnected" in pass_names + assert "anthropic_stream_upstream_closed" in pass_names + + +@pytest.mark.asyncio +async def test_anthropic_stream_traces_emitted_frames(tmp_path) -> None: + logger = TransactionLogger("anthropic", "claude-test", parent_dir=tmp_path) + + async def stream(): + yield 'data: {"choices":[{"delta":{"content":"hi"}}]}\n\n' + yield "data: [DONE]\n\n" + + chunks = [chunk async for chunk in anthropic_streaming_wrapper(stream(), "claude-test", transaction_logger=logger)] + + assert any("event: message_start" in chunk for chunk in chunks) + assert any("event: message_stop" in chunk for chunk in chunks) + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "anthropic_stream_message_start" in pass_names + assert "anthropic_stream_content_delta" in pass_names + assert "anthropic_stream_message_delta" in pass_names + assert "anthropic_stream_message_stop" in pass_names diff --git a/tests/test_native_provider_streaming.py b/tests/test_native_provider_streaming.py index 138ae2972..43fa6c75e 100644 --- a/tests/test_native_provider_streaming.py +++ b/tests/test_native_provider_streaming.py @@ -4,6 +4,7 @@ import pytest +from rotator_library.adapters import PayloadAdapter, register_adapter from rotator_library.field_cache import FieldCacheRule from rotator_library.native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor from rotator_library.transaction_logger import TransactionLogger @@ -74,3 +75,38 @@ async def stream_json_lines(self, endpoint, *, headers, json): pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] assert "transform_log_error" in pass_names + + +@pytest.mark.asyncio +async def test_native_provider_stream_runs_stream_event_adapter_chain(tmp_path) -> None: + class StreamTextAdapter(PayloadAdapter): + name = "test_stream_text_adapter" + supported_stages = ("stream_event",) + + async def transform_stream_event(self, payload, context): + payload.delta.content[0].text = "adapted" + return payload + + register_adapter(StreamTextAdapter, replace=True) + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + adapter_names=("test_stream_text_adapter",), + transaction_logger=logger, + ) + + events = [ + event + async for event in NativeProviderExecutor().stream( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(FakeStreamingClient([{"choices": [{"delta": {"content": "before"}}]}, "[DONE]"])), + ) + ] + + assert events[0]["choices"][0]["delta"]["content"] == "adapted" + pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] + assert "after_stream_event_adapter_chain" in pass_names diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index 80b4baa9c..f7d0e0fbc 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -212,6 +212,7 @@ async def test_responses_stream_records_common_stream_metrics(tmp_path) -> None: assert "responses_stream_event_created" in trace_text assert "responses_stream_event_output_text_delta" in trace_text assert "responses_stream_event_completed" in trace_text + assert "responses_sse_formatted_event" in trace_text usage_entry = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines() if '"usage_accounting_summary"' in line][-1] assert usage_entry["data"]["cost"]["provider_reported_cost"] == 0.044 diff --git a/tests/test_transaction_logger_transform_trace.py b/tests/test_transaction_logger_transform_trace.py index 60be0b93d..250a0df60 100644 --- a/tests/test_transaction_logger_transform_trace.py +++ b/tests/test_transaction_logger_transform_trace.py @@ -183,6 +183,20 @@ def test_log_transform_error_uses_standard_shape_and_scrubs_text(tmp_path) -> No assert "secret" not in json.dumps(entries[0]["data"]) +def test_trace_redacts_camel_case_secret_keys(tmp_path) -> None: + logger = TransactionLogger("openai", "openai/gpt-test", parent_dir=tmp_path) + + logger.log_transform_pass( + "camel_secret_payload", + {"apiKey": "a", "accessToken": "b", "refreshToken": "c", "clientSecret": "d", "idToken": "e"}, + direction="request", + stage="client", + ) + + data = _trace_entries(logger.log_dir)[0]["data"] + assert set(data.values()) == {REDACTED} + + @pytest.mark.asyncio async def test_provider_transforms_trace_each_live_boundary(tmp_path) -> None: class HookPlugin: @@ -235,6 +249,31 @@ def convert_for_litellm(self, provider_override=None, **kwargs): assert entries[-1]["changed_from_previous"] is True +@pytest.mark.asyncio +async def test_provider_builtin_transform_errors_are_traced(tmp_path) -> None: + def broken_transform(kwargs, model, provider): + raise RuntimeError("bad apiKey: secret") + + logger = TransactionLogger("broken", "broken/test", parent_dir=tmp_path) + transforms = ProviderTransforms({}) + transforms._transforms["broken"] = [broken_transform] + + with pytest.raises(RuntimeError): + await transforms.apply( + "broken", + "broken/test", + "cred", + {"model": "broken/test", "apiKey": "secret"}, + transaction_logger=logger, + credential_id="cred_1", + ) + + entries = _trace_entries(logger.log_dir) + error_entry = [entry for entry in entries if entry["pass_name"] == "transform_log_error"][-1] + assert error_entry["data"]["failed_pass_name"] == "builtin_provider_transform" + assert "secret" not in json.dumps(error_entry["data"]) + + @pytest.mark.asyncio async def test_provider_transforms_do_not_deepcopy_for_trace_when_disabled() -> None: class NoCopyValue: From 3ea61130f1e09f97fa737a3608cd527a0adbe74d Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 19:38:16 +0200 Subject: [PATCH 153/182] fix(trace): address Phase 2c review gaps Moves final client response logging after usage normalization, replaces stream iterator snapshots with a sanitized stream-opened marker, keeps core lifecycle traces visible, and traces Responses conversion/store failures. Tests: pytest tests/test_anthropic_transform_tracing.py tests/test_transaction_logger_transform_trace.py tests/test_transform_trace.py tests/test_native_provider_streaming.py tests/test_native_provider_executor.py tests/test_responses_streaming.py tests/test_responses_service.py tests/test_adapter_registry.py tests/test_field_cache_trace.py tests/test_request_executor_stream_metrics.py tests/test_executor_usage_accounting.py tests/test_protocol_openai_chat.py --- src/rotator_library/client/executor.py | 31 +++++++------------ src/rotator_library/client/streaming.py | 2 +- src/rotator_library/responses/service.py | 12 +++++-- tests/test_request_executor_stream_metrics.py | 4 +-- 4 files changed, 25 insertions(+), 24 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 054811f60..f2627ae2f 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -1198,23 +1198,6 @@ async def _execute_non_streaming( f"Recorded usage from response object for key {mask_credential(cred)}" ) - # Log response if transaction logging enabled - if context.transaction_logger: - try: - response_data = ( - response.model_dump() - if hasattr(response, "model_dump") - else response - ) - response_data = _redact_context_field_cache_paths(response_data, context, "response", plugin) - context.transaction_logger.log_response( - response_data - ) - except Exception as log_err: - lib_logger.debug( - f"Failed to log response: {log_err}" - ) - normalized_response = self._normalize_response_usage(response, model) trace_normalized_response = _redact_context_field_cache_paths(normalized_response, context, "response", plugin) self._log_executor_trace( @@ -1226,6 +1209,16 @@ async def _execute_non_streaming( credential_id=cred_context.stable_id, metadata={"provider": provider, "model": model}, ) + # Legacy response.json and final_client_response must + # reflect the same post-normalization payload returned + # to the caller, not the pre-accounting SDK object. + if context.transaction_logger: + try: + context.transaction_logger.log_response(trace_normalized_response) + except Exception as log_err: + lib_logger.debug( + f"Failed to log response: {log_err}" + ) return normalized_response except RoutingExecutionError as e: @@ -1491,8 +1484,8 @@ async def _execute_streaming( self._log_executor_trace( context, - "raw_provider_stream_response", - stream, + "provider_stream_opened", + {"stream_type": f"{type(stream).__module__}.{type(stream).__name__}"}, direction="response", stage="provider", credential_id=cred_context.stable_id, diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py index 7fe437581..ee66b5720 100644 --- a/src/rotator_library/client/streaming.py +++ b/src/rotator_library/client/streaming.py @@ -97,7 +97,7 @@ async def wrap_stream( upstream_closed = False stream_cancelled = False last_heartbeat_at = monitor.metrics.started_at - lifecycle_logger = transaction_logger if stream_settings.trace_metrics else None + lifecycle_logger = transaction_logger self._log_stream_lifecycle( lifecycle_logger, "stream_started", diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index 63f9a42e4..36c77cb9a 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -111,14 +111,22 @@ async def create_response( if transaction_logger: self._trace(transaction_logger, "responses_bridge_chat_response", self._response_to_dict(chat_response), direction="response", stage="provider") - response_payload = self.bridge.from_chat_response(chat_response, unified) + try: + response_payload = self.bridge.from_chat_response(chat_response, unified) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_bridge_chat_response", exc, self._response_to_dict(chat_response)) + raise _record_responses_session_anchor(session_info, response_payload) self._trace(transaction_logger, "responses_parsed_response", response_payload, direction="response", stage="protocol") self._trace_responses_usage(transaction_logger, response_payload, unified.model, source="responses_response") if raw_request.get("store", True): stored = self._stored_response(raw_request, response_payload, parent, session_info=session_info) - await self.store.save(stored) + try: + await self.store.save(stored) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_store_response", exc, stored.to_dict()) + raise self._trace(transaction_logger, "responses_stored_response", stored.to_dict(), direction="metadata", stage="final") self._trace(transaction_logger, "responses_final_response", response_payload, direction="response", stage="final") diff --git a/tests/test_request_executor_stream_metrics.py b/tests/test_request_executor_stream_metrics.py index db2301b96..45d94acdc 100644 --- a/tests/test_request_executor_stream_metrics.py +++ b/tests/test_request_executor_stream_metrics.py @@ -111,8 +111,8 @@ async def test_stream_trace_metrics_can_be_disabled_without_changing_output(tmp_ assert chunks[-1] == "data: [DONE]\n\n" if (logger.log_dir / "transform_trace.jsonl").exists(): pass_names = _trace_passes(logger.log_dir) - assert "stream_started" not in pass_names - assert "stream_metrics_final" not in pass_names + assert "stream_started" in pass_names + assert "stream_metrics_final" in pass_names @pytest.mark.asyncio From 005685e0d82fb8ba903a20c8e47a56c491491733 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 19:44:00 +0200 Subject: [PATCH 154/182] fix(trace): trace responses stream store failures Adds store-specific transform-error traces for completed, failed, and in-progress Responses streaming saves, with secret-scrubbed payloads. Tests: pytest tests/test_responses_streaming.py tests/test_responses_service.py tests/test_anthropic_transform_tracing.py tests/test_transaction_logger_transform_trace.py tests/test_native_provider_streaming.py tests/test_request_executor_stream_metrics.py --- src/rotator_library/responses/service.py | 19 ++++++-- tests/test_responses_streaming.py | 56 ++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index 36c77cb9a..8ee063ed2 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -300,7 +300,7 @@ async def stream_events( completed = response_completed_payload(state, _usage_to_responses_stream(usage)) _record_responses_session_anchor(session_info, completed) self._trace_responses_usage(transaction_logger, completed, unified.model, source="responses_stream") - stored = await self._store_stream_response(stream_request, completed, parent, session_info=session_info) + stored = await self._store_stream_response(stream_request, completed, parent, transaction_logger=transaction_logger, session_info=session_info) if stored: self._trace(transaction_logger, "responses_stored_stream_response", completed, direction="metadata", stage="final") else: @@ -333,7 +333,7 @@ async def stream_events( if state.output_text: failed["output"] = [output_item_done_payload(state)["item"]] self._log_transform_error(transaction_logger, "responses_stream", exc, stream_request) - stored = await self._store_stream_response(stream_request, failed, parent, failed=True, session_info=session_info) + stored = await self._store_stream_response(stream_request, failed, parent, failed=True, transaction_logger=transaction_logger, session_info=session_info) if stored: self._trace(transaction_logger, "responses_stored_failed_stream_response", {"response_id": failed.get("id"), "status": "failed"}, direction="metadata", stage="final") self._trace(transaction_logger, "responses_stream_event_failed", failed, direction="stream", stage="final", metadata={"transport": transport}, scrub_strings=True) @@ -498,13 +498,19 @@ async def _store_stream_response( parent: Optional[StoredResponse], *, failed: bool = False, + transaction_logger: Optional[Any] = None, session_info: Optional[dict[str, Any]] = None, ) -> bool: if not raw_request.get("store", True): return False if failed and not self.store_settings.store_failed: return False - await self.store.save(self._stored_response(raw_request, response_payload, parent, session_info=session_info)) + stored = self._stored_response(raw_request, response_payload, parent, session_info=session_info) + try: + await self.store.save(stored) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_store_stream_response", exc, stored.to_dict()) + raise return True async def _store_stream_current_state( @@ -520,7 +526,12 @@ async def _store_stream_current_state( if not self.store_settings.store_in_progress or not raw_request.get("store", True): return False - await self.store.save(self._stored_response(raw_request, response_payload, parent, session_info=session_info)) + stored = self._stored_response(raw_request, response_payload, parent, session_info=session_info) + try: + await self.store.save(stored) + except Exception as exc: + self._log_transform_error(transaction_logger, "responses_store_stream_current_state", exc, stored.to_dict()) + raise self._trace( transaction_logger, "responses_stored_stream_current_state", diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index f7d0e0fbc..7fea52597 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -55,6 +55,20 @@ async def chunks(): return chunks() +class FailingStore: + async def save(self, response): + raise RuntimeError("store failed Authorization: Bearer secret-token") + + async def get(self, response_id): + return None + + async def delete(self, response_id): + return False + + async def list_input_items(self, response_id): + return None + + @pytest.mark.asyncio async def test_stream_response_emits_responses_sse_events_and_stores_final_response() -> None: store = InMemoryResponsesStore() @@ -168,6 +182,48 @@ async def test_stream_events_can_store_in_progress_state() -> None: assert stored.status == "in_progress" +@pytest.mark.asyncio +async def test_stream_response_store_failures_emit_store_specific_trace(tmp_path) -> None: + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + service = ResponsesService(store=FailingStore()) + + with pytest.raises(RuntimeError): + _ = [ + chunk + async for chunk in service.stream_response( + {"model": "gpt-test", "input": "Hello", "stream": True}, + FakeStreamingClient(), + transaction_logger=logger, + ) + ] + + entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + errors = [entry for entry in entries if entry["pass_name"] == "transform_log_error"] + assert any(entry["data"]["failed_pass_name"] == "responses_store_stream_response" for entry in errors) + assert "secret-token" not in json.dumps(errors) + + +@pytest.mark.asyncio +async def test_stream_current_state_store_failures_emit_store_specific_trace(tmp_path) -> None: + logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) + service = ResponsesService(store=FailingStore(), store_settings=ResponsesStoreSettings(store_in_progress=True)) + + with pytest.raises(RuntimeError): + _ = [ + event + async for event in service.stream_events( + {"model": "gpt-test", "input": "Hello", "stream": True}, + FakeStreamingClient(), + transaction_logger=logger, + ) + ] + + entries = [json.loads(line) for line in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8").splitlines()] + errors = [entry for entry in entries if entry["pass_name"] == "transform_log_error"] + assert any(entry["data"]["failed_pass_name"] == "responses_store_stream_current_state" for entry in errors) + assert "secret-token" not in json.dumps(errors) + + @pytest.mark.asyncio async def test_stream_response_failure_trace_scrubs_header_like_secret_text(tmp_path) -> None: logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) From d3e2bfdecca5afa442a8f70d028ca8e929dcb18a Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 19:50:35 +0200 Subject: [PATCH 155/182] fix(trace): normalize dict responses before final trace Extends non-streaming response usage normalization to dict responses so final_client_response always reflects post-normalization usage, matching the Phase 2c trace contract. Tests: pytest tests/test_executor_usage_accounting.py tests/test_anthropic_transform_tracing.py tests/test_transaction_logger_transform_trace.py tests/test_responses_streaming.py tests/test_request_executor_stream_metrics.py --- src/rotator_library/client/executor.py | 4 +++- tests/test_executor_usage_accounting.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index f2627ae2f..72790d6c4 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -2270,7 +2270,9 @@ def _normalize_response_usage(response: Any, model: str) -> Any: dicts (streaming) and pydantic objects (non-streaming). Internal tracking values from UsageRecord accounting are unaffected. """ - if hasattr(response, "usage") and response.usage: + if isinstance(response, dict): + normalize_usage_for_response(response.get("usage"), model) + elif hasattr(response, "usage") and response.usage: normalize_usage_for_response(response.usage, model) return response diff --git a/tests/test_executor_usage_accounting.py b/tests/test_executor_usage_accounting.py index c1786c082..1f28cc537 100644 --- a/tests/test_executor_usage_accounting.py +++ b/tests/test_executor_usage_accounting.py @@ -67,3 +67,19 @@ def test_executor_accounting_uses_configured_env_pricing(tmp_path, monkeypatch) assert cost.pricing_source == "env" assert cost.input_cost == 6.0 + + +def test_normalize_response_usage_handles_dict_responses() -> None: + response = { + "usage": { + "prompt_tokens": 4, + "completion_tokens": 1, + "completion_tokens_details": {"reasoning_tokens": 3}, + } + } + + result = RequestExecutor._normalize_response_usage(response, "gpt-test") + + assert result is response + assert response["usage"]["completion_tokens"] == 4 + assert response["usage"]["total_tokens"] == 8 From 5a4d42b90f4cdf21c2fa111797ac6ba793015cdc Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 19:59:35 +0200 Subject: [PATCH 156/182] docs(experimental): plan field-cache runtime completion Adds the Phase 3c corrective plan for unified field-cache runtime sources and targets, native adapter trace safety, and credential-scope fail-closed behavior. Tests: not run (planning document only) --- ...phase-3c-field-cache-runtime-completion.md | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 docs/experimental/phase-3c-field-cache-runtime-completion.md diff --git a/docs/experimental/phase-3c-field-cache-runtime-completion.md b/docs/experimental/phase-3c-field-cache-runtime-completion.md new file mode 100644 index 000000000..4c6bb01db --- /dev/null +++ b/docs/experimental/phase-3c-field-cache-runtime-completion.md @@ -0,0 +1,66 @@ +# Phase 3c: Field-Cache Runtime Completion And Trace Safety + +## Goal + +Close the Phase 3/3b third-pass findings around field-cache runtime coverage, adapter-chain trace safety, and credential-scoped isolation. + +## Scope + +- Execute all declared field-cache sources used by native runtime: `request`, `response`, `stream_event`, `unified_request`, `unified_response`, and `unified_stream_event`. +- Execute declared injection targets used by native runtime: `request`, `unified_request`, and `metadata`. +- Suppress generic adapter-chain payload traces in native execution where field-cache-aware redaction has not yet run, and rely on native executor traces that apply rule-aware redaction. +- Fail closed when a rule includes `credential` scope but the runtime lacks a credential identifier. +- Add focused tests for each corrected behavior. + +## Non-Goals + +- Do not replace the field-cache store or add a database. +- Do not make native streaming inject into stream events; streaming extraction remains expected behavior. +- Do not make Copilot declare field-cache rules when none are required. +- Do not implement custom provider-specific turn inference beyond existing metadata hints. + +## Implementation Plan + +1. Credential scope isolation. + - Update cache-key construction so missing `credential` scope returns `None`. + - Keep existing `allow_missing_session` behavior only for `session` scope. + - Add tests proving credential-scoped extraction/injection skips instead of using a shared missing bucket. + +2. Unified request injection and extraction. + - In native non-streaming and streaming paths, create the field-cache engine before protocol build. + - Extract `unified_request` after parsing client payload. + - Inject `metadata` before building protocol context when metadata rules exist. + - Inject `unified_request` before `protocol.build_request()`. + - Hydrate injected serialized unified-request dictionaries back into the current `UnifiedRequest` for supported common fields. + - Trace `after_unified_request_field_cache_injection` and `after_metadata_field_cache_injection`. + +3. Unified response extraction. + - Extract `unified_response` after protocol response parsing and before formatting. + - Keep existing provider-response extraction after response adapter chain. + - Trace `after_unified_response_field_cache_extraction`. + +4. Unified stream-event extraction. + - Extract `unified_stream_event` after native stream-event adapter chain and before formatting. + - Keep existing provider stream-event extraction from formatted event payload. + - Trace `after_unified_stream_event_field_cache_extraction`. + +5. Native adapter trace safety. + - Disable generic adapter-chain traces inside native request/response paths, matching the stream-event path. + - Keep native executor traces after adapter chains, where configured field-cache paths are redacted. + - Add tests with arbitrary configured provider-state paths proving generic `before_adapter_chain`/`after_adapter` traces are not emitted in native execution and native traces redact configured paths. + +6. Tests. + - Field-cache engine credential-scope fail-closed tests. + - Native executor unified request source/target tests. + - Native executor metadata injection tests. + - Native executor unified response extraction tests. + - Native stream unified event extraction tests. + - Native adapter trace redaction/suppression tests. + +## Acceptance Criteria + +- Native runtime executes all declared field-cache sources/targets that it accepts, except stream-event injection which remains intentionally unsupported. +- Missing credential scope never shares a `_none` cache key. +- Native adapter chain traces cannot leak arbitrary configured provider-state fields before rule-aware redaction. +- Existing field-cache and native provider tests continue to pass. +- Both `explore` and `explore-heavy` reviewers report no blockers/highs/mediums. From e29cd9e262812c6d5e767a713c47dba70205c5cf Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 20:03:28 +0200 Subject: [PATCH 157/182] fix(field-cache): complete native runtime coverage Executes unified request/response/stream-event field-cache sources, supports metadata and unified-request injection targets in native execution, suppresses unsafe generic native adapter traces, and fails closed on missing credential scope. Tests: pytest tests/test_field_cache_engine.py tests/test_field_cache_trace.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_adapter_registry.py tests/test_request_executor_native_routing.py tests/test_protocol_openai_chat.py tests/test_responses_streaming.py tests/test_transform_trace.py --- src/rotator_library/field_cache/engine.py | 2 + .../native_provider/executor.py | 97 ++++++++++++- tests/test_field_cache_engine.py | 20 +++ tests/test_native_provider_executor.py | 129 ++++++++++++++++++ tests/test_native_provider_streaming.py | 41 +++++- 5 files changed, 283 insertions(+), 6 deletions(-) diff --git a/src/rotator_library/field_cache/engine.py b/src/rotator_library/field_cache/engine.py index b32afcef3..c9a2134da 100644 --- a/src/rotator_library/field_cache/engine.py +++ b/src/rotator_library/field_cache/engine.py @@ -41,6 +41,8 @@ def build_cache_key(rule: FieldCacheRule, context: FieldCacheContext) -> Optiona for scope in rule.scope: value = context.value_for_scope(scope) if value is None or value == "": + if scope == "credential": + return None if scope == "session" and not rule.allow_missing_session: return None value = "_none" diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index f626e7880..f7dc9d0f3 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -13,6 +13,7 @@ from ..field_cache import FieldCacheEngine, InMemoryFieldCacheStore from ..field_cache.paths import FieldCachePathError, PathToken, parse_path from ..protocols import ProtocolError, get_protocol, serialize_value +from ..protocols.types import UnifiedRequest from ..transform_trace import REDACTED from ..usage.accounting import extract_usage_record from ..usage.costs import CostCalculator @@ -41,15 +42,21 @@ async def execute(self, raw_request: dict[str, Any], context: NativeProviderCont self._trace(context, "native_protocol_selected", {"protocol": protocol.name}, direction="metadata", stage="protocol") try: self._trace(context, "raw_native_client_request", raw_request, direction="request", stage="client") + cache_engine = FieldCacheEngine(context.field_cache_rules, store=self.field_cache_store) + context = await self._inject_metadata(context, cache_engine) protocol_context = context.protocol_context() unified_request = protocol.parse_request(raw_request, protocol_context) self._trace(context, "parsed_native_unified_request", unified_request, direction="request", stage="protocol") + await cache_engine.extract("unified_request", serialize_value(unified_request), context.field_cache_context(), transaction_logger=logger) + self._trace(context, "after_unified_request_field_cache_extraction", {"source": "unified_request"}, direction="request", stage="adapter", snapshot=False) + unified_request = await self._inject_unified_request(unified_request, context, cache_engine) provider_request = protocol.build_request(unified_request, protocol_context) self._trace(context, "built_native_provider_request", provider_request, direction="request", stage="protocol") adapters = [get_adapter(name) for name in context.adapter_names] - provider_request = await run_adapter_chain(adapters, provider_request, context.adapter_context(), stage="request") + adapter_context = context.adapter_context() + adapter_context.transaction_logger = None + provider_request = await run_adapter_chain(adapters, provider_request, adapter_context, stage="request") self._trace(context, "after_request_adapter_chain", provider_request, direction="request", stage="adapter") - cache_engine = FieldCacheEngine(context.field_cache_rules, store=self.field_cache_store) provider_request, _ = await cache_engine.inject( "request", provider_request, @@ -62,9 +69,13 @@ async def execute(self, raw_request: dict[str, Any], context: NativeProviderCont self._trace(context, "raw_native_provider_response", raw_response, direction="response", stage="provider") unified_response = protocol.parse_response(raw_response, protocol_context) self._trace(context, "parsed_native_unified_response", unified_response, direction="response", stage="protocol") + await cache_engine.extract("unified_response", serialize_value(unified_response), context.field_cache_context(), transaction_logger=logger) + self._trace(context, "after_unified_response_field_cache_extraction", {"source": "unified_response"}, direction="response", stage="adapter", snapshot=False) provider_response = protocol.format_response(unified_response, protocol_context) self._trace(context, "formatted_native_response", provider_response, direction="response", stage="protocol") - provider_response = await run_adapter_chain(adapters, provider_response, context.adapter_context(), stage="response") + adapter_context = context.adapter_context() + adapter_context.transaction_logger = None + provider_response = await run_adapter_chain(adapters, provider_response, adapter_context, stage="response") self._trace(context, "after_response_adapter_chain", provider_response, direction="response", stage="adapter") await cache_engine.extract("response", provider_response, context.field_cache_context(), transaction_logger=logger) usage_record = extract_usage_record( @@ -118,17 +129,23 @@ async def stream(self, raw_request: dict[str, Any], context: NativeProviderConte self._trace(context, "native_protocol_selected", {"protocol": protocol.name}, direction="metadata", stage="protocol") try: self._trace(context, "raw_native_client_request", raw_request, direction="request", stage="client") + cache_engine = FieldCacheEngine(context.field_cache_rules, store=self.field_cache_store) + context = await self._inject_metadata(context, cache_engine) protocol_context = context.protocol_context() request_payload = dict(raw_request) request_payload["stream"] = True unified_request = protocol.parse_request(request_payload, protocol_context) self._trace(context, "parsed_native_unified_request", unified_request, direction="request", stage="protocol") + await cache_engine.extract("unified_request", serialize_value(unified_request), context.field_cache_context(), transaction_logger=logger) + self._trace(context, "after_unified_request_field_cache_extraction", {"source": "unified_request"}, direction="request", stage="adapter", snapshot=False) + unified_request = await self._inject_unified_request(unified_request, context, cache_engine) provider_request = protocol.build_request(unified_request, protocol_context) self._trace(context, "built_native_provider_request", provider_request, direction="request", stage="protocol") adapters = [get_adapter(name) for name in context.adapter_names] - provider_request = await run_adapter_chain(adapters, provider_request, context.adapter_context(), stage="request") + adapter_context = context.adapter_context() + adapter_context.transaction_logger = None + provider_request = await run_adapter_chain(adapters, provider_request, adapter_context, stage="request") self._trace(context, "after_request_adapter_chain", provider_request, direction="request", stage="adapter") - cache_engine = FieldCacheEngine(context.field_cache_rules, store=self.field_cache_store) provider_request, _ = await cache_engine.inject( "request", provider_request, @@ -152,6 +169,8 @@ async def stream(self, raw_request: dict[str, Any], context: NativeProviderConte adapter_context.transaction_logger = None event = await run_adapter_chain(adapters, event, adapter_context, stage="stream_event") self._trace(context, "after_stream_event_adapter_chain", event, direction="stream", stage="adapter", snapshot=False) + await cache_engine.extract("unified_stream_event", serialize_value(event), context.field_cache_context(), transaction_logger=logger) + self._trace(context, "after_unified_stream_event_field_cache_extraction", {"source": "unified_stream_event"}, direction="stream", stage="adapter", snapshot=False) event_payload = stream_event_payload(event) self._trace(context, "parsed_native_stream_event", event_payload, direction="stream", stage="protocol") await cache_engine.extract("stream_event", event_payload, context.field_cache_context(), transaction_logger=logger) @@ -224,6 +243,42 @@ def _ensure_supported_operation(protocol: Any, context: NativeProviderContext) - payload={"provider": context.provider, "model": context.model, "operation": context.operation}, ) + async def _inject_metadata(self, context: NativeProviderContext, cache_engine: FieldCacheEngine) -> NativeProviderContext: + """Inject cached metadata before protocol/adapter contexts are built.""" + + metadata, operations = await cache_engine.inject( + "metadata", + dict(context.metadata), + context.field_cache_context(), + transaction_logger=context.transaction_logger, + ) + if operations: + self._trace(context, "after_metadata_field_cache_injection", metadata, direction="metadata", stage="adapter", snapshot=False) + if metadata == context.metadata: + return context + return replace(context, metadata=metadata) + + async def _inject_unified_request( + self, + unified_request: UnifiedRequest, + context: NativeProviderContext, + cache_engine: FieldCacheEngine, + ) -> UnifiedRequest: + """Inject cached values into a serialized unified request and hydrate it.""" + + serialized = serialize_value(unified_request) + injected, operations = await cache_engine.inject( + "unified_request", + serialized, + context.field_cache_context(), + transaction_logger=context.transaction_logger, + ) + if operations: + self._trace(context, "after_unified_request_field_cache_injection", injected, direction="request", stage="adapter") + if injected == serialized: + return unified_request + return _hydrate_unified_request(unified_request, injected) + def _redact_field_cache_paths(data: Any, context: NativeProviderContext, direction: str) -> Any: """Redact configured cache paths before broad native payload traces. @@ -253,6 +308,38 @@ def _redact_field_cache_paths(data: Any, context: NativeProviderContext, directi return redacted +def _hydrate_unified_request(original: UnifiedRequest, injected: Any) -> UnifiedRequest: + """Hydrate common unified-request fields after serialized cache injection. + + Field-cache path injection operates on JSON-like dictionaries. Protocol + builders still expect `UnifiedRequest`, so this helper copies supported + top-level fields back onto the dataclass while leaving complex message/tool + objects untouched unless providers handle them through provider-payload rules. + """ + + if not isinstance(injected, dict): + return original + safe_fields = { + "operation", + "model", + "stream", + "input", + "modalities", + "files", + "generation_params", + "response_format", + "previous_response_id", + "metadata", + "raw", + "extra", + } + values = {field_name: getattr(original, field_name) for field_name in UnifiedRequest._fields} + for field_name in safe_fields: + if field_name in injected: + values[field_name] = injected[field_name] + return UnifiedRequest(**values) + + def _trace_redaction_paths(paths: list[str], *, direction: str) -> list[str]: """Return configured paths plus raw-stream envelope fallbacks for traces.""" diff --git a/tests/test_field_cache_engine.py b/tests/test_field_cache_engine.py index de6862f03..842a34801 100644 --- a/tests/test_field_cache_engine.py +++ b/tests/test_field_cache_engine.py @@ -122,6 +122,26 @@ async def test_missing_session_scope_skips_by_default() -> None: assert operations[0].reason == "missing_required_scope" +@pytest.mark.asyncio +async def test_missing_credential_scope_skips_instead_of_sharing_none_bucket() -> None: + rule = _reasoning_rule(scope=("provider", "model", "credential")) + engine = FieldCacheEngine([rule]) + + operations = await engine.extract( + "response", + {"choices": [{"message": {"reasoning_content": "x"}}]}, + _context(credential_id=None), + ) + updated, injection_operations = await engine.inject("request", {"messages": [{}]}, _context(credential_id=None)) + + assert operations[0].skipped is True + assert operations[0].reason == "missing_required_scope" + assert injection_operations[0].skipped is True + assert injection_operations[0].reason == "missing_required_scope" + assert updated == {"messages": [{}]} + assert build_cache_key(rule, _context(credential_id=None)) is None + + @pytest.mark.asyncio async def test_missing_path_is_noop() -> None: engine = FieldCacheEngine([_reasoning_rule()]) diff --git a/tests/test_native_provider_executor.py b/tests/test_native_provider_executor.py index 4d50c8d19..650a2383a 100644 --- a/tests/test_native_provider_executor.py +++ b/tests/test_native_provider_executor.py @@ -4,6 +4,7 @@ import pytest +from rotator_library.adapters import PayloadAdapter, register_adapter from rotator_library.field_cache import FieldCacheInjection, FieldCacheRule from rotator_library.native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor from rotator_library.protocols import ProtocolError @@ -210,6 +211,134 @@ async def test_native_provider_executor_rejects_unsupported_operation_before_tra assert client.calls == [] +@pytest.mark.asyncio +async def test_native_runtime_executes_unified_request_source_and_target() -> None: + rule = FieldCacheRule( + name="unified_request_state", + source="unified_request", + path="metadata.client_state", + inject=FieldCacheInjection(target="unified_request", path="metadata.cached_client_state"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=(rule,), + ) + executor = NativeProviderExecutor() + + await executor.execute( + {"model": "gpt-test", "messages": [], "metadata": {"client_state": "state-a"}}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "chat_1", "choices": []})), + ) + second_client = FakeHTTPClient({"id": "chat_2", "choices": []}) + await executor.execute( + {"model": "gpt-test", "messages": [], "metadata": {}}, + context, + NativeHTTPTransport(second_client), + ) + + assert second_client.calls[0]["json"]["metadata"]["cached_client_state"] == "state-a" + + +@pytest.mark.asyncio +async def test_native_runtime_executes_unified_response_source() -> None: + rule = FieldCacheRule( + name="response_object", + source="unified_response", + path="metadata.object", + inject=FieldCacheInjection(target="request", path="metadata.cached_object"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=(rule,), + ) + executor = NativeProviderExecutor() + + await executor.execute({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(FakeHTTPClient({"id": "chat_1", "object": "chat.completion", "choices": []}))) + second_client = FakeHTTPClient({"id": "chat_2", "choices": []}) + await executor.execute({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(second_client)) + + assert second_client.calls[0]["json"]["metadata"]["cached_object"] == "chat.completion" + + +@pytest.mark.asyncio +async def test_native_runtime_executes_metadata_target_before_adapters() -> None: + class MetadataEchoAdapter(PayloadAdapter): + name = "test_metadata_echo" + supported_stages = ("request",) + + async def transform_request(self, payload, context): + payload.setdefault("metadata", {})["adapter_seen_state"] = context.metadata.get("cached_state") + return payload + + register_adapter(MetadataEchoAdapter, replace=True) + rule = FieldCacheRule( + name="metadata_state", + source="response", + path="choices.0.message.reasoning_content", + inject=FieldCacheInjection(target="metadata", path="cached_state"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + adapter_names=("test_metadata_echo",), + field_cache_rules=(rule,), + ) + executor = NativeProviderExecutor() + + await executor.execute({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(FakeHTTPClient({"id": "chat_1", "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "meta-state"}}]}))) + second_client = FakeHTTPClient({"id": "chat_2", "choices": []}) + await executor.execute({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(second_client)) + + assert second_client.calls[0]["json"]["metadata"]["adapter_seen_state"] == "meta-state" + + +@pytest.mark.asyncio +async def test_native_adapter_generic_traces_are_suppressed_for_field_cache_safety(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + rule = FieldCacheRule( + name="hidden_state", + source="response", + path="choices.0.message.reasoning_content", + inject=FieldCacheInjection(target="request", path="metadata.hidden_state"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + adapter_names=("model_override",), + adapter_config={"model_override": {"model": "provider/gpt-test"}}, + field_cache_rules=(rule,), + transaction_logger=logger, + ) + + await NativeProviderExecutor().execute( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "chat_1", "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "opaque-state"}}]})), + ) + + entries = _trace_entries(logger.log_dir) + pass_names = [entry["pass_name"] for entry in entries] + assert "before_adapter_chain" not in pass_names + assert "after_adapter" not in pass_names + assert "after_request_adapter_chain" in pass_names + assert "opaque-state" not in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + + @pytest.mark.asyncio async def test_native_provider_stream_rejects_unsupported_operation_before_transport() -> None: context = NativeProviderContext( diff --git a/tests/test_native_provider_streaming.py b/tests/test_native_provider_streaming.py index 43fa6c75e..bd84d4a5a 100644 --- a/tests/test_native_provider_streaming.py +++ b/tests/test_native_provider_streaming.py @@ -5,7 +5,7 @@ import pytest from rotator_library.adapters import PayloadAdapter, register_adapter -from rotator_library.field_cache import FieldCacheRule +from rotator_library.field_cache import FieldCacheInjection, FieldCacheRule from rotator_library.native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor from rotator_library.transaction_logger import TransactionLogger @@ -110,3 +110,42 @@ async def transform_stream_event(self, payload, context): assert events[0]["choices"][0]["delta"]["content"] == "adapted" pass_names = [entry["pass_name"] for entry in _trace_entries(logger.log_dir)] assert "after_stream_event_adapter_chain" in pass_names + + +@pytest.mark.asyncio +async def test_native_provider_stream_extracts_unified_stream_events_for_later_requests() -> None: + rule = FieldCacheRule( + name="unified_stream_text", + source="unified_stream_event", + path="delta.content.0.text", + inject=FieldCacheInjection(target="request", path="metadata.cached_stream_text"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=(rule,), + ) + executor = NativeProviderExecutor() + + _ = [ + event + async for event in executor.stream( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(FakeStreamingClient([{"choices": [{"delta": {"content": "stream-state"}}]}, "[DONE]"])), + ) + ] + second_client = FakeStreamingClient(["[DONE]"]) + _ = [ + event + async for event in executor.stream( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(second_client), + ) + ] + + assert second_client.calls[0]["json"]["metadata"]["cached_stream_text"] == "stream-state" From 58ea43affa06403d0d2652645274086c225c4a57 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 20:11:20 +0200 Subject: [PATCH 158/182] fix(field-cache): close Phase 3c review gaps Extracts request-source field-cache rules in native execution and redacts configured metadata-injection paths in native traces. Tests: pytest tests/test_field_cache_engine.py tests/test_field_cache_trace.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_adapter_registry.py tests/test_request_executor_native_routing.py tests/test_protocol_openai_chat.py tests/test_responses_streaming.py tests/test_transform_trace.py --- .../native_provider/executor.py | 6 ++ tests/test_native_provider_executor.py | 66 +++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index f7dc9d0f3..1e7f673eb 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -64,6 +64,8 @@ async def execute(self, raw_request: dict[str, Any], context: NativeProviderCont transaction_logger=logger, ) self._trace(context, "after_field_cache_injection", provider_request, direction="request", stage="adapter") + await cache_engine.extract("request", provider_request, context.field_cache_context(), transaction_logger=logger) + self._trace(context, "after_request_field_cache_extraction", {"source": "request"}, direction="request", stage="adapter", snapshot=False) self._trace(context, "native_provider_request", provider_request, direction="request", stage="provider") raw_response = await transport.post_json(context.endpoint, headers=context.headers, payload=provider_request) self._trace(context, "raw_native_provider_response", raw_response, direction="response", stage="provider") @@ -153,6 +155,8 @@ async def stream(self, raw_request: dict[str, Any], context: NativeProviderConte transaction_logger=logger, ) self._trace(context, "after_field_cache_injection", provider_request, direction="request", stage="adapter") + await cache_engine.extract("request", provider_request, context.field_cache_context(), transaction_logger=logger) + self._trace(context, "after_request_field_cache_extraction", {"source": "request"}, direction="request", stage="adapter", snapshot=False) self._trace(context, "native_provider_stream_request", provider_request, direction="request", stage="provider") async for raw_chunk in transport.stream_json_lines(context.endpoint, headers=context.headers, payload=provider_request): self._trace(context, "raw_native_provider_stream_chunk", raw_chunk, direction="stream", stage="provider") @@ -296,6 +300,8 @@ def _redact_field_cache_paths(data: Any, context: NativeProviderContext, directi paths: list[str] = [] if direction == "request" and rule.inject: paths.append(rule.inject.path) + if direction == "metadata" and rule.inject and rule.inject.target == "metadata": + paths.append(rule.inject.path) if direction in {"response", "stream"}: paths.append(rule.path) for path in _trace_redaction_paths(paths, direction=direction): diff --git a/tests/test_native_provider_executor.py b/tests/test_native_provider_executor.py index 650a2383a..a1a4c8f02 100644 --- a/tests/test_native_provider_executor.py +++ b/tests/test_native_provider_executor.py @@ -339,6 +339,72 @@ async def test_native_adapter_generic_traces_are_suppressed_for_field_cache_safe assert "opaque-state" not in (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") +@pytest.mark.asyncio +async def test_native_runtime_executes_request_source() -> None: + rule = FieldCacheRule( + name="request_state", + source="request", + path="metadata.outgoing_state", + inject=FieldCacheInjection(target="unified_request", path="metadata.reused_request_state"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=(rule,), + ) + executor = NativeProviderExecutor() + + await executor.execute( + {"model": "gpt-test", "messages": [], "metadata": {"outgoing_state": "request-state"}}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "chat_1", "choices": []})), + ) + second_client = FakeHTTPClient({"id": "chat_2", "choices": []}) + await executor.execute({"model": "gpt-test", "messages": [], "metadata": {}}, context, NativeHTTPTransport(second_client)) + + assert second_client.calls[0]["json"]["metadata"]["reused_request_state"] == "request-state" + + +@pytest.mark.asyncio +async def test_native_metadata_injection_trace_redacts_configured_paths(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + rule = FieldCacheRule( + name="metadata_secret", + source="response", + path="choices.0.message.reasoning_content", + inject=FieldCacheInjection(target="metadata", path="cached_blob"), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + field_cache_rules=(rule,), + transaction_logger=logger, + ) + executor = NativeProviderExecutor() + + await executor.execute( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "chat_1", "choices": [{"message": {"role": "assistant", "content": "ok", "reasoning_content": "metadata-secret"}}]})), + ) + await executor.execute( + {"model": "gpt-test", "messages": []}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "chat_2", "choices": []})), + ) + + trace_text = (logger.log_dir / "transform_trace.jsonl").read_text(encoding="utf-8") + assert "metadata-secret" not in trace_text + metadata_entries = [entry for entry in _trace_entries(logger.log_dir) if entry["pass_name"] == "after_metadata_field_cache_injection"] + assert metadata_entries[-1]["data"]["cached_blob"] == "[REDACTED]" + + @pytest.mark.asyncio async def test_native_provider_stream_rejects_unsupported_operation_before_transport() -> None: context = NativeProviderContext( From 855486cbcee5b1b3d79c694ab401fe79c5ea04a2 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 20:19:40 +0200 Subject: [PATCH 159/182] docs(experimental): plan responses correction Adds the Phase 4c corrective plan for configurable Responses storage, continuation lineage replay, and top-level route error bodies. Tests: not run (planning document only) --- ...ase-4c-responses-storage-lineage-errors.md | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 docs/experimental/phase-4c-responses-storage-lineage-errors.md diff --git a/docs/experimental/phase-4c-responses-storage-lineage-errors.md b/docs/experimental/phase-4c-responses-storage-lineage-errors.md new file mode 100644 index 000000000..b35b1a329 --- /dev/null +++ b/docs/experimental/phase-4c-responses-storage-lineage-errors.md @@ -0,0 +1,60 @@ +# Phase 4c: Responses Storage, Continuation Lineage, And Error Shape + +## Goal + +Close the Phase 4/4b third-pass Responses findings while keeping the existing bridge architecture and no-SQLite constraint. + +## Scope + +- Wire configurable Responses storage at app startup. +- Keep process-local memory storage as the default. +- Add provider-cache-backed JSON storage as an opt-in durable backend. +- Preserve existing `ResponsesStoreSettings` policy for TTL, maximum items, failed storage, and in-progress storage. +- Extend `previous_response_id` continuation context to replay parent input and output lineage, oldest-to-newest. +- Return top-level OpenAI-compatible `error` bodies from Responses routes. + +## Non-Goals + +- Do not introduce SQLite or a new database. +- Do not replace the chat-completions bridge with native Responses execution. +- Do not add a WebSocket route. +- Do not implement a cancel endpoint. +- Do not perfect every rich Responses item type in the bridge; cover known text/message/tool-call cases and preserve unsupported fields. + +## Implementation Plan + +1. Configurable Responses store backend. + - Add runtime config helpers for `memory` and `provider_cache` backends. + - Use existing provider-cache JSON storage for durable mode. + - Keep memory backend as the default. + +2. Startup wiring. + - Construct `ResponsesService(store=..., store_settings=...)` at proxy startup and fallback service creation. + - Keep direct `ResponsesService()` defaults unchanged for tests and embedding users. + +3. Continuation lineage. + - Load parent chains through stored `request.previous_response_id` with depth and cycle guards. + - Pass oldest-to-newest lineage into the bridge. + +4. Bridge replay. + - Replay parent input items as user messages when convertible. + - Replay parent output message items as assistant messages. + - Keep existing parent-output behavior for external callers using the old argument. + +5. Top-level route errors. + - Return `JSONResponse(status_code=..., content={"error": ...})` for Responses service errors. + - Apply to create validation, retrieve, delete, and input-items routes. + +6. Tests. + - Config/env tests for storage backend selection. + - Store tests for provider-cache-backed persistence. + - Bridge/service tests for parent input + output lineage replay. + - Route tests for top-level error bodies. + +## Acceptance Criteria + +- App startup can use durable provider-cache-backed Responses storage via config/env. +- Existing memory storage remains default. +- Continuations replay parent input and output lineage, not just parent output. +- Responses route errors use top-level OpenAI-compatible `error` bodies. +- Focused tests pass and both reviewers report no blockers/highs/mediums. From fdabe101609fdd56f8bcfd97bcce9e73befe9a5f Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 20:27:39 +0200 Subject: [PATCH 160/182] fix(responses): wire storage and lineage Adds configurable provider-cache-backed Responses storage, wires app startup to the configured store, replays parent input/output lineage for continuations, and returns top-level Responses error bodies. Tests: pytest tests/test_experimental_config.py tests/test_responses_store.py tests/test_responses_bridge.py tests/test_responses_service.py tests/test_responses_routes.py tests/test_responses_streaming.py tests/test_protocol_responses.py tests/test_responses_usage_accounting.py tests/test_stream_transport.py --- src/proxy_app/main.py | 18 +++++----- src/rotator_library/config/experimental.py | 42 ++++++++++++++++++++++ src/rotator_library/responses/__init__.py | 3 +- src/rotator_library/responses/bridge.py | 36 ++++++++++++++++++- src/rotator_library/responses/service.py | 23 ++++++++++-- src/rotator_library/responses/store.py | 29 +++++++++++++++ tests/test_experimental_config.py | 35 ++++++++++++++++++ tests/test_responses_bridge.py | 26 ++++++++++++++ tests/test_responses_routes.py | 7 ++-- tests/test_responses_service.py | 35 ++++++++++++++++++ tests/test_responses_store.py | 23 +++++++++++- 11 files changed, 261 insertions(+), 16 deletions(-) diff --git a/src/proxy_app/main.py b/src/proxy_app/main.py index fe7f91ff3..5cf4a4dc1 100644 --- a/src/proxy_app/main.py +++ b/src/proxy_app/main.py @@ -598,8 +598,9 @@ async def process_credential(provider: str, path: str, provider_instance): # the existing chat-completions client path; later native providers can reuse # the same route/storage surface without changing clients. from rotator_library.config.experimental import get_responses_store_settings + from rotator_library.responses import create_configured_responses_store - app.state.responses_service = ResponsesService(store_settings=get_responses_store_settings()) + app.state.responses_service = ResponsesService(store=create_configured_responses_store(), store_settings=get_responses_store_settings()) # Warn if no provider credentials are configured if not client.all_credentials: @@ -677,8 +678,9 @@ def get_responses_service(request: Request) -> ResponsesService: service = getattr(request.app.state, "responses_service", None) if service is None: from rotator_library.config.experimental import get_responses_store_settings + from rotator_library.responses import create_configured_responses_store - service = ResponsesService(store_settings=get_responses_store_settings()) + service = ResponsesService(store=create_configured_responses_store(), store_settings=get_responses_store_settings()) request.app.state.responses_service = service return service @@ -1054,7 +1056,7 @@ async def responses_create( try: request_data = await request.json() except json.JSONDecodeError: - raise HTTPException(status_code=400, detail="Invalid JSON in request body.") + return JSONResponse(status_code=400, content={"error": {"message": "Invalid JSON in request body.", "type": "invalid_request_error", "code": 400}}) if logger: logger.log_request(headers=dict(request.headers), body=request_data) transaction_logger = TransactionLogger("responses", request_data.get("model", "unknown")) if ENABLE_REQUEST_LOGGING else None @@ -1074,12 +1076,12 @@ async def responses_create( payload = _responses_error_response(e) if logger: logger.log_final_response(status_code=e.status_code, headers=None, body=payload) - raise HTTPException(status_code=e.status_code, detail=payload) + return JSONResponse(status_code=e.status_code, content=payload) except Exception as e: logging.error(f"Responses endpoint error: {e}") if logger: logger.log_final_response(status_code=500, headers=None, body={"error": str(e)}) - raise HTTPException(status_code=500, detail=str(e)) + return JSONResponse(status_code=500, content={"error": {"message": str(e), "type": "internal_error", "code": 500}}) @app.get("/v1/responses/{response_id}") @@ -1093,7 +1095,7 @@ async def responses_get( try: return JSONResponse(content=await service.get_response(response_id)) except ResponsesServiceError as e: - raise HTTPException(status_code=e.status_code, detail=_responses_error_response(e)) + return JSONResponse(status_code=e.status_code, content=_responses_error_response(e)) @app.delete("/v1/responses/{response_id}") @@ -1107,7 +1109,7 @@ async def responses_delete( try: return JSONResponse(content=await service.delete_response(response_id)) except ResponsesServiceError as e: - raise HTTPException(status_code=e.status_code, detail=_responses_error_response(e)) + return JSONResponse(status_code=e.status_code, content=_responses_error_response(e)) @app.get("/v1/responses/{response_id}/input_items") @@ -1121,7 +1123,7 @@ async def responses_input_items( try: return JSONResponse(content=await service.list_input_items(response_id)) except ResponsesServiceError as e: - raise HTTPException(status_code=e.status_code, detail=_responses_error_response(e)) + return JSONResponse(status_code=e.status_code, content=_responses_error_response(e)) # --- Anthropic Messages API Endpoint --- diff --git a/src/rotator_library/config/experimental.py b/src/rotator_library/config/experimental.py index ca0d00bf3..be1e76346 100644 --- a/src/rotator_library/config/experimental.py +++ b/src/rotator_library/config/experimental.py @@ -85,6 +85,18 @@ class RetryRuntimeSettings: failure_history_max_entries: int = 200 +@dataclass(frozen=True) +class ResponsesStoreRuntimeSettings: + """Runtime backend selection for Responses storage.""" + + backend: str = "memory" + cache_name: str = "responses" + cache_prefix: str = "responses" + cache_dir: Optional[str] = None + cache_memory_ttl_seconds: int = 3600 + cache_disk_ttl_seconds: int = 172800 + + def load_experimental_config(path: str | os.PathLike[str] | None = None, env: Mapping[str, str] | None = None) -> ExperimentalConfig: """Load optional JSON config from an explicit path or config env var.""" @@ -229,6 +241,30 @@ def get_responses_store_settings( ) +def get_responses_store_runtime_settings( + *, + config: ExperimentalConfig | None = None, + env: Mapping[str, str] | None = None, +) -> ResponsesStoreRuntimeSettings: + """Return Responses store backend settings with env overriding JSON.""" + + source = env if env is not None else os.environ + active = config if config is not None else load_experimental_config(env=source) + responses = active.responses if isinstance(active.responses, dict) else {} + store = responses.get("store", {}) if isinstance(responses.get("store"), dict) else responses + backend = str(_env_or_json(source, "RESPONSES_STORE_BACKEND", store, "backend", default="memory")).strip().lower() + if backend not in {"memory", "provider_cache"}: + raise ExperimentalConfigError("RESPONSES_STORE_BACKEND must be 'memory' or 'provider_cache'") + return ResponsesStoreRuntimeSettings( + backend=backend, + cache_name=str(_env_or_json(source, "RESPONSES_STORE_CACHE_NAME", store, "cache_name", default="responses")), + cache_prefix=str(_env_or_json(source, "RESPONSES_STORE_CACHE_PREFIX", store, "cache_prefix", default="responses")), + cache_dir=_optional_string(_env_or_json(source, "RESPONSES_STORE_CACHE_DIR", store, "cache_dir")), + cache_memory_ttl_seconds=max(1, _int_setting(source, "RESPONSES_STORE_CACHE_MEMORY_TTL_SECONDS", store, "cache_memory_ttl_seconds", 3600)), + cache_disk_ttl_seconds=max(1, _int_setting(source, "RESPONSES_STORE_CACHE_DISK_TTL_SECONDS", store, "cache_disk_ttl_seconds", 172800)), + ) + + def parse_field_cache_rules(config: ExperimentalConfig, provider: str, model: str) -> tuple[FieldCacheRule, ...]: """Parse configured field-cache rules for a provider/model. @@ -407,6 +443,12 @@ def _optional_positive_int(value: Any, name: str) -> Optional[int]: return parsed if parsed > 0 else None +def _optional_string(value: Any) -> Optional[str]: + if value in (None, ""): + return None + return str(value) + + def _field_cache_rule_from_dict(data: Mapping[str, Any]) -> FieldCacheRule: inject_data = data.get("inject") inject = None diff --git a/src/rotator_library/responses/__init__.py b/src/rotator_library/responses/__init__.py index a16e76bde..0ba921406 100644 --- a/src/rotator_library/responses/__init__.py +++ b/src/rotator_library/responses/__init__.py @@ -5,7 +5,7 @@ from .bridge import ResponsesBridge from .service import ResponsesService, ResponsesServiceError -from .store import InMemoryResponsesStore, ProviderCacheResponsesStore, ResponsesStore +from .store import InMemoryResponsesStore, ProviderCacheResponsesStore, ResponsesStore, create_configured_responses_store from .streaming import ResponsesSSEFormatter, ResponsesStreamEvent, ResponsesWebSocketFormatter from .types import ResponsesStoreSettings, StoredResponse, generate_response_id @@ -21,5 +21,6 @@ "ResponsesStore", "ResponsesWebSocketFormatter", "StoredResponse", + "create_configured_responses_store", "generate_response_id", ] diff --git a/src/rotator_library/responses/bridge.py b/src/rotator_library/responses/bridge.py index e90893c4b..abd578994 100644 --- a/src/rotator_library/responses/bridge.py +++ b/src/rotator_library/responses/bridge.py @@ -53,6 +53,7 @@ def to_chat_kwargs( unified: UnifiedRequest, *, parent_response: Optional[dict[str, Any]] = None, + parent_responses: Optional[list[dict[str, Any]]] = None, ) -> dict[str, Any]: """Convert a parsed Responses request to chat-completions kwargs.""" @@ -60,7 +61,11 @@ def to_chat_kwargs( system_text = _blocks_to_text(unified.system) if system_text: messages.append({"role": "system", "content": system_text}) - if parent_response: + if parent_responses is not None: + for parent in parent_responses: + messages.extend(_parent_request_to_messages(parent.get("request") or {})) + messages.extend(_parent_output_to_messages(parent.get("output") or parent.get("response", {}).get("output") or [])) + elif parent_response: messages.extend(_parent_output_to_messages(parent_response.get("output") or [])) messages.extend(_message_to_chat(message) for message in unified.messages) kwargs: dict[str, Any] = { @@ -190,6 +195,35 @@ def _parent_output_to_messages(output: list[Any]) -> list[dict[str, Any]]: return messages +def _parent_request_to_messages(request: dict[str, Any]) -> list[dict[str, Any]]: + """Replay convertible parent Responses input items as chat messages.""" + + if not isinstance(request, dict): + return [] + protocol = ResponsesProtocol() + try: + unified = protocol.parse_request(request) + except Exception: + return _raw_input_to_messages(request.get("input")) + return [_message_to_chat(message) for message in unified.messages] + + +def _raw_input_to_messages(value: Any) -> list[dict[str, Any]]: + if value in (None, ""): + return [] + if isinstance(value, str): + return [{"role": "user", "content": value}] + if isinstance(value, list): + messages: list[dict[str, Any]] = [] + for item in value: + if isinstance(item, str): + messages.append({"role": "user", "content": item}) + elif isinstance(item, dict) and item.get("role") and item.get("content") is not None: + messages.append({"role": item.get("role"), "content": item.get("content")}) + return messages + return [] + + def _chat_message_to_output_items(message: dict[str, Any], index: int) -> list[dict[str, Any]]: items = [] content = message.get("content") diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index 8ee063ed2..701b00c7c 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -88,7 +88,8 @@ async def create_response( parent = await self._load_previous_response(unified.previous_response_id, transaction_logger) try: - chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_response=parent.response if parent else None) + parent_lineage = await self._load_response_lineage(parent) + chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_responses=[stored.to_dict() for stored in parent_lineage] if parent_lineage else None) except Exception as exc: self._log_transform_error(transaction_logger, "responses_bridge_chat_request", exc, unified.to_dict()) raise @@ -191,7 +192,8 @@ async def stream_events( self._trace(transaction_logger, "responses_parsed_request", unified.to_dict(), direction="request", stage="protocol") parent = await self._load_previous_response(unified.previous_response_id, transaction_logger) try: - chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_response=parent.response if parent else None) + parent_lineage = await self._load_response_lineage(parent) + chat_kwargs = self.bridge.to_chat_kwargs(unified, parent_responses=[stored.to_dict() for stored in parent_lineage] if parent_lineage else None) except Exception as exc: self._log_transform_error(transaction_logger, "responses_bridge_chat_request", exc, unified.to_dict()) raise @@ -374,6 +376,23 @@ async def list_input_items(self, response_id: str) -> dict[str, Any]: raise ResponsesServiceError(f"Response not found: {response_id}", status_code=404, error_type="not_found_error") return {"object": "list", "data": items} + async def _load_response_lineage(self, parent: Optional[StoredResponse], *, max_depth: int = 20) -> list[StoredResponse]: + """Return parent continuation lineage from oldest to newest.""" + + if parent is None: + return [] + lineage: list[StoredResponse] = [] + seen: set[str] = set() + current: Optional[StoredResponse] = parent + while current is not None and current.id not in seen and len(lineage) < max_depth: + seen.add(current.id) + lineage.append(current) + previous_id = current.request.get("previous_response_id") if isinstance(current.request, dict) else None + if not previous_id: + break + current = await self.store.get(str(previous_id)) + return list(reversed(lineage)) + async def _load_previous_response(self, response_id: Optional[str], transaction_logger: Optional[Any]) -> Optional[StoredResponse]: if not response_id: return None diff --git a/src/rotator_library/responses/store.py b/src/rotator_library/responses/store.py index b5c507ee9..e8e3f387d 100644 --- a/src/rotator_library/responses/store.py +++ b/src/rotator_library/responses/store.py @@ -7,6 +7,7 @@ import json from copy import deepcopy +from pathlib import Path from typing import Any, Optional, Protocol from .types import StoredResponse @@ -85,6 +86,9 @@ def __init__(self, provider_cache: Any, *, prefix: str = "responses") -> None: async def save(self, response: StoredResponse) -> None: await self._cache.store_async(self._key(response.id), json.dumps(response.to_dict(), ensure_ascii=False)) + flush = getattr(self._cache, "_save_to_disk", None) + if callable(flush): + await flush() async def get(self, response_id: str) -> Optional[StoredResponse]: raw = await self._cache.retrieve_async(self._key(response_id)) @@ -113,3 +117,28 @@ async def list_input_items(self, response_id: str) -> Optional[list[Any]]: def _key(self, response_id: str) -> str: safe_id = response_id.replace("/", "_").replace("\\", "_").replace(":", "_") return f"{self._prefix}:{safe_id}" + + +def create_configured_responses_store(*, config: Any = None, env: Any = None) -> ResponsesStore: + """Create the configured Responses store backend. + + Memory remains the default. The provider-cache backend uses the existing JSON + cache implementation so durable storage does not require a new database. + """ + + from ..config.experimental import get_responses_store_runtime_settings, get_responses_store_settings + from ..providers.provider_cache import create_provider_cache + + runtime = get_responses_store_runtime_settings(config=config, env=env) + settings = get_responses_store_settings(config=config, env=env) + if runtime.backend == "memory": + return InMemoryResponsesStore(max_items=settings.max_items) + cache_dir = Path(runtime.cache_dir) if runtime.cache_dir else None + provider_cache = create_provider_cache( + runtime.cache_name, + cache_dir=cache_dir, + memory_ttl_seconds=runtime.cache_memory_ttl_seconds, + disk_ttl_seconds=runtime.cache_disk_ttl_seconds, + env_prefix=f"{runtime.cache_name.upper().replace('-', '_')}_CACHE", + ) + return ProviderCacheResponsesStore(provider_cache, prefix=runtime.cache_prefix) diff --git a/tests/test_experimental_config.py b/tests/test_experimental_config.py index 9506edbe1..083b4faf1 100644 --- a/tests/test_experimental_config.py +++ b/tests/test_experimental_config.py @@ -7,6 +7,7 @@ as_int, env_price_key, get_responses_store_settings, + get_responses_store_runtime_settings, get_retry_runtime_settings, get_stream_runtime_settings, load_config_from_mapping, @@ -94,6 +95,40 @@ def test_responses_store_settings_env_overrides_json() -> None: assert settings.store_in_progress is True +def test_responses_store_runtime_settings_env_overrides_json(tmp_path) -> None: + config = load_config_from_mapping( + { + "responses": { + "store": { + "backend": "memory", + "cache_name": "json_responses", + "cache_prefix": "json_prefix", + "cache_dir": str(tmp_path / "json-cache"), + "cache_memory_ttl_seconds": 10, + "cache_disk_ttl_seconds": 20, + } + } + } + ) + + settings = get_responses_store_runtime_settings( + config=config, + env={"RESPONSES_STORE_BACKEND": "provider_cache", "RESPONSES_STORE_CACHE_MEMORY_TTL_SECONDS": "30"}, + ) + + assert settings.backend == "provider_cache" + assert settings.cache_name == "json_responses" + assert settings.cache_prefix == "json_prefix" + assert settings.cache_dir == str(tmp_path / "json-cache") + assert settings.cache_memory_ttl_seconds == 30 + assert settings.cache_disk_ttl_seconds == 20 + + +def test_responses_store_runtime_rejects_unknown_backend() -> None: + with pytest.raises(ExperimentalConfigError): + get_responses_store_runtime_settings(config=load_config_from_mapping({"responses": {"store": {"backend": "sqlite"}}}), env={}) + + def test_new_config_sections_still_reject_secret_like_keys() -> None: with pytest.raises(ExperimentalConfigError): load_config_from_mapping({"responses": {"store": {"authorization": "hidden"}}}) diff --git a/tests/test_responses_bridge.py b/tests/test_responses_bridge.py index 29d822b48..22691aef7 100644 --- a/tests/test_responses_bridge.py +++ b/tests/test_responses_bridge.py @@ -46,6 +46,32 @@ def test_bridge_adds_parent_response_messages_for_previous_response_id() -> None assert "session_scope" not in kwargs["_session_tracking_hints"] +def test_bridge_replays_parent_input_and_output_lineage() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Now"}) + lineage = [ + { + "request": {"model": "gpt-test", "input": "First user"}, + "response": {"output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "First assistant"}]}]}, + }, + { + "request": {"model": "gpt-test", "input": "Second user", "previous_response_id": "resp_1"}, + "response": {"output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Second assistant"}]}]}, + }, + ] + + kwargs = bridge.to_chat_kwargs(unified, parent_responses=lineage) + + assert kwargs["messages"] == [ + {"role": "user", "content": "First user"}, + {"role": "assistant", "content": "First assistant"}, + {"role": "user", "content": "Second user"}, + {"role": "assistant", "content": "Second assistant"}, + {"role": "user", "content": "Now"}, + ] + + def test_bridge_preserves_tool_definitions() -> None: protocol = ResponsesProtocol() bridge = ResponsesBridge(protocol) diff --git a/tests/test_responses_routes.py b/tests/test_responses_routes.py index 1cdb18fd6..bf1e47188 100644 --- a/tests/test_responses_routes.py +++ b/tests/test_responses_routes.py @@ -49,7 +49,7 @@ def test_post_responses_missing_model_returns_400() -> None: response = client.post("/v1/responses", json={"input": "hello"}) assert response.status_code == 400 - assert response.json()["detail"]["error"]["type"] == "invalid_request_error" + assert response.json()["error"]["type"] == "invalid_request_error" def test_post_responses_stream_missing_model_returns_400_before_sse() -> None: @@ -58,7 +58,7 @@ def test_post_responses_stream_missing_model_returns_400_before_sse() -> None: response = client.post("/v1/responses", json={"input": "hello", "stream": True}) assert response.status_code == 400 - assert response.json()["detail"]["error"]["type"] == "invalid_request_error" + assert response.json()["error"]["type"] == "invalid_request_error" def test_post_responses_stream_missing_previous_response_returns_404_before_sse() -> None: @@ -67,7 +67,7 @@ def test_post_responses_stream_missing_previous_response_returns_404_before_sse( response = client.post("/v1/responses", json={"model": "gpt-test", "input": "hello", "stream": True, "previous_response_id": "missing"}) assert response.status_code == 404 - assert response.json()["detail"]["error"]["type"] == "not_found_error" + assert response.json()["error"]["type"] == "not_found_error" def test_get_delete_and_input_items_routes() -> None: @@ -86,6 +86,7 @@ def test_get_delete_and_input_items_routes() -> None: assert deleted.status_code == 200 assert deleted.json() == {"id": created["id"], "object": "response.deleted", "deleted": True} assert missing.status_code == 404 + assert missing.json()["error"]["type"] == "not_found_error" def test_post_responses_stream_returns_sse_events() -> None: diff --git a/tests/test_responses_service.py b/tests/test_responses_service.py index 3970a3a90..2c0bcf1c0 100644 --- a/tests/test_responses_service.py +++ b/tests/test_responses_service.py @@ -151,11 +151,46 @@ async def test_previous_response_id_loads_parent_context() -> None: await service.create_response({"model": "gpt-test", "input": "Continue", "previous_response_id": "resp_parent"}, client) assert client.calls[0]["messages"] == [ + {"role": "user", "content": "Earlier"}, {"role": "assistant", "content": "Earlier"}, {"role": "user", "content": "Continue"}, ] +@pytest.mark.asyncio +async def test_previous_response_id_loads_full_lineage_oldest_first() -> None: + store = InMemoryResponsesStore() + await store.save( + StoredResponse( + id="resp_grandparent", + model="gpt-test", + status="completed", + request={"model": "gpt-test", "input": "First"}, + response={"id": "resp_grandparent", "object": "response", "output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "First answer"}]}]}, + ) + ) + await store.save( + StoredResponse( + id="resp_parent", + model="gpt-test", + status="completed", + request={"model": "gpt-test", "input": "Second", "previous_response_id": "resp_grandparent"}, + response={"id": "resp_parent", "object": "response", "output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Second answer"}]}]}, + ) + ) + client = FakeClient() + + await ResponsesService(store=store).create_response({"model": "gpt-test", "input": "Third", "previous_response_id": "resp_parent"}, client) + + assert client.calls[0]["messages"] == [ + {"role": "user", "content": "First"}, + {"role": "assistant", "content": "First answer"}, + {"role": "user", "content": "Second"}, + {"role": "assistant", "content": "Second answer"}, + {"role": "user", "content": "Third"}, + ] + + @pytest.mark.asyncio async def test_internal_session_hints_do_not_leak_to_direct_clients_or_traces(tmp_path) -> None: store = InMemoryResponsesStore() diff --git a/tests/test_responses_store.py b/tests/test_responses_store.py index b7025f74a..15316bb55 100644 --- a/tests/test_responses_store.py +++ b/tests/test_responses_store.py @@ -5,7 +5,7 @@ import pytest -from rotator_library.responses import InMemoryResponsesStore, ProviderCacheResponsesStore, StoredResponse, generate_response_id +from rotator_library.responses import InMemoryResponsesStore, ProviderCacheResponsesStore, StoredResponse, create_configured_responses_store, generate_response_id def _stored(response_id: str = "resp_test") -> StoredResponse: @@ -128,3 +128,24 @@ async def delete_async(self, key: str) -> bool: assert await store.delete(stored.id) is True assert await store.get(stored.id) is None + + +@pytest.mark.asyncio +async def test_configured_provider_cache_store_persists_between_instances(tmp_path) -> None: + env = { + "RESPONSES_STORE_BACKEND": "provider_cache", + "RESPONSES_STORE_CACHE_NAME": "responses_test", + "RESPONSES_STORE_CACHE_PREFIX": "responses", + "RESPONSES_STORE_CACHE_DIR": str(tmp_path), + "RESPONSES_STORE_CACHE_MEMORY_TTL_SECONDS": "60", + "RESPONSES_STORE_CACHE_DISK_TTL_SECONDS": "60", + } + first_store = create_configured_responses_store(env=env) + stored = _stored("resp_durable") + + await first_store.save(stored) + second_store = create_configured_responses_store(env=env) + loaded = await second_store.get(stored.id) + + assert loaded is not None + assert loaded.to_dict() == stored.to_dict() From 8d0c88fd6e9b30b431fd7e0c13aab3f4b6b69b67 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 20:33:14 +0200 Subject: [PATCH 161/182] fix(responses): preserve tool-call lineage Replays parent Responses function-call output items as chat tool calls during previous_response_id continuations and documents durable Responses store env vars. Tests: pytest tests/test_experimental_config.py tests/test_responses_store.py tests/test_responses_bridge.py tests/test_responses_service.py tests/test_responses_routes.py tests/test_responses_streaming.py tests/test_protocol_responses.py tests/test_responses_usage_accounting.py tests/test_stream_transport.py --- .env.example | 11 ++++++++-- src/rotator_library/responses/bridge.py | 28 +++++++++++++++++++++---- tests/test_responses_bridge.py | 20 ++++++++++++++++++ 3 files changed, 53 insertions(+), 6 deletions(-) diff --git a/.env.example b/.env.example index 88ff74cb5..2f6fc45d6 100644 --- a/.env.example +++ b/.env.example @@ -369,9 +369,16 @@ # FAILURE_HISTORY_MAX_ENTRIES=200 # --- Responses API Store Policy --- -# Responses are stored in-memory by default. TTL/max limits are disabled when -# unset or <= 0. Failed stream responses are stored by default; in-progress +# Responses are stored in-memory by default. Use provider_cache for durable JSON +# storage via the existing provider-cache layer. TTL/max limits are disabled +# when unset or <= 0. Failed stream responses are stored by default; in-progress # streaming state is disabled unless explicitly enabled. +# RESPONSES_STORE_BACKEND=memory +# RESPONSES_STORE_CACHE_NAME=responses +# RESPONSES_STORE_CACHE_PREFIX=responses +# RESPONSES_STORE_CACHE_DIR= +# RESPONSES_STORE_CACHE_MEMORY_TTL_SECONDS=3600 +# RESPONSES_STORE_CACHE_DISK_TTL_SECONDS=172800 # RESPONSES_STORE_TTL_SECONDS=0 # RESPONSES_STORE_MAX_ITEMS=0 # RESPONSES_STORE_FAILED=true diff --git a/src/rotator_library/responses/bridge.py b/src/rotator_library/responses/bridge.py index abd578994..aef9e8470 100644 --- a/src/rotator_library/responses/bridge.py +++ b/src/rotator_library/responses/bridge.py @@ -187,14 +187,34 @@ def _tool_to_chat(tool: ToolDefinition) -> dict[str, Any]: def _parent_output_to_messages(output: list[Any]) -> list[dict[str, Any]]: messages = [] for item in output: - if not isinstance(item, dict) or item.get("type") != "message": + if not isinstance(item, dict): continue - text = _responses_content_to_text(item.get("content") or []) - if text: - messages.append({"role": item.get("role") or "assistant", "content": text}) + if item.get("type") == "message": + text = _responses_content_to_text(item.get("content") or []) + if text: + messages.append({"role": item.get("role") or "assistant", "content": text}) + elif item.get("type") in {"function_call", "custom_tool_call"}: + tool_call = _responses_tool_call_to_chat(item) + if tool_call: + messages.append({"role": "assistant", "content": None, "tool_calls": [tool_call]}) return messages +def _responses_tool_call_to_chat(item: dict[str, Any]) -> dict[str, Any] | None: + """Convert a stored Responses tool-call item to OpenAI chat shape.""" + + call_id = item.get("call_id") or item.get("id") + name = item.get("name") or item.get("tool_name") or item.get("type") + arguments = item.get("arguments", item.get("input", "")) + if call_id is None or name is None: + return None + return { + "id": call_id, + "type": "function", + "function": {"name": str(name), "arguments": arguments if isinstance(arguments, str) else serialize_value(arguments)}, + } + + def _parent_request_to_messages(request: dict[str, Any]) -> list[dict[str, Any]]: """Replay convertible parent Responses input items as chat messages.""" diff --git a/tests/test_responses_bridge.py b/tests/test_responses_bridge.py index 22691aef7..db0b72f37 100644 --- a/tests/test_responses_bridge.py +++ b/tests/test_responses_bridge.py @@ -72,6 +72,26 @@ def test_bridge_replays_parent_input_and_output_lineage() -> None: ] +def test_bridge_replays_parent_tool_call_outputs() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Continue"}) + lineage = [ + { + "request": {"model": "gpt-test", "input": "Use tool"}, + "response": {"output": [{"id": "call_1", "type": "function_call", "call_id": "call_1", "name": "lookup", "arguments": "{}"}]}, + } + ] + + kwargs = bridge.to_chat_kwargs(unified, parent_responses=lineage) + + assert kwargs["messages"] == [ + {"role": "user", "content": "Use tool"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}}]}, + {"role": "user", "content": "Continue"}, + ] + + def test_bridge_preserves_tool_definitions() -> None: protocol = ResponsesProtocol() bridge = ResponsesBridge(protocol) From 4d727a3eee5d6cd261ce9bd0c4b90fca438398fd Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 20:38:35 +0200 Subject: [PATCH 162/182] fix(responses): replay tool-result lineage Converts stored Responses function_call_output inputs into Chat tool messages during continuation lineage replay so tool-use continuations keep both calls and results. Tests: pytest tests/test_experimental_config.py tests/test_responses_store.py tests/test_responses_bridge.py tests/test_responses_service.py tests/test_responses_routes.py tests/test_responses_streaming.py tests/test_protocol_responses.py tests/test_responses_usage_accounting.py tests/test_stream_transport.py --- src/rotator_library/responses/bridge.py | 17 ++++++++++++++++ tests/test_responses_bridge.py | 26 +++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/src/rotator_library/responses/bridge.py b/src/rotator_library/responses/bridge.py index aef9e8470..3c7a530aa 100644 --- a/src/rotator_library/responses/bridge.py +++ b/src/rotator_library/responses/bridge.py @@ -139,6 +139,12 @@ def responses_session_hints(previous_response_id: Optional[str], *, affinity_key def _message_to_chat(message: UnifiedMessage) -> dict[str, Any]: + if message.role == "tool": + result = _tool_result_from_message(message) + payload = {"role": "tool", "content": result["content"]} + if result.get("tool_call_id"): + payload["tool_call_id"] = result["tool_call_id"] + return payload payload = {"role": message.role, "content": _blocks_to_chat_content(message.content)} if message.name: payload["name"] = message.name @@ -149,6 +155,17 @@ def _message_to_chat(message: UnifiedMessage) -> dict[str, Any]: return payload +def _tool_result_from_message(message: UnifiedMessage) -> dict[str, Any]: + for block in message.content: + if block.tool_result is not None: + content = block.tool_result.content + return { + "tool_call_id": message.tool_call_id or block.tool_result.tool_call_id, + "content": content if isinstance(content, str) else serialize_value(content), + } + return {"tool_call_id": message.tool_call_id, "content": _blocks_to_chat_content(message.content)} + + def _blocks_to_chat_content(blocks: list[ContentBlock]) -> Any: if not blocks: return "" diff --git a/tests/test_responses_bridge.py b/tests/test_responses_bridge.py index db0b72f37..dc84e645a 100644 --- a/tests/test_responses_bridge.py +++ b/tests/test_responses_bridge.py @@ -92,6 +92,32 @@ def test_bridge_replays_parent_tool_call_outputs() -> None: ] +def test_bridge_replays_parent_tool_result_inputs_as_tool_messages() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Continue"}) + lineage = [ + { + "request": { + "model": "gpt-test", + "input": [ + {"type": "message", "role": "user", "content": "Use tool"}, + {"type": "function_call_output", "call_id": "call_1", "output": "tool result"}, + ], + }, + "response": {"output": []}, + } + ] + + kwargs = bridge.to_chat_kwargs(unified, parent_responses=lineage) + + assert kwargs["messages"] == [ + {"role": "user", "content": "Use tool"}, + {"role": "tool", "content": "tool result", "tool_call_id": "call_1"}, + {"role": "user", "content": "Continue"}, + ] + + def test_bridge_preserves_tool_definitions() -> None: protocol = ResponsesProtocol() bridge = ResponsesBridge(protocol) From 6ee91b96121e1a126b5cbde6dd26f25d9c856397 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 20:46:33 +0200 Subject: [PATCH 163/182] docs(experimental): plan provider-native correction Adds the Phase 5c corrective plan for client-protocol native responses, provider native contract hooks, Claude Code hardening, Antigravity alias preservation, and centralized native streaming fail-closed behavior. Tests: not run (planning document only) --- .../phase-5c-provider-native-correction.md | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 docs/experimental/phase-5c-provider-native-correction.md diff --git a/docs/experimental/phase-5c-provider-native-correction.md b/docs/experimental/phase-5c-provider-native-correction.md new file mode 100644 index 000000000..f07939758 --- /dev/null +++ b/docs/experimental/phase-5c-provider-native-correction.md @@ -0,0 +1,72 @@ +# Phase 5c: Provider-Native Output Protocol And Contract Correction + +## Goal + +Close the Phase 5/5b third-pass provider-native findings while preserving the current safety stance: native streaming remains disabled for priority providers until each live stream path is proven, and LiteLLM remains fallback only for unsupported cases. + +## Scope + +- Return native provider responses in the originating client protocol instead of the provider protocol. +- Add explicit native endpoint/header/operation-selection hooks to `ProviderInterface`. +- Ensure Claude Code native requests always include required Anthropic `max_tokens`. +- Harden Claude Code API-key header behavior without breaking bearer/OAuth-style credentials. +- Fix Antigravity alias normalization so duplicated aliases do not collapse incorrectly and low/high thinking aliases can still affect request metadata. +- Centralize explicit-native streaming fail-closed behavior. +- Keep priority-provider native streaming disabled by default. +- Add focused tests for all blocker/high/medium findings. + +## Non-Goals + +- Do not enable live native streaming for Claude Code, Codex, Copilot, or Antigravity. +- Do not restore retired Antigravity device-profile/fingerprint behavior. +- Do not replace credential rotation, usage tracking, or routing. +- Do not implement all rich Codex Responses item conversions in this phase. +- Do not commit user-facing reports. + +## Implementation Plan + +1. Client target protocol in native context. + - Add `client_protocol_name` to `NativeProviderContext`. + - Default direct/library native use to provider protocol for backwards compatibility. + - Set `client_protocol_name="openai_chat"` from the chat-completions executor path. + - Format non-streaming native responses using the client protocol after provider protocol parsing. + - Format native stream events using the client protocol when a client protocol is set. + +2. Provider interface native contract. + - Add default `get_native_endpoint()` and `get_native_headers()` methods that raise clear `NotImplementedError`. + - Add `supports_native_operation()` defaulting to operation support through provider declarations. + - Add `should_use_native_protocol()` defaulting to true only when a provider declares a native protocol and operation is supported. + - Update executor checks to call hooks rather than relying on `hasattr()`. + +3. Claude Code hardening. + - `prepare_native_request()` ensures Anthropic `max_tokens` is present, using existing request value or `CLAUDE_CODE_MAX_TOKENS` default. + - Header selection supports bearer credentials and API-key credentials: + - `CLAUDE_CODE_AUTH_HEADER=bearer|x-api-key|auto` + - auto uses `x-api-key` for `sk-ant-*` style keys and bearer otherwise. + - Preserve `anthropic-version` and content type. + +4. Antigravity alias/thinking metadata. + - Fix alias-to-upstream mapping to use public alias map directly, not a lossy reverse map. + - Preserve thinking-level hints for `gemini-3-pro-low` / `gemini-3-pro-high` in request metadata before the upstream model is normalized. + - Keep model discovery output stable and prefixed. + +5. Native streaming fail-closed helper. + - Centralize explicit-native streaming unsupported handling so auto and explicit modes use the same fail-closed decision. + - Keep priority providers non-streaming native only. + +6. Tests. + - Native executor: provider protocol response -> OpenAI chat client response. + - Request executor: routed native Claude/Codex/Antigravity chat completions return OpenAI Chat shape. + - Claude Code: missing `max_tokens` gets defaulted and auth header modes behave correctly. + - Antigravity: low/high aliases normalize safely and preserve request metadata. + - ProviderInterface: native hooks exist and unsupported operation checks are explicit. + - Streaming: explicit native streaming fail-closed path uses centralized helper. + +## Acceptance Criteria + +- `/v1/chat/completions` native routes return OpenAI Chat response shape regardless of provider-native protocol. +- Claude Code native Messages requests include `max_tokens`. +- Antigravity model normalization does not lose low/high alias intent. +- Provider native endpoint/header/operation hooks are explicit on `ProviderInterface`. +- Explicit native streaming fails closed through one helper path. +- Focused tests pass and both reviewers report no blockers/highs/mediums. From 342a62ec24b1613c6542e9bf78e2b243be04cccb Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 21:05:14 +0200 Subject: [PATCH 164/182] fix(providers): return client protocol from native calls Carries a client protocol through native execution so chat-completion routes receive OpenAI Chat responses from Anthropic, Responses, and Gemini native provider calls. Adds explicit provider native endpoint/header/operation hooks, Claude Code max_tokens/auth hardening, and safer Antigravity alias/thinking metadata handling. Tests: pytest tests/test_request_executor_native_routing.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_native_streaming_transport_seam.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_provider_protocol_declarations.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_responses.py tests/test_protocol_gemini.py --- src/rotator_library/client/executor.py | 15 +++++++--- .../native_provider/context.py | 3 +- .../native_provider/executor.py | 8 +++-- src/rotator_library/protocols/openai_chat.py | 23 +++++++++++++- .../providers/antigravity_provider.py | 18 ++++++++++- .../providers/claude_code_provider.py | 17 +++++++++-- .../providers/provider_interface.py | 30 +++++++++++++++++++ tests/test_request_executor_native_routing.py | 29 ++++++++++++++++-- 8 files changed, 130 insertions(+), 13 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 72790d6c4..bb4411d4c 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -675,18 +675,20 @@ def _build_native_provider_context( protocol_name = _provider_native_protocol(plugin, model, target) if not protocol_name: raise RoutingExecutionError(f"Provider {provider} has no native protocol declaration") - if not hasattr(plugin, "get_native_endpoint") or not hasattr(plugin, "get_native_headers"): - raise RoutingExecutionError(f"Provider {provider} has no native endpoint/header helpers") public_model = model native_model = plugin.normalize_native_model(model) if hasattr(plugin, "normalize_native_model") else _strip_provider_prefix(model) request_payload = _native_request_payload(raw_request or {}) + request_payload["_proxy_model"] = public_model if native_model: request_payload["model"] = native_model operation = plugin.get_native_operation(native_model, request_payload, stream=stream) if hasattr(plugin, "get_native_operation") else "chat" + if hasattr(plugin, "supports_native_operation") and not plugin.supports_native_operation(native_model, operation): + raise RoutingExecutionError(f"Provider {provider} does not support native operation {operation}") if hasattr(plugin, "prepare_native_request"): prepared = plugin.prepare_native_request(request_payload, model=native_model, operation=operation) if prepared is not request_payload: request_payload = dict(prepared) + request_payload.pop("_proxy_model", None) self._log_executor_trace( context, "provider_native_request_prepared", @@ -696,14 +698,19 @@ def _build_native_provider_context( credential_id=credential_id, metadata={"provider": provider, "model": public_model, "native_model": native_model, "operation": operation}, ) - endpoint = plugin.get_native_endpoint(model=native_model, operation=operation) - headers = plugin.get_native_headers(credential_secret, model=native_model, operation=operation) + request_payload.pop("_proxy_model", None) + try: + endpoint = plugin.get_native_endpoint(model=native_model, operation=operation) + headers = plugin.get_native_headers(credential_secret, model=native_model, operation=operation) + except NotImplementedError as exc: + raise RoutingExecutionError(str(exc)) from exc native_context = NativeProviderContext( provider=provider, model=native_model, protocol_name=protocol_name, endpoint=endpoint, operation=operation, + client_protocol_name="openai_chat", headers=headers, credential_id=credential_id, session_id=context.session_id, diff --git a/src/rotator_library/native_provider/context.py b/src/rotator_library/native_provider/context.py index bacf49b3f..cbfb91012 100644 --- a/src/rotator_library/native_provider/context.py +++ b/src/rotator_library/native_provider/context.py @@ -27,6 +27,7 @@ class NativeProviderContext: protocol_name: str endpoint: str operation: str = "chat" + client_protocol_name: Optional[str] = None headers: dict[str, str] = field(default_factory=dict) credential_id: Optional[str] = None session_id: Optional[str] = None @@ -46,7 +47,7 @@ def protocol_context(self, *, target_protocol: Optional[str] = None) -> Protocol provider=self.provider, model=self.model, source_protocol=self.protocol_name, - target_protocol=target_protocol or self.protocol_name, + target_protocol=target_protocol or self.client_protocol_name or self.protocol_name, session_id=self.session_id, credential_stable_id=self.credential_id, transport=self.transport, diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index 1e7f673eb..da0336c5d 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -73,7 +73,10 @@ async def execute(self, raw_request: dict[str, Any], context: NativeProviderCont self._trace(context, "parsed_native_unified_response", unified_response, direction="response", stage="protocol") await cache_engine.extract("unified_response", serialize_value(unified_response), context.field_cache_context(), transaction_logger=logger) self._trace(context, "after_unified_response_field_cache_extraction", {"source": "unified_response"}, direction="response", stage="adapter", snapshot=False) - provider_response = protocol.format_response(unified_response, protocol_context) + response_protocol = get_protocol(context.client_protocol_name) if context.client_protocol_name else protocol + response_context = context.protocol_context(target_protocol=response_protocol.name) + self._trace(context, "native_response_protocol_selected", {"protocol": response_protocol.name}, direction="metadata", stage="protocol", snapshot=False) + provider_response = response_protocol.format_response(unified_response, response_context) self._trace(context, "formatted_native_response", provider_response, direction="response", stage="protocol") adapter_context = context.adapter_context() adapter_context.transaction_logger = None @@ -186,7 +189,8 @@ async def stream(self, raw_request: dict[str, Any], context: NativeProviderConte stage="adapter", snapshot=False, ) - formatted = protocol.format_stream_event(event, protocol_context) + response_protocol = get_protocol(context.client_protocol_name) if context.client_protocol_name else protocol + formatted = response_protocol.format_stream_event(event, context.protocol_context(target_protocol=response_protocol.name)) self._trace(context, "formatted_client_stream_event", formatted, direction="stream", stage="final", snapshot=False) yield formatted except Exception as exc: diff --git a/src/rotator_library/protocols/openai_chat.py b/src/rotator_library/protocols/openai_chat.py index 27c4659df..d9b196459 100644 --- a/src/rotator_library/protocols/openai_chat.py +++ b/src/rotator_library/protocols/openai_chat.py @@ -157,7 +157,7 @@ def format_response(self, unified_response: UnifiedResponse, context: ProtocolCo choices.append( { "index": index, - "message": self._format_message(message), + "message": _format_response_message(self._format_message(message), message), "finish_reason": unified_response.stop_reason, } ) @@ -495,5 +495,26 @@ def _format_openai_usage(usage: Usage | None) -> dict[str, Any] | None: return payload +def _format_response_message(payload: dict[str, Any], message: UnifiedMessage) -> dict[str, Any]: + """Return Chat Completions response-message shape. + + Request messages may legitimately preserve content-part arrays. Assistant + response messages from non-chat native protocols often arrive as text parts; + Chat Completions clients expect the final message content to be a string in + that common case. + """ + + if payload.get("content") is not None and message.content: + if all(block.type in {"text", "input_text", "output_text"} and not block.extra for block in message.content): + payload = dict(payload) + payload["content"] = _first_response_text(message.content) or "" + return payload + + +def _first_response_text(blocks: Iterable[ContentBlock]) -> Optional[str]: + parts = [block.text for block in blocks if block.type in {"text", "input_text", "output_text"} and block.text] + return "".join(parts) if parts else first_text(blocks) + + def _without(payload: dict[str, Any], keys: set[str]) -> dict[str, Any]: return {k: deepcopy(v) for k, v in payload.items() if k not in keys} diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py index c6d61e406..675f3d16e 100644 --- a/src/rotator_library/providers/antigravity_provider.py +++ b/src/rotator_library/providers/antigravity_provider.py @@ -131,8 +131,13 @@ def prepare_native_request(self, request: dict[str, Any], model: str = "", opera """ prepared = dict(request) + public_model = str(request.get("_proxy_model") or request.get("model") or "") if model: prepared["model"] = model + prepared.pop("_proxy_model", None) + thinking_level = _thinking_level_from_model(public_model) + if thinking_level: + prepared.setdefault("metadata", {})["thinking_level"] = thinking_level if "contents" not in prepared and isinstance(prepared.get("messages"), list): prepared["contents"] = [_message_to_gemini_content(message) for message in prepared.pop("messages")] return prepared @@ -183,7 +188,9 @@ def _with_prefix(model: str) -> str: @staticmethod def _alias_to_internal(alias: str) -> str: - return MODEL_ALIAS_REVERSE.get(alias, alias) + if alias in {"rev19-uic3-1p", "gemini-3-pro-image", "gemini-3-pro-low", "gemini-3-pro-high"}: + return MODEL_ALIAS_MAP.get(alias, alias) + return MODEL_ALIAS_REVERSE.get(alias, MODEL_ALIAS_MAP.get(alias, alias)) @staticmethod def _api_to_user_model(internal: str) -> str: @@ -212,3 +219,12 @@ def _message_to_gemini_content(message: Any) -> dict[str, Any]: content = message.get("content", "") parts = content if isinstance(content, list) else [{"text": str(content)}] return {"role": role, "parts": parts} + + +def _thinking_level_from_model(model: str) -> Optional[str]: + clean = model.split("/", 1)[1] if model.startswith("antigravity/") else model + if clean.endswith("-low"): + return "low" + if clean.endswith("-high"): + return "high" + return None diff --git a/src/rotator_library/providers/claude_code_provider.py b/src/rotator_library/providers/claude_code_provider.py index 7e6a198e6..40f5c373b 100644 --- a/src/rotator_library/providers/claude_code_provider.py +++ b/src/rotator_library/providers/claude_code_provider.py @@ -69,11 +69,17 @@ def get_api_base(self) -> str: def get_native_headers(self, credential_identifier: str, model: str = "", operation: str = "messages") -> dict[str, str]: """Return headers for native mocked HTTP requests.""" - return { - "Authorization": f"Bearer {credential_identifier}", + mode = os.getenv("CLAUDE_CODE_AUTH_HEADER", "auto").strip().lower() + use_api_key = mode == "x-api-key" or (mode == "auto" and credential_identifier.startswith("sk-ant-")) + headers = { "anthropic-version": os.getenv("CLAUDE_CODE_ANTHROPIC_VERSION", "2023-06-01"), "content-type": "application/json", } + if use_api_key: + headers["x-api-key"] = credential_identifier + else: + headers["Authorization"] = f"Bearer {credential_identifier}" + return headers def get_native_operation(self, model: str = "", request: dict[str, Any] | None = None, stream: bool = False) -> str: """Claude Code uses the Anthropic Messages operation for completions.""" @@ -90,6 +96,13 @@ def supports_native_streaming(self, model: str = "", operation: str = "messages" return False + def prepare_native_request(self, request: dict[str, Any], model: str = "", operation: str = "messages") -> dict[str, Any]: + """Ensure Anthropic Messages-required fields are present.""" + + prepared = dict(request) + prepared.setdefault("max_tokens", int(os.getenv("CLAUDE_CODE_MAX_TOKENS", "4096"))) + return prepared + def get_native_endpoint(self, model: str = "", operation: str = "messages") -> str: """Return the provider endpoint for a native operation.""" diff --git a/src/rotator_library/providers/provider_interface.py b/src/rotator_library/providers/provider_interface.py index 3ce8817af..710b44546 100644 --- a/src/rotator_library/providers/provider_interface.py +++ b/src/rotator_library/providers/provider_interface.py @@ -373,6 +373,26 @@ def supports_native_streaming(self, model: str = "", operation: str = "chat") -> return self.native_streaming_supported + def supports_native_operation(self, model: str = "", operation: str = "chat") -> bool: + """Return whether this provider supports a native operation.""" + + protocol_name = self.get_protocol_name(model) + if not protocol_name: + return False + try: + from ..protocols import get_protocol + + return get_protocol(protocol_name).supports_operation(operation) + except Exception: + return False + + def should_use_native_protocol(self, model: str = "", operation: str = "chat", *, stream: bool = False, execution: str = "auto") -> bool: + """Return whether routing should use this provider's native protocol.""" + + if stream and not self.supports_native_streaming(model, operation): + return False + return bool(self.get_protocol_name(model) and self.supports_native_operation(model, operation)) + def get_native_operation(self, model: str = "", request: Optional[Dict[str, Any]] = None, stream: bool = False) -> str: """Return the provider-native operation for a request. @@ -385,6 +405,16 @@ def get_native_operation(self, model: str = "", request: Optional[Dict[str, Any] return "chat" + def get_native_endpoint(self, model: str = "", operation: str = "chat") -> str: + """Return the upstream endpoint for a native operation.""" + + raise NotImplementedError(f"{self.__class__.__name__} does not define a native endpoint") + + def get_native_headers(self, credential_identifier: str, model: str = "", operation: str = "chat") -> Dict[str, str]: + """Return non-payload HTTP headers for native requests.""" + + raise NotImplementedError(f"{self.__class__.__name__} does not define native headers") + def normalize_native_model(self, model: str) -> str: """Return the upstream model name for native provider calls. diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index 7218472d8..b3a899a35 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -279,8 +279,10 @@ async def test_claude_code_provider_runs_mock_live_native_request(monkeypatch) - response = await _executor(http_client)._execute_provider_request("claude_code", context.model, provider, "secret", "stable", dict(context.kwargs), context) assert response["id"] == "msg_1" + assert response["choices"][0]["message"]["content"] == "ok" assert http_client.calls[0]["endpoint"] == "https://claude-code.test/v1/messages" assert http_client.calls[0]["json"]["model"] == "claude-sonnet-4-5" + assert http_client.calls[0]["json"]["max_tokens"] == 4096 assert http_client.calls[0]["json"]["messages"][0]["role"] == "user" @@ -298,6 +300,7 @@ async def test_codex_provider_runs_mock_live_native_request(monkeypatch) -> None response = await _executor(http_client)._execute_provider_request("codex", context.model, provider, "secret", "stable", dict(context.kwargs), context) assert response["id"] == "resp_1" + assert response["choices"][0]["message"]["content"] == "ok" assert http_client.calls[0]["endpoint"] == "https://codex.test/v1/responses" assert http_client.calls[0]["json"]["model"] == "gpt-5.1-codex" assert http_client.calls[0]["json"]["input"][0]["content"] == [{"type": "text", "text": "hi"}] @@ -336,13 +339,35 @@ async def test_antigravity_provider_runs_mock_live_native_request(monkeypatch) - response = await _executor(http_client)._execute_provider_request("antigravity", context.model, provider, "secret", "stable", dict(context.kwargs), context) - assert response["candidates"][0]["content"]["parts"][0]["text"] == "ok" + assert response["choices"][0]["message"]["content"] == "ok" assert http_client.calls[0]["endpoint"] == "https://antigravity.test/v1internal:generateContent" assert http_client.calls[0]["json"]["model"] == "claude-sonnet-4-5" assert http_client.calls[0]["json"]["request"]["contents"][0]["parts"][0]["text"] == "hi" assert http_client.calls[0]["json"]["requestType"] == "CHAT_COMPLETION" assert "requestId" in http_client.calls[0]["json"] - assert "messages" not in http_client.calls[0]["json"]["request"] + + +def test_claude_code_header_modes(monkeypatch) -> None: + provider = ClaudeCodeProvider() + + monkeypatch.setenv("CLAUDE_CODE_AUTH_HEADER", "auto") + assert provider.get_native_headers("sk-ant-test")["x-api-key"] == "sk-ant-test" + assert provider.get_native_headers("oauth-token")["Authorization"] == "Bearer oauth-token" + monkeypatch.setenv("CLAUDE_CODE_AUTH_HEADER", "x-api-key") + assert provider.get_native_headers("any-token")["x-api-key"] == "any-token" + + +def test_antigravity_alias_normalization_preserves_thinking_level() -> None: + provider = AntigravityProvider() + request = {"_proxy_model": "antigravity/gemini-3-pro-low", "model": "gemini-3-pro-preview", "messages": [{"role": "user", "content": "hi"}]} + + prepared = provider.prepare_native_request(request, model=provider.normalize_native_model("antigravity/gemini-3-pro-low"), operation="generate") + + assert provider.normalize_native_model("antigravity/gemini-3-pro-low") == "gemini-3-pro-preview" + assert prepared["metadata"]["thinking_level"] == "low" + assert "_proxy_model" not in prepared + assert "messages" not in prepared + assert prepared["contents"][0]["parts"][0]["text"] == "hi" def test_native_request_payload_drops_litellm_only_fields() -> None: From f53cb866ae905a69124a4125a1fea9cec9c4b2e8 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 21:14:59 +0200 Subject: [PATCH 165/182] fix(providers): harden native selection and streaming Applies Antigravity thinking aliases to upstream request behavior, honors provider native opt-out hooks in auto mode, and formats cross-protocol native stream events as client OpenAI Chat SSE instead of raw provider chunks. Tests: pytest tests/test_request_executor_native_routing.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_native_streaming_transport_seam.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_provider_protocol_declarations.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_responses.py tests/test_protocol_gemini.py --- src/rotator_library/client/executor.py | 21 +++++++++++++++- src/rotator_library/protocols/openai_chat.py | 14 ++++++++++- .../providers/antigravity_provider.py | 12 ++++++++-- tests/test_antigravity_provider_restore.py | 4 +++- tests/test_native_provider_streaming.py | 24 +++++++++++++++++++ tests/test_request_executor_native_routing.py | 18 ++++++++++++++ 6 files changed, 88 insertions(+), 5 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index bb4411d4c..413a4d8e9 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -598,7 +598,7 @@ async def _execute_provider_request( ) return await plugin.acompletion(self._http_client, **kwargs) - if execution == "native" or (execution == "auto" and _provider_native_protocol(plugin, model, target)): + if execution == "native" or (execution == "auto" and _should_use_native_protocol(plugin, model, target, kwargs, stream=False, execution=execution)): native_context, native_request = self._build_native_provider_context( provider, model, @@ -2405,6 +2405,25 @@ def _provider_native_protocol(plugin: Any, model: str, target: Optional[RouteTar return None +def _should_use_native_protocol(plugin: Any, model: str, target: Optional[RouteTarget], kwargs: Dict[str, Any], *, stream: bool, execution: str) -> bool: + """Return whether auto routing should use provider-native execution.""" + + protocol_name = _provider_native_protocol(plugin, model, target) + if not plugin or not protocol_name: + return False + native_model = plugin.normalize_native_model(model) if hasattr(plugin, "normalize_native_model") else _strip_provider_prefix(model) + request_payload = _native_request_payload(kwargs) + request_payload["_proxy_model"] = model + if native_model: + request_payload["model"] = native_model + operation = plugin.get_native_operation(native_model, request_payload, stream=stream) if hasattr(plugin, "get_native_operation") else "chat" + hook = getattr(plugin, "should_use_native_protocol", None) + if callable(hook): + return bool(hook(model=native_model, operation=operation, stream=stream, execution=execution)) + support = getattr(plugin, "supports_native_operation", None) + return bool(support(native_model, operation) if callable(support) else True) + + def _strip_provider_prefix(model: str) -> str: """Return model without the proxy-facing provider prefix.""" diff --git a/src/rotator_library/protocols/openai_chat.py b/src/rotator_library/protocols/openai_chat.py index d9b196459..b75f31bbf 100644 --- a/src/rotator_library/protocols/openai_chat.py +++ b/src/rotator_library/protocols/openai_chat.py @@ -209,7 +209,7 @@ def parse_stream_event(self, raw_event: Any, context: ProtocolContext | None = N def format_stream_event(self, unified_event: UnifiedStreamEvent, context: ProtocolContext | None = None) -> Any: if unified_event.type == "done": return "data: [DONE]\n\n" - if unified_event.raw is not None: + if unified_event.raw is not None and (context is None or context.source_protocol in {None, "openai_chat"}): payload = deepcopy(unified_event.raw) if unified_event.delta is not None and isinstance(payload, dict) and isinstance(payload.get("choices"), list) and payload["choices"]: choice = payload["choices"][0] @@ -220,6 +220,18 @@ def format_stream_event(self, unified_event: UnifiedStreamEvent, context: Protoc formatted_delta.pop("role", None) choice["delta"] = formatted_delta return payload + if unified_event.delta is not None: + delta = _format_response_message(self._format_message(unified_event.delta), unified_event.delta) + if unified_event.extra.get("finish_reason") is None: + delta.pop("role", None) + payload = { + "id": unified_event.extra.get("id"), + "object": "chat.completion.chunk", + "model": unified_event.extra.get("model"), + "choices": [{"index": 0, "delta": delta, "finish_reason": unified_event.extra.get("finish_reason")}], + "usage": _format_openai_usage(unified_event.usage), + } + return f"data: {json.dumps({k: v for k, v in payload.items() if v is not None})}\n\n" return f"data: {json.dumps(unified_event.to_dict())}\n\n" def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = None) -> Usage | None: diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py index 675f3d16e..6daf3b639 100644 --- a/src/rotator_library/providers/antigravity_provider.py +++ b/src/rotator_library/providers/antigravity_provider.py @@ -132,11 +132,13 @@ def prepare_native_request(self, request: dict[str, Any], model: str = "", opera prepared = dict(request) public_model = str(request.get("_proxy_model") or request.get("model") or "") + thinking_level = _thinking_level_from_model(public_model) if model: - prepared["model"] = model + prepared["model"] = _model_with_thinking_variant(model, thinking_level) prepared.pop("_proxy_model", None) - thinking_level = _thinking_level_from_model(public_model) if thinking_level: + generation_config = prepared.setdefault("generationConfig", {}) + generation_config.setdefault("thinkingConfig", {})["thinkingLevel"] = thinking_level prepared.setdefault("metadata", {})["thinking_level"] = thinking_level if "contents" not in prepared and isinstance(prepared.get("messages"), list): prepared["contents"] = [_message_to_gemini_content(message) for message in prepared.pop("messages")] @@ -228,3 +230,9 @@ def _thinking_level_from_model(model: str) -> Optional[str]: if clean.endswith("-high"): return "high" return None + + +def _model_with_thinking_variant(model: str, thinking_level: Optional[str]) -> str: + if model == "gemini-3-pro-preview" and thinking_level in {"low", "high"}: + return f"gemini-3-pro-{thinking_level}" + return model diff --git a/tests/test_antigravity_provider_restore.py b/tests/test_antigravity_provider_restore.py index b8039ff57..3dc86e81f 100644 --- a/tests/test_antigravity_provider_restore.py +++ b/tests/test_antigravity_provider_restore.py @@ -74,7 +74,9 @@ def test_antigravity_native_operation_model_and_stream_support() -> None: assert provider.get_native_operation("gemini-3-flash", {}, stream=True) == "stream_generate" assert provider.supports_native_streaming("gemini-3-flash", operation="stream_generate") is False assert provider.supports_native_streaming("gemini-3-flash", operation="generate") is False - assert provider.prepare_native_request({"model": "antigravity/gemini-3-pro-low"}, model="gemini-3-pro-preview", operation="generate")["model"] == "gemini-3-pro-preview" + prepared = provider.prepare_native_request({"model": "antigravity/gemini-3-pro-low"}, model="gemini-3-pro-preview", operation="generate") + assert prepared["model"] == "gemini-3-pro-low" + assert prepared["generationConfig"]["thinkingConfig"]["thinkingLevel"] == "low" def test_antigravity_prepare_native_request_converts_messages_to_gemini_contents() -> None: diff --git a/tests/test_native_provider_streaming.py b/tests/test_native_provider_streaming.py index bd84d4a5a..d3109e86b 100644 --- a/tests/test_native_provider_streaming.py +++ b/tests/test_native_provider_streaming.py @@ -112,6 +112,30 @@ async def transform_stream_event(self, payload, context): assert "after_stream_event_adapter_chain" in pass_names +@pytest.mark.asyncio +async def test_native_cross_protocol_stream_formats_openai_chat_sse() -> None: + context = NativeProviderContext( + provider="claude_code", + model="claude-sonnet-4-5", + protocol_name="anthropic_messages", + client_protocol_name="openai_chat", + endpoint="https://example.test/messages", + operation="messages", + ) + chunks = [ + {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "hi"}}, + "[DONE]", + ] + + events = [event async for event in NativeProviderExecutor().stream({"model": "claude-sonnet-4-5", "messages": [], "max_tokens": 1}, context, NativeHTTPTransport(FakeStreamingClient(chunks)))] + + assert events[0].startswith("data: ") + payload = json.loads(events[0][len("data: ") :].strip()) + assert payload["object"] == "chat.completion.chunk" + assert payload["choices"][0]["delta"]["content"] == "hi" + assert "content_block_delta" not in events[0] + + @pytest.mark.asyncio async def test_native_provider_stream_extracts_unified_stream_events_for_later_requests() -> None: rule = FieldCacheRule( diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index b3a899a35..4da465108 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -112,6 +112,11 @@ def get_field_cache_rules(self, model=""): ) +class NativeOptOutPlugin(NativePlugin): + def should_use_native_protocol(self, model="", operation="chat", *, stream=False, execution="auto"): + return False + + class CustomPlugin: def __init__(self): self.calls = [] @@ -364,12 +369,25 @@ def test_antigravity_alias_normalization_preserves_thinking_level() -> None: prepared = provider.prepare_native_request(request, model=provider.normalize_native_model("antigravity/gemini-3-pro-low"), operation="generate") assert provider.normalize_native_model("antigravity/gemini-3-pro-low") == "gemini-3-pro-preview" + assert prepared["model"] == "gemini-3-pro-low" + assert prepared["generationConfig"]["thinkingConfig"]["thinkingLevel"] == "low" assert prepared["metadata"]["thinking_level"] == "low" assert "_proxy_model" not in prepared assert "messages" not in prepared assert prepared["contents"][0]["parts"][0]["text"] == "hi" +def test_auto_native_selection_honors_provider_opt_out() -> None: + assert executor_module._should_use_native_protocol( + NativeOptOutPlugin(), + "provider/gpt-test", + None, + {"model": "provider/gpt-test", "messages": []}, + stream=False, + execution="auto", + ) is False + + def test_native_request_payload_drops_litellm_only_fields() -> None: payload = executor_module._native_request_payload( { From ff51efeb1e3d95a4b3b4883baefc46310d7df22a Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 21:23:47 +0200 Subject: [PATCH 166/182] fix(providers): close native streaming and cache gaps Uses the fail-closed native streaming support helper in live routing, extracts provider-native response cache fields before client-protocol formatting, and scopes Antigravity quota groups by model family instead of broad Gemini/Claude buckets. Tests: pytest tests/test_request_executor_native_routing.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_native_streaming_transport_seam.py tests/test_claude_code_provider.py tests/test_codex_provider.py tests/test_copilot_provider.py tests/test_antigravity_provider_restore.py tests/test_provider_protocol_declarations.py tests/test_protocol_openai_chat.py tests/test_protocol_anthropic_messages.py tests/test_protocol_responses.py tests/test_protocol_gemini.py --- src/rotator_library/client/executor.py | 16 +++++----- .../native_provider/executor.py | 3 +- .../providers/antigravity_provider.py | 9 ++++-- tests/test_antigravity_provider_restore.py | 10 ++++++ tests/test_native_provider_executor.py | 32 +++++++++++++++++++ 5 files changed, 59 insertions(+), 11 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 413a4d8e9..1469fd1e2 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -68,6 +68,7 @@ from ..routing.policy import normalize_route_error_type from ..routing.types import RouteTarget from ..native_provider import NativeHTTPTransport, NativeProviderContext, NativeProviderExecutor +from ..native_provider.streaming import provider_supports_native_streaming as native_provider_supports_streaming from ..field_cache.paths import FieldCachePathError, PathToken, parse_path from ..transform_trace import REDACTED from ..usage.accounting import UsageRecord, extract_usage_record @@ -2470,14 +2471,13 @@ def _provider_supports_native_streaming(plugin: Any, model: str) -> bool: try: operation = resolver(model, {"model": model, "stream": True}, stream=True) except TypeError: - operation = resolver(model) - try: - return bool(support(model=model, operation=operation)) - except TypeError: - try: - return bool(support(model)) - except TypeError: - return bool(support()) + try: + operation = resolver(model) + except Exception: + return False + except Exception: + return False + return native_provider_supports_streaming(plugin, model=model, operation=operation) def _should_use_native_streaming(plugin: Any, model: str, target: Optional[RouteTarget], execution: str, provider: str) -> bool: diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index da0336c5d..aa3f51963 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -69,6 +69,8 @@ async def execute(self, raw_request: dict[str, Any], context: NativeProviderCont self._trace(context, "native_provider_request", provider_request, direction="request", stage="provider") raw_response = await transport.post_json(context.endpoint, headers=context.headers, payload=provider_request) self._trace(context, "raw_native_provider_response", raw_response, direction="response", stage="provider") + await cache_engine.extract("response", raw_response, context.field_cache_context(), transaction_logger=logger) + self._trace(context, "after_response_field_cache_extraction", {"source": "response", "payload": "raw_provider_response"}, direction="response", stage="adapter", snapshot=False) unified_response = protocol.parse_response(raw_response, protocol_context) self._trace(context, "parsed_native_unified_response", unified_response, direction="response", stage="protocol") await cache_engine.extract("unified_response", serialize_value(unified_response), context.field_cache_context(), transaction_logger=logger) @@ -82,7 +84,6 @@ async def execute(self, raw_request: dict[str, Any], context: NativeProviderCont adapter_context.transaction_logger = None provider_response = await run_adapter_chain(adapters, provider_response, adapter_context, stage="response") self._trace(context, "after_response_adapter_chain", provider_response, direction="response", stage="adapter") - await cache_engine.extract("response", provider_response, context.field_cache_context(), transaction_logger=logger) usage_record = extract_usage_record( provider_response, provider=context.provider, diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py index 6daf3b639..50e61c1f4 100644 --- a/src/rotator_library/providers/antigravity_provider.py +++ b/src/rotator_library/providers/antigravity_provider.py @@ -70,8 +70,13 @@ class AntigravityProvider(ProviderInterface): ) native_streaming_supported = False model_quota_groups = { - "gemini": ["gemini-2.5-flash", "gemini-2.5-flash-lite", "gemini-3-pro-preview", "gemini-3-flash"], - "claude": ["claude-sonnet-4.5", "claude-opus-4.5", "claude-opus-4.6"], + "gemini_3_pro": ["gemini-3-pro-preview", "gemini-3-pro-low", "gemini-3-pro-high"], + "gemini_3_flash": ["gemini-3-flash"], + "gemini_2_5_flash": ["gemini-2.5-flash"], + "gemini_2_5_flash_lite": ["gemini-2.5-flash-lite"], + "claude_sonnet_4_5": ["claude-sonnet-4.5"], + "claude_opus_4_5": ["claude-opus-4.5"], + "claude_opus_4_6": ["claude-opus-4.6"], } default_rotation_mode = "sequential" diff --git a/tests/test_antigravity_provider_restore.py b/tests/test_antigravity_provider_restore.py index 3dc86e81f..e7b90043f 100644 --- a/tests/test_antigravity_provider_restore.py +++ b/tests/test_antigravity_provider_restore.py @@ -43,6 +43,16 @@ def test_antigravity_provider_restores_safe_declarations() -> None: assert rules[0].scope == ("provider", "model", "credential", "session") +def test_antigravity_quota_groups_are_model_family_scoped() -> None: + groups = AntigravityProvider.model_quota_groups + + assert groups["gemini_3_pro"] == ["gemini-3-pro-preview", "gemini-3-pro-low", "gemini-3-pro-high"] + assert groups["gemini_3_flash"] == ["gemini-3-flash"] + assert groups["gemini_2_5_flash"] == ["gemini-2.5-flash"] + assert groups["gemini_2_5_flash_lite"] == ["gemini-2.5-flash-lite"] + assert groups["claude_sonnet_4_5"] == ["claude-sonnet-4.5"] + + def test_antigravity_provider_builds_static_headers_without_device_profile(monkeypatch) -> None: monkeypatch.setenv("ANTIGRAVITY_API_BASE", "https://antigravity.test/v1internal") provider = AntigravityProvider() diff --git a/tests/test_native_provider_executor.py b/tests/test_native_provider_executor.py index a1a4c8f02..3db53fd73 100644 --- a/tests/test_native_provider_executor.py +++ b/tests/test_native_provider_executor.py @@ -269,6 +269,38 @@ async def test_native_runtime_executes_unified_response_source() -> None: assert second_client.calls[0]["json"]["metadata"]["cached_object"] == "chat.completion" +@pytest.mark.asyncio +async def test_native_response_source_extracts_raw_provider_response_before_client_formatting() -> None: + rule = FieldCacheRule( + name="anthropic_signature", + source="response", + path="content.*.signature", + mode="all", + inject=FieldCacheInjection(target="request", path="metadata.signatures", as_list=True), + allow_missing_session=True, + ) + context = NativeProviderContext( + provider="claude_code", + model="claude-sonnet-4-5", + protocol_name="anthropic_messages", + client_protocol_name="openai_chat", + endpoint="https://example.test/messages", + operation="messages", + field_cache_rules=(rule,), + ) + executor = NativeProviderExecutor() + + await executor.execute( + {"model": "claude-sonnet-4-5", "messages": [], "max_tokens": 1}, + context, + NativeHTTPTransport(FakeHTTPClient({"id": "msg_1", "type": "message", "role": "assistant", "content": [{"type": "text", "text": "ok", "signature": "sig-1"}]})), + ) + second_client = FakeHTTPClient({"id": "msg_2", "type": "message", "role": "assistant", "content": []}) + await executor.execute({"model": "claude-sonnet-4-5", "messages": [], "max_tokens": 1}, context, NativeHTTPTransport(second_client)) + + assert second_client.calls[0]["json"]["metadata"]["signatures"] == ["sig-1"] + + @pytest.mark.asyncio async def test_native_runtime_executes_metadata_target_before_adapters() -> None: class MetadataEchoAdapter(PayloadAdapter): From 29d66ff3ec13ffa2f761726b37bd094ea88aab12 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 21:28:26 +0200 Subject: [PATCH 167/182] fix(providers): preserve plain antigravity preview alias Removes ambiguous Gemini 3 Pro preview from reverse alias normalization so plain preview requests stay preview while explicit low/high aliases still carry thinking behavior. Tests: pytest tests/test_antigravity_provider_restore.py tests/test_request_executor_native_routing.py tests/test_native_provider_executor.py tests/test_native_provider_streaming.py tests/test_provider_protocol_declarations.py tests/test_protocol_openai_chat.py tests/test_protocol_gemini.py --- src/rotator_library/providers/antigravity_provider.py | 3 ++- tests/test_antigravity_provider_restore.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py index 50e61c1f4..5b871f34b 100644 --- a/src/rotator_library/providers/antigravity_provider.py +++ b/src/rotator_library/providers/antigravity_provider.py @@ -41,7 +41,8 @@ "claude-opus-4-5": "claude-opus-4.5", "claude-opus-4-6": "claude-opus-4.6", } -MODEL_ALIAS_REVERSE = {public: internal for internal, public in MODEL_ALIAS_MAP.items()} +_AMBIGUOUS_REVERSE_ALIASES = {"gemini-3-pro-preview"} +MODEL_ALIAS_REVERSE = {public: internal for internal, public in MODEL_ALIAS_MAP.items() if public not in _AMBIGUOUS_REVERSE_ALIASES} EXCLUDED_MODELS = {"chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-2.5-pro"} diff --git a/tests/test_antigravity_provider_restore.py b/tests/test_antigravity_provider_restore.py index e7b90043f..ce5dbba87 100644 --- a/tests/test_antigravity_provider_restore.py +++ b/tests/test_antigravity_provider_restore.py @@ -74,6 +74,7 @@ def test_antigravity_model_aliases_and_tracking_normalization() -> None: assert provider._alias_to_internal("claude-sonnet-4.5") == "claude-sonnet-4-5" assert provider.normalize_native_model("antigravity/claude-sonnet-4.5") == "claude-sonnet-4-5" + assert provider.normalize_native_model("antigravity/gemini-3-pro-preview") == "gemini-3-pro-preview" assert provider.normalize_model_for_tracking("antigravity/claude-sonnet-4-5") == "antigravity/claude-sonnet-4.5" From c225c26d8c2ae692b0b9cd86dbbd884d0d1fa46c Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 21:35:52 +0200 Subject: [PATCH 168/182] docs(experimental): plan routing fallback correction Adds the Phase 6c corrective plan for stale fallback tests, group target promotion, streaming execution-mode parity, native config hard stops, route-error aliases, and target namespace adjustment. Tests: not run (planning document only) --- .../phase-6c-routing-fallback-correction.md | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 docs/experimental/phase-6c-routing-fallback-correction.md diff --git a/docs/experimental/phase-6c-routing-fallback-correction.md b/docs/experimental/phase-6c-routing-fallback-correction.md new file mode 100644 index 000000000..4df2eb01f --- /dev/null +++ b/docs/experimental/phase-6c-routing-fallback-correction.md @@ -0,0 +1,62 @@ +# Phase 6c: Routing/Fallback Correctness And Stale Test Repair + +## Goal + +Close all Phase 6/6b third-pass routing findings while preserving the fallback group architecture and the hard-stop safety from Phase 6b. + +## Scope + +- Replace stale `tests/test_fallback_groups.py` with tests for the current routing package. +- Implement requested-model-in-group promotion. +- Align streaming execution-mode behavior with non-streaming behavior for `@custom`, `@native`, and `@litellm_fallback`. +- Ensure explicit native configuration errors are hard-stop configuration errors, not retryable unsupported-operation errors. +- Expand structured route-error alias normalization. +- Adjust fallback target session namespace per target. +- Add focused tests for every blocker/high/medium. + +## Non-Goals + +- Do not resurrect the old `FallbackGroupManager`. +- Do not reintroduce `RotatingClient(model_fallback_groups=...)`. +- Do not implement routing for embeddings unless it is trivial and needed by tests. +- Do not weaken hard-stop policy. +- Do not commit user-facing reports. + +## Implementation Plan + +1. Replace stale fallback tests. + - Rewrite `tests/test_fallback_groups.py` around `routing.config`, `FallbackResolver`, `FallbackAttemptRunner`, and executor helpers. + - Cover env/JSON group parsing, target promotion, execution suffix parsing, and fallback exhaustion behavior. + +2. Requested-model promotion. + - Promote a requested provider/model target to the first attempt when it appears in a fallback group. + - Preserve the relative order of all other targets. + - Apply this both for explicit `MODEL_ROUTE_* = group:name` routes and provider-prefixed requests that appear in any group. + +3. Streaming execution-mode parity. + - Share execution-mode selection logic between streaming and non-streaming paths. + - `@litellm_fallback` must force LiteLLM. + - `@custom` must require custom provider logic. + - `@native` must fail closed if native streaming is unsupported. + - `auto` remains custom first, then native streaming if explicitly supported, then LiteLLM. + +4. Explicit native configuration errors. + - Missing native declaration, unsupported operation, and missing endpoint/header helpers should raise `RoutingExecutionError(error_type="configuration_error")` for explicit native execution failures. + - Hard-stop policy then blocks fallback. + +5. Structured route-error aliases. + - Add common status-less aliases including `invalid_api_key`, `unauthorized`, `invalid_argument`, `rate_limited`, `too_many_requests`, `resource_exhausted`, `unavailable`, `deadline_exceeded`, and context-window variants. + +6. Session namespace adjustment. + - `clone_context_for_target()` should rewrite standard session-tracking namespaces to the target provider/model instead of preserving the first target namespace. + - Unknown/custom namespace shapes remain unchanged. + +## Acceptance Criteria + +- Broad test collection no longer fails on stale fallback imports. +- Requested provider-prefixed models inside a group are tried first. +- Streaming and non-streaming execution-mode selection are consistent. +- Explicit native config errors are hard-stop configuration errors. +- Common structured aliases normalize correctly. +- Fallback target contexts do not reuse the first target's provider/model session namespace. +- Focused tests pass and both reviewers report no blockers/highs/mediums. From 1fe74bf97816e5e3ab618f632e6da858aee7573b Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 21:45:47 +0200 Subject: [PATCH 169/182] fix(routing): harden fallback selection Replaces stale fallback-group tests with current routing coverage, promotes requested models within fallback groups, aligns streaming execution-mode precedence with non-streaming behavior, hard-stops native config errors, expands structured route-error aliases, and rewrites target session namespaces during fallback cloning. Tests: pytest tests/test_fallback_groups.py tests/test_fallback_resolver.py tests/test_fallback_policy.py tests/test_fallback_attempt_runner.py tests/test_request_executor_fallback_groups.py tests/test_request_executor_fallback_error_summary.py tests/test_routing_config.py tests/test_config_routing_json.py tests/test_routing_attempts.py tests/test_request_builder_routing.py tests/test_request_executor_native_routing.py tests/test_streaming_fallback_policy.py tests/test_retry_policy.py tests/test_cooldown_activation.py --- src/rotator_library/client/executor.py | 54 ++++++++---- src/rotator_library/routing/attempts.py | 13 +++ src/rotator_library/routing/policy.py | 13 +++ src/rotator_library/routing/resolver.py | 24 +++++- tests/test_fallback_groups.py | 86 +++++++++++++++++++ tests/test_fallback_policy.py | 12 +++ tests/test_request_executor_native_routing.py | 10 ++- tests/test_routing_attempts.py | 32 +++++++ 8 files changed, 224 insertions(+), 20 deletions(-) create mode 100644 tests/test_fallback_groups.py diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 1469fd1e2..cbdf3d371 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -89,7 +89,7 @@ class RoutingExecutionError(RuntimeError): """Internal error used when a routed target cannot use its requested mode.""" - def __init__(self, message: str, error_type: str = "unsupported_operation") -> None: + def __init__(self, message: str, error_type: str = "configuration_error") -> None: super().__init__(message) self.error_type = error_type @@ -1437,8 +1437,42 @@ async def _execute_streaming( target = _current_route_target(context) execution = target.execution if target else "auto" - # Make the API call - if _should_use_native_streaming(plugin, model, target, execution, provider): + # Make the API call. Keep execution-mode precedence aligned with + # the non-streaming path: explicit LiteLLM wins, explicit custom + # requires custom logic, explicit native fails closed, and auto + # prefers custom before native streaming. + if execution == "litellm_fallback": + kwargs["api_key"] = credential_secret + kwargs["stream"] = True + self._apply_litellm_logger(kwargs) + kwargs.pop("transaction_context", None) + self._log_executor_trace( + context, + "provider_execution_request", + kwargs, + direction="request", + stage="provider", + credential_id=cred_context.stable_id, + metadata={"execution": "litellm_stream", "provider": provider, "model": model}, + ) + stream = await litellm.acompletion(**kwargs) + elif execution == "custom" or (execution == "auto" and plugin and plugin.has_custom_logic()): + if not plugin or not plugin.has_custom_logic(): + raise RoutingExecutionError(f"Provider {provider} does not support custom execution") + kwargs["credential_identifier"] = credential_secret + self._log_executor_trace( + context, + "provider_execution_request", + kwargs, + direction="request", + stage="provider", + credential_id=cred_context.stable_id, + metadata={"execution": "custom_stream", "provider": provider, "model": model}, + ) + stream = await plugin.acompletion( + self._http_client, **kwargs + ) + elif _should_use_native_streaming(plugin, model, target, execution, provider): native_context, native_request = self._build_native_provider_context( provider, model, @@ -1459,20 +1493,6 @@ async def _execute_streaming( metadata={"protocol": native_context.protocol_name, "operation": native_context.operation}, ) stream = self._get_native_executor().stream(native_request, native_context, NativeHTTPTransport(self._http_client)) - elif plugin and plugin.has_custom_logic(): - kwargs["credential_identifier"] = credential_secret - self._log_executor_trace( - context, - "provider_execution_request", - kwargs, - direction="request", - stage="provider", - credential_id=cred_context.stable_id, - metadata={"execution": "custom_stream", "provider": provider, "model": model}, - ) - stream = await plugin.acompletion( - self._http_client, **kwargs - ) else: kwargs["api_key"] = credential_secret kwargs["stream"] = True diff --git a/src/rotator_library/routing/attempts.py b/src/rotator_library/routing/attempts.py index 325b19bc8..582368ee8 100644 --- a/src/rotator_library/routing/attempts.py +++ b/src/rotator_library/routing/attempts.py @@ -41,4 +41,17 @@ def clone_context_for_target( provider_config=provider_config if provider_config is not None else context.provider_config, credential_secrets=dict(credential_secrets) if credential_secrets is not None else dict(context.credential_secrets), routing_target_index=target_index, + session_tracking_namespace=_namespace_for_target(context.session_tracking_namespace, target), ) + + +def _namespace_for_target(namespace: str | None, target: RouteTarget) -> str | None: + """Rewrite standard session namespaces for the fallback target provider/model.""" + + if not namespace or ":provider:" not in namespace or ":model:" not in namespace: + return namespace + prefix, _, rest = namespace.partition(":provider:") + _provider, sep, _model = rest.partition(":model:") + if not sep: + return namespace + return f"{prefix}:provider:{target.provider}:model:{target.prefixed_model}" diff --git a/src/rotator_library/routing/policy.py b/src/rotator_library/routing/policy.py index d9becdec0..549ee114a 100644 --- a/src/rotator_library/routing/policy.py +++ b/src/rotator_library/routing/policy.py @@ -10,24 +10,37 @@ _ALIASES = { "auth": "authentication", + "invalid_api_key": "authentication", + "invalid_key": "authentication", + "unauthorized": "authentication", "permission": "forbidden", "permission_denied": "forbidden", "access_denied": "forbidden", "bad_request": "invalid_request", + "invalid_argument": "invalid_request", + "bad_argument": "invalid_request", "validation": "invalid_request", "permanent": "invalid_request", "context_length": "context_window_exceeded", "context_length_exceeded": "context_window_exceeded", "context_window": "context_window_exceeded", "too_many_tokens": "context_window_exceeded", + "max_tokens_exceeded": "context_window_exceeded", + "token_limit_exceeded": "context_window_exceeded", "pre_request_callback": "pre_request_callback_error", "configuration": "configuration_error", "config": "configuration_error", "quota": "quota_exceeded", + "resource_exhausted": "quota_exceeded", "capacity": "rate_limit", + "rate_limited": "rate_limit", + "too_many_requests": "rate_limit", "transient": "server_error", + "unavailable": "server_error", "network": "api_connection", "connection": "api_connection", + "deadline_exceeded": "api_connection", + "timeout": "api_connection", "400": "invalid_request", "401": "authentication", "403": "forbidden", diff --git a/src/rotator_library/routing/resolver.py b/src/rotator_library/routing/resolver.py index f384095b0..53b66d58b 100644 --- a/src/rotator_library/routing/resolver.py +++ b/src/rotator_library/routing/resolver.py @@ -6,7 +6,7 @@ from __future__ import annotations from .config import RoutingConfigError, parse_route_target -from .types import RoutingConfig, RoutingDecision +from .types import FallbackGroup, RouteTarget, RoutingConfig, RoutingDecision class FallbackResolver: @@ -24,9 +24,29 @@ def resolve(self, requested_model: str) -> RoutingDecision: group = self.config.fallback_groups.get(group_name) if not group: raise RoutingConfigError(f"unknown fallback group {group_name}") - return RoutingDecision(requested_model=requested_model, group_name=group.name, group=group, targets=group.targets, reason="model_route_group") + targets = _promote_requested_target(group, requested_model) + reason = "model_route_group_promoted" if targets != group.targets else "model_route_group" + return RoutingDecision(requested_model=requested_model, group_name=group.name, group=group, targets=targets, reason=reason) if route: return RoutingDecision(requested_model=requested_model, targets=(parse_route_target(route),), reason="model_route_target") + for group in self.config.fallback_groups.values(): + targets = _promote_requested_target(group, requested_model) + if targets != group.targets or any(_same_target(target, requested_model) for target in group.targets): + return RoutingDecision(requested_model=requested_model, group_name=group.name, group=group, targets=targets, reason="provider_model_group_promoted") if "/" in requested_model: return RoutingDecision(requested_model=requested_model, targets=(parse_route_target(requested_model),), reason="direct_provider_model") raise RoutingConfigError(f"model {requested_model!r} is not provider-prefixed and has no route") + + +def _promote_requested_target(group: FallbackGroup, requested_model: str) -> tuple[RouteTarget, ...]: + """Return group targets with the requested provider/model attempted first.""" + + matching = [target for target in group.targets if _same_target(target, requested_model)] + if not matching: + return group.targets + selected = matching[0] + return (selected, *(target for target in group.targets if target is not selected)) + + +def _same_target(target: RouteTarget, requested_model: str) -> bool: + return target.prefixed_model.lower() == requested_model.lower() diff --git a/tests/test_fallback_groups.py b/tests/test_fallback_groups.py new file mode 100644 index 000000000..66f24bdb0 --- /dev/null +++ b/tests/test_fallback_groups.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import pytest + +from rotator_library.routing import FallbackResolver, FallbackGroup, RouteTarget, RoutingConfig +from rotator_library.routing.config import RoutingConfigError, load_routing_config_from_env, parse_route_target +from rotator_library.routing.executor import FallbackAttemptRunner + + +def test_parse_route_target_supports_execution_suffix() -> None: + target = parse_route_target("openai/gpt-5@litellm_fallback") + + assert target.provider == "openai" + assert target.model == "gpt-5" + assert target.execution == "litellm_fallback" + + +def test_env_fallback_group_config_loads_current_routing_types() -> None: + config = load_routing_config_from_env( + { + "FALLBACK_GROUPS": "main", + "FALLBACK_GROUP_MAIN": "openai/gpt-5@litellm_fallback,anthropic/claude@native", + "MODEL_ROUTE_GPT5": "group:main", + }, + config=RoutingConfig(), + ) + + group = config.fallback_groups["main"] + assert [target.prefixed_model for target in group.targets] == ["openai/gpt-5", "anthropic/claude"] + assert group.targets[0].execution == "litellm_fallback" + assert group.targets[1].execution == "native" + assert config.model_routes["gpt5"] == "group:main" + + +def test_resolver_promotes_requested_provider_model_inside_group() -> None: + group = FallbackGroup( + name="main", + targets=( + RouteTarget("openai", "gpt-5"), + RouteTarget("anthropic", "claude"), + RouteTarget("google", "gemini"), + ), + ) + decision = FallbackResolver(RoutingConfig(fallback_groups={"main": group})).resolve("anthropic/claude") + + assert decision.reason == "provider_model_group_promoted" + assert [target.prefixed_model for target in decision.targets] == ["anthropic/claude", "openai/gpt-5", "google/gemini"] + + +def test_resolver_promotes_requested_model_for_group_route_alias() -> None: + group = FallbackGroup( + name="main", + targets=(RouteTarget("openai", "gpt-5"), RouteTarget("anthropic", "claude")), + ) + config = RoutingConfig(fallback_groups={"main": group}, model_routes={"anthropic/claude": "group:main"}) + + decision = FallbackResolver(config).resolve("anthropic/claude") + + assert decision.reason == "model_route_group_promoted" + assert [target.prefixed_model for target in decision.targets] == ["anthropic/claude", "openai/gpt-5"] + + +def test_resolver_rejects_missing_group_route() -> None: + with pytest.raises(RoutingConfigError): + FallbackResolver(RoutingConfig(model_routes={"alias": "group:missing"})).resolve("alias") + + +@pytest.mark.asyncio +async def test_fallback_runner_uses_decision_group_policy() -> None: + group = FallbackGroup(name="main", targets=(RouteTarget("a", "one"), RouteTarget("b", "two")), failover_on=frozenset({"rate_limit"})) + decision = FallbackResolver(RoutingConfig(fallback_groups={"main": group}, model_routes={"alias": "group:main"})).resolve("alias") + attempts: list[str] = [] + + class RateLimitError(RuntimeError): + error_type = "rate_limit" + + async def attempt(target, index): + attempts.append(target.prefixed_model) + if target.provider == "a": + raise RateLimitError("rate limit") + return "ok" + + result = await FallbackAttemptRunner().run(decision, attempt) + + assert result == "ok" + assert attempts == ["a/one", "b/two"] diff --git a/tests/test_fallback_policy.py b/tests/test_fallback_policy.py index 3cc8531d9..74adaeaab 100644 --- a/tests/test_fallback_policy.py +++ b/tests/test_fallback_policy.py @@ -65,3 +65,15 @@ def test_policy_normalizes_user_facing_aliases() -> None: assert normalize_route_error_type("context_length_exceeded") == "context_window_exceeded" assert FallbackPolicy().should_fallback("network") is True assert FallbackPolicy().should_fallback("validation") is False + + +def test_policy_normalizes_common_structured_provider_aliases() -> None: + assert normalize_route_error_type("invalid_api_key") == "authentication" + assert normalize_route_error_type("unauthorized") == "authentication" + assert normalize_route_error_type("invalid_argument") == "invalid_request" + assert normalize_route_error_type("max_tokens_exceeded") == "context_window_exceeded" + assert normalize_route_error_type("rate_limited") == "rate_limit" + assert normalize_route_error_type("too_many_requests") == "rate_limit" + assert normalize_route_error_type("resource_exhausted") == "quota_exceeded" + assert normalize_route_error_type("unavailable") == "server_error" + assert normalize_route_error_type("deadline_exceeded") == "api_connection" diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index 4da465108..17711341a 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -419,6 +419,14 @@ def test_route_error_type_from_response_uses_structured_status_codes() -> None: assert executor_module._route_error_type_from_response({"error": {"code": 403}}) == "forbidden" assert executor_module._route_error_type_from_response({"error": {"status": 401}}) == "authentication" assert executor_module._route_error_type_from_response({"error": {"details": {"status": 403}}}) == "forbidden" + + +def test_route_error_type_from_response_uses_structured_aliases() -> None: + assert executor_module._route_error_type_from_response({"error": {"type": "invalid_api_key"}}) == "authentication" + assert executor_module._route_error_type_from_response({"error": {"code": "invalid_argument"}}) == "invalid_request" + assert executor_module._route_error_type_from_response({"error": {"code": "resource_exhausted"}}) == "quota_exceeded" + assert executor_module._route_error_type_from_response({"error": {"code": "unavailable"}}) == "server_error" + assert executor_module._route_error_type_from_response({"error": {"code": "deadline_exceeded"}}) == "api_connection" assert executor_module._route_error_type_from_response({"error": {"details": {"status_code": 503}}}) == "server_error" @@ -543,4 +551,4 @@ async def test_native_execution_mode_fails_when_provider_has_no_native_declarati with pytest.raises(RoutingExecutionError) as exc: await _executor()._execute_provider_request("provider", "provider/gpt-test", None, "secret", "stable", dict(context.kwargs), context) - assert exc.value.error_type == "unsupported_operation" + assert exc.value.error_type == "configuration_error" diff --git a/tests/test_routing_attempts.py b/tests/test_routing_attempts.py index 1a8a978b9..750108622 100644 --- a/tests/test_routing_attempts.py +++ b/tests/test_routing_attempts.py @@ -48,3 +48,35 @@ def test_clone_context_for_target_preserves_request_metadata() -> None: assert cloned.session_id == "session-1" assert cloned.classifier == "global" assert cloned.routing_group_name == "chain" + + +def test_clone_context_for_target_rewrites_standard_session_namespace() -> None: + original = RequestContext( + model="openai/gpt-5", + provider="openai", + kwargs={"model": "openai/gpt-5"}, + streaming=False, + credentials=["cred-a"], + deadline=123.0, + session_tracking_namespace="scope:openai:provider:openai:model:openai/gpt-5", + ) + + cloned = clone_context_for_target(original, parse_route_target("anthropic/claude")) + + assert cloned.session_tracking_namespace == "scope:openai:provider:anthropic:model:anthropic/claude" + + +def test_clone_context_for_target_preserves_custom_session_namespace() -> None: + original = RequestContext( + model="openai/gpt-5", + provider="openai", + kwargs={"model": "openai/gpt-5"}, + streaming=False, + credentials=["cred-a"], + deadline=123.0, + session_tracking_namespace="custom-namespace", + ) + + cloned = clone_context_for_target(original, parse_route_target("anthropic/claude")) + + assert cloned.session_tracking_namespace == "custom-namespace" From bb2c497c1d8bc95e59506d919ad48ce3617ecb1b Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 21:54:04 +0200 Subject: [PATCH 170/182] fix(routing): honor native stream opt outs Makes streaming auto-native selection call the same provider opt-out hook as non-streaming execution and rewrites fallback session namespaces with the target usage scope instead of preserving the first target scope prefix. Tests: pytest tests/test_fallback_groups.py tests/test_fallback_resolver.py tests/test_fallback_policy.py tests/test_fallback_attempt_runner.py tests/test_request_executor_fallback_groups.py tests/test_request_executor_fallback_error_summary.py tests/test_routing_config.py tests/test_config_routing_json.py tests/test_routing_attempts.py tests/test_request_builder_routing.py tests/test_request_executor_native_routing.py tests/test_streaming_fallback_policy.py tests/test_retry_policy.py tests/test_cooldown_activation.py --- src/rotator_library/client/executor.py | 6 +++++- src/rotator_library/routing/attempts.py | 11 +++++++---- tests/test_request_executor_native_routing.py | 10 ++++++++++ tests/test_routing_attempts.py | 2 +- 4 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index cbdf3d371..9ce5b23f5 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -2515,7 +2515,11 @@ def _should_use_native_streaming(plugin: Any, model: str, target: Optional[Route f"Provider {provider} does not support native streaming for {model}", error_type="configuration_error", ) - return bool(execution == "auto" and plugin and _provider_native_protocol(plugin, model, target) and _provider_supports_native_streaming(plugin, model)) + if execution != "auto" or not plugin: + return False + if not _should_use_native_protocol(plugin, model, target, {"model": model, "stream": True}, stream=True, execution=execution): + return False + return _provider_supports_native_streaming(plugin, model) def _redact_context_field_cache_paths(payload: Any, context: RequestContext, direction: str, plugin: Any) -> Any: diff --git a/src/rotator_library/routing/attempts.py b/src/rotator_library/routing/attempts.py index 582368ee8..bfdc5f695 100644 --- a/src/rotator_library/routing/attempts.py +++ b/src/rotator_library/routing/attempts.py @@ -31,27 +31,30 @@ def clone_context_for_target( kwargs: dict[str, Any] = dict(context.kwargs) kwargs["model"] = target.prefixed_model + next_usage_key = usage_manager_key if usage_manager_key is not None else target.provider return replace( context, model=target.prefixed_model, provider=target.provider, kwargs=kwargs, credentials=list(credentials) if credentials is not None else list(context.credentials), - usage_manager_key=usage_manager_key if usage_manager_key is not None else target.provider, + usage_manager_key=next_usage_key, provider_config=provider_config if provider_config is not None else context.provider_config, credential_secrets=dict(credential_secrets) if credential_secrets is not None else dict(context.credential_secrets), routing_target_index=target_index, - session_tracking_namespace=_namespace_for_target(context.session_tracking_namespace, target), + session_tracking_namespace=_namespace_for_target(context.session_tracking_namespace, target, scope_key=next_usage_key), ) -def _namespace_for_target(namespace: str | None, target: RouteTarget) -> str | None: +def _namespace_for_target(namespace: str | None, target: RouteTarget, *, scope_key: str) -> str | None: """Rewrite standard session namespaces for the fallback target provider/model.""" if not namespace or ":provider:" not in namespace or ":model:" not in namespace: return namespace prefix, _, rest = namespace.partition(":provider:") + if not prefix.startswith("scope:"): + return namespace _provider, sep, _model = rest.partition(":model:") if not sep: return namespace - return f"{prefix}:provider:{target.provider}:model:{target.prefixed_model}" + return f"scope:{scope_key}:provider:{target.provider}:model:{target.prefixed_model}" diff --git a/tests/test_request_executor_native_routing.py b/tests/test_request_executor_native_routing.py index 17711341a..ac7c0a7bd 100644 --- a/tests/test_request_executor_native_routing.py +++ b/tests/test_request_executor_native_routing.py @@ -388,6 +388,16 @@ def test_auto_native_selection_honors_provider_opt_out() -> None: ) is False +def test_auto_native_streaming_selection_honors_provider_opt_out() -> None: + assert executor_module._should_use_native_streaming( + NativeOptOutPlugin(), + "provider/gpt-test", + parse_route_target("provider/gpt-test"), + "auto", + "provider", + ) is False + + def test_native_request_payload_drops_litellm_only_fields() -> None: payload = executor_module._native_request_payload( { diff --git a/tests/test_routing_attempts.py b/tests/test_routing_attempts.py index 750108622..4f84c76a9 100644 --- a/tests/test_routing_attempts.py +++ b/tests/test_routing_attempts.py @@ -63,7 +63,7 @@ def test_clone_context_for_target_rewrites_standard_session_namespace() -> None: cloned = clone_context_for_target(original, parse_route_target("anthropic/claude")) - assert cloned.session_tracking_namespace == "scope:openai:provider:anthropic:model:anthropic/claude" + assert cloned.session_tracking_namespace == "scope:anthropic:provider:anthropic:model:anthropic/claude" def test_clone_context_for_target_preserves_custom_session_namespace() -> None: From c40c6f220bb499a50c7920194f4ba27de5f81756 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 22:01:23 +0200 Subject: [PATCH 171/182] docs(experimental): plan retry cooldown correction Adds the Phase 7c corrective plan for cooldown budget fail-fast behavior, transient backoff thresholds, success reset, structured routing attempt history, and model-scoped cooldown isolation. Tests: not run (planning document only) --- .../phase-7c-retry-cooldown-correction.md | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 docs/experimental/phase-7c-retry-cooldown-correction.md diff --git a/docs/experimental/phase-7c-retry-cooldown-correction.md b/docs/experimental/phase-7c-retry-cooldown-correction.md new file mode 100644 index 000000000..8fba3b192 --- /dev/null +++ b/docs/experimental/phase-7c-retry-cooldown-correction.md @@ -0,0 +1,57 @@ +# Phase 7c: Cooldown Budget, Transient Backoff, And Attempt History + +## Goal + +Close all Phase 7/7b third-pass retry/cooldown findings while preserving existing retry-after parsing and credential-rotation semantics. + +## Scope + +- Fail fast when an active provider/model cooldown exceeds the remaining request deadline budget. +- Prevent a single generic `server_error` / `api_connection` without `retry_after` from starting provider-wide cooldown. +- Record repeated transient failures into `FailureHistory` even when the cooldown threshold has not yet been reached. +- Clear or reduce failure history on successful provider/model calls. +- Populate structured `routing_attempt_history` for live non-streaming and streaming fallback attempts. +- Add tests for cooldown-over-budget, single transient no-cooldown, repeated transient cooldown, success reset, model-scoped cooldown isolation, and stream parity. + +## Non-Goals + +- Do not replace the retry-after parser. +- Do not replace `UsageManager` or `SessionTracker`. +- Do not introduce durable failure-history storage. +- Do not change small retry-after same-credential behavior. +- Do not weaken fallback hard-stop policy from Phase 6c. + +## Implementation Plan + +1. Cooldown-over-budget fail-fast. + - Raise a retryable routing error instead of returning when cooldown remaining exceeds the request deadline budget. + - Trace `cooldown_wait_exceeds_budget` without exposing credentials or provider payloads. + +2. Generic transient cooldown threshold. + - No-`retry_after` `server_error` and `api_connection` should start cooldown only after `FailureHistory.backoff_for()` crosses the configured threshold. + - Large explicit `retry_after` behavior remains unchanged. + +3. Record skipped transient failures. + - Record sanitized transient entries when cooldown is skipped because the backoff threshold is not yet met. + - Keep history bounded and in-memory only. + +4. Success reset. + - Add a clear/reset helper on `FailureHistory`. + - Clear matching provider/model transient entries after successful non-streaming and completed streaming calls. + +5. Structured attempt history. + - Append sanitized attempt entries for fallback failures and successes into `RequestContext.routing_attempt_history`. + - Include error type, target identity, execution mode, output visibility, status code when available, fallback decision, and timing where cheap. + +6. Model cooldown isolation. + - Add tests proving model-scoped cooldowns block only the matching model while provider-scoped cooldowns block all provider models. + +## Acceptance Criteria + +- Cooldown that exceeds request budget does not silently allow execution. +- A single generic transient without retry-after does not start provider-wide cooldown. +- Repeated transient failures can still start bounded cooldown/backoff. +- Successful calls clear matching failure-history entries. +- Live routing attempt history is populated and sanitized. +- Model-scoped cooldown blocks only its model plus provider-wide cooldown still blocks all provider models. +- Focused tests pass and both reviewers report no blockers/highs/mediums. From f7a6aea98c1dc4002c2685393e594beb8f7c7723 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 22:12:33 +0200 Subject: [PATCH 172/182] fix(retry): enforce cooldown backoff semantics Fails fast when active cooldowns exceed the request budget, prevents single generic transients from starting provider-wide cooldowns, records skipped transient failures for bounded backoff, clears failure history on success, and populates sanitized routing attempt history. Tests: pytest tests/test_retry_policy.py tests/test_cooldown_activation.py tests/test_request_executor_fallback_groups.py tests/test_streaming_fallback_policy.py tests/test_fallback_policy.py tests/test_fallback_groups.py tests/test_fallback_resolver.py tests/test_request_executor_fallback_error_summary.py tests/test_routing_attempts.py tests/test_streaming_error_handler.py tests/test_streaming_usage_accounting.py tests/test_request_executor_stream_metrics.py --- src/rotator_library/client/executor.py | 92 +++++++++++++++++-- src/rotator_library/client/streaming.py | 3 + src/rotator_library/retry_policy.py | 29 +++++- tests/test_cooldown_activation.py | 39 +++++++- .../test_request_executor_fallback_groups.py | 6 +- tests/test_retry_policy.py | 36 ++++++++ tests/test_streaming_fallback_policy.py | 6 +- tests/test_streaming_usage_accounting.py | 9 ++ 8 files changed, 206 insertions(+), 14 deletions(-) diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 9ce5b23f5..c1c062d71 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -773,6 +773,7 @@ async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> metadata={"group": context.routing_group_name, "targets": [_target_trace(target) for target in targets]}, ) for index, target in enumerate(targets): + attempt_started_at = time.monotonic() target_context = clone_context_for_target( context, target, @@ -800,7 +801,9 @@ async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> metadata={"target_index": index, "error_type": error_type, "exception": exc.__class__.__name__}, ) target_failures.append(_target_failure_summary(target, error_type)) - if index >= len(targets) - 1 or not policy.should_fallback(error_type, group=context.routing_group): + fallback_allowed = index < len(targets) - 1 and policy.should_fallback(error_type, group=context.routing_group) + _append_routing_attempt_history(context, target, index, success=False, error_type=error_type, fallback_allowed=fallback_allowed, duration_ms=_elapsed_ms(attempt_started_at)) + if not fallback_allowed: self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type, "fallback_targets": target_failures}) raise self._log_routing_trace(context, "routing_fallback_selected", _target_trace(targets[index + 1]), metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type}) @@ -816,11 +819,13 @@ async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> metadata={"target_index": index, "error_type": error_type}, ) target_failures.append(_target_failure_summary(target, error_type, status_code=_route_status_code_from_response(result))) - if index < len(targets) - 1 and policy.should_fallback(error_type, group=context.routing_group): + fallback_allowed = index < len(targets) - 1 and policy.should_fallback(error_type, group=context.routing_group) + _append_routing_attempt_history(context, target, index, success=False, error_type=error_type, status_code=_route_status_code_from_response(result), fallback_allowed=fallback_allowed, duration_ms=_elapsed_ms(attempt_started_at)) + if fallback_allowed: self._log_routing_trace(context, "routing_fallback_selected", _target_trace(targets[index + 1]), metadata={"from_target_index": index, "to_target_index": index + 1, "reason": error_type}) continue self._log_routing_trace(context, "routing_fallback_exhausted", _target_trace(target), metadata={"error_type": error_type, "fallback_targets": target_failures}) - return _with_fallback_summary(result, target_failures) + return _with_fallback_summary(result, target_failures, context.routing_attempt_history) self._log_routing_trace( context, @@ -828,6 +833,7 @@ async def _execute_non_streaming_with_fallback(self, context: RequestContext) -> _target_trace(target), metadata={"target_index": index}, ) + _append_routing_attempt_history(context, target, index, success=True, duration_ms=_elapsed_ms(attempt_started_at)) return result if isinstance(last_failure, Exception): @@ -851,6 +857,7 @@ async def _execute_streaming_with_fallback(self, context: RequestContext) -> Asy metadata={"group": context.routing_group_name, "streaming_policy": _group_streaming_policy(context.routing_group), "targets": [_target_trace(target) for target in targets]}, ) for index, target in enumerate(targets): + attempt_started_at = time.monotonic() emitted_output = False pending_chunks: List[str] = [] terminal_error_type: Optional[str] = None @@ -896,7 +903,9 @@ async def _execute_streaming_with_fallback(self, context: RequestContext) -> Asy metadata={"target_index": index, "error_type": error_type, "emitted_output": emitted_output, "terminal_error_frame": True}, ) target_failures.append(_target_failure_summary(target, error_type)) - if index < len(targets) - 1 and _streaming_policy_allows_fallback(context.routing_group) and policy.should_fallback(error_type, group=context.routing_group, stream=True, emitted_output=False): + fallback_allowed = index < len(targets) - 1 and _streaming_policy_allows_fallback(context.routing_group) and policy.should_fallback(error_type, group=context.routing_group, stream=True, emitted_output=False) + _append_routing_attempt_history(context, target, index, success=False, error_type=error_type, emitted_output=emitted_output, fallback_allowed=fallback_allowed, duration_ms=_elapsed_ms(attempt_started_at)) + if fallback_allowed: self._log_routing_trace( context, "routing_fallback_selected", @@ -916,6 +925,7 @@ async def _execute_streaming_with_fallback(self, context: RequestContext) -> Asy _target_trace(target), metadata={"target_index": index, "emitted_output": emitted_output}, ) + _append_routing_attempt_history(context, target, index, success=True, emitted_output=emitted_output, duration_ms=_elapsed_ms(attempt_started_at)) return except Exception as exc: error_type = _route_error_type(exc, target.provider) @@ -926,6 +936,8 @@ async def _execute_streaming_with_fallback(self, context: RequestContext) -> Asy metadata={"target_index": index, "error_type": error_type, "emitted_output": emitted_output}, ) target_failures.append(_target_failure_summary(target, error_type)) + fallback_allowed = (not emitted_output) and index < len(targets) - 1 and _streaming_policy_allows_fallback(context.routing_group) and policy.should_fallback(error_type, group=context.routing_group, stream=True, emitted_output=False) + _append_routing_attempt_history(context, target, index, success=False, error_type=error_type, emitted_output=emitted_output, fallback_allowed=fallback_allowed, duration_ms=_elapsed_ms(attempt_started_at)) if emitted_output: self._log_routing_trace( context, @@ -934,7 +946,7 @@ async def _execute_streaming_with_fallback(self, context: RequestContext) -> Asy metadata={"target_index": index, "error_type": error_type}, ) raise - if index < len(targets) - 1 and _streaming_policy_allows_fallback(context.routing_group) and policy.should_fallback(error_type, group=context.routing_group, stream=True, emitted_output=False): + if fallback_allowed: self._log_routing_trace( context, "routing_fallback_selected", @@ -1200,6 +1212,7 @@ async def _execute_non_streaming( approx_cost=approx_cost, response_headers=response_headers, ) + self._clear_failure_history_on_success(provider, model) self._record_session_response(context, response) lib_logger.info( @@ -1533,6 +1546,7 @@ async def _execute_streaming( response_callback=lambda response: self._record_session_response( context, response ), + success_callback=lambda: self._clear_failure_history_on_success(provider, model), transaction_logger=context.transaction_logger, ) @@ -1956,7 +1970,19 @@ async def _wait_for_cooldown( lib_logger.warning( f"Provider {provider} cooldown ({remaining:.1f}s) exceeds budget ({budget:.1f}s)" ) - return # Will fail on no keys available + self._log_provider_cooldown_trace( + context, + "cooldown_wait_exceeds_budget", + provider, + ClassifiedError(error_type="rate_limit", original_exception=RuntimeError("cooldown exceeds budget"), retry_after=int(remaining)), + int(remaining), + "cooldown_exceeds_budget", + model=model, + ) + raise RoutingExecutionError( + f"Provider {provider} cooldown exceeds request budget", + error_type="rate_limit", + ) lib_logger.info(f"Waiting {remaining:.1f}s for {provider} cooldown") self._log_cooldown_wait_trace(context, provider, model, remaining) await asyncio.sleep(remaining) @@ -2103,6 +2129,10 @@ async def _maybe_start_provider_cooldown( failure_history=getattr(self, "_failure_history", None), ) if not decision.should_start: + if decision.reason == "transient_backoff_threshold_not_met" and classified.error_type in {"server_error", "api_connection"}: + history = getattr(self, "_failure_history", None) + if history is not None: + history.record(provider=provider, model=decision.model, error_type=classified.error_type, scope=decision.scope, duration=0, reason=decision.reason) self._log_provider_cooldown_trace( context, "provider_cooldown_skipped", @@ -2137,6 +2167,13 @@ async def _maybe_start_provider_cooldown( except Exception as exc: lib_logger.debug("Failed to start provider cooldown for %s: %s", provider, exc) + def _clear_failure_history_on_success(self, provider: str, model: Optional[str]) -> None: + """Clear transient failure history after a successful provider/model call.""" + + history = getattr(self, "_failure_history", None) + if history is not None and hasattr(history, "clear"): + history.clear(provider=provider, model=model) + @staticmethod def _log_provider_cooldown_trace( context: Optional[RequestContext], @@ -2747,6 +2784,45 @@ def _target_failure_summary(target: RouteTarget, error_type: str, *, status_code return summary +def _append_routing_attempt_history( + context: RequestContext, + target: RouteTarget, + target_index: int, + *, + success: bool, + error_type: Optional[str] = None, + status_code: Optional[int] = None, + emitted_output: Optional[bool] = None, + fallback_allowed: Optional[bool] = None, + duration_ms: Optional[int] = None, +) -> None: + """Append sanitized live fallback-attempt metadata to the request context.""" + + entry: Dict[str, Any] = { + "target_index": target_index, + "target": target.name, + "provider": target.provider, + "model": target.prefixed_model, + "execution": target.execution, + "success": success, + } + if error_type is not None: + entry["error_type"] = normalize_route_error_type(error_type) + if status_code is not None: + entry["status_code"] = status_code + if emitted_output is not None: + entry["emitted_output"] = bool(emitted_output) + if fallback_allowed is not None: + entry["fallback_allowed"] = bool(fallback_allowed) + if duration_ms is not None: + entry["duration_ms"] = max(0, duration_ms) + context.routing_attempt_history.append(entry) + + +def _elapsed_ms(started_at: float) -> int: + return int((time.monotonic() - started_at) * 1000) + + def _group_streaming_policy(group: Any) -> str: """Return the active streaming fallback policy for trace and decisions.""" @@ -2826,7 +2902,7 @@ def _summary_hard_stop_type(summary: str) -> str: return "invalid_request" -def _with_fallback_summary(response: Any, target_failures: List[Dict[str, Any]]) -> Any: +def _with_fallback_summary(response: Any, target_failures: List[Dict[str, Any]], attempt_history: Optional[List[Dict[str, Any]]] = None) -> Any: """Attach fallback target summaries to a structured error response.""" if not target_failures or not isinstance(response, dict) or not isinstance(response.get("error"), dict): @@ -2834,6 +2910,8 @@ def _with_fallback_summary(response: Any, target_failures: List[Dict[str, Any]]) details = response["error"].setdefault("details", {}) if isinstance(details, dict): details["fallback_targets"] = list(target_failures) + if attempt_history: + details["routing_attempt_history"] = list(attempt_history) return response diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py index ee66b5720..2e035fe99 100644 --- a/src/rotator_library/client/streaming.py +++ b/src/rotator_library/client/streaming.py @@ -56,6 +56,7 @@ async def wrap_stream( cred_context: Optional["CredentialContext"] = None, skip_cost_calculation: bool = False, response_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + success_callback: Optional[Callable[[], None]] = None, transaction_logger: Optional[Any] = None, ) -> AsyncGenerator[str, None]: """ @@ -362,6 +363,8 @@ async def close_upstream(reason: str, *, force: bool = False) -> None: prompt_tokens_cache_write=prompt_tokens_cache_write, approx_cost=cost_breakdown.total_cost, ) + if success_callback: + success_callback() if response_callback and (assistant_parts or tool_call_ids): # Intentionally only record response anchors after a complete diff --git a/src/rotator_library/retry_policy.py b/src/rotator_library/retry_policy.py index 80dce1578..313c35c53 100644 --- a/src/rotator_library/retry_policy.py +++ b/src/rotator_library/retry_policy.py @@ -113,10 +113,15 @@ def decide_provider_cooldown( if error_type in {"server_error", "api_connection"} and default_duration >= provider_cooldown_min_seconds: backoff_level = 0 duration = int(default_duration) - if failure_history is not None: - backoff = failure_history.backoff_for(provider=provider, error_type=error_type, scope=scope, model=model if scope == "model" else None, default_duration=duration) - duration = backoff.duration - backoff_level = backoff.level + if scope == "model": + return ProviderCooldownDecision(True, duration=duration, reason="model_capacity_cooldown", scope=scope, model=model, backoff_level=backoff_level) + if failure_history is None: + return ProviderCooldownDecision(False, reason="missing_failure_history", scope=scope, model=model if scope == "model" else None) + backoff = failure_history.backoff_for(provider=provider, error_type=error_type, scope=scope, model=model if scope == "model" else None, default_duration=duration) + duration = backoff.duration + backoff_level = backoff.level + if backoff_level <= 0: + return ProviderCooldownDecision(False, reason="transient_backoff_threshold_not_met", scope=scope, model=model if scope == "model" else None) return ProviderCooldownDecision(True, duration=duration, reason="model_capacity_cooldown" if scope == "model" else "default_transient_cooldown", scope=scope, model=model if scope == "model" else None, backoff_level=backoff_level) return ProviderCooldownDecision(False, reason="missing_retry_after") @@ -163,6 +168,22 @@ def snapshot(self) -> tuple[FailureHistoryEntry, ...]: return tuple(self._entries) + def clear(self, *, provider: str, model: Optional[str] = None, scope: Optional[str] = None, error_type: Optional[str] = None) -> None: + """Clear matching failure entries after a successful provider/model call.""" + + kept = [ + entry + for entry in self._entries + if not ( + entry.provider == provider + and (scope is None or entry.scope == scope) + and (error_type is None or entry.error_type == error_type) + and (entry.scope != "model" or model is None or entry.model == model) + ) + ] + self._entries.clear() + self._entries.extend(kept) + def backoff_for(self, *, provider: Optional[str], error_type: str, scope: str, model: Optional[str], default_duration: int) -> BackoffDecision: """Return bounded backoff for repeated transient failures.""" diff --git a/tests/test_cooldown_activation.py b/tests/test_cooldown_activation.py index 1bb238eeb..b7f0490a3 100644 --- a/tests/test_cooldown_activation.py +++ b/tests/test_cooldown_activation.py @@ -4,7 +4,7 @@ import pytest -from rotator_library.client.executor import RequestExecutor, _can_start_stream_provider_cooldown +from rotator_library.client.executor import RequestExecutor, RoutingExecutionError, _can_start_stream_provider_cooldown from rotator_library.cooldown_manager import CooldownManager from rotator_library.core.types import RequestContext from rotator_library.error_handler import ClassifiedError @@ -28,6 +28,16 @@ async def get_max_remaining(self, provider, *, model=None): return 0 +class BudgetCooldown(FakeCooldown): + def __init__(self, remaining: float) -> None: + super().__init__() + self.remaining = remaining + + async def get_max_remaining(self, provider, *, model=None): + self.waits.append((provider, model)) + return self.remaining + + @pytest.mark.asyncio async def test_start_cooldown_extends_but_does_not_shorten() -> None: manager = CooldownManager() @@ -166,6 +176,33 @@ async def test_wait_for_cooldown_uses_model_scope_when_available() -> None: assert cooldown.waits == [("openai", "gpt-5")] +@pytest.mark.asyncio +async def test_wait_for_cooldown_exceeding_budget_fails_fast() -> None: + cooldown = BudgetCooldown(remaining=60) + + with pytest.raises(RoutingExecutionError) as exc: + await _executor(cooldown)._wait_for_cooldown("openai", 1.0, model="gpt-5") + + assert exc.value.error_type == "rate_limit" + assert cooldown.waits == [("openai", "gpt-5")] + + +@pytest.mark.asyncio +async def test_generic_transient_records_history_before_starting_cooldown(monkeypatch) -> None: + monkeypatch.setenv("PROVIDER_BACKOFF_THRESHOLD", "2") + monkeypatch.setenv("PROVIDER_COOLDOWN_DEFAULT_SECONDS", "10") + monkeypatch.setenv("PROVIDER_COOLDOWN_MIN_SECONDS", "10") + cooldown = FakeCooldown() + executor = _executor(cooldown) + + await executor._maybe_start_provider_cooldown("openai", _classified("server_error"), context=None, model="gpt-5") + assert cooldown.scoped_started == [] + assert executor._failure_history.snapshot()[0].reason == "transient_backoff_threshold_not_met" + + await executor._maybe_start_provider_cooldown("openai", _classified("server_error"), context=None, model="gpt-5") + assert cooldown.scoped_started == [("openai", 10, None, "provider", "default_transient_cooldown")] + + def test_streaming_provider_cooldown_gate_allows_only_pre_output_failures() -> None: assert _can_start_stream_provider_cooldown(None) is True assert _can_start_stream_provider_cooldown('data: {"error":{"type":"rate_limit"}}\n\n') is True diff --git a/tests/test_request_executor_fallback_groups.py b/tests/test_request_executor_fallback_groups.py index d86bdcce8..6779d22ed 100644 --- a/tests/test_request_executor_fallback_groups.py +++ b/tests/test_request_executor_fallback_groups.py @@ -51,11 +51,15 @@ async def test_non_streaming_fallback_group_tries_next_target_on_retryable_error attempts = [] targets = (parse_route_target("codex/gpt-5.1-codex"), parse_route_target("openai/gpt-5.1")) - result = await _executor_with_attempts(attempts)._execute_non_streaming_with_fallback(_context(routing_targets=targets)) + context = _context(routing_targets=targets) + result = await _executor_with_attempts(attempts)._execute_non_streaming_with_fallback(context) assert result == {"id": "ok", "model": "openai/gpt-5.1"} assert [attempt.provider for attempt in attempts] == ["codex", "openai"] assert [attempt.kwargs["model"] for attempt in attempts] == ["codex/gpt-5.1-codex", "openai/gpt-5.1"] + assert context.routing_attempt_history[0]["error_type"] == "rate_limit" + assert context.routing_attempt_history[0]["fallback_allowed"] is True + assert context.routing_attempt_history[1]["success"] is True @pytest.mark.asyncio diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index 1386ffc96..6e2fee72e 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -142,6 +142,42 @@ def test_failure_history_escalates_repeated_transient_backoff(monkeypatch) -> No assert decision.backoff_level == 2 +def test_single_generic_transient_without_retry_after_does_not_cooldown(monkeypatch) -> None: + history = FailureHistory(clock=lambda: 1000.0) + monkeypatch.setenv("PROVIDER_BACKOFF_THRESHOLD", "3") + + decision = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=10, + provider="openai", + failure_history=history, + ) + + assert decision.should_start is False + assert decision.reason == "transient_backoff_threshold_not_met" + + +def test_failure_history_clear_resets_repeated_transient_backoff(monkeypatch) -> None: + history = FailureHistory(clock=lambda: 1000.0) + monkeypatch.setenv("PROVIDER_BACKOFF_THRESHOLD", "2") + history.record(provider="openai", model=None, error_type="server_error", scope="provider", duration=0, reason="first") + + history.clear(provider="openai", model="gpt-test") + decision = decide_provider_cooldown( + _classified("server_error"), + small_cooldown_threshold=10, + provider_cooldown_min_seconds=10, + default_duration=10, + provider="openai", + failure_history=history, + ) + + assert history.snapshot() == () + assert decision.should_start is False + + def test_failure_history_backoff_is_provider_scoped(monkeypatch) -> None: history = FailureHistory(clock=lambda: 1000.0) monkeypatch.setenv("PROVIDER_BACKOFF_THRESHOLD", "2") diff --git a/tests/test_streaming_fallback_policy.py b/tests/test_streaming_fallback_policy.py index 7f4d656e3..55084e55e 100644 --- a/tests/test_streaming_fallback_policy.py +++ b/tests/test_streaming_fallback_policy.py @@ -74,10 +74,14 @@ async def fake_stream(self, context): executor._execute_streaming = MethodType(fake_stream, executor) - chunks = [chunk async for chunk in executor._execute_streaming_with_fallback(_context())] + context = _context() + chunks = [chunk async for chunk in executor._execute_streaming_with_fallback(context)] assert attempts == ["codex", "openai"] assert chunks[-1] == "data: [DONE]\n\n" + assert context.routing_attempt_history[0]["error_type"] == "rate_limit" + assert context.routing_attempt_history[0]["fallback_allowed"] is True + assert context.routing_attempt_history[1]["success"] is True @pytest.mark.asyncio diff --git a/tests/test_streaming_usage_accounting.py b/tests/test_streaming_usage_accounting.py index 8db6c492a..824fbf4c6 100644 --- a/tests/test_streaming_usage_accounting.py +++ b/tests/test_streaming_usage_accounting.py @@ -99,6 +99,15 @@ async def test_streaming_without_usage_still_marks_success_with_zero_usage() -> assert cred_context.success_kwargs["prompt_tokens_cache_write"] == 0 +@pytest.mark.asyncio +async def test_streaming_completed_calls_success_callback() -> None: + called = [] + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_zero_usage_chunks(), "cred", "gpt-test", success_callback=lambda: called.append(True))] + + assert called == [True] + + @pytest.mark.asyncio async def test_streaming_usage_uses_configured_env_pricing(monkeypatch) -> None: monkeypatch.setenv("MODEL_PRICE_OPENAI_GPT_TEST_INPUT", "2.0") From 81984ed7d0f12dcb6aa0b82e8ab1b444ea326096 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 22:54:26 +0200 Subject: [PATCH 173/182] docs(experimental): plan streaming hardening correction Adds the Phase 8c corrective plan for Responses runtime stream settings, heartbeat frames, upstream close behavior, timeout handling, disconnect handling, and Anthropic iterator close safety. Tests: not run (planning document only) --- .../phase-8c-streaming-runtime-hardening.md | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 docs/experimental/phase-8c-streaming-runtime-hardening.md diff --git a/docs/experimental/phase-8c-streaming-runtime-hardening.md b/docs/experimental/phase-8c-streaming-runtime-hardening.md new file mode 100644 index 000000000..fa731ef9a --- /dev/null +++ b/docs/experimental/phase-8c-streaming-runtime-hardening.md @@ -0,0 +1,60 @@ +# Phase 8c: Responses Stream Runtime Hardening And Anthropic Close Safety + +## Goal + +Close all Phase 8/8b third-pass streaming findings while preserving the transport-neutral Responses stream model and the existing Anthropic compatibility wrapper. + +## Scope + +- Apply Phase 8b stream runtime settings directly in `ResponsesService.stream_events()`. +- Add Responses heartbeat formatting support so SSE wrappers can emit non-visible heartbeat frames. +- Enforce Responses TTFB timeout before first upstream chunk. +- Enforce Responses stall timeout after a prior upstream chunk/visible output. +- Close upstream Responses chat streams on client disconnect, timeout, or abnormal exit. +- Keep Responses cost-comment handling for Phase 9c, but avoid making heartbeat/cost metadata visible output. +- Ensure Anthropic compatibility streaming always attempts to close upstream on disconnect and wrapper exit, including generator-style streams that only expose close on the iterator. + +## Non-Goals + +- Do not add WebSocket routes. +- Do not replace the Responses chat bridge with native Responses execution. +- Do not change public Responses event ordering except inserting SSE comment heartbeats when configured. +- Do not make timeout defaults active; all timeout/heartbeat knobs remain opt-in through existing config. +- Do not solve Phase 9 cost-comment propagation here except preserving metadata as non-visible. + +## Implementation Plan + +1. Responses heartbeat formatter. + - Add heartbeat handling to `ResponsesSSEFormatter` and `ResponsesWebSocketFormatter`. + - Represent heartbeat as a transport-neutral non-terminal `ResponsesStreamEvent` with `event_name="heartbeat"`. + - SSE heartbeat output is `": heartbeat\n\n"`. + +2. Responses runtime settings. + - Load `get_stream_runtime_settings()` inside `ResponsesService.stream_events()`. + - Use polling around upstream `__anext__()` to enforce optional TTFB and stall timeouts and emit optional heartbeats. + - Keep all settings disabled by default. + +3. Upstream close helper. + - Close upstream iterator/stream with `aclose()` or `close()` on disconnect, timeout, or abnormal exit. + - Trace close/close-failure events. + +4. Disconnect detection. + - Poll `request.is_disconnected()` while waiting for upstream chunks. + - If disconnected, close upstream when configured and stop without emitting model output. + +5. Anthropic close safety. + - Track both the async iterator and original OpenAI stream as close candidates. + - Ensure disconnect and wrapper exit close whichever object exposes `aclose()` / `close()`. + +6. Tests. + - Responses heartbeat, TTFB timeout, stall timeout after output, disconnect close, and heartbeat SSE comment tests. + - Anthropic iterator-only upstream close test. + - Broader streaming/Responses/Anthropic regression slice. + +## Acceptance Criteria + +- Responses streaming honors `STREAM_TTFB_TIMEOUT_SECONDS`, `STREAM_STALL_TIMEOUT_SECONDS`, `STREAM_HEARTBEAT_INTERVAL_SECONDS`, and `STREAM_CANCEL_UPSTREAM_ON_DISCONNECT`. +- Responses heartbeat frames are non-visible SSE comments. +- Responses upstream chat streams are closed on disconnect/timeout/abnormal exit. +- Anthropic compatibility streaming closes upstream even when close is exposed only on the async iterator. +- Focused tests pass and both reviewers report no blockers/highs/mediums. From f8600a61161d18bc8494a4752417482b55307473 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 23:04:24 +0200 Subject: [PATCH 174/182] fix(streaming): harden responses runtime streams Applies stream runtime settings to Responses streaming, adds non-visible heartbeat formatting, enforces TTFB and stall timeout handling, closes upstream streams on timeout or disconnect, and closes Anthropic compatibility streams when only the iterator exposes close. Tests: pytest tests/test_responses_streaming.py tests/test_responses_routes.py tests/test_anthropic_transform_tracing.py tests/test_streaming_error_handler.py tests/test_streaming_usage_accounting.py tests/test_request_executor_stream_metrics.py tests/test_native_provider_streaming.py tests/test_native_streaming_transport_seam.py tests/test_stream_policy.py tests/test_stream_transport.py tests/test_config_stream_settings.py tests/test_stream_metrics.py tests/test_stream_events.py --- .../anthropic_compat/streaming.py | 15 +- src/rotator_library/responses/service.py | 128 +++++++++++++++++- src/rotator_library/responses/streaming.py | 16 +++ tests/test_anthropic_transform_tracing.py | 39 ++++++ tests/test_responses_streaming.py | 103 ++++++++++++++ 5 files changed, 293 insertions(+), 8 deletions(-) diff --git a/src/rotator_library/anthropic_compat/streaming.py b/src/rotator_library/anthropic_compat/streaming.py index 4ad588d87..126ed70ab 100644 --- a/src/rotator_library/anthropic_compat/streaming.py +++ b/src/rotator_library/anthropic_compat/streaming.py @@ -67,6 +67,7 @@ async def anthropic_streaming_wrapper( accumulated_thinking = "" # Track accumulated thinking for logging stop_reason_final = "end_turn" # Track final stop reason for logging upstream_closed = False + stream_iterator = openai_stream.__aiter__() def trace_frame(pass_name: str, frame: str, payload: Any | None = None) -> str: """Trace one Anthropic stream conversion frame and return it for yield.""" @@ -90,13 +91,15 @@ async def close_upstream(reason: str) -> None: if upstream_closed: return upstream_closed = True - close = getattr(openai_stream, "aclose", None) - if callable(close): - await close() - else: - sync_close = getattr(openai_stream, "close", None) + for candidate in (stream_iterator, openai_stream): + close = getattr(candidate, "aclose", None) + if callable(close): + await close() + break + sync_close = getattr(candidate, "close", None) if callable(sync_close): sync_close() + break if transaction_logger: transaction_logger.log_transform_pass( "anthropic_stream_upstream_closed", @@ -109,7 +112,7 @@ async def close_upstream(reason: str) -> None: ) try: - async for chunk_str in openai_stream: + async for chunk_str in stream_iterator: if transaction_logger: transaction_logger.log_transform_pass( "anthropic_stream_source_chunk", diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index 701b00c7c..ccd13e47d 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -5,12 +5,14 @@ from __future__ import annotations +import asyncio import time from copy import deepcopy from typing import Any, AsyncGenerator, Optional from ..protocols import ProtocolContext from ..streaming import StreamEvent, StreamMonitor +from ..config.experimental import get_stream_runtime_settings from ..usage.accounting import extract_usage_record from ..usage.costs import CostCalculator from ..protocols.responses import ResponsesProtocol @@ -156,6 +158,16 @@ async def stream_response( metadata={"event_name": event.event_name, "terminal": event.terminal, "transport": transport}, scrub_strings=True, ) + if event.heartbeat: + self._trace( + transaction_logger, + "responses_sse_formatted_heartbeat", + formatted, + direction="stream", + stage="final", + metadata={"transport": transport}, + scrub_strings=True, + ) yield formatted async def validate_stream_request(self, raw_request: dict[str, Any]) -> None: @@ -218,6 +230,88 @@ async def stream_events( usage = None item_started = False monitor = StreamMonitor(clock=time.monotonic) + stream_settings = get_stream_runtime_settings() + chat_stream = None + stream_iterator = None + upstream_closed = False + pending_next_task = None + + async def close_upstream(reason: str) -> None: + """Best-effort close for upstream Responses bridge streams.""" + + nonlocal upstream_closed + if upstream_closed: + return + upstream_closed = True + for candidate in (stream_iterator, chat_stream): + if candidate is None: + continue + try: + closer = getattr(candidate, "aclose", None) + if callable(closer): + await closer() + self._trace(transaction_logger, "responses_stream_upstream_closed", {"reason": reason}, direction="stream", stage="provider", metadata={"transport": transport}) + return + closer = getattr(candidate, "close", None) + if callable(closer): + closer() + self._trace(transaction_logger, "responses_stream_upstream_closed", {"reason": reason}, direction="stream", stage="provider", metadata={"transport": transport}) + return + except Exception as exc: + self._trace(transaction_logger, "responses_stream_upstream_close_failed", {"reason": reason, "error_type": type(exc).__name__}, direction="stream", stage="provider", metadata={"transport": transport}) + return + + async def next_upstream_chunk(*, first: bool) -> tuple[str, Any]: + """Return the next upstream chunk or a control marker.""" + + timeout = stream_settings.ttfb_timeout_seconds if first else stream_settings.stall_timeout_seconds + heartbeat = stream_settings.heartbeat_seconds + started_at = time.monotonic() + last_heartbeat_at = started_at + nonlocal pending_next_task + if pending_next_task is None or pending_next_task.done(): + pending_next_task = asyncio.create_task(stream_iterator.__anext__()) + next_task = pending_next_task + while True: + if request is not None and await request.is_disconnected(): + self._trace(transaction_logger, "responses_stream_disconnected", {"reason": "client_disconnected"}, direction="stream", stage="client", metadata={"transport": transport}) + if stream_settings.cancel_upstream_on_disconnect: + next_task.cancel() + pending_next_task = None + await close_upstream("client_disconnected") + return "disconnect", None + elapsed = time.monotonic() - started_at + waits = [] + if timeout is not None: + remaining_timeout = timeout - elapsed + if remaining_timeout <= 0: + next_task.cancel() + pending_next_task = None + await close_upstream("ttfb_timeout" if first else "stall_timeout") + raise ResponsesServiceError( + f"Responses stream {'TTFB' if first else 'stall'} timeout", + status_code=504, + error_type="api_connection", + ) + waits.append(remaining_timeout) + if heartbeat is not None: + remaining_heartbeat = heartbeat - (time.monotonic() - last_heartbeat_at) + if remaining_heartbeat <= 0: + last_heartbeat_at = time.monotonic() + return "heartbeat", None + waits.append(remaining_heartbeat) + wait_timeout = min(waits) if waits else None + if wait_timeout is None: + chunk = await next_task + pending_next_task = None + return "chunk", chunk + done, _ = await asyncio.wait({next_task}, timeout=wait_timeout) + if done: + pending_next_task = None + return "chunk", next_task.result() + if heartbeat is not None and time.monotonic() - last_heartbeat_at >= heartbeat: + last_heartbeat_at = time.monotonic() + return "heartbeat", None if transaction_logger: self._trace( transaction_logger, @@ -233,7 +327,20 @@ async def stream_events( yield ResponsesStreamEvent("response.created", created) try: chat_stream = await client.acompletion(request=request, **chat_kwargs) - async for raw_chunk in chat_stream: + stream_iterator = chat_stream.__aiter__() + first_chunk = True + while True: + try: + marker, raw_chunk = await next_upstream_chunk(first=first_chunk) + except StopAsyncIteration: + break + if marker == "disconnect": + return + if marker == "heartbeat": + self._trace(transaction_logger, "responses_stream_heartbeat", {"comment": "heartbeat"}, direction="stream", stage="transport", metadata={"transport": transport}) + yield ResponsesStreamEvent("heartbeat", {"comment": "heartbeat"}) + continue + first_chunk = False if monitor.metrics.first_byte_at is None: monitor.record_event(StreamEvent("raw_chunk", protocol="responses", raw=raw_chunk)) if transaction_logger: @@ -331,7 +438,7 @@ async def stream_events( yield ResponsesStreamEvent("done", {}, terminal=True) except Exception as exc: monitor.record_event(StreamEvent("error", protocol="responses", data={"error_type": exc.__class__.__name__})) - failed = response_failed_payload(response_id, unified.model, {"message": str(exc), "type": exc.__class__.__name__}) + failed = response_failed_payload(response_id, unified.model, _stream_failure_error(exc)) if state.output_text: failed["output"] = [output_item_done_payload(state)["item"]] self._log_transform_error(transaction_logger, "responses_stream", exc, stream_request) @@ -351,6 +458,9 @@ async def stream_events( yield ResponsesStreamEvent("response.failed", failed) self._trace(transaction_logger, "stream_done_event", {"raw": "done"}, direction="stream", stage="final", metadata={"transport": transport, "failed": True}) yield ResponsesStreamEvent("done", {}, terminal=True) + finally: + if chat_stream is not None and not upstream_closed: + await close_upstream("wrapper_exit") async def get_response(self, response_id: str) -> dict[str, Any]: """Return a stored response payload or raise a 404-compatible error.""" @@ -662,6 +772,20 @@ def _stream_error_message(chunk: dict[str, Any]) -> str: return str(message) if message else "Upstream stream error" +def _stream_failure_error(exc: Exception) -> dict[str, Any]: + """Return a client-safe Responses stream failure object.""" + + error_type = getattr(exc, "error_type", None) or exc.__class__.__name__ + result = {"message": str(exc), "type": str(error_type)} + if isinstance(exc, ResponsesServiceError) and error_type == "api_connection": + text = str(exc).lower() + if "ttfb" in text: + result["timeout_type"] = "ttfb" + elif "stall" in text: + result["timeout_type"] = "stall" + return result + + def _chunk_text_delta(chunk: dict[str, Any]) -> str: choices = chunk.get("choices") if isinstance(chunk.get("choices"), list) else [] if not choices: diff --git a/src/rotator_library/responses/streaming.py b/src/rotator_library/responses/streaming.py index 699a971f9..f76111134 100644 --- a/src/rotator_library/responses/streaming.py +++ b/src/rotator_library/responses/streaming.py @@ -34,6 +34,12 @@ class ResponsesStreamEvent: payload: dict[str, Any] terminal: bool = False + @property + def heartbeat(self) -> bool: + """Return whether this event is transport metadata, not model output.""" + + return self.event_name == "heartbeat" + class ResponsesSSEFormatter: """Format Responses API events as HTTP Server-Sent Events.""" @@ -46,10 +52,18 @@ def format_event(self, event_name: str, payload: dict[str, Any]) -> str: def format_stream_event(self, event: ResponsesStreamEvent) -> str: """Format one transport-neutral event for HTTP SSE.""" + if event.heartbeat: + return self.format_heartbeat(str(event.payload.get("comment") or "heartbeat")) if event.terminal: return self.done() return self.format_event(event.event_name, event.payload) + def format_heartbeat(self, comment: str = "heartbeat") -> str: + """Return a non-visible SSE comment heartbeat frame.""" + + safe_comment = comment.replace("\r", " ").replace("\n", " ") + return f": {safe_comment}\n\n" + def done(self) -> str: """Return the final compatibility sentinel used by many SSE clients.""" @@ -68,6 +82,8 @@ def format_event(self, event_name: str, payload: dict[str, Any]) -> str: def format_stream_event(self, event: ResponsesStreamEvent) -> str: """Format one transport-neutral event as a WebSocket message payload.""" + if event.heartbeat: + return json.dumps({"event": "heartbeat", "data": {"visible_output": False}}, ensure_ascii=False) if event.terminal: return json.dumps({"event": "done", "data": {}}, ensure_ascii=False) return self.format_event(event.event_name, event.payload) diff --git a/tests/test_anthropic_transform_tracing.py b/tests/test_anthropic_transform_tracing.py index 98bc9faa7..1f1c7ffbe 100644 --- a/tests/test_anthropic_transform_tracing.py +++ b/tests/test_anthropic_transform_tracing.py @@ -68,6 +68,32 @@ async def aclose(self) -> None: self.closed = True +class IteratorOnlyCloseStream: + def __init__(self) -> None: + self.iterator = IteratorOnlyCloseStreamIterator() + + def __aiter__(self): + return self.iterator + + +class IteratorOnlyCloseStreamIterator: + def __init__(self) -> None: + self.closed = False + self._chunks = iter(['data: {"choices":[{"delta":{"content":"hi"}}]}\n\n']) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._chunks) + except StopIteration: + raise StopAsyncIteration + + async def aclose(self) -> None: + self.closed = True + + @pytest.mark.asyncio async def test_anthropic_stream_traces_and_closes_on_disconnect(tmp_path) -> None: logger = TransactionLogger("anthropic", "claude-test", parent_dir=tmp_path) @@ -86,6 +112,19 @@ async def disconnected() -> bool: assert "anthropic_stream_upstream_closed" in pass_names +@pytest.mark.asyncio +async def test_anthropic_stream_closes_iterator_only_upstream() -> None: + stream = IteratorOnlyCloseStream() + + async def disconnected() -> bool: + return True + + chunks = [chunk async for chunk in anthropic_streaming_wrapper(stream, "claude-test", is_disconnected=disconnected)] + + assert chunks == [] + assert stream.iterator.closed is True + + @pytest.mark.asyncio async def test_anthropic_stream_traces_emitted_frames(tmp_path) -> None: logger = TransactionLogger("anthropic", "claude-test", parent_dir=tmp_path) diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index 7fea52597..5eb9a5906 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import pytest @@ -23,6 +24,42 @@ async def chunks(): return chunks() +class DelayedCloseableStream: + def __init__(self, chunks: list[str], delays: list[float]) -> None: + self.chunks = chunks + self.delays = delays + self.index = 0 + self.closed = False + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index >= len(self.chunks): + raise StopAsyncIteration + delay = self.delays[self.index] + chunk = self.chunks[self.index] + self.index += 1 + await asyncio.sleep(delay) + return chunk + + async def aclose(self) -> None: + self.closed = True + + +class DelayedStreamingClient: + def __init__(self, stream: DelayedCloseableStream) -> None: + self.stream = stream + + async def acompletion(self, **kwargs): + return self.stream + + +class DisconnectRequest: + async def is_disconnected(self) -> bool: + return True + + class FailingStreamingClient: def __init__(self, message: str = "stream exploded") -> None: self.message = message @@ -245,12 +282,78 @@ async def test_stream_response_failure_trace_scrubs_header_like_secret_text(tmp_ def test_transport_formatters_expose_sse_and_websocket_seam() -> None: assert ResponsesSSEFormatter().transport == "sse" + assert ResponsesSSEFormatter().format_stream_event(ResponsesStreamEvent("heartbeat", {"comment": "heartbeat"})) == ": heartbeat\n\n" websocket = ResponsesWebSocketFormatter() assert websocket.transport == "websocket" assert websocket.future_supported is True assert websocket.format_stream_event(ResponsesStreamEvent("response.created", {"id": "resp"})) == '{"event": "response.created", "data": {"id": "resp"}}' +@pytest.mark.asyncio +async def test_stream_response_emits_non_visible_heartbeat(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + stream = DelayedCloseableStream( + ['data: {"choices":[{"delta":{"content":"hi"}}]}\n\n', "data: [DONE]\n\n"], + [0.03, 0.0], + ) + service = ResponsesService(store=InMemoryResponsesStore()) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream))] + + assert ": heartbeat\n\n" in events + assert "event: response.output_text.delta" in "".join(events) + + +@pytest.mark.asyncio +async def test_stream_response_ttfb_timeout_closes_upstream(monkeypatch) -> None: + monkeypatch.setenv("STREAM_TTFB_TIMEOUT_SECONDS", "0.01") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], [0.05]) + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream))] + + event_text = "".join(events) + assert stream.closed is True + assert "event: response.failed" in event_text + assert "ttfb" in event_text.lower() + assert event_text.endswith("data: [DONE]\n\n") + + +@pytest.mark.asyncio +async def test_stream_response_stall_timeout_preserves_partial_output(monkeypatch) -> None: + monkeypatch.setenv("STREAM_STALL_TIMEOUT_SECONDS", "0.01") + stream = DelayedCloseableStream( + ['data: {"choices":[{"delta":{"content":"partial"}}]}\n\n', 'data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], + [0.0, 0.05], + ) + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream))] + + event_text = "".join(events) + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stream.closed is True + assert "event: response.failed" in event_text + assert "stall" in event_text.lower() + assert stored is not None + assert stored.output_items[0]["content"][0]["text"] == "partial" + + +@pytest.mark.asyncio +async def test_stream_events_disconnect_closes_upstream(monkeypatch) -> None: + monkeypatch.setenv("STREAM_CANCEL_UPSTREAM_ON_DISCONNECT", "true") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], [0.05]) + service = ResponsesService(store=InMemoryResponsesStore()) + + events = [event async for event in service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream), request=DisconnectRequest())] + + assert [event.event_name for event in events] == ["response.created"] + assert stream.closed is True + + @pytest.mark.asyncio async def test_responses_stream_records_common_stream_metrics(tmp_path) -> None: logger = TransactionLogger("responses", "gpt-test", parent_dir=tmp_path) From 177e4f494caf0487bd3f54fde1d5e845ae670254 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 23:10:53 +0200 Subject: [PATCH 175/182] fix(streaming): preserve response timeout deadlines Keeps pending upstream/acquisition tasks alive across heartbeat frames, enforces TTFB across acquisition and first-chunk waits, awaits pending task cancellation before upstream close, and adds combined heartbeat-timeout coverage. Tests: pytest tests/test_responses_streaming.py tests/test_responses_routes.py tests/test_anthropic_transform_tracing.py tests/test_streaming_error_handler.py tests/test_streaming_usage_accounting.py tests/test_request_executor_stream_metrics.py tests/test_native_provider_streaming.py tests/test_native_streaming_transport_seam.py tests/test_stream_policy.py tests/test_stream_transport.py tests/test_config_stream_settings.py tests/test_stream_metrics.py tests/test_stream_events.py --- src/rotator_library/responses/service.py | 115 +++++++++++++++++++++-- tests/test_responses_streaming.py | 62 +++++++++++- 2 files changed, 166 insertions(+), 11 deletions(-) diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index ccd13e47d..55fe14fea 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -235,6 +235,24 @@ async def stream_events( stream_iterator = None upstream_closed = False pending_next_task = None + pending_next_started_at = None + pending_next_last_heartbeat_at = None + acquire_task = None + acquire_started_at = None + acquire_last_heartbeat_at = None + + async def cancel_task(task: Any) -> None: + """Cancel and await an in-flight stream task before closing its source.""" + + if task is None or task.done(): + return + task.cancel() + try: + await task + except (asyncio.CancelledError, StopAsyncIteration): + return + except Exception: + return async def close_upstream(reason: str) -> None: """Best-effort close for upstream Responses bridge streams.""" @@ -242,41 +260,47 @@ async def close_upstream(reason: str) -> None: nonlocal upstream_closed if upstream_closed: return - upstream_closed = True + attempted = False for candidate in (stream_iterator, chat_stream): if candidate is None: continue + attempted = True try: closer = getattr(candidate, "aclose", None) if callable(closer): await closer() + upstream_closed = True self._trace(transaction_logger, "responses_stream_upstream_closed", {"reason": reason}, direction="stream", stage="provider", metadata={"transport": transport}) return closer = getattr(candidate, "close", None) if callable(closer): closer() + upstream_closed = True self._trace(transaction_logger, "responses_stream_upstream_closed", {"reason": reason}, direction="stream", stage="provider", metadata={"transport": transport}) return except Exception as exc: self._trace(transaction_logger, "responses_stream_upstream_close_failed", {"reason": reason, "error_type": type(exc).__name__}, direction="stream", stage="provider", metadata={"transport": transport}) - return + continue + if attempted: + self._trace(transaction_logger, "responses_stream_upstream_close_failed", {"reason": reason, "error_type": "no_close_method"}, direction="stream", stage="provider", metadata={"transport": transport}) async def next_upstream_chunk(*, first: bool) -> tuple[str, Any]: """Return the next upstream chunk or a control marker.""" timeout = stream_settings.ttfb_timeout_seconds if first else stream_settings.stall_timeout_seconds heartbeat = stream_settings.heartbeat_seconds - started_at = time.monotonic() - last_heartbeat_at = started_at - nonlocal pending_next_task + nonlocal pending_next_task, pending_next_started_at, pending_next_last_heartbeat_at if pending_next_task is None or pending_next_task.done(): pending_next_task = asyncio.create_task(stream_iterator.__anext__()) + pending_next_started_at = time.monotonic() + pending_next_last_heartbeat_at = pending_next_started_at next_task = pending_next_task + started_at = pending_next_started_at or time.monotonic() while True: if request is not None and await request.is_disconnected(): self._trace(transaction_logger, "responses_stream_disconnected", {"reason": "client_disconnected"}, direction="stream", stage="client", metadata={"transport": transport}) if stream_settings.cancel_upstream_on_disconnect: - next_task.cancel() + await cancel_task(next_task) pending_next_task = None await close_upstream("client_disconnected") return "disconnect", None @@ -285,7 +309,7 @@ async def next_upstream_chunk(*, first: bool) -> tuple[str, Any]: if timeout is not None: remaining_timeout = timeout - elapsed if remaining_timeout <= 0: - next_task.cancel() + await cancel_task(next_task) pending_next_task = None await close_upstream("ttfb_timeout" if first else "stall_timeout") raise ResponsesServiceError( @@ -295,23 +319,86 @@ async def next_upstream_chunk(*, first: bool) -> tuple[str, Any]: ) waits.append(remaining_timeout) if heartbeat is not None: + last_heartbeat_at = pending_next_last_heartbeat_at or started_at remaining_heartbeat = heartbeat - (time.monotonic() - last_heartbeat_at) if remaining_heartbeat <= 0: - last_heartbeat_at = time.monotonic() + pending_next_last_heartbeat_at = time.monotonic() return "heartbeat", None waits.append(remaining_heartbeat) wait_timeout = min(waits) if waits else None if wait_timeout is None: chunk = await next_task pending_next_task = None + pending_next_started_at = None + pending_next_last_heartbeat_at = None return "chunk", chunk done, _ = await asyncio.wait({next_task}, timeout=wait_timeout) if done: pending_next_task = None + pending_next_started_at = None + pending_next_last_heartbeat_at = None return "chunk", next_task.result() + last_heartbeat_at = pending_next_last_heartbeat_at or started_at if heartbeat is not None and time.monotonic() - last_heartbeat_at >= heartbeat: - last_heartbeat_at = time.monotonic() + pending_next_last_heartbeat_at = time.monotonic() return "heartbeat", None + + async def acquire_upstream_stream() -> tuple[str, Any]: + """Acquire the upstream stream under the same TTFB/disconnect policy.""" + + nonlocal acquire_task, acquire_started_at, acquire_last_heartbeat_at + if acquire_task is None or acquire_task.done(): + acquire_task = asyncio.create_task(client.acompletion(request=request, **chat_kwargs)) + acquire_started_at = time.monotonic() + acquire_last_heartbeat_at = acquire_started_at + task = acquire_task + started_at = acquire_started_at or time.monotonic() + timeout = stream_settings.ttfb_timeout_seconds + heartbeat = stream_settings.heartbeat_seconds + while True: + if request is not None and await request.is_disconnected(): + self._trace(transaction_logger, "responses_stream_disconnected", {"reason": "client_disconnected", "phase": "acquire"}, direction="stream", stage="client", metadata={"transport": transport}) + await cancel_task(task) + acquire_task = None + acquire_started_at = None + acquire_last_heartbeat_at = None + return "disconnect", None + waits = [] + elapsed = time.monotonic() - started_at + if timeout is not None: + remaining_timeout = timeout - elapsed + if remaining_timeout <= 0: + await cancel_task(task) + acquire_task = None + acquire_started_at = None + acquire_last_heartbeat_at = None + raise ResponsesServiceError("Responses stream TTFB timeout", status_code=504, error_type="api_connection") + waits.append(remaining_timeout) + if heartbeat is not None: + last_heartbeat_at = acquire_last_heartbeat_at or started_at + remaining_heartbeat = heartbeat - (time.monotonic() - last_heartbeat_at) + if remaining_heartbeat <= 0: + acquire_last_heartbeat_at = time.monotonic() + return "heartbeat", None + waits.append(remaining_heartbeat) + wait_timeout = min(waits) if waits else None + if wait_timeout is None: + stream = await task + acquire_task = None + acquire_started_at = None + acquire_last_heartbeat_at = None + return "stream", stream + done, _ = await asyncio.wait({task}, timeout=wait_timeout) + if done: + acquire_task = None + acquire_started_at = None + acquire_last_heartbeat_at = None + return "stream", task.result() + last_heartbeat_at = acquire_last_heartbeat_at or started_at + if heartbeat is not None and time.monotonic() - last_heartbeat_at >= heartbeat: + acquire_last_heartbeat_at = time.monotonic() + return "heartbeat", None + if transaction_logger: self._trace( transaction_logger, @@ -326,7 +413,15 @@ async def next_upstream_chunk(*, first: bool) -> tuple[str, Any]: await self._store_stream_current_state(stream_request, created, parent, transaction_logger=transaction_logger) yield ResponsesStreamEvent("response.created", created) try: - chat_stream = await client.acompletion(request=request, **chat_kwargs) + while chat_stream is None: + marker, acquired = await acquire_upstream_stream() + if marker == "disconnect": + return + if marker == "heartbeat": + self._trace(transaction_logger, "responses_stream_heartbeat", {"comment": "heartbeat", "phase": "acquire"}, direction="stream", stage="transport", metadata={"transport": transport}) + yield ResponsesStreamEvent("heartbeat", {"comment": "heartbeat"}) + continue + chat_stream = acquired stream_iterator = chat_stream.__aiter__() first_chunk = True while True: diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index 5eb9a5906..2c652a5f4 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -55,11 +55,30 @@ async def acompletion(self, **kwargs): return self.stream +class SlowAcquireStreamingClient: + def __init__(self, stream: DelayedCloseableStream, delay: float) -> None: + self.stream = stream + self.delay = delay + + async def acompletion(self, **kwargs): + await asyncio.sleep(self.delay) + return self.stream + + class DisconnectRequest: async def is_disconnected(self) -> bool: return True +class DisconnectAfterAcquireRequest: + def __init__(self) -> None: + self.calls = 0 + + async def is_disconnected(self) -> bool: + self.calls += 1 + return self.calls > 1 + + class FailingStreamingClient: def __init__(self, message: str = "stream exploded") -> None: self.message = message @@ -320,6 +339,47 @@ async def test_stream_response_ttfb_timeout_closes_upstream(monkeypatch) -> None assert event_text.endswith("data: [DONE]\n\n") +@pytest.mark.asyncio +async def test_stream_response_heartbeat_does_not_reset_ttfb_timeout(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + monkeypatch.setenv("STREAM_TTFB_TIMEOUT_SECONDS", "0.03") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], [0.1]) + service = ResponsesService(store=InMemoryResponsesStore()) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream))] + + event_text = "".join(events) + assert ": heartbeat\n\n" in events + assert stream.closed is True + assert "event: response.failed" in event_text + assert "ttfb" in event_text.lower() + + +@pytest.mark.asyncio +async def test_stream_response_acquire_wait_honors_ttfb_timeout(monkeypatch) -> None: + monkeypatch.setenv("STREAM_TTFB_TIMEOUT_SECONDS", "0.01") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], [0.0]) + service = ResponsesService(store=InMemoryResponsesStore()) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, SlowAcquireStreamingClient(stream, 0.05))] + + event_text = "".join(events) + assert "event: response.failed" in event_text + assert "ttfb" in event_text.lower() + + +@pytest.mark.asyncio +async def test_stream_response_acquire_wait_can_emit_heartbeat(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"hi"}}]}\n\n', "data: [DONE]\n\n"], [0.0, 0.0]) + service = ResponsesService(store=InMemoryResponsesStore()) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, SlowAcquireStreamingClient(stream, 0.03))] + + assert ": heartbeat\n\n" in events + assert "event: response.completed" in "".join(events) + + @pytest.mark.asyncio async def test_stream_response_stall_timeout_preserves_partial_output(monkeypatch) -> None: monkeypatch.setenv("STREAM_STALL_TIMEOUT_SECONDS", "0.01") @@ -348,7 +408,7 @@ async def test_stream_events_disconnect_closes_upstream(monkeypatch) -> None: stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], [0.05]) service = ResponsesService(store=InMemoryResponsesStore()) - events = [event async for event in service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream), request=DisconnectRequest())] + events = [event async for event in service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream), request=DisconnectAfterAcquireRequest())] assert [event.event_name for event in events] == ["response.created"] assert stream.closed is True From f458603c23cc196ac600c57197e39c28e12ef343 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 23:19:53 +0200 Subject: [PATCH 176/182] fix(streaming): consume completed heartbeat tasks Prevents heartbeat-yielded upstream and acquisition tasks from being overwritten before their results are consumed, and cancels pending acquisition/read tasks during final stream cleanup before closing upstream. Tests: pytest tests/test_responses_streaming.py tests/test_responses_routes.py tests/test_anthropic_transform_tracing.py tests/test_streaming_error_handler.py tests/test_streaming_usage_accounting.py tests/test_request_executor_stream_metrics.py tests/test_native_provider_streaming.py tests/test_native_streaming_transport_seam.py tests/test_stream_policy.py tests/test_stream_transport.py tests/test_config_stream_settings.py tests/test_stream_metrics.py tests/test_stream_events.py --- src/rotator_library/responses/service.py | 16 +++++- tests/test_responses_streaming.py | 73 +++++++++++++++++++++++- 2 files changed, 86 insertions(+), 3 deletions(-) diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index 55fe14fea..7bba6ec34 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -290,13 +290,18 @@ async def next_upstream_chunk(*, first: bool) -> tuple[str, Any]: timeout = stream_settings.ttfb_timeout_seconds if first else stream_settings.stall_timeout_seconds heartbeat = stream_settings.heartbeat_seconds nonlocal pending_next_task, pending_next_started_at, pending_next_last_heartbeat_at - if pending_next_task is None or pending_next_task.done(): + if pending_next_task is None: pending_next_task = asyncio.create_task(stream_iterator.__anext__()) pending_next_started_at = time.monotonic() pending_next_last_heartbeat_at = pending_next_started_at next_task = pending_next_task started_at = pending_next_started_at or time.monotonic() while True: + if next_task.done(): + pending_next_task = None + pending_next_started_at = None + pending_next_last_heartbeat_at = None + return "chunk", next_task.result() if request is not None and await request.is_disconnected(): self._trace(transaction_logger, "responses_stream_disconnected", {"reason": "client_disconnected"}, direction="stream", stage="client", metadata={"transport": transport}) if stream_settings.cancel_upstream_on_disconnect: @@ -347,7 +352,7 @@ async def acquire_upstream_stream() -> tuple[str, Any]: """Acquire the upstream stream under the same TTFB/disconnect policy.""" nonlocal acquire_task, acquire_started_at, acquire_last_heartbeat_at - if acquire_task is None or acquire_task.done(): + if acquire_task is None: acquire_task = asyncio.create_task(client.acompletion(request=request, **chat_kwargs)) acquire_started_at = time.monotonic() acquire_last_heartbeat_at = acquire_started_at @@ -356,6 +361,11 @@ async def acquire_upstream_stream() -> tuple[str, Any]: timeout = stream_settings.ttfb_timeout_seconds heartbeat = stream_settings.heartbeat_seconds while True: + if task.done(): + acquire_task = None + acquire_started_at = None + acquire_last_heartbeat_at = None + return "stream", task.result() if request is not None and await request.is_disconnected(): self._trace(transaction_logger, "responses_stream_disconnected", {"reason": "client_disconnected", "phase": "acquire"}, direction="stream", stage="client", metadata={"transport": transport}) await cancel_task(task) @@ -554,6 +564,8 @@ async def acquire_upstream_stream() -> tuple[str, Any]: self._trace(transaction_logger, "stream_done_event", {"raw": "done"}, direction="stream", stage="final", metadata={"transport": transport, "failed": True}) yield ResponsesStreamEvent("done", {}, terminal=True) finally: + await cancel_task(pending_next_task) + await cancel_task(acquire_task) if chat_stream is not None and not upstream_closed: await close_upstream("wrapper_exit") diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index 2c652a5f4..e3a98716f 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -59,12 +59,28 @@ class SlowAcquireStreamingClient: def __init__(self, stream: DelayedCloseableStream, delay: float) -> None: self.stream = stream self.delay = delay + self.calls = 0 async def acompletion(self, **kwargs): + self.calls += 1 await asyncio.sleep(self.delay) return self.stream +class CancelAwareAcquireClient: + def __init__(self, delay: float = 1.0) -> None: + self.cancelled = False + self.delay = delay + + async def acompletion(self, **kwargs): + try: + await asyncio.sleep(self.delay) + except asyncio.CancelledError: + self.cancelled = True + raise + return DelayedCloseableStream([], []) + + class DisconnectRequest: async def is_disconnected(self) -> bool: return True @@ -374,12 +390,67 @@ async def test_stream_response_acquire_wait_can_emit_heartbeat(monkeypatch) -> N stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"hi"}}]}\n\n', "data: [DONE]\n\n"], [0.0, 0.0]) service = ResponsesService(store=InMemoryResponsesStore()) - events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, SlowAcquireStreamingClient(stream, 0.03))] + client = SlowAcquireStreamingClient(stream, 0.03) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, client)] assert ": heartbeat\n\n" in events + assert client.calls == 1 assert "event: response.completed" in "".join(events) +@pytest.mark.asyncio +async def test_stream_response_heartbeat_does_not_drop_completed_first_chunk(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + stream = DelayedCloseableStream( + ['data: {"choices":[{"delta":{"content":"kept"}}]}\n\n', "data: [DONE]\n\n"], + [0.05, 0.0], + ) + service = ResponsesService(store=InMemoryResponsesStore()) + events = service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream)) + + created = await anext(events) + heartbeat = await anext(events) + await asyncio.sleep(0.06) + added = await anext(events) + delta = await anext(events) + await events.aclose() + + assert created.event_name == "response.created" + assert heartbeat.heartbeat is True + assert added.event_name == "response.output_item.added" + assert delta.event_name == "response.output_text.delta" + assert delta.payload["delta"] == "kept" + + +@pytest.mark.asyncio +async def test_stream_events_aclose_cancels_pending_read_after_heartbeat(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], [1.0]) + service = ResponsesService(store=InMemoryResponsesStore()) + events = service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, DelayedStreamingClient(stream)) + + assert (await anext(events)).event_name == "response.created" + assert (await anext(events)).heartbeat is True + await events.aclose() + + assert stream.closed is True + + +@pytest.mark.asyncio +async def test_stream_events_aclose_cancels_pending_acquire_after_heartbeat(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + client = CancelAwareAcquireClient() + service = ResponsesService(store=InMemoryResponsesStore()) + events = service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, client) + + assert (await anext(events)).event_name == "response.created" + assert (await anext(events)).heartbeat is True + await events.aclose() + + assert client.cancelled is True + + @pytest.mark.asyncio async def test_stream_response_stall_timeout_preserves_partial_output(monkeypatch) -> None: monkeypatch.setenv("STREAM_STALL_TIMEOUT_SECONDS", "0.01") From a5f2731b42475595b8ca8a5b40142cb815340742 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 23:24:43 +0200 Subject: [PATCH 177/182] fix(streaming): keep responses ttfb end-to-end Uses one Responses TTFB deadline across stream acquisition and first chunk, and closes streams returned by completed acquisition tasks when the generator exits after a heartbeat. Tests: pytest tests/test_responses_streaming.py tests/test_responses_routes.py tests/test_anthropic_transform_tracing.py tests/test_streaming_error_handler.py tests/test_streaming_usage_accounting.py tests/test_request_executor_stream_metrics.py tests/test_native_provider_streaming.py tests/test_native_streaming_transport_seam.py tests/test_stream_policy.py tests/test_stream_transport.py tests/test_config_stream_settings.py tests/test_stream_metrics.py tests/test_stream_events.py --- src/rotator_library/responses/service.py | 11 +++++++-- tests/test_responses_streaming.py | 31 ++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index 7bba6ec34..02f8c728e 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -240,6 +240,7 @@ async def stream_events( acquire_task = None acquire_started_at = None acquire_last_heartbeat_at = None + ttfb_started_at = time.monotonic() async def cancel_task(task: Any) -> None: """Cancel and await an in-flight stream task before closing its source.""" @@ -295,7 +296,7 @@ async def next_upstream_chunk(*, first: bool) -> tuple[str, Any]: pending_next_started_at = time.monotonic() pending_next_last_heartbeat_at = pending_next_started_at next_task = pending_next_task - started_at = pending_next_started_at or time.monotonic() + started_at = ttfb_started_at if first else (pending_next_started_at or time.monotonic()) while True: if next_task.done(): pending_next_task = None @@ -354,7 +355,7 @@ async def acquire_upstream_stream() -> tuple[str, Any]: nonlocal acquire_task, acquire_started_at, acquire_last_heartbeat_at if acquire_task is None: acquire_task = asyncio.create_task(client.acompletion(request=request, **chat_kwargs)) - acquire_started_at = time.monotonic() + acquire_started_at = ttfb_started_at acquire_last_heartbeat_at = acquire_started_at task = acquire_task started_at = acquire_started_at or time.monotonic() @@ -564,6 +565,12 @@ async def acquire_upstream_stream() -> tuple[str, Any]: self._trace(transaction_logger, "stream_done_event", {"raw": "done"}, direction="stream", stage="final", metadata={"transport": transport, "failed": True}) yield ResponsesStreamEvent("done", {}, terminal=True) finally: + if chat_stream is None and acquire_task is not None and acquire_task.done() and not acquire_task.cancelled(): + try: + chat_stream = acquire_task.result() + stream_iterator = chat_stream.__aiter__() + except Exception: + chat_stream = None await cancel_task(pending_next_task) await cancel_task(acquire_task) if chat_stream is not None and not upstream_closed: diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index e3a98716f..f1f77416f 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -384,6 +384,20 @@ async def test_stream_response_acquire_wait_honors_ttfb_timeout(monkeypatch) -> assert "ttfb" in event_text.lower() +@pytest.mark.asyncio +async def test_stream_response_ttfb_spans_acquire_and_first_chunk(monkeypatch) -> None: + monkeypatch.setenv("STREAM_TTFB_TIMEOUT_SECONDS", "0.05") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], [0.03]) + service = ResponsesService(store=InMemoryResponsesStore()) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, SlowAcquireStreamingClient(stream, 0.03))] + + event_text = "".join(events) + assert stream.closed is True + assert "event: response.failed" in event_text + assert "ttfb" in event_text.lower() + + @pytest.mark.asyncio async def test_stream_response_acquire_wait_can_emit_heartbeat(monkeypatch) -> None: monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") @@ -451,6 +465,23 @@ async def test_stream_events_aclose_cancels_pending_acquire_after_heartbeat(monk assert client.cancelled is True +@pytest.mark.asyncio +async def test_stream_events_aclose_closes_completed_acquire_stream_after_heartbeat(monkeypatch) -> None: + monkeypatch.setenv("STREAM_HEARTBEAT_INTERVAL_SECONDS", "0.01") + stream = DelayedCloseableStream(['data: {"choices":[{"delta":{"content":"late"}}]}\n\n'], [1.0]) + client = SlowAcquireStreamingClient(stream, 0.015) + service = ResponsesService(store=InMemoryResponsesStore()) + events = service.stream_events({"model": "gpt-test", "input": "Hello", "stream": True}, client) + + assert (await anext(events)).event_name == "response.created" + assert (await anext(events)).heartbeat is True + await asyncio.sleep(0.02) + await events.aclose() + + assert client.calls == 1 + assert stream.closed is True + + @pytest.mark.asyncio async def test_stream_response_stall_timeout_preserves_partial_output(monkeypatch) -> None: monkeypatch.setenv("STREAM_STALL_TIMEOUT_SECONDS", "0.01") From 46e40154f037863b524d432dfb3c1495a9e998ba Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 23:33:11 +0200 Subject: [PATCH 178/182] docs(experimental): plan usage cost correction Adds the Phase 9c corrective plan for top-level provider costs, cache-write double-count prevention, Responses SSE cost metadata, native streaming cost traces, and stream cost visibility policy. Tests: not run (planning document only) --- .../phase-9c-usage-cost-correction.md | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 docs/experimental/phase-9c-usage-cost-correction.md diff --git a/docs/experimental/phase-9c-usage-cost-correction.md b/docs/experimental/phase-9c-usage-cost-correction.md new file mode 100644 index 000000000..ec056965f --- /dev/null +++ b/docs/experimental/phase-9c-usage-cost-correction.md @@ -0,0 +1,61 @@ +# Phase 9c: Usage And Cost Correctness Completion + +## Goal + +Close all Phase 9/9b third-pass usage/cost findings while preserving `UsageManager` as the authoritative persistence/limit engine and keeping provider-reported cost precedence from Phase 9b. + +## Scope + +- Preserve provider-reported top-level response cost even when a `usage` object exists. +- Prevent OpenAI-like cache-write tokens from being double-counted in normalized totals/costs. +- Preserve/sum structured provider cost breakdowns that omit `total_cost`. +- Carry Responses streaming SSE cost comments/events into completed usage/cost records. +- Ensure `event: cost` frames are treated as metadata and never visible output for retry/fallback. +- Add native streaming `usage_accounting_summary` traces and preserve provider-reported stream cost where protocol events/raw chunks expose usage. + +## Non-Goals + +- Do not replace `UsageManager`, quota windows, or persisted usage JSON format. +- Do not invent cost totals when neither provider-reported cost nor configured/advisory pricing exists. +- Do not add billing persistence beyond existing usage state and traces. +- Do not implement admin quota/cost dashboards. +- Do not alter public token fields except to correct double-counting. + +## Implementation Plan + +1. Top-level provider cost preservation. + - Merge sibling top-level cost fields into nested `usage` payloads before normalization. + - Covered keys: `cost`, `total_cost`, `cost_details`, `provider_reported_cost`, `currency`, `costMetadata`. + +2. Structured provider cost breakdowns. + - Sum known provider-reported cost fields when no total is present. + - Preserve currency and source metadata. + +3. Cache-write double-count prevention. + - For OpenAI-like usage, subtract both cache-read and cache-write tokens from regular input tokens. + - Keep cache-write tokens as their own normalized bucket. + +4. Responses streaming SSE cost comments/events. + - Parse `: cost ...` comments and `event: cost` frames. + - Merge cost into in-progress/final Responses usage without emitting model output. + - Let final provider usage cost override earlier cost comments. + +5. Metadata visibility for `event: cost`. + - Add stream policy tests locking cost frames as non-visible and retry-safe. + +6. Native streaming usage/cost trace. + - Track stream usage/cost from unified events and raw chunks. + - Emit `usage_accounting_summary` on stream completion. + +7. Tests. + - Cover usage normalization, cost precedence, Responses streaming costs, stream visibility policy, and native streaming cost trace. + +## Acceptance Criteria + +- Top-level provider-reported cost is preserved when a `usage` object exists. +- Cache-write tokens are not double-counted in OpenAI-like normalized totals/costs. +- Structured provider cost breakdowns without totals are summed and treated as provider-reported cost. +- Responses streaming carries SSE cost comments/events into final usage/cost and treats them as metadata. +- `event: cost` frames do not block retry/fallback as visible output. +- Native streaming emits `usage_accounting_summary` and preserves stream cost when available. +- Focused tests pass and both reviewers report no blockers/highs/mediums. From d8ad65c4b74ae925b8f9ebaf3b45763e626fca37 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 23:43:05 +0200 Subject: [PATCH 179/182] fix(usage): preserve provider cost metadata Preserves top-level provider-reported costs with nested usage, avoids OpenAI cache-write double counting, sums structured provider cost breakdowns, carries Responses SSE cost metadata through streamed completions, treats cost events as non-visible metadata, and adds native streaming usage/cost summary traces. Tests: pytest tests/test_usage_accounting.py tests/test_usage_costs.py tests/test_usage_quota_snapshots.py tests/test_responses_streaming.py tests/test_responses_service.py tests/test_responses_routes.py tests/test_stream_policy.py tests/test_native_provider_streaming.py tests/test_native_provider_executor.py tests/test_streaming_usage_accounting.py tests/test_executor_usage_accounting.py tests/test_responses_usage_accounting.py tests/test_native_usage_accounting.py --- .../native_provider/executor.py | 43 ++++++++++ src/rotator_library/responses/service.py | 79 ++++++++++++++++++- src/rotator_library/streaming/policy.py | 5 +- src/rotator_library/usage/accounting.py | 52 +++++++++++- tests/test_native_provider_streaming.py | 33 ++++++++ tests/test_responses_streaming.py | 72 +++++++++++++++++ tests/test_stream_policy.py | 8 ++ tests/test_streaming_usage_accounting.py | 4 +- tests/test_usage_accounting.py | 33 +++++++- 9 files changed, 320 insertions(+), 9 deletions(-) diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index aa3f51963..25e2d0c20 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -162,10 +162,16 @@ async def stream(self, raw_request: dict[str, Any], context: NativeProviderConte await cache_engine.extract("request", provider_request, context.field_cache_context(), transaction_logger=logger) self._trace(context, "after_request_field_cache_extraction", {"source": "request"}, direction="request", stage="adapter", snapshot=False) self._trace(context, "native_provider_stream_request", provider_request, direction="request", stage="provider") + usage_record = extract_usage_record(None, provider=context.provider, model=context.model, source="native_provider_stream") async for raw_chunk in transport.stream_json_lines(context.endpoint, headers=context.headers, payload=provider_request): self._trace(context, "raw_native_provider_stream_chunk", raw_chunk, direction="stream", stage="provider") event = protocol.parse_stream_event(raw_chunk, protocol_context) self._trace(context, "parsed_native_unified_stream_event", event, direction="stream", stage="protocol", snapshot=False) + usage_record = _merge_stream_usage_records( + usage_record, + extract_usage_record(serialize_value(event), provider=context.provider, model=context.model, source="native_stream_event"), + extract_usage_record(raw_chunk, provider=context.provider, model=context.model, source="native_raw_stream_event"), + ) if event.type == "done": event_payload = stream_event_payload(event) self._trace(context, "parsed_native_stream_event", event_payload, direction="stream", stage="protocol") @@ -194,6 +200,15 @@ async def stream(self, raw_request: dict[str, Any], context: NativeProviderConte formatted = response_protocol.format_stream_event(event, context.protocol_context(target_protocol=response_protocol.name)) self._trace(context, "formatted_client_stream_event", formatted, direction="stream", stage="final", snapshot=False) yield formatted + cost_breakdown = CostCalculator().calculate(usage_record, model=context.model, provider=context.provider) + self._trace( + context, + "usage_accounting_summary", + {"usage": usage_record.to_dict(), "cost": cost_breakdown.to_dict()}, + direction="metadata", + stage="final", + snapshot=False, + ) except Exception as exc: if logger: logger.log_transform_error( @@ -289,6 +304,34 @@ async def _inject_unified_request( return _hydrate_unified_request(unified_request, injected) +def _merge_stream_usage_records(base: Any, event_record: Any, raw_record: Any) -> Any: + """Merge native stream usage, preserving raw provider cost when needed.""" + + selected = event_record if _usage_record_has_values(event_record) else base + if not _usage_record_has_values(selected) and _usage_record_has_values(raw_record): + selected = raw_record + if selected.provider_reported_cost is None and raw_record.provider_reported_cost is not None: + selected = replace( + selected, + provider_reported_cost=raw_record.provider_reported_cost, + cost_currency=raw_record.cost_currency, + cost_source=raw_record.cost_source, + ) + return selected + + +def _usage_record_has_values(record: Any) -> bool: + return bool( + record.input_tokens + or record.completion_tokens + or record.reasoning_tokens + or record.cache_read_tokens + or record.cache_write_tokens + or record.raw_total_tokens + or record.provider_reported_cost is not None + ) + + def _redact_field_cache_paths(data: Any, context: NativeProviderContext, direction: str) -> Any: """Redact configured cache paths before broad native payload traces. diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index 02f8c728e..c50afb5d5 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -6,6 +6,7 @@ from __future__ import annotations import asyncio +import json import time from copy import deepcopy from typing import Any, AsyncGenerator, Optional @@ -459,6 +460,11 @@ async def acquire_upstream_stream() -> tuple[str, Any]: metadata={"transport": transport}, ) self._trace(transaction_logger, "raw_chat_bridge_stream_chunk", raw_chunk, direction="stream", stage="provider") + cost_usage = _responses_sse_cost_usage(raw_chunk) + if cost_usage is not None: + usage = _merge_responses_stream_usage(usage, cost_usage) + self._trace(transaction_logger, "responses_stream_cost_event", cost_usage, direction="stream", stage="metadata", metadata={"transport": transport}) + continue chunk = parse_chat_sse_chunk(raw_chunk) if not chunk or chunk.get("type") == "done": continue @@ -466,7 +472,7 @@ async def acquire_upstream_stream() -> tuple[str, Any]: if chunk.get("error") is not None or chunk.get("type") == "error": raise ResponsesServiceError(_stream_error_message(chunk), status_code=502, error_type="upstream_error") if chunk.get("usage"): - usage = chunk["usage"] + usage = _merge_responses_stream_usage(chunk["usage"], usage) delta = _chunk_text_delta(chunk) if not delta: continue @@ -900,6 +906,77 @@ def _stream_failure_error(exc: Exception) -> dict[str, Any]: return result +def _responses_sse_cost_usage(chunk: Any) -> Optional[dict[str, Any]]: + """Extract Responses stream cost metadata from SSE comments/events.""" + + if not isinstance(chunk, str): + return None + payload = _responses_sse_cost_payload(chunk) + if payload is None: + return None + if isinstance(payload, (int, float, str)): + payload = {"provider_reported_cost": payload, "source": "responses_sse_cost"} + if not isinstance(payload, dict): + return None + cost = payload.get("provider_reported_cost", payload.get("total_cost", payload.get("cost"))) + if cost is None: + return None + return { + "provider_reported_cost": cost, + "currency": payload.get("currency", "USD"), + "cost_details": payload, + } + + +def _responses_sse_cost_payload(chunk: str) -> Any: + event_type = None + data_lines: list[str] = [] + for line in chunk.strip().splitlines(): + stripped = line.strip() + if stripped.startswith(":"): + comment = stripped[1:].strip() + if comment.startswith("cost"): + return _parse_cost_text(comment[4:].strip()) + continue + if stripped.startswith("event:"): + event_type = stripped[6:].strip() + continue + if stripped.startswith("data:"): + data_lines.append(stripped[5:].strip()) + if event_type == "cost" and data_lines: + return _parse_cost_text("\n".join(data_lines).strip()) + return None + + +def _parse_cost_text(text: str) -> Any: + if not text: + return None + try: + return json.loads(text) + except json.JSONDecodeError: + try: + return float(text) + except ValueError: + return None + + +def _merge_responses_stream_usage(primary: Any, fallback_cost: Any) -> Any: + """Merge earlier stream cost metadata into later token usage when needed.""" + + if not isinstance(primary, dict): + return fallback_cost if primary is None else primary + if not isinstance(fallback_cost, dict): + return primary + merged = deepcopy(primary) + has_cost = any(key in merged for key in ("cost_details", "cost", "total_cost", "provider_reported_cost")) + if has_cost: + return merged + for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "currency"): + if key in fallback_cost: + merged[key] = deepcopy(fallback_cost[key]) + return merged + + def _chunk_text_delta(chunk: dict[str, Any]) -> str: choices = chunk.get("choices") if isinstance(chunk.get("choices"), list) else [] if not choices: diff --git a/src/rotator_library/streaming/policy.py b/src/rotator_library/streaming/policy.py index 5973a9747..776b0dfe5 100644 --- a/src/rotator_library/streaming/policy.py +++ b/src/rotator_library/streaming/policy.py @@ -23,9 +23,12 @@ def can_retry_stream_after_error(last_streamed_chunk: Optional[str], allow_reaso return True if is_stream_heartbeat_or_comment(last_streamed_chunk): return True + metadata = _sse_json(last_streamed_chunk, malformed_is_visible=False) + if isinstance(metadata, dict) and (metadata.get("event_type") or metadata.get("type")) == "cost": + return True if not allow_reasoning_only_retry: return False - data = _sse_json(last_streamed_chunk, malformed_is_visible=False) + data = metadata if data is None: return False diff --git a/src/rotator_library/usage/accounting.py b/src/rotator_library/usage/accounting.py index 04247e849..fccc51327 100644 --- a/src/rotator_library/usage/accounting.py +++ b/src/rotator_library/usage/accounting.py @@ -114,10 +114,27 @@ def _unwrap_usage(value: Any) -> Any: if value is None: return None if isinstance(value, dict): - return value.get("usage", value) + if "usage" in value: + usage = _as_dict(value.get("usage")) + return _merge_top_level_cost_fields(usage, value) + return value + data = _as_dict(value) + if "usage" in data: + usage = _as_dict(data.get("usage")) + return _merge_top_level_cost_fields(usage, data) return getattr(value, "usage", value) +def _merge_top_level_cost_fields(usage: dict[str, Any], response: dict[str, Any]) -> dict[str, Any]: + """Copy sibling cost metadata into a nested usage payload.""" + + merged = dict(usage) + for key in ("cost", "total_cost", "cost_details", "provider_reported_cost", "currency", "costMetadata"): + if key in response and key not in merged: + merged[key] = response[key] + return merged + + def _as_dict(value: Any) -> dict[str, Any]: if isinstance(value, dict): return value @@ -133,6 +150,7 @@ def _as_dict(value: Any) -> dict[str, Any]: result: dict[str, Any] = {} for key in ( "prompt_tokens", + "usage", "completion_tokens", "total_tokens", "prompt_tokens_details", @@ -189,7 +207,7 @@ def _from_openai_like_usage(data: dict[str, Any], *, provider: Optional[str], mo ) if reasoning and completion_tokens >= reasoning: completion_tokens -= reasoning - input_tokens = max(0, prompt_tokens - cache_read) + input_tokens = max(0, prompt_tokens - cache_read - cache_write) cost = _extract_cost(data) return UsageRecord( input_tokens=input_tokens, @@ -271,18 +289,46 @@ def _extract_cost(data: dict[str, Any]) -> tuple[Optional[float], str, Optional[ if cost_payload else None ) + breakdown_used = False if raw_cost is None and cost_payload: raw_cost = cost_payload.get("total_cost", cost_payload.get("total")) if raw_cost is None and cost_payload: raw_cost = cost_payload.get("cost") + if raw_cost is None and cost_payload: + breakdown_total = _sum_cost_breakdown(cost_payload) + if breakdown_total is not None: + raw_cost = breakdown_total + breakdown_used = True if raw_cost is None: raw_cost = data.get("provider_reported_cost", data.get("total_cost", data.get("cost"))) cost_value = _float_or_none(raw_cost) currency = str(cost_payload.get("currency") or data.get("currency") or "USD") - source = cost_payload.get("source") or ("provider_reported" if cost_value is not None else None) + source = cost_payload.get("source") or ("provider_reported_breakdown" if breakdown_used else ("provider_reported" if cost_value is not None else None)) return cost_value, currency, str(source) if source else None +def _sum_cost_breakdown(payload: dict[str, Any]) -> Optional[float]: + """Sum structured provider cost buckets when no total is reported.""" + + total = 0.0 + found = False + for key in ( + "input_cost", + "prompt_cost", + "cache_read_cost", + "cache_write_cost", + "output_cost", + "completion_cost", + "reasoning_cost", + "thinking_cost", + ): + value = _float_or_none(payload.get(key)) + if value is not None: + total += value + found = True + return total if found else None + + def _looks_like_gemini(data: dict[str, Any]) -> bool: return any(key in data for key in ("promptTokenCount", "candidatesTokenCount", "thoughtsTokenCount", "cachedContentTokenCount")) diff --git a/tests/test_native_provider_streaming.py b/tests/test_native_provider_streaming.py index d3109e86b..b26238565 100644 --- a/tests/test_native_provider_streaming.py +++ b/tests/test_native_provider_streaming.py @@ -60,6 +60,39 @@ async def test_native_provider_stream_traces_and_yields_formatted_events(tmp_pat assert "opaque-vendor-state" not in trace_text +@pytest.mark.asyncio +async def test_native_provider_stream_traces_usage_accounting_summary(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + transaction_logger=logger, + ) + chunks = [ + { + "choices": [], + "usage": { + "prompt_tokens": 2, + "completion_tokens": 3, + "cost_details": {"total_cost": 0.04, "currency": "USD", "source": "stream_usage"}, + }, + }, + "[DONE]", + ] + + _ = [event async for event in NativeProviderExecutor().stream({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(FakeStreamingClient(chunks)))] + + entries = _trace_entries(logger.log_dir) + summaries = [entry for entry in entries if entry["pass_name"] == "usage_accounting_summary"] + assert summaries + assert summaries[-1]["data"]["usage"]["input_tokens"] == 2 + assert summaries[-1]["data"]["usage"]["completion_tokens"] == 3 + assert summaries[-1]["data"]["usage"]["provider_reported_cost"] == 0.04 + assert summaries[-1]["data"]["cost"]["provider_reported_cost"] == 0.04 + + @pytest.mark.asyncio async def test_native_provider_stream_logs_errors(tmp_path) -> None: class BrokenClient: diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index f1f77416f..68e5b3ef3 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -127,6 +127,36 @@ async def chunks(): return chunks() +class CostCommentStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield ': cost {"total_cost":0.031,"currency":"USD","source":"responses_sse"}\n\n' + yield 'data: {"choices":[{"delta":{"content":"priced"}}]}\n\n' + yield "data: [DONE]\n\n" + + return chunks() + + +class CostEventStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield 'event: cost\ndata: {"total_cost":0.017,"currency":"EUR","source":"responses_event_cost"}\n\n' + yield 'data: {"choices":[{"delta":{"content":"priced"}}]}\n\n' + yield "data: [DONE]\n\n" + + return chunks() + + +class CostCommentFinalOverrideStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield ': cost 0.031\n\n' + yield 'data: {"choices":[{"delta":{"content":"priced"}}],"usage":{"prompt_tokens":1,"completion_tokens":1,"cost_details":{"total_cost":0.062,"source":"final_usage"}}}\n\n' + yield "data: [DONE]\n\n" + + return chunks() + + class FailingStore: async def save(self, response): raise RuntimeError("store failed Authorization: Bearer secret-token") @@ -229,6 +259,48 @@ async def test_stream_response_event_error_frames_are_failed() -> None: assert stored.response["error"]["message"] == "event failed" +@pytest.mark.asyncio +async def test_stream_response_preserves_sse_cost_comment() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, CostCommentStreamingClient())] + + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.usage["cost_details"]["total_cost"] == 0.031 + assert stored.usage["cost_details"]["source"] == "responses_sse" + + +@pytest.mark.asyncio +async def test_stream_response_preserves_sse_cost_event() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, CostEventStreamingClient())] + + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.usage["cost_details"]["total_cost"] == 0.017 + assert stored.usage["cost_details"]["currency"] == "EUR" + + +@pytest.mark.asyncio +async def test_stream_response_final_usage_cost_overrides_sse_cost_comment() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, CostCommentFinalOverrideStreamingClient())] + + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.usage["cost_details"]["total_cost"] == 0.062 + assert stored.usage["cost_details"]["source"] == "final_usage" + + @pytest.mark.asyncio async def test_stream_response_can_skip_failed_storage() -> None: store = InMemoryResponsesStore() diff --git a/tests/test_stream_policy.py b/tests/test_stream_policy.py index 49d2a354c..69ec20a5c 100644 --- a/tests/test_stream_policy.py +++ b/tests/test_stream_policy.py @@ -19,6 +19,14 @@ def test_heartbeat_comments_do_not_block_stream_retry() -> None: assert can_retry_stream_after_error(': heartbeat\n\n', False) is True +def test_cost_events_do_not_count_as_visible_output() -> None: + cost_event = 'event: cost\ndata: {"total_cost":0.01}\n\n' + + assert is_visible_stream_output(cost_event) is False + assert is_visible_stream_output(cost_event, protocol="responses") is False + assert can_retry_stream_after_error(cost_event, False) is True + + def test_visible_output_detection_for_chat_chunks() -> None: assert is_visible_stream_output('data: {"choices":[{"delta":{"content":"hello"}}]}\n\n') is True assert is_visible_stream_output('data: {"choices":[{"delta":{"tool_calls":[{"id":"call_1"}]}}]}\n\n') is True diff --git a/tests/test_streaming_usage_accounting.py b/tests/test_streaming_usage_accounting.py index 824fbf4c6..2dc2c1a42 100644 --- a/tests/test_streaming_usage_accounting.py +++ b/tests/test_streaming_usage_accounting.py @@ -67,7 +67,7 @@ async def test_streaming_usage_uses_normalized_accounting_and_trace(tmp_path, mo chunks = [chunk async for chunk in StreamingHandler().wrap_stream(_usage_chunks(), "cred", "gpt-test", cred_context=cred_context, transaction_logger=logger)] assert chunks[-1] == "data: [DONE]\n\n" - assert cred_context.success_kwargs["prompt_tokens"] == 60 + assert cred_context.success_kwargs["prompt_tokens"] == 55 assert cred_context.success_kwargs["prompt_tokens_cache_read"] == 40 assert cred_context.success_kwargs["prompt_tokens_cache_write"] == 5 assert cred_context.success_kwargs["completion_tokens"] == 20 @@ -115,7 +115,7 @@ async def test_streaming_usage_uses_configured_env_pricing(monkeypatch) -> None: _ = [chunk async for chunk in StreamingHandler().wrap_stream(_usage_chunks(), "cred", "openai/gpt-test", cred_context=cred_context)] - assert cred_context.success_kwargs["approx_cost"] == 120.0 + assert cred_context.success_kwargs["approx_cost"] == 110.0 @pytest.mark.asyncio diff --git a/tests/test_usage_accounting.py b/tests/test_usage_accounting.py index 018793045..75bb6ae2c 100644 --- a/tests/test_usage_accounting.py +++ b/tests/test_usage_accounting.py @@ -21,13 +21,14 @@ def test_openai_dict_usage_extracts_cache_and_reasoning_without_double_counting( model="gpt-test", ) - assert record.input_tokens == 60 + assert record.input_tokens == 55 assert record.cache_read_tokens == 40 assert record.cache_write_tokens == 5 assert record.completion_tokens == 20 assert record.reasoning_tokens == 10 assert record.output_tokens == 30 assert record.raw_total_tokens == 130 + assert record.total_tokens == 130 def test_openai_usage_extracts_provider_reported_cost_details() -> None: @@ -46,6 +47,34 @@ def test_openai_usage_extracts_provider_reported_cost_details() -> None: assert record.cost_source == "provider_usage" +def test_top_level_cost_is_preserved_when_usage_exists() -> None: + record = extract_usage_record( + { + "usage": {"prompt_tokens": 10, "completion_tokens": 5}, + "total_cost": 0.03, + "currency": "EUR", + } + ) + + assert record.provider_reported_cost == 0.03 + assert record.cost_currency == "EUR" + + +def test_structured_cost_breakdown_without_total_is_summed() -> None: + record = extract_usage_record( + { + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "cost_details": {"input_cost": 0.01, "output_cost": 0.02, "reasoning_cost": 0.003}, + } + } + ) + + assert record.provider_reported_cost == 0.033 + assert record.cost_source == "provider_reported_breakdown" + + def test_openai_object_usage_extracts_attributes() -> None: response = SimpleNamespace( usage=SimpleNamespace( @@ -58,7 +87,7 @@ def test_openai_object_usage_extracts_attributes() -> None: record = extract_usage_record(response) - assert record.input_tokens == 10 + assert record.input_tokens == 9 assert record.cache_read_tokens == 2 assert record.cache_write_tokens == 1 assert record.completion_tokens == 4 From ee7b9b53d5d6372ce6722ff981f4d8a8c449e56d Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Sun, 31 May 2026 23:53:16 +0200 Subject: [PATCH 180/182] fix(usage): handle reference stream cost shapes Adds request_cost_usd support for SSE cost comments and protocol usage, preserves top-level stream cost siblings in live streaming and Responses bridge paths, expands structured breakdown summing, treats scalar cost events as retry-safe metadata, and preserves earlier native stream cost when later token usage arrives. Tests: pytest tests/test_usage_accounting.py tests/test_usage_costs.py tests/test_usage_quota_snapshots.py tests/test_responses_streaming.py tests/test_responses_service.py tests/test_responses_routes.py tests/test_responses_bridge.py tests/test_stream_policy.py tests/test_native_provider_streaming.py tests/test_native_provider_executor.py tests/test_streaming_usage_accounting.py tests/test_executor_usage_accounting.py tests/test_responses_usage_accounting.py tests/test_native_usage_accounting.py tests/test_protocol_openai_chat.py tests/test_protocol_responses.py --- src/rotator_library/client/streaming.py | 14 ++++-- .../native_provider/executor.py | 7 +++ src/rotator_library/protocols/openai_chat.py | 2 +- src/rotator_library/protocols/responses.py | 2 +- src/rotator_library/responses/bridge.py | 12 ++++- src/rotator_library/responses/service.py | 17 ++++++- src/rotator_library/streaming/policy.py | 2 + src/rotator_library/usage/accounting.py | 16 +++++-- tests/test_native_provider_streaming.py | 24 ++++++++++ tests/test_responses_bridge.py | 18 ++++++++ tests/test_responses_streaming.py | 45 +++++++++++++++++++ tests/test_stream_policy.py | 3 ++ tests/test_streaming_usage_accounting.py | 41 +++++++++++++++++ tests/test_usage_accounting.py | 8 +++- 14 files changed, 198 insertions(+), 13 deletions(-) diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py index 2e035fe99..2f692cb57 100644 --- a/src/rotator_library/client/streaming.py +++ b/src/rotator_library/client/streaming.py @@ -741,8 +741,16 @@ def _usage_from_sse_string(chunk: str) -> Optional[dict[str, Any]]: data = json.loads(payload) except json.JSONDecodeError: return None - usage = data.get("usage") if isinstance(data, dict) else None - return usage if isinstance(usage, dict) else None + if not isinstance(data, dict): + return None + usage = data.get("usage") + if not isinstance(usage, dict): + return None + merged = dict(usage) + for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + if key in data and key not in merged: + merged[key] = data[key] + return merged def _usage_record_from_sse_cost_chunk(chunk: Any, *, model: str) -> UsageRecord: @@ -758,7 +766,7 @@ def _usage_record_from_sse_cost_chunk(chunk: Any, *, model: str) -> UsageRecord: if not isinstance(cost_payload, dict): return UsageRecord(source="stream_cost_event", model=model) return extract_usage_record( - {"usage": {"provider_reported_cost": cost_payload.get("provider_reported_cost", cost_payload.get("total_cost", cost_payload.get("cost"))), "currency": cost_payload.get("currency", "USD"), "cost_details": cost_payload}}, + {"usage": {"provider_reported_cost": cost_payload.get("provider_reported_cost", cost_payload.get("request_cost_usd", cost_payload.get("total_cost", cost_payload.get("cost")))), "currency": cost_payload.get("currency", "USD"), "cost_details": cost_payload}}, model=model, source="stream_cost_event", ) diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index 25e2d0c20..029f2dae4 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -310,6 +310,13 @@ def _merge_stream_usage_records(base: Any, event_record: Any, raw_record: Any) - selected = event_record if _usage_record_has_values(event_record) else base if not _usage_record_has_values(selected) and _usage_record_has_values(raw_record): selected = raw_record + if selected.provider_reported_cost is None and base.provider_reported_cost is not None: + selected = replace( + selected, + provider_reported_cost=base.provider_reported_cost, + cost_currency=base.cost_currency, + cost_source=base.cost_source, + ) if selected.provider_reported_cost is None and raw_record.provider_reported_cost is not None: selected = replace( selected, diff --git a/src/rotator_library/protocols/openai_chat.py b/src/rotator_library/protocols/openai_chat.py index b75f31bbf..bd95c967b 100644 --- a/src/rotator_library/protocols/openai_chat.py +++ b/src/rotator_library/protocols/openai_chat.py @@ -250,7 +250,7 @@ def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = N cost = None cost_details = usage.get("cost_details") if isinstance(cost_details, dict): - provider_cost = cost_details.get("total_cost") or cost_details.get("cost") + provider_cost = cost_details.get("total_cost") or cost_details.get("request_cost_usd") or cost_details.get("cost") cost = CostDetails( provider_reported_cost=float(provider_cost) if provider_cost is not None else None, currency=str(cost_details.get("currency") or "USD"), diff --git a/src/rotator_library/protocols/responses.py b/src/rotator_library/protocols/responses.py index 20afbaba7..1a2d8d599 100644 --- a/src/rotator_library/protocols/responses.py +++ b/src/rotator_library/protocols/responses.py @@ -186,7 +186,7 @@ def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = N cost = None cost_details = usage.get("cost_details") if isinstance(cost_details, dict): - provider_cost = cost_details.get("total_cost") or cost_details.get("cost") + provider_cost = cost_details.get("total_cost") or cost_details.get("request_cost_usd") or cost_details.get("cost") cost = CostDetails( provider_reported_cost=float(provider_cost) if provider_cost is not None else None, currency=str(cost_details.get("currency") or "USD"), diff --git a/src/rotator_library/responses/bridge.py b/src/rotator_library/responses/bridge.py index 3c7a530aa..93fb4598b 100644 --- a/src/rotator_library/responses/bridge.py +++ b/src/rotator_library/responses/bridge.py @@ -113,7 +113,7 @@ def from_chat_response( "model": response.get("model") or unified_request.model, "status": _status_from_chat(response), "output": output, - "usage": _usage_to_responses(response.get("usage")), + "usage": _usage_to_responses(response), } if unified_request.metadata: responses_payload["metadata"] = deepcopy(unified_request.metadata) @@ -309,6 +309,12 @@ def _status_from_chat(response: dict[str, Any]) -> str: def _usage_to_responses(usage: Any) -> Any: + if isinstance(usage, dict) and isinstance(usage.get("usage"), dict): + nested = dict(usage["usage"]) + for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + if key in usage and key not in nested: + nested[key] = deepcopy(usage[key]) + usage = nested if not isinstance(usage, dict): return usage result = { @@ -322,9 +328,11 @@ def _usage_to_responses(usage: Any) -> Any: completion_details = usage.get("completion_tokens_details") or usage.get("output_tokens_details") if isinstance(completion_details, dict): result["output_tokens_details"] = {"reasoning_tokens": completion_details.get("reasoning_tokens", 0)} - for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "currency"): + for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): if key in usage: result[key] = deepcopy(usage[key]) + if "request_cost_usd" in result and "cost_details" not in result: + result["cost_details"] = {"request_cost_usd": result["request_cost_usd"]} return result diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index c50afb5d5..edfe36269 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -472,7 +472,7 @@ async def acquire_upstream_stream() -> tuple[str, Any]: if chunk.get("error") is not None or chunk.get("type") == "error": raise ResponsesServiceError(_stream_error_message(chunk), status_code=502, error_type="upstream_error") if chunk.get("usage"): - usage = _merge_responses_stream_usage(chunk["usage"], usage) + usage = _merge_responses_stream_usage(_responses_chunk_usage(chunk), usage) delta = _chunk_text_delta(chunk) if not delta: continue @@ -918,7 +918,7 @@ def _responses_sse_cost_usage(chunk: Any) -> Optional[dict[str, Any]]: payload = {"provider_reported_cost": payload, "source": "responses_sse_cost"} if not isinstance(payload, dict): return None - cost = payload.get("provider_reported_cost", payload.get("total_cost", payload.get("cost"))) + cost = payload.get("provider_reported_cost", payload.get("request_cost_usd", payload.get("total_cost", payload.get("cost")))) if cost is None: return None return { @@ -928,6 +928,19 @@ def _responses_sse_cost_usage(chunk: Any) -> Optional[dict[str, Any]]: } +def _responses_chunk_usage(chunk: dict[str, Any]) -> Any: + """Return stream chunk usage with sibling cost metadata preserved.""" + + usage = chunk.get("usage") + if not isinstance(usage, dict): + return usage + merged = dict(usage) + for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + if key in chunk and key not in merged: + merged[key] = deepcopy(chunk[key]) + return merged + + def _responses_sse_cost_payload(chunk: str) -> Any: event_type = None data_lines: list[str] = [] diff --git a/src/rotator_library/streaming/policy.py b/src/rotator_library/streaming/policy.py index 776b0dfe5..1dfb0ac55 100644 --- a/src/rotator_library/streaming/policy.py +++ b/src/rotator_library/streaming/policy.py @@ -119,6 +119,8 @@ def _sse_json(chunk: str, *, malformed_is_visible: bool) -> dict[str, Any] | obj except json.JSONDecodeError: return _MALFORMED_VISIBLE if malformed_is_visible else None if not isinstance(parsed, dict): + if event_type == "cost": + return {"event_type": "cost", "value": parsed} return _MALFORMED_VISIBLE if malformed_is_visible else None if event_type and "event_type" not in parsed: parsed["event_type"] = event_type diff --git a/src/rotator_library/usage/accounting.py b/src/rotator_library/usage/accounting.py index fccc51327..58a65d757 100644 --- a/src/rotator_library/usage/accounting.py +++ b/src/rotator_library/usage/accounting.py @@ -129,7 +129,7 @@ def _merge_top_level_cost_fields(usage: dict[str, Any], response: dict[str, Any] """Copy sibling cost metadata into a nested usage payload.""" merged = dict(usage) - for key in ("cost", "total_cost", "cost_details", "provider_reported_cost", "currency", "costMetadata"): + for key in ("cost", "total_cost", "cost_details", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): if key in response and key not in merged: merged[key] = response[key] return merged @@ -172,6 +172,7 @@ def _as_dict(value: Any) -> dict[str, Any]: "cost_details", "costMetadata", "provider_reported_cost", + "request_cost_usd", "total_cost", "currency", ): @@ -285,7 +286,7 @@ def _extract_cost(data: dict[str, Any]) -> tuple[Optional[float], str, Optional[ cost_payload = _as_dict(data.get("cost") or data.get("cost_details") or data.get("costMetadata") or {}) raw_cost = ( - cost_payload.get("provider_reported_cost") + cost_payload.get("provider_reported_cost", cost_payload.get("request_cost_usd")) if cost_payload else None ) @@ -300,7 +301,7 @@ def _extract_cost(data: dict[str, Any]) -> tuple[Optional[float], str, Optional[ raw_cost = breakdown_total breakdown_used = True if raw_cost is None: - raw_cost = data.get("provider_reported_cost", data.get("total_cost", data.get("cost"))) + raw_cost = data.get("provider_reported_cost", data.get("request_cost_usd", data.get("total_cost", data.get("cost")))) cost_value = _float_or_none(raw_cost) currency = str(cost_payload.get("currency") or data.get("currency") or "USD") source = cost_payload.get("source") or ("provider_reported_breakdown" if breakdown_used else ("provider_reported" if cost_value is not None else None)) @@ -317,10 +318,19 @@ def _sum_cost_breakdown(payload: dict[str, Any]) -> Optional[float]: "prompt_cost", "cache_read_cost", "cache_write_cost", + "cached_input_cost", + "cache_write_input_cost", "output_cost", "completion_cost", + "completions_cost", "reasoning_cost", "thinking_cost", + "upstream_inference_cost", + "upstream_inference_prompt_cost", + "upstream_inference_completions_cost", + "request_cost", + "web_search_cost", + "search_cost", ): value = _float_or_none(payload.get(key)) if value is not None: diff --git a/tests/test_native_provider_streaming.py b/tests/test_native_provider_streaming.py index b26238565..3a453f453 100644 --- a/tests/test_native_provider_streaming.py +++ b/tests/test_native_provider_streaming.py @@ -93,6 +93,30 @@ async def test_native_provider_stream_traces_usage_accounting_summary(tmp_path) assert summaries[-1]["data"]["cost"]["provider_reported_cost"] == 0.04 +@pytest.mark.asyncio +async def test_native_provider_stream_preserves_earlier_cost_when_later_usage_arrives(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + transaction_logger=logger, + ) + chunks = [ + {"choices": [], "usage": {"cost_details": {"total_cost": 0.07, "source": "early_cost"}}}, + {"choices": [], "usage": {"prompt_tokens": 2, "completion_tokens": 3}}, + "[DONE]", + ] + + _ = [event async for event in NativeProviderExecutor().stream({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(FakeStreamingClient(chunks)))] + + summaries = [entry for entry in _trace_entries(logger.log_dir) if entry["pass_name"] == "usage_accounting_summary"] + assert summaries[-1]["data"]["usage"]["input_tokens"] == 2 + assert summaries[-1]["data"]["usage"]["completion_tokens"] == 3 + assert summaries[-1]["data"]["usage"]["provider_reported_cost"] == 0.07 + + @pytest.mark.asyncio async def test_native_provider_stream_logs_errors(tmp_path) -> None: class BrokenClient: diff --git a/tests/test_responses_bridge.py b/tests/test_responses_bridge.py index dc84e645a..41fac1b7b 100644 --- a/tests/test_responses_bridge.py +++ b/tests/test_responses_bridge.py @@ -169,6 +169,24 @@ def test_bridge_converts_chat_response_to_responses_payload() -> None: assert response["usage"]["total_tokens"] == 3 +def test_bridge_preserves_chat_response_top_level_cost_with_usage() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Hello"}) + chat_response = { + "id": "chat_1", + "model": "gpt-test", + "choices": [{"message": {"role": "assistant", "content": "Hi"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 2}, + "request_cost_usd": 0.012, + } + + response = bridge.from_chat_response(chat_response, unified) + + assert response["usage"]["cost_details"]["request_cost_usd"] == 0.012 + assert response["usage"]["cost_details"]["total_cost"] == 0.012 + + def test_bridge_converts_chat_tool_calls_to_responses_output_items() -> None: protocol = ResponsesProtocol() bridge = ResponsesBridge(protocol) diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index 68e5b3ef3..7fb612af3 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -147,6 +147,25 @@ async def chunks(): return chunks() +class RequestCostCommentStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield ': cost {"request_cost_usd":0.023,"source":"reference_sse"}\n\n' + yield 'data: {"choices":[{"delta":{"content":"priced"}}]}\n\n' + yield "data: [DONE]\n\n" + + return chunks() + + +class TopLevelUsageCostStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield 'data: {"choices":[{"delta":{"content":"priced"}}],"usage":{"prompt_tokens":1,"completion_tokens":1},"total_cost":0.029}\n\n' + yield "data: [DONE]\n\n" + + return chunks() + + class CostCommentFinalOverrideStreamingClient: async def acompletion(self, **kwargs): async def chunks(): @@ -287,6 +306,32 @@ async def test_stream_response_preserves_sse_cost_event() -> None: assert stored.usage["cost_details"]["currency"] == "EUR" +@pytest.mark.asyncio +async def test_stream_response_preserves_reference_request_cost_comment() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, RequestCostCommentStreamingClient())] + + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.usage["cost_details"]["request_cost_usd"] == 0.023 + + +@pytest.mark.asyncio +async def test_stream_response_preserves_top_level_chunk_cost_with_usage() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, TopLevelUsageCostStreamingClient())] + + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.usage["total_cost"] == 0.029 + + @pytest.mark.asyncio async def test_stream_response_final_usage_cost_overrides_sse_cost_comment() -> None: store = InMemoryResponsesStore() diff --git a/tests/test_stream_policy.py b/tests/test_stream_policy.py index 69ec20a5c..4f279eab2 100644 --- a/tests/test_stream_policy.py +++ b/tests/test_stream_policy.py @@ -21,10 +21,13 @@ def test_heartbeat_comments_do_not_block_stream_retry() -> None: def test_cost_events_do_not_count_as_visible_output() -> None: cost_event = 'event: cost\ndata: {"total_cost":0.01}\n\n' + scalar_cost_event = "event: cost\ndata: 0.01\n\n" assert is_visible_stream_output(cost_event) is False assert is_visible_stream_output(cost_event, protocol="responses") is False assert can_retry_stream_after_error(cost_event, False) is True + assert is_visible_stream_output(scalar_cost_event) is False + assert can_retry_stream_after_error(scalar_cost_event, False) is True def test_visible_output_detection_for_chat_chunks() -> None: diff --git a/tests/test_streaming_usage_accounting.py b/tests/test_streaming_usage_accounting.py index 2dc2c1a42..7058ff7ea 100644 --- a/tests/test_streaming_usage_accounting.py +++ b/tests/test_streaming_usage_accounting.py @@ -46,6 +46,20 @@ async def _cost_event_chunks(): yield {"id": "chunk_1", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1}} +async def _scalar_cost_event_chunks(): + yield "event: cost\ndata: 0.033\n\n" + yield {"id": "chunk_1", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1}} + + +async def _request_cost_comment_chunks(): + yield ': cost {"request_cost_usd":0.044,"source":"reference_sse"}\n\n' + yield {"id": "chunk_1", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1}} + + +async def _top_level_cost_usage_sse_chunks(): + yield 'data: {"choices":[{"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1},"total_cost":0.055}\n\n' + + async def _cost_comment_overridden_by_final_usage_chunks(): yield ': cost 0.042\n\n' yield { @@ -137,6 +151,33 @@ async def test_streaming_cost_event_updates_approx_cost() -> None: assert cred_context.success_kwargs["approx_cost"] == 0.021 +@pytest.mark.asyncio +async def test_streaming_scalar_cost_event_updates_approx_cost() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_scalar_cost_event_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 0.033 + + +@pytest.mark.asyncio +async def test_streaming_reference_request_cost_comment_updates_approx_cost() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_request_cost_comment_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 0.044 + + +@pytest.mark.asyncio +async def test_streaming_sse_usage_preserves_top_level_cost() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_top_level_cost_usage_sse_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 0.055 + + @pytest.mark.asyncio async def test_streaming_final_usage_cost_overrides_comment_cost() -> None: cred_context = FakeCredentialContext() diff --git a/tests/test_usage_accounting.py b/tests/test_usage_accounting.py index 75bb6ae2c..b999d86e5 100644 --- a/tests/test_usage_accounting.py +++ b/tests/test_usage_accounting.py @@ -60,13 +60,19 @@ def test_top_level_cost_is_preserved_when_usage_exists() -> None: assert record.cost_currency == "EUR" +def test_reference_request_cost_usd_is_preserved() -> None: + record = extract_usage_record({"usage": {"prompt_tokens": 1, "completion_tokens": 1}, "request_cost_usd": 0.019}) + + assert record.provider_reported_cost == 0.019 + + def test_structured_cost_breakdown_without_total_is_summed() -> None: record = extract_usage_record( { "usage": { "prompt_tokens": 10, "completion_tokens": 5, - "cost_details": {"input_cost": 0.01, "output_cost": 0.02, "reasoning_cost": 0.003}, + "cost_details": {"cached_input_cost": 0.01, "upstream_inference_cost": 0.02, "web_search_cost": 0.003}, } } ) From 41e60e04a1257e159ce8e7887aea5ae32061eb71 Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Mon, 1 Jun 2026 00:00:21 +0200 Subject: [PATCH 181/182] fix(usage): preserve all stream cost metadata Merges top-level provider cost siblings for dict streaming chunks, broadens reference cost breakdown aliases, and lets later raw native token usage replace cost-only records while preserving earlier provider-reported cost. Tests: pytest tests/test_usage_accounting.py tests/test_usage_costs.py tests/test_usage_quota_snapshots.py tests/test_responses_streaming.py tests/test_responses_service.py tests/test_responses_routes.py tests/test_responses_bridge.py tests/test_stream_policy.py tests/test_native_provider_streaming.py tests/test_native_provider_executor.py tests/test_streaming_usage_accounting.py tests/test_executor_usage_accounting.py tests/test_responses_usage_accounting.py tests/test_native_usage_accounting.py tests/test_protocol_openai_chat.py tests/test_protocol_responses.py --- src/rotator_library/client/streaming.py | 22 ++++++++++++++--- .../native_provider/executor.py | 15 ++++++++++-- src/rotator_library/usage/accounting.py | 9 +++++++ tests/test_native_provider_streaming.py | 24 +++++++++++++++++++ tests/test_streaming_usage_accounting.py | 13 ++++++++++ tests/test_usage_accounting.py | 22 +++++++++++++++++ 6 files changed, 100 insertions(+), 5 deletions(-) diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py index 2f692cb57..f938c4fcd 100644 --- a/src/rotator_library/client/streaming.py +++ b/src/rotator_library/client/streaming.py @@ -510,8 +510,9 @@ def _process_chunk( else: chunk_dict = chunk - # Extract metadata before modifying - usage = chunk_dict.get("usage") + # Extract metadata before modifying. Providers can report cost as a + # sibling of `usage`, so keep those fields attached for normalization. + usage = _usage_from_chunk_dict(chunk_dict) finish_reason = None chunk_has_tool_calls = False @@ -565,7 +566,7 @@ def _process_chunk( # INTERMEDIATE CHUNK: Never emit finish_reason choice["finish_reason"] = None - usage = chunk_dict.get("usage") + usage = _usage_from_chunk_dict(chunk_dict) if isinstance(usage, dict): normalize_usage_for_response(usage, model) @@ -753,6 +754,21 @@ def _usage_from_sse_string(chunk: str) -> Optional[dict[str, Any]]: return merged +def _usage_from_chunk_dict(chunk: Any) -> Optional[dict[str, Any]]: + """Return dict chunk usage with top-level provider cost siblings.""" + + if not isinstance(chunk, dict): + return None + usage = chunk.get("usage") + if not isinstance(usage, dict): + return None + merged = dict(usage) + for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + if key in chunk and key not in merged: + merged[key] = chunk[key] + return merged + + def _usage_record_from_sse_cost_chunk(chunk: Any, *, model: str) -> UsageRecord: """Extract provider-reported stream cost from SSE comments/events.""" diff --git a/src/rotator_library/native_provider/executor.py b/src/rotator_library/native_provider/executor.py index 029f2dae4..6b234048a 100644 --- a/src/rotator_library/native_provider/executor.py +++ b/src/rotator_library/native_provider/executor.py @@ -307,8 +307,8 @@ async def _inject_unified_request( def _merge_stream_usage_records(base: Any, event_record: Any, raw_record: Any) -> Any: """Merge native stream usage, preserving raw provider cost when needed.""" - selected = event_record if _usage_record_has_values(event_record) else base - if not _usage_record_has_values(selected) and _usage_record_has_values(raw_record): + selected = event_record if _usage_record_has_token_values(event_record) else base + if not _usage_record_has_token_values(selected) and _usage_record_has_token_values(raw_record): selected = raw_record if selected.provider_reported_cost is None and base.provider_reported_cost is not None: selected = replace( @@ -339,6 +339,17 @@ def _usage_record_has_values(record: Any) -> bool: ) +def _usage_record_has_token_values(record: Any) -> bool: + return bool( + record.input_tokens + or record.completion_tokens + or record.reasoning_tokens + or record.cache_read_tokens + or record.cache_write_tokens + or record.raw_total_tokens + ) + + def _redact_field_cache_paths(data: Any, context: NativeProviderContext, direction: str) -> Any: """Redact configured cache paths before broad native payload traces. diff --git a/src/rotator_library/usage/accounting.py b/src/rotator_library/usage/accounting.py index 58a65d757..16d937315 100644 --- a/src/rotator_library/usage/accounting.py +++ b/src/rotator_library/usage/accounting.py @@ -328,6 +328,15 @@ def _sum_cost_breakdown(payload: dict[str, Any]) -> Optional[float]: "upstream_inference_cost", "upstream_inference_prompt_cost", "upstream_inference_completions_cost", + "upstream_inference_input_cost", + "upstream_inference_output_cost", + "image_input_cost", + "image_output_cost", + "audio_input_cost", + "audio_output_cost", + "data_storage_cost", + "estimated_cost", + "cost_in_usd_ticks", "request_cost", "web_search_cost", "search_cost", diff --git a/tests/test_native_provider_streaming.py b/tests/test_native_provider_streaming.py index 3a453f453..de3d0a5c8 100644 --- a/tests/test_native_provider_streaming.py +++ b/tests/test_native_provider_streaming.py @@ -117,6 +117,30 @@ async def test_native_provider_stream_preserves_earlier_cost_when_later_usage_ar assert summaries[-1]["data"]["usage"]["provider_reported_cost"] == 0.07 +@pytest.mark.asyncio +async def test_native_provider_stream_preserves_cost_when_later_raw_usage_arrives(tmp_path) -> None: + logger = TransactionLogger("native", "gpt-test", parent_dir=tmp_path) + context = NativeProviderContext( + provider="native", + model="gpt-test", + protocol_name="openai_chat", + endpoint="https://example.test/chat", + transaction_logger=logger, + ) + chunks = [ + {"choices": [], "usage": {"cost_details": {"total_cost": 0.08, "source": "early_cost"}}}, + {"choices": [], "prompt_tokens": 2, "completion_tokens": 3}, + "[DONE]", + ] + + _ = [event async for event in NativeProviderExecutor().stream({"model": "gpt-test", "messages": []}, context, NativeHTTPTransport(FakeStreamingClient(chunks)))] + + summaries = [entry for entry in _trace_entries(logger.log_dir) if entry["pass_name"] == "usage_accounting_summary"] + assert summaries[-1]["data"]["usage"]["input_tokens"] == 2 + assert summaries[-1]["data"]["usage"]["completion_tokens"] == 3 + assert summaries[-1]["data"]["usage"]["provider_reported_cost"] == 0.08 + + @pytest.mark.asyncio async def test_native_provider_stream_logs_errors(tmp_path) -> None: class BrokenClient: diff --git a/tests/test_streaming_usage_accounting.py b/tests/test_streaming_usage_accounting.py index 7058ff7ea..272b06f5f 100644 --- a/tests/test_streaming_usage_accounting.py +++ b/tests/test_streaming_usage_accounting.py @@ -60,6 +60,10 @@ async def _top_level_cost_usage_sse_chunks(): yield 'data: {"choices":[{"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1},"total_cost":0.055}\n\n' +async def _top_level_cost_usage_dict_chunks(): + yield {"id": "chunk_1", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1}, "total_cost": 0.066} + + async def _cost_comment_overridden_by_final_usage_chunks(): yield ': cost 0.042\n\n' yield { @@ -178,6 +182,15 @@ async def test_streaming_sse_usage_preserves_top_level_cost() -> None: assert cred_context.success_kwargs["approx_cost"] == 0.055 +@pytest.mark.asyncio +async def test_streaming_dict_usage_preserves_top_level_cost() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_top_level_cost_usage_dict_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 0.066 + + @pytest.mark.asyncio async def test_streaming_final_usage_cost_overrides_comment_cost() -> None: cred_context = FakeCredentialContext() diff --git a/tests/test_usage_accounting.py b/tests/test_usage_accounting.py index b999d86e5..5db290a9b 100644 --- a/tests/test_usage_accounting.py +++ b/tests/test_usage_accounting.py @@ -81,6 +81,28 @@ def test_structured_cost_breakdown_without_total_is_summed() -> None: assert record.cost_source == "provider_reported_breakdown" +def test_reference_extended_cost_breakdown_aliases_are_summed() -> None: + record = extract_usage_record( + { + "usage": { + "prompt_tokens": 1, + "completion_tokens": 1, + "cost_details": { + "upstream_inference_input_cost": 0.01, + "upstream_inference_output_cost": 0.02, + "image_input_cost": 0.003, + "audio_input_cost": 0.004, + "data_storage_cost": 0.005, + "estimated_cost": 0.006, + "cost_in_usd_ticks": 0.007, + }, + } + } + ) + + assert record.provider_reported_cost == 0.055 + + def test_openai_object_usage_extracts_attributes() -> None: response = SimpleNamespace( usage=SimpleNamespace( From 96ef2f6dae3f4bd8e17470fc9832aee27fe5103d Mon Sep 17 00:00:00 2001 From: Mirrowel <28632877+Mirrowel@users.noreply.github.com> Date: Mon, 1 Jun 2026 00:14:13 +0200 Subject: [PATCH 182/182] fix(usage): normalize estimated cost fields Converts cost_in_usd_ticks to USD before summing, preserves top-level estimated_cost across accounting and streaming/Responses paths, and adds regression coverage for estimated-cost cost metadata. Tests: pytest tests/test_usage_accounting.py tests/test_usage_costs.py tests/test_usage_quota_snapshots.py tests/test_responses_streaming.py tests/test_responses_service.py tests/test_responses_routes.py tests/test_responses_bridge.py tests/test_stream_policy.py tests/test_native_provider_streaming.py tests/test_native_provider_executor.py tests/test_streaming_usage_accounting.py tests/test_executor_usage_accounting.py tests/test_responses_usage_accounting.py tests/test_native_usage_accounting.py tests/test_protocol_openai_chat.py tests/test_protocol_responses.py --- src/rotator_library/client/streaming.py | 6 ++--- src/rotator_library/protocols/openai_chat.py | 2 +- src/rotator_library/protocols/responses.py | 2 +- src/rotator_library/responses/bridge.py | 6 +++-- src/rotator_library/responses/service.py | 10 ++++----- src/rotator_library/usage/accounting.py | 10 ++++++--- tests/test_responses_bridge.py | 17 +++++++++++++++ tests/test_responses_streaming.py | 23 ++++++++++++++++++++ tests/test_streaming_usage_accounting.py | 14 ++++++++++++ tests/test_usage_accounting.py | 8 ++++++- 10 files changed, 82 insertions(+), 16 deletions(-) diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py index f938c4fcd..eb73df6fe 100644 --- a/src/rotator_library/client/streaming.py +++ b/src/rotator_library/client/streaming.py @@ -748,7 +748,7 @@ def _usage_from_sse_string(chunk: str) -> Optional[dict[str, Any]]: if not isinstance(usage, dict): return None merged = dict(usage) - for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): if key in data and key not in merged: merged[key] = data[key] return merged @@ -763,7 +763,7 @@ def _usage_from_chunk_dict(chunk: Any) -> Optional[dict[str, Any]]: if not isinstance(usage, dict): return None merged = dict(usage) - for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): if key in chunk and key not in merged: merged[key] = chunk[key] return merged @@ -782,7 +782,7 @@ def _usage_record_from_sse_cost_chunk(chunk: Any, *, model: str) -> UsageRecord: if not isinstance(cost_payload, dict): return UsageRecord(source="stream_cost_event", model=model) return extract_usage_record( - {"usage": {"provider_reported_cost": cost_payload.get("provider_reported_cost", cost_payload.get("request_cost_usd", cost_payload.get("total_cost", cost_payload.get("cost")))), "currency": cost_payload.get("currency", "USD"), "cost_details": cost_payload}}, + {"usage": {"provider_reported_cost": cost_payload.get("provider_reported_cost", cost_payload.get("request_cost_usd", cost_payload.get("total_cost", cost_payload.get("cost", cost_payload.get("estimated_cost"))))), "currency": cost_payload.get("currency", "USD"), "cost_details": cost_payload}}, model=model, source="stream_cost_event", ) diff --git a/src/rotator_library/protocols/openai_chat.py b/src/rotator_library/protocols/openai_chat.py index bd95c967b..c21452510 100644 --- a/src/rotator_library/protocols/openai_chat.py +++ b/src/rotator_library/protocols/openai_chat.py @@ -250,7 +250,7 @@ def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = N cost = None cost_details = usage.get("cost_details") if isinstance(cost_details, dict): - provider_cost = cost_details.get("total_cost") or cost_details.get("request_cost_usd") or cost_details.get("cost") + provider_cost = cost_details.get("total_cost") or cost_details.get("request_cost_usd") or cost_details.get("cost") or cost_details.get("estimated_cost") cost = CostDetails( provider_reported_cost=float(provider_cost) if provider_cost is not None else None, currency=str(cost_details.get("currency") or "USD"), diff --git a/src/rotator_library/protocols/responses.py b/src/rotator_library/protocols/responses.py index 1a2d8d599..170b7cb46 100644 --- a/src/rotator_library/protocols/responses.py +++ b/src/rotator_library/protocols/responses.py @@ -186,7 +186,7 @@ def extract_usage(self, raw_or_unified: Any, context: ProtocolContext | None = N cost = None cost_details = usage.get("cost_details") if isinstance(cost_details, dict): - provider_cost = cost_details.get("total_cost") or cost_details.get("request_cost_usd") or cost_details.get("cost") + provider_cost = cost_details.get("total_cost") or cost_details.get("request_cost_usd") or cost_details.get("cost") or cost_details.get("estimated_cost") cost = CostDetails( provider_reported_cost=float(provider_cost) if provider_cost is not None else None, currency=str(cost_details.get("currency") or "USD"), diff --git a/src/rotator_library/responses/bridge.py b/src/rotator_library/responses/bridge.py index 93fb4598b..2dcda9c7b 100644 --- a/src/rotator_library/responses/bridge.py +++ b/src/rotator_library/responses/bridge.py @@ -311,7 +311,7 @@ def _status_from_chat(response: dict[str, Any]) -> str: def _usage_to_responses(usage: Any) -> Any: if isinstance(usage, dict) and isinstance(usage.get("usage"), dict): nested = dict(usage["usage"]) - for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): if key in usage and key not in nested: nested[key] = deepcopy(usage[key]) usage = nested @@ -328,11 +328,13 @@ def _usage_to_responses(usage: Any) -> Any: completion_details = usage.get("completion_tokens_details") or usage.get("output_tokens_details") if isinstance(completion_details, dict): result["output_tokens_details"] = {"reasoning_tokens": completion_details.get("reasoning_tokens", 0)} - for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): if key in usage: result[key] = deepcopy(usage[key]) if "request_cost_usd" in result and "cost_details" not in result: result["cost_details"] = {"request_cost_usd": result["request_cost_usd"]} + if "estimated_cost" in result and "cost_details" not in result: + result["cost_details"] = {"estimated_cost": result["estimated_cost"]} return result diff --git a/src/rotator_library/responses/service.py b/src/rotator_library/responses/service.py index edfe36269..e3c68008f 100644 --- a/src/rotator_library/responses/service.py +++ b/src/rotator_library/responses/service.py @@ -918,7 +918,7 @@ def _responses_sse_cost_usage(chunk: Any) -> Optional[dict[str, Any]]: payload = {"provider_reported_cost": payload, "source": "responses_sse_cost"} if not isinstance(payload, dict): return None - cost = payload.get("provider_reported_cost", payload.get("request_cost_usd", payload.get("total_cost", payload.get("cost")))) + cost = payload.get("provider_reported_cost", payload.get("request_cost_usd", payload.get("total_cost", payload.get("cost", payload.get("estimated_cost"))))) if cost is None: return None return { @@ -935,7 +935,7 @@ def _responses_chunk_usage(chunk: dict[str, Any]) -> Any: if not isinstance(usage, dict): return usage merged = dict(usage) - for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): if key in chunk and key not in merged: merged[key] = deepcopy(chunk[key]) return merged @@ -981,10 +981,10 @@ def _merge_responses_stream_usage(primary: Any, fallback_cost: Any) -> Any: if not isinstance(fallback_cost, dict): return primary merged = deepcopy(primary) - has_cost = any(key in merged for key in ("cost_details", "cost", "total_cost", "provider_reported_cost")) + has_cost = any(key in merged for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd")) if has_cost: return merged - for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "currency"): + for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd", "currency"): if key in fallback_cost: merged[key] = deepcopy(fallback_cost[key]) return merged @@ -1015,7 +1015,7 @@ def _usage_to_responses_stream(usage: Any) -> Any: completion_details = usage.get("completion_tokens_details") or usage.get("output_tokens_details") if isinstance(completion_details, dict): result["output_tokens_details"] = {"reasoning_tokens": completion_details.get("reasoning_tokens", 0)} - for key in ("cost_details", "cost", "total_cost", "provider_reported_cost", "currency"): + for key in ("cost_details", "cost", "total_cost", "estimated_cost", "provider_reported_cost", "request_cost_usd", "currency"): if key in usage: result[key] = deepcopy(usage[key]) return result diff --git a/src/rotator_library/usage/accounting.py b/src/rotator_library/usage/accounting.py index 16d937315..1b1bdb122 100644 --- a/src/rotator_library/usage/accounting.py +++ b/src/rotator_library/usage/accounting.py @@ -129,7 +129,7 @@ def _merge_top_level_cost_fields(usage: dict[str, Any], response: dict[str, Any] """Copy sibling cost metadata into a nested usage payload.""" merged = dict(usage) - for key in ("cost", "total_cost", "cost_details", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): + for key in ("cost", "total_cost", "estimated_cost", "cost_details", "provider_reported_cost", "request_cost_usd", "currency", "costMetadata"): if key in response and key not in merged: merged[key] = response[key] return merged @@ -173,6 +173,7 @@ def _as_dict(value: Any) -> dict[str, Any]: "costMetadata", "provider_reported_cost", "request_cost_usd", + "estimated_cost", "total_cost", "currency", ): @@ -301,7 +302,7 @@ def _extract_cost(data: dict[str, Any]) -> tuple[Optional[float], str, Optional[ raw_cost = breakdown_total breakdown_used = True if raw_cost is None: - raw_cost = data.get("provider_reported_cost", data.get("request_cost_usd", data.get("total_cost", data.get("cost")))) + raw_cost = data.get("provider_reported_cost", data.get("request_cost_usd", data.get("total_cost", data.get("cost", data.get("estimated_cost"))))) cost_value = _float_or_none(raw_cost) currency = str(cost_payload.get("currency") or data.get("currency") or "USD") source = cost_payload.get("source") or ("provider_reported_breakdown" if breakdown_used else ("provider_reported" if cost_value is not None else None)) @@ -336,7 +337,6 @@ def _sum_cost_breakdown(payload: dict[str, Any]) -> Optional[float]: "audio_output_cost", "data_storage_cost", "estimated_cost", - "cost_in_usd_ticks", "request_cost", "web_search_cost", "search_cost", @@ -345,6 +345,10 @@ def _sum_cost_breakdown(payload: dict[str, Any]) -> Optional[float]: if value is not None: total += value found = True + ticks = _float_or_none(payload.get("cost_in_usd_ticks")) + if ticks is not None: + total += ticks / 10_000_000_000 + found = True return total if found else None diff --git a/tests/test_responses_bridge.py b/tests/test_responses_bridge.py index 41fac1b7b..a5dfe9706 100644 --- a/tests/test_responses_bridge.py +++ b/tests/test_responses_bridge.py @@ -187,6 +187,23 @@ def test_bridge_preserves_chat_response_top_level_cost_with_usage() -> None: assert response["usage"]["cost_details"]["total_cost"] == 0.012 +def test_bridge_preserves_chat_response_estimated_cost_with_usage() -> None: + protocol = ResponsesProtocol() + bridge = ResponsesBridge(protocol) + unified = protocol.parse_request({"model": "gpt-test", "input": "Hello"}) + chat_response = { + "id": "chat_1", + "model": "gpt-test", + "choices": [{"message": {"role": "assistant", "content": "Hi"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 2}, + "estimated_cost": 0.013, + } + + response = bridge.from_chat_response(chat_response, unified) + + assert response["usage"]["cost_details"]["estimated_cost"] == 0.013 + + def test_bridge_converts_chat_tool_calls_to_responses_output_items() -> None: protocol = ResponsesProtocol() bridge = ResponsesBridge(protocol) diff --git a/tests/test_responses_streaming.py b/tests/test_responses_streaming.py index 7fb612af3..82fdc9af1 100644 --- a/tests/test_responses_streaming.py +++ b/tests/test_responses_streaming.py @@ -157,6 +157,16 @@ async def chunks(): return chunks() +class EstimatedCostCommentStreamingClient: + async def acompletion(self, **kwargs): + async def chunks(): + yield ': cost {"estimated_cost":0.024,"source":"reference_estimate"}\n\n' + yield 'data: {"choices":[{"delta":{"content":"priced"}}]}\n\n' + yield "data: [DONE]\n\n" + + return chunks() + + class TopLevelUsageCostStreamingClient: async def acompletion(self, **kwargs): async def chunks(): @@ -319,6 +329,19 @@ async def test_stream_response_preserves_reference_request_cost_comment() -> Non assert stored.usage["cost_details"]["request_cost_usd"] == 0.023 +@pytest.mark.asyncio +async def test_stream_response_preserves_estimated_cost_comment() -> None: + store = InMemoryResponsesStore() + service = ResponsesService(store=store) + + events = [chunk async for chunk in service.stream_response({"model": "gpt-test", "input": "Hello", "stream": True}, EstimatedCostCommentStreamingClient())] + + response_id = events[0].split('"id": "')[1].split('"')[0] + stored = await store.get(response_id) + assert stored is not None + assert stored.usage["cost_details"]["estimated_cost"] == 0.024 + + @pytest.mark.asyncio async def test_stream_response_preserves_top_level_chunk_cost_with_usage() -> None: store = InMemoryResponsesStore() diff --git a/tests/test_streaming_usage_accounting.py b/tests/test_streaming_usage_accounting.py index 272b06f5f..8077ee634 100644 --- a/tests/test_streaming_usage_accounting.py +++ b/tests/test_streaming_usage_accounting.py @@ -56,6 +56,11 @@ async def _request_cost_comment_chunks(): yield {"id": "chunk_1", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1}} +async def _estimated_cost_comment_chunks(): + yield ': cost {"estimated_cost":0.045,"source":"reference_estimate"}\n\n' + yield {"id": "chunk_1", "choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1, "completion_tokens": 1}} + + async def _top_level_cost_usage_sse_chunks(): yield 'data: {"choices":[{"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1},"total_cost":0.055}\n\n' @@ -173,6 +178,15 @@ async def test_streaming_reference_request_cost_comment_updates_approx_cost() -> assert cred_context.success_kwargs["approx_cost"] == 0.044 +@pytest.mark.asyncio +async def test_streaming_estimated_cost_comment_updates_approx_cost() -> None: + cred_context = FakeCredentialContext() + + _ = [chunk async for chunk in StreamingHandler().wrap_stream(_estimated_cost_comment_chunks(), "cred", "gpt-test", cred_context=cred_context)] + + assert cred_context.success_kwargs["approx_cost"] == 0.045 + + @pytest.mark.asyncio async def test_streaming_sse_usage_preserves_top_level_cost() -> None: cred_context = FakeCredentialContext() diff --git a/tests/test_usage_accounting.py b/tests/test_usage_accounting.py index 5db290a9b..e30e82886 100644 --- a/tests/test_usage_accounting.py +++ b/tests/test_usage_accounting.py @@ -66,6 +66,12 @@ def test_reference_request_cost_usd_is_preserved() -> None: assert record.provider_reported_cost == 0.019 +def test_top_level_estimated_cost_is_preserved() -> None: + record = extract_usage_record({"usage": {"prompt_tokens": 1, "completion_tokens": 1}, "estimated_cost": 0.027}) + + assert record.provider_reported_cost == 0.027 + + def test_structured_cost_breakdown_without_total_is_summed() -> None: record = extract_usage_record( { @@ -94,7 +100,7 @@ def test_reference_extended_cost_breakdown_aliases_are_summed() -> None: "audio_input_cost": 0.004, "data_storage_cost": 0.005, "estimated_cost": 0.006, - "cost_in_usd_ticks": 0.007, + "cost_in_usd_ticks": 70_000_000, }, } }