Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 14 additions & 23 deletions verifiers/clients/openai_chat_completions_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import base64
import functools
from collections.abc import Iterable, Mapping
from typing import Any, TypeAlias, cast

import numpy as np

from openai import (
AsyncOpenAI,
AuthenticationError,
Expand Down Expand Up @@ -36,6 +33,7 @@
from openai.types.shared_params import FunctionDefinition

from verifiers.clients.client import Client
from verifiers.clients.routed_experts import compose_split_routed_experts
from verifiers.errors import (
EmptyModelResponseError,
InvalidModelResponseError,
Expand Down Expand Up @@ -459,27 +457,20 @@ def parse_tokens(response: OpenAIChatResponse) -> ResponseTokens | None:
logprobs_content = response.choices[0].logprobs["content"]
completion_logprobs = [token["logprob"] for token in logprobs_content]

has_routed_experts = (
isinstance(
routed_experts := getattr(choice, "routed_experts", None), dict
)
and "data" in routed_experts
and "shape" in routed_experts
)
if has_routed_experts:
routed_experts = cast(dict[str, Any], routed_experts)
routed_experts = cast(
list[list[list[int]]],
(
np.frombuffer(
base64.b85decode(routed_experts["data"]), dtype=np.int32
)
.reshape(routed_experts["shape"])
.tolist()
),
) # [seq_len, layers, topk]
else:
response_extra = response.model_extra or {}
choice_extra = choice.model_extra or {}
if (
"prompt_routed_experts" not in response_extra
and "routed_experts" not in choice_extra
):
routed_experts = None
else:
routed_experts = compose_split_routed_experts(
prompt_routed_experts=response_extra["prompt_routed_experts"],
completion_routed_experts=choice_extra["routed_experts"],
prompt_len=len(prompt_ids),
completion_len=len(completion_ids),
)
return ResponseTokens(
prompt_ids=prompt_ids,
prompt_mask=prompt_mask,
Expand Down
20 changes: 18 additions & 2 deletions verifiers/clients/openai_completions_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
get_usage_field,
handle_openai_overlong_prompt,
)
from verifiers.clients.routed_experts import compose_split_routed_experts
from verifiers.errors import (
EmptyModelResponseError,
InvalidModelResponseError,
Expand Down Expand Up @@ -82,8 +83,7 @@ async def get_native_response(
) -> OpenAITextResponse:
if tools:
raise ValueError(
"Completions API does not support tools. "
"Use chat_completions or messages client_type instead."
"Completions API does not support tools. Use chat_completions or messages client_type instead."
)

def normalize_sampling_args(sampling_args: SamplingArgs):
Expand Down Expand Up @@ -170,12 +170,28 @@ def parse_tokens(response: OpenAITextResponse) -> ResponseTokens | None:
)
if completion_logprobs is None:
return None
choice = response.choices[0]
response_extra = response.model_extra or {}
choice_extra = choice.model_extra or {}
if (
"prompt_routed_experts" not in response_extra
and "routed_experts" not in choice_extra
):
routed_experts = None
else:
routed_experts = compose_split_routed_experts(
prompt_routed_experts=response_extra["prompt_routed_experts"],
completion_routed_experts=choice_extra["routed_experts"],
prompt_len=len(prompt_ids),
completion_len=len(completion_ids),
)
return ResponseTokens(
prompt_ids=prompt_ids,
prompt_mask=prompt_mask,
completion_ids=completion_ids,
completion_mask=completion_mask,
completion_logprobs=completion_logprobs,
routed_experts=routed_experts,
)

return Response(
Expand Down
18 changes: 15 additions & 3 deletions verifiers/clients/renderer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,25 @@
from typing import Any, ClassVar, cast

from openai import AsyncOpenAI

from renderers import Message as RendererMessage
from renderers import (
MultimodalRenderer,
RenderedTokens,
Renderer,
RendererPool,
ToolCallFunction,
ToolSpec,
create_renderer_pool,
is_multimodal,
)
from renderers import ToolCall as RendererToolCall
from renderers import ToolCallFunction
from renderers.client import generate

from verifiers.clients.client import Client
from verifiers.clients.openai_chat_completions_client import (
handle_openai_overlong_prompt,
)
from verifiers.clients.routed_experts import compose_split_routed_experts
from verifiers.errors import EmptyModelResponseError
from verifiers.types import (
AssistantMessage,
Expand Down Expand Up @@ -630,14 +630,26 @@ async def from_native_response(self, response: dict[str, Any]) -> Response:
prompt_ids = response.get("prompt_ids", [])
completion_ids = response.get("completion_ids", [])
completion_logprobs = response.get("completion_logprobs", [])
if (
"prompt_routed_experts" not in response
and "completion_routed_experts" not in response
):
routed_experts = None
else:
routed_experts = compose_split_routed_experts(
prompt_routed_experts=response["prompt_routed_experts"],
completion_routed_experts=response["completion_routed_experts"],
prompt_len=len(prompt_ids),
completion_len=len(completion_ids),
)

tokens = ResponseTokens(
prompt_ids=prompt_ids,
prompt_mask=[0] * len(prompt_ids),
completion_ids=completion_ids,
completion_mask=[1] * len(completion_ids),
completion_logprobs=completion_logprobs,
routed_experts=response.get("routed_experts"),
routed_experts=routed_experts,
multi_modal_data=response.get("multi_modal_data"),
)

Expand Down
85 changes: 85 additions & 0 deletions verifiers/clients/routed_experts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from __future__ import annotations

import base64
from typing import Any, Mapping, cast

from verifiers.types import RoutedExperts

# Bytes per stored expert index; payloads in this module are validated as
# dtype "int16", and byte lengths/strides are computed from this constant.
INT16_BYTES = 2


def _shape_numel(shape: list[int]) -> int:
seq_len, num_layers, topk = shape
return seq_len * num_layers * topk


def _token_stride(shape: list[int]) -> int:
    """Return the number of bytes occupied by one token's routing entries.

    Only ``shape[1]`` and ``shape[2]`` are read; the leading dimension does
    not affect the result.
    """
    layers, experts_per_token = shape[1], shape[2]
    return layers * experts_per_token * INT16_BYTES


def _validate_routed_experts(payload: RoutedExperts) -> RoutedExperts:
    """Check that ``payload`` is internally consistent and return it unchanged.

    The payload must declare dtype ``"int16"``, have a three-dimensional
    shape, and carry exactly ``numel(shape) * INT16_BYTES`` bytes of data.

    Args:
        payload: The routed-experts payload to check.

    Returns:
        The same ``payload`` object, so the call can be used inline.

    Raises:
        AssertionError: If any check fails. The checks are raised explicitly
            instead of using ``assert`` statements so they still run when
            Python is started with ``-O``; the exception type is kept so
            existing handlers behave as before.
    """
    if payload.dtype != "int16":
        raise AssertionError(
            f"routed experts dtype must be 'int16', got {payload.dtype!r}"
        )
    if len(payload.shape) != 3:
        raise AssertionError(
            f"routed experts shape must have 3 dimensions, got {payload.shape!r}"
        )
    expected_bytes = _shape_numel(payload.shape) * INT16_BYTES
    if len(payload.data) != expected_bytes:
        raise AssertionError(
            f"routed experts data has {len(payload.data)} bytes, "
            f"expected {expected_bytes} for shape {payload.shape!r}"
        )
    return payload


def _decode_routed_experts(raw: Any) -> RoutedExperts:
    """Convert a wire-format routed-experts payload into ``RoutedExperts``.

    Args:
        raw: One of: a ``RoutedExperts`` instance (validated and returned
            as-is); an object exposing ``model_dump`` (dumped to a mapping
            first); or a mapping with ``encoding``, ``dtype``, ``shape`` and
            ``data`` entries, where ``data`` is base64 text.

    Returns:
        A validated ``RoutedExperts`` whose ``data`` holds the decoded bytes.

    Raises:
        AssertionError: If ``encoding`` is not ``"base64"``, ``dtype`` is not
            ``"int16"``, or the decoded payload fails validation. These are
            raised explicitly rather than with ``assert`` so malformed
            payloads are still rejected when Python runs with ``-O``.
    """
    if isinstance(raw, RoutedExperts):
        return _validate_routed_experts(raw)

    if hasattr(raw, "model_dump"):
        raw = raw.model_dump(mode="python")

    raw = cast(Mapping[str, Any], raw)
    if raw["encoding"] != "base64":
        raise AssertionError(
            f"routed experts encoding must be 'base64', got {raw['encoding']!r}"
        )
    if raw["dtype"] != "int16":
        raise AssertionError(
            f"routed experts dtype must be 'int16', got {raw['dtype']!r}"
        )
    shape = [int(dim) for dim in raw["shape"]]
    data = base64.b64decode(raw["data"])
    return _validate_routed_experts(RoutedExperts(shape=shape, data=data))


def slice_routed_experts(payload: RoutedExperts, end: int) -> RoutedExperts:
    """Return a new payload containing only the first ``end`` tokens.

    Args:
        payload: The routed-experts payload to truncate; it is validated
            before slicing.
        end: Number of leading tokens to keep. Must satisfy
            ``0 <= end <= payload.shape[0]``.

    Returns:
        A ``RoutedExperts`` with shape ``[end, layers, topk]`` and the
        matching prefix of ``payload.data``.

    Raises:
        AssertionError: If ``payload`` is inconsistent or ``end`` is out of
            range. The range check is raised explicitly rather than with
            ``assert``: if it were skipped under ``python -O``, a negative or
            oversized ``end`` would yield a shape that disagrees with the
            sliced bytes.
    """
    payload = _validate_routed_experts(payload)
    if not 0 <= end <= payload.shape[0]:
        raise AssertionError(
            f"slice end {end} is outside the range [0, {payload.shape[0]}]"
        )
    stride = _token_stride(payload.shape)
    return RoutedExperts(
        shape=[end, payload.shape[1], payload.shape[2]],
        data=payload.data[: end * stride],
    )


def compose_split_routed_experts(
    *,
    prompt_routed_experts: Any,
    completion_routed_experts: Any,
    prompt_len: int,
    completion_len: int,
) -> RoutedExperts | None:
    """Compose split prompt/completion routing into compact int16 bytes.

    Args:
        prompt_routed_experts: Wire-format or ``RoutedExperts`` payload for
            the prompt tokens, or ``None``.
        completion_routed_experts: Wire-format or ``RoutedExperts`` payload
            for the completion tokens, or ``None``. It is only decoded when
            ``completion_len`` is greater than 1.
        prompt_len: Number of prompt tokens; must equal the prompt payload's
            leading dimension.
        completion_len: Number of completion tokens. The completion payload
            is expected to cover ``completion_len - 1`` tokens.

    Returns:
        ``None`` when both payloads are ``None``; the decoded prompt payload
        when ``completion_len`` is 0; otherwise a payload with shape
        ``[prompt_len + completion_len, layers, topk]``.

    Raises:
        TypeError: If a payload that is needed is ``None`` while the other
            one is present. The original code failed with the same exception
            type, but with an opaque "not subscriptable" message.
        AssertionError: If lengths or per-token dimensions do not match.
            Raised explicitly rather than with ``assert`` so the checks still
            run when Python is started with ``-O``.
    """

    if prompt_routed_experts is None and completion_routed_experts is None:
        return None
    if prompt_routed_experts is None:
        raise TypeError(
            "prompt_routed_experts is required when "
            "completion_routed_experts is provided"
        )

    prompt = _decode_routed_experts(prompt_routed_experts)
    if prompt.shape[0] != prompt_len:
        raise AssertionError(
            f"prompt routed experts cover {prompt.shape[0]} tokens, "
            f"expected {prompt_len}"
        )

    if completion_len == 0:
        return prompt

    expected_completion_routed_len = max(completion_len - 1, 0)
    if expected_completion_routed_len == 0:
        completion_data = b""
    else:
        if completion_routed_experts is None:
            raise TypeError(
                "completion_routed_experts is required when "
                "completion_len is greater than 1"
            )
        completion = _decode_routed_experts(completion_routed_experts)
        if completion.shape[1:] != prompt.shape[1:]:
            raise AssertionError(
                f"completion per-token shape {completion.shape[1:]!r} does not "
                f"match prompt per-token shape {prompt.shape[1:]!r}"
            )
        if completion.shape[0] != expected_completion_routed_len:
            raise AssertionError(
                f"completion routed experts cover {completion.shape[0]} tokens, "
                f"expected {expected_completion_routed_len}"
            )
        completion_data = completion.data

    # The completion payload is one token short of ``completion_len``, so a
    # single zero-filled token row is appended to make the byte length match
    # the composed ``prompt_len + completion_len`` shape.
    stride = _token_stride(prompt.shape)
    return _validate_routed_experts(
        RoutedExperts(
            shape=[prompt_len + completion_len, prompt.shape[1], prompt.shape[2]],
            data=prompt.data + completion_data + (b"\0" * stride),
        )
    )
16 changes: 14 additions & 2 deletions verifiers/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,25 @@ class Usage(CustomBaseModel):
total_tokens: int


class RoutedExperts(CustomBaseModel):
    """Routed-expert data stored as raw bytes plus the shape describing them.

    The shape is laid out as ``[seq_len, layers, topk]``.
    """

    # Element type of ``data``; "int16" is the only accepted value.
    dtype: Literal["int16"] = "int16"
    # Dimensions of the routing data; must contain exactly three entries.
    shape: list[int]
    # Raw bytes of the routing data.
    data: bytes

    @field_validator("shape")
    @classmethod
    def validate_shape(cls, value: list[int]) -> list[int]:
        """Reject shapes that do not have exactly three dimensions.

        Raises:
            ValueError: If ``value`` does not contain exactly three entries.
                An explicit raise is used instead of ``assert`` so the check
                is not stripped when Python runs with ``-O``.
                NOTE(review): assumes ``CustomBaseModel`` is a pydantic model,
                in which case this surfaces as a validation error, as the
                previous ``AssertionError`` did.
        """
        if len(value) != 3:
            raise ValueError(
                f"shape must have exactly 3 dimensions, got {len(value)}"
            )
        return value


class ResponseTokens(CustomBaseModel):
prompt_ids: list[int]
prompt_mask: list[int]
completion_ids: list[int]
completion_mask: list[int]
completion_logprobs: list[float]
routed_experts: list[list[list[int]]] | None = None # [seq_len, layers, topk]
routed_experts: RoutedExperts | None = None # [seq_len, layers, topk]
# Renderer-emitted multimodal sidecar (renderers.base.MultiModalData)
# carrying processed pixel_values / placeholder ranges per modality.
# Populated by the renderer client when the rollout went through a
Expand Down Expand Up @@ -221,7 +233,7 @@ class TrajectoryStepTokens(TypedDict):
completion_logprobs: list[float]
overlong_prompt: bool
is_truncated: bool
routed_experts: list[list[list[int]]] | None # [seq_len, layers, topk]
routed_experts: RoutedExperts | None # [seq_len, layers, topk]
# Renderer-emitted multimodal sidecar (renderers.base.MultiModalData)
# carrying processed pixel_values / placeholder ranges per modality.
# ``NotRequired`` because text-only rollouts (and non-renderer client
Expand Down
6 changes: 4 additions & 2 deletions verifiers/utils/response_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from verifiers.clients.routed_experts import slice_routed_experts
from verifiers.types import (
AssistantMessage,
Messages,
Expand Down Expand Up @@ -48,14 +49,15 @@ async def parse_response_tokens(
completion_ids = []
completion_mask = []
completion_logprobs = []
routed_experts = [] if routed_experts is not None else None
if routed_experts is not None:
routed_experts = slice_routed_experts(routed_experts, max_seq_len)
elif prompt_len + completion_len > max_seq_len:
is_truncated = True
completion_ids = tokens.completion_ids[: max_seq_len - prompt_len]
completion_mask = tokens.completion_mask[: max_seq_len - prompt_len]
completion_logprobs = tokens.completion_logprobs[: max_seq_len - prompt_len]
if routed_experts is not None:
routed_experts = routed_experts[: max_seq_len - prompt_len]
routed_experts = slice_routed_experts(routed_experts, max_seq_len)
else:
is_truncated = False
else:
Expand Down
Loading