diff --git a/verifiers/clients/openai_chat_completions_client.py b/verifiers/clients/openai_chat_completions_client.py index c755d8dd4..84cd671f3 100644 --- a/verifiers/clients/openai_chat_completions_client.py +++ b/verifiers/clients/openai_chat_completions_client.py @@ -1,10 +1,7 @@ -import base64 import functools from collections.abc import Iterable, Mapping from typing import Any, TypeAlias, cast -import numpy as np - from openai import ( AsyncOpenAI, AuthenticationError, @@ -36,6 +33,7 @@ from openai.types.shared_params import FunctionDefinition from verifiers.clients.client import Client +from verifiers.clients.routed_experts import compose_split_routed_experts from verifiers.errors import ( EmptyModelResponseError, InvalidModelResponseError, @@ -459,27 +457,20 @@ def parse_tokens(response: OpenAIChatResponse) -> ResponseTokens | None: logprobs_content = response.choices[0].logprobs["content"] completion_logprobs = [token["logprob"] for token in logprobs_content] - has_routed_experts = ( - isinstance( - routed_experts := getattr(choice, "routed_experts", None), dict - ) - and "data" in routed_experts - and "shape" in routed_experts - ) - if has_routed_experts: - routed_experts = cast(dict[str, Any], routed_experts) - routed_experts = cast( - list[list[list[int]]], - ( - np.frombuffer( - base64.b85decode(routed_experts["data"]), dtype=np.int32 - ) - .reshape(routed_experts["shape"]) - .tolist() - ), - ) # [seq_len, layers, topk] - else: + response_extra = response.model_extra or {} + choice_extra = choice.model_extra or {} + if ( + "prompt_routed_experts" not in response_extra + and "routed_experts" not in choice_extra + ): routed_experts = None + else: + routed_experts = compose_split_routed_experts( + prompt_routed_experts=response_extra["prompt_routed_experts"], + completion_routed_experts=choice_extra["routed_experts"], + prompt_len=len(prompt_ids), + completion_len=len(completion_ids), + ) return ResponseTokens( prompt_ids=prompt_ids, 
prompt_mask=prompt_mask, diff --git a/verifiers/clients/openai_completions_client.py b/verifiers/clients/openai_completions_client.py index f7115322a..1fe82f849 100644 --- a/verifiers/clients/openai_completions_client.py +++ b/verifiers/clients/openai_completions_client.py @@ -9,6 +9,7 @@ get_usage_field, handle_openai_overlong_prompt, ) +from verifiers.clients.routed_experts import compose_split_routed_experts from verifiers.errors import ( EmptyModelResponseError, InvalidModelResponseError, @@ -82,8 +83,7 @@ async def get_native_response( ) -> OpenAITextResponse: if tools: raise ValueError( - "Completions API does not support tools. " - "Use chat_completions or messages client_type instead." + "Completions API does not support tools. Use chat_completions or messages client_type instead." ) def normalize_sampling_args(sampling_args: SamplingArgs): @@ -170,12 +170,28 @@ def parse_tokens(response: OpenAITextResponse) -> ResponseTokens | None: ) if completion_logprobs is None: return None + choice = response.choices[0] + response_extra = response.model_extra or {} + choice_extra = choice.model_extra or {} + if ( + "prompt_routed_experts" not in response_extra + and "routed_experts" not in choice_extra + ): + routed_experts = None + else: + routed_experts = compose_split_routed_experts( + prompt_routed_experts=response_extra["prompt_routed_experts"], + completion_routed_experts=choice_extra["routed_experts"], + prompt_len=len(prompt_ids), + completion_len=len(completion_ids), + ) return ResponseTokens( prompt_ids=prompt_ids, prompt_mask=prompt_mask, completion_ids=completion_ids, completion_mask=completion_mask, completion_logprobs=completion_logprobs, + routed_experts=routed_experts, ) return Response( diff --git a/verifiers/clients/renderer_client.py b/verifiers/clients/renderer_client.py index ad7644357..b06c0e0bc 100644 --- a/verifiers/clients/renderer_client.py +++ b/verifiers/clients/renderer_client.py @@ -17,25 +17,25 @@ from typing import Any, ClassVar, cast 
from openai import AsyncOpenAI - from renderers import Message as RendererMessage from renderers import ( MultimodalRenderer, RenderedTokens, Renderer, RendererPool, + ToolCallFunction, ToolSpec, create_renderer_pool, is_multimodal, ) from renderers import ToolCall as RendererToolCall -from renderers import ToolCallFunction from renderers.client import generate from verifiers.clients.client import Client from verifiers.clients.openai_chat_completions_client import ( handle_openai_overlong_prompt, ) +from verifiers.clients.routed_experts import compose_split_routed_experts from verifiers.errors import EmptyModelResponseError from verifiers.types import ( AssistantMessage, @@ -630,6 +630,18 @@ async def from_native_response(self, response: dict[str, Any]) -> Response: prompt_ids = response.get("prompt_ids", []) completion_ids = response.get("completion_ids", []) completion_logprobs = response.get("completion_logprobs", []) + if ( + "prompt_routed_experts" not in response + and "completion_routed_experts" not in response + ): + routed_experts = None + else: + routed_experts = compose_split_routed_experts( + prompt_routed_experts=response["prompt_routed_experts"], + completion_routed_experts=response["completion_routed_experts"], + prompt_len=len(prompt_ids), + completion_len=len(completion_ids), + ) tokens = ResponseTokens( prompt_ids=prompt_ids, @@ -637,7 +649,7 @@ async def from_native_response(self, response: dict[str, Any]) -> Response: completion_ids=completion_ids, completion_mask=[1] * len(completion_ids), completion_logprobs=completion_logprobs, - routed_experts=response.get("routed_experts"), + routed_experts=routed_experts, multi_modal_data=response.get("multi_modal_data"), ) diff --git a/verifiers/clients/routed_experts.py b/verifiers/clients/routed_experts.py new file mode 100644 index 000000000..080c491a9 --- /dev/null +++ b/verifiers/clients/routed_experts.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import base64 +from typing import Any, 
INT16_BYTES = 2  # size in bytes of one int16 expert id


class RoutedExpertsError(AssertionError, ValueError):
    """Raised when a routed-experts payload is malformed or inconsistent.

    Inherits from ``AssertionError`` so callers that handled the failures of
    the earlier ``assert``-based checks keep working, and from ``ValueError``
    so the failure reads as ordinary input validation. Unlike ``assert``
    statements, these checks are not removed when Python runs with ``-O``.
    """


def _require(condition: bool, message: str) -> None:
    """Raise ``RoutedExpertsError(message)`` unless *condition* is true."""
    if not condition:
        raise RoutedExpertsError(message)


def _shape_numel(shape: list[int]) -> int:
    """Return the element count of a ``[seq_len, num_layers, topk]`` shape."""
    seq_len, num_layers, topk = shape
    return seq_len * num_layers * topk


def _token_stride(shape: list[int]) -> int:
    """Return the number of bytes one token (one ``shape[0]`` row) occupies."""
    return shape[1] * shape[2] * INT16_BYTES


def _validate_routed_experts(payload: RoutedExperts) -> RoutedExperts:
    """Check that *payload* is int16, rank 3, and sized to match its shape.

    Returns:
        The same *payload* object, so the call can be used inline.

    Raises:
        RoutedExpertsError: if the dtype, rank, a dimension, or the byte
            length of ``payload.data`` is inconsistent.
    """
    _require(
        payload.dtype == "int16",
        f"routed_experts dtype must be 'int16', got {payload.dtype!r}",
    )
    _require(
        len(payload.shape) == 3,
        f"routed_experts shape must have 3 dims, got {payload.shape!r}",
    )
    # Two negative dims would multiply to a positive element count and slip
    # past the byte-length check below, so reject them explicitly.
    _require(
        all(dim >= 0 for dim in payload.shape),
        f"routed_experts shape must be non-negative, got {payload.shape!r}",
    )
    expected_bytes = _shape_numel(payload.shape) * INT16_BYTES
    _require(
        len(payload.data) == expected_bytes,
        f"routed_experts data has {len(payload.data)} bytes, expected "
        f"{expected_bytes} for shape {payload.shape!r}",
    )
    return payload


def _decode_routed_experts(raw: Any) -> RoutedExperts:
    """Turn a wire payload into a validated ``RoutedExperts``.

    Accepts an existing ``RoutedExperts``, an object exposing
    ``model_dump`` (dumped with ``mode="python"``), or a mapping with the
    keys ``encoding``, ``dtype``, ``shape`` and ``data``.

    Raises:
        RoutedExpertsError: if ``encoding`` is not ``"base64"``, ``dtype``
            is not ``"int16"``, or the decoded payload fails validation.
        KeyError: if the mapping lacks one of the required keys.
    """
    if isinstance(raw, RoutedExperts):
        return _validate_routed_experts(raw)

    if hasattr(raw, "model_dump"):
        raw = raw.model_dump(mode="python")

    raw = cast(Mapping[str, Any], raw)
    _require(
        raw["encoding"] == "base64",
        f"routed_experts encoding must be 'base64', got {raw['encoding']!r}",
    )
    _require(
        raw["dtype"] == "int16",
        f"routed_experts dtype must be 'int16', got {raw['dtype']!r}",
    )
    shape = [int(dim) for dim in raw["shape"]]
    data = base64.b64decode(raw["data"])
    return _validate_routed_experts(RoutedExperts(shape=shape, data=data))


def slice_routed_experts(payload: RoutedExperts, end: int) -> RoutedExperts:
    """Return a new payload holding only the first *end* tokens of *payload*.

    Raises:
        RoutedExpertsError: if *payload* is inconsistent or *end* is not in
            ``[0, payload.shape[0]]``.
    """
    payload = _validate_routed_experts(payload)
    _require(
        0 <= end <= payload.shape[0],
        f"slice end {end} is outside [0, {payload.shape[0]}]",
    )
    stride = _token_stride(payload.shape)
    return RoutedExperts(
        shape=[end, payload.shape[1], payload.shape[2]],
        data=payload.data[: end * stride],
    )


def compose_split_routed_experts(
    *,
    prompt_routed_experts: Any,
    completion_routed_experts: Any,
    prompt_len: int,
    completion_len: int,
) -> RoutedExperts | None:
    """Compose split prompt/completion routing into compact int16 bytes.

    The prompt payload must cover exactly ``prompt_len`` tokens and the
    completion payload exactly ``completion_len - 1`` tokens. The result
    covers ``prompt_len + completion_len`` tokens; the final row is filled
    with zero bytes.
    NOTE(review): presumably the server reports no routing for the last
    sampled token, hence the zero row -- confirm against the server.

    Returns:
        ``None`` when both inputs are ``None``; the decoded prompt payload
        alone when ``completion_len`` is 0; otherwise the combined payload.

    Raises:
        RoutedExpertsError: if only one half of the routing data is present
            when it is needed, or if lengths or per-token shapes disagree.
    """

    if prompt_routed_experts is None and completion_routed_experts is None:
        return None

    expected_completion_routed_len = max(completion_len - 1, 0)

    # Report a half-missing payload up front; otherwise decoding ``None``
    # fails later with an unrelated "'NoneType' is not subscriptable" error.
    _require(
        prompt_routed_experts is not None,
        "completion routed_experts were provided without prompt_routed_experts",
    )
    _require(
        expected_completion_routed_len == 0
        or completion_routed_experts is not None,
        "prompt_routed_experts were provided without completion routed_experts",
    )

    prompt = _decode_routed_experts(prompt_routed_experts)
    _require(
        prompt.shape[0] == prompt_len,
        f"prompt routed_experts cover {prompt.shape[0]} tokens, "
        f"expected {prompt_len}",
    )

    if completion_len == 0:
        return prompt

    if expected_completion_routed_len == 0:
        # A single completion token contributes only the zero row below.
        completion_data = b""
    else:
        completion = _decode_routed_experts(completion_routed_experts)
        _require(
            completion.shape[1:] == prompt.shape[1:],
            f"completion routed_experts per-token shape {completion.shape[1:]!r} "
            f"does not match prompt {prompt.shape[1:]!r}",
        )
        _require(
            completion.shape[0] == expected_completion_routed_len,
            f"completion routed_experts cover {completion.shape[0]} tokens, "
            f"expected {expected_completion_routed_len}",
        )
        completion_data = completion.data

    stride = _token_stride(prompt.shape)
    return _validate_routed_experts(
        RoutedExperts(
            shape=[prompt_len + completion_len, prompt.shape[1], prompt.shape[2]],
            data=prompt.data + completion_data + (b"\0" * stride),
        )
    )
# ``NotRequired`` because text-only rollouts (and non-renderer client diff --git a/verifiers/utils/response_utils.py b/verifiers/utils/response_utils.py index 9bbb38ad8..338c1b6f8 100644 --- a/verifiers/utils/response_utils.py +++ b/verifiers/utils/response_utils.py @@ -1,3 +1,4 @@ +from verifiers.clients.routed_experts import slice_routed_experts from verifiers.types import ( AssistantMessage, Messages, @@ -48,14 +49,15 @@ async def parse_response_tokens( completion_ids = [] completion_mask = [] completion_logprobs = [] - routed_experts = [] if routed_experts is not None else None + if routed_experts is not None: + routed_experts = slice_routed_experts(routed_experts, max_seq_len) elif prompt_len + completion_len > max_seq_len: is_truncated = True completion_ids = tokens.completion_ids[: max_seq_len - prompt_len] completion_mask = tokens.completion_mask[: max_seq_len - prompt_len] completion_logprobs = tokens.completion_logprobs[: max_seq_len - prompt_len] if routed_experts is not None: - routed_experts = routed_experts[: max_seq_len - prompt_len] + routed_experts = slice_routed_experts(routed_experts, max_seq_len) else: is_truncated = False else: