Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ renderers = false

[tool.uv.sources]
# Pinned to renderers PR #11 until the next PyPI release lands; drop after.
# 1f3de65 = Dynamo chat nvext transport for token-in /chat/completions.
renderers = { git = "https://github.com/PrimeIntellect-ai/renderers.git", rev = "7ca1ab3" }
# 17005dd = Dynamo chat nvext transport with engine_data response support.
renderers = { git = "https://github.com/PrimeIntellect-ai/renderers.git", rev = "17005dd" }

[tool.uv.extra-build-dependencies]
flash-attn = [{ requirement = "torch", match-runtime = true }]
Expand Down
228 changes: 226 additions & 2 deletions tests/test_openai_chat_completions_token_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from types import SimpleNamespace
from typing import Any, cast

import pytest
Expand Down Expand Up @@ -169,7 +170,9 @@ async def test_get_native_response_falls_back_to_super_when_no_prefix_match(
sentinel = {"source": "super"}
calls: list[dict[str, Any]] = []

async def fake_get_prompt_ids(self, state, prompt_messages, oai_tools): # noqa: ANN001
async def fake_get_prompt_ids( # noqa: ANN001
self, state, prompt_messages, oai_tools, chat_template_kwargs=None
):
return None

async def fake_super_get_native_response( # noqa: ANN001
Expand Down Expand Up @@ -235,7 +238,9 @@ async def test_get_native_response_uses_token_route_when_prompt_ids_available(
recording_client = _RecordingClient()
client = OpenAIChatCompletionsTokenClient(recording_client)

async def fake_get_prompt_ids(self, state, prompt_messages, oai_tools): # noqa: ANN001
async def fake_get_prompt_ids( # noqa: ANN001
self, state, prompt_messages, oai_tools, chat_template_kwargs=None
):
return [10, 20]

monkeypatch.setattr(
Expand Down Expand Up @@ -270,3 +275,222 @@ async def fake_get_prompt_ids(self, state, prompt_messages, oai_tools): # noqa:
assert len(recording_client.calls) == 1
assert recording_client.calls[0]["path"] == "/chat/completions/tokens"
assert recording_client.calls[0]["body"]["tokens"] == [10, 20]


# ---------------------------------------------------------------------------
# dynamo_chat_nvext transport (Dynamo bis/dynamo-rl)
# ---------------------------------------------------------------------------


class _StubRenderer:
"""Renderer stand-in for the dynamo_chat_nvext transport tests.

Returns deterministic ids so we can assert on body shape without pulling
in a real HuggingFace tokenizer download. ``render_ids`` returns a
fixed sequence; ``get_stop_token_ids`` returns a marker pair.
"""

def __init__(self) -> None:
self.render_calls: list[dict[str, Any]] = []

def render_ids(
self,
messages,
*,
tools=None,
add_generation_prompt: bool = False,
) -> list[int]:
self.render_calls.append(
{
"messages": messages,
"tools": tools,
"add_generation_prompt": add_generation_prompt,
}
)
# Encode the call shape into ids so tests can disambiguate the two
# bridge tokenize calls without a real tokenizer.
return [42, len(messages), int(add_generation_prompt)]

def get_stop_token_ids(self) -> list[int]:
return [99, 100]


class _DynamoTestClient(OpenAIChatCompletionsTokenClient):
"""Dynamo-transport TITO client with a stubbed renderer.

Subclass override is the cleanest way to inject the stub without going
through ``ClientConfig`` (which would require a real ``api_base_url``
and ``setup_client`` to construct the AsyncOpenAI). The recording
client captures the eventual ``self.client.post(...)`` call.
"""

_stub_renderer: _StubRenderer

def __init__(self, recording_client) -> None:
super().__init__(recording_client)
self._stub_renderer = _StubRenderer()

@property
def renderer_transport(self) -> str: # type: ignore[override]
return "dynamo_chat_nvext"

def _get_renderer(self, model: str): # type: ignore[override]
return self._stub_renderer


@pytest.mark.asyncio
async def test_local_tokenize_uses_renderer_under_dynamo_transport():
"""Bridge tokenize must NOT hit any HTTP route under dynamo_chat_nvext.

Goes straight through ``_local_tokenize`` -> ``renderer.render_ids``.
The recording client would record any errant POST; we assert it sees
none.
"""
recording_client = _RecordingClient()
client = _DynamoTestClient(recording_client)

ids_full = await client.tokenize(
messages=[{"role": "user", "content": "u"}],
tools=None,
model="test-model",
)
ids_base = await client.tokenize(
messages=[{"role": "user", "content": "u"}],
tools=None,
model="test-model",
extra_kwargs={"add_generation_prompt": False},
)

# Both calls hit the renderer, neither hit the wire.
assert recording_client.calls == []
assert client._stub_renderer.render_calls[0]["add_generation_prompt"] is True
assert client._stub_renderer.render_calls[1]["add_generation_prompt"] is False
# And the stub encodes that into the returned ids' last element.
assert ids_full[-1] == 1
assert ids_base[-1] == 0


@pytest.mark.asyncio
async def test_get_native_response_uses_dynamo_chat_nvext_under_transport(
monkeypatch: pytest.MonkeyPatch,
):
"""Dynamo transport must POST to /chat/completions with nvext.token_data.

Mirrors test_get_native_response_uses_token_route_when_prompt_ids_available
but for the new transport.
"""
recording_client = _RecordingClient()
client = _DynamoTestClient(recording_client)

async def fake_get_prompt_ids( # noqa: ANN001
self, state, prompt_messages, oai_tools, chat_template_kwargs=None
):
return [10, 20, 30]

monkeypatch.setattr(
OpenAIChatCompletionsTokenClient, "get_prompt_ids", fake_get_prompt_ids
)

state = cast(
State,
{
"model": "test-model",
"trajectory": [
_make_step(
prompt=[{"role": "user", "content": "u1"}],
completion=[{"role": "assistant", "content": "a1"}],
prompt_ids=[1],
completion_ids=[2],
)
],
},
)
prompt = cast(Any, [{"role": "user", "content": "u2"}])

response = await client.get_native_response(
prompt=prompt,
model="test-model",
sampling_args={
"max_completion_tokens": 16,
"temperature": 0.5,
"extra_body": {
"nvext": {
"extra_fields": ["timing"],
"cache_salt": "ckpt-42",
},
"cache_salt": "top-level-salt",
},
},
tools=None,
state=state,
)

assert response["ok"] is True
assert len(recording_client.calls) == 1
call = recording_client.calls[0]

# Wire-shape assertions: route, nvext.token_data, stop_token_ids,
# placeholder messages, sampling fields promoted.
assert call["path"] == "/chat/completions"
body = call["body"]
assert body["nvext"]["token_data"] == [10, 20, 30]
assert body["nvext"]["extra_fields"] == ["timing", "engine_data"]
assert body["nvext"]["cache_salt"] == "ckpt-42"
assert body["cache_salt"] == "top-level-salt"
assert body["stop_token_ids"] == [99, 100]
assert body["messages"] == [{"role": "user", "content": "(token-in mode)"}]
assert body["max_completion_tokens"] == 16
assert body["temperature"] == 0.5
assert body["logprobs"] is True
assert body["stream"] is False

# No /chat/completions/tokens, no /tokenize for the dynamo transport.
assert all(
c["path"] != "/chat/completions/tokens" and not c["path"].endswith("/tokenize")
for c in recording_client.calls
)


@pytest.mark.asyncio
async def test_from_native_response_grafts_dynamo_engine_data_tokens():
client = OpenAIChatCompletionsClient(_NoopClient())
message = SimpleNamespace(
content="ok",
tool_calls=None,
model_dump=lambda: {},
)
response = SimpleNamespace(
id="chatcmpl-test",
created=0,
model="test-model",
usage=SimpleNamespace(
prompt_tokens=3,
completion_tokens=2,
total_tokens=5,
),
nvext={
"engine_data": {
"prompt_token_ids": [1, 2, 3],
"completion_token_ids": [4, 5],
},
},
choices=[
SimpleNamespace(
finish_reason="stop",
message=message,
logprobs={
"content": [
{"logprob": -0.1},
{"logprob": -0.2},
]
},
)
],
)

parsed = await client.from_native_response(cast(Any, response))

assert parsed.message.tokens is not None
assert parsed.message.tokens.prompt_ids == [1, 2, 3]
assert parsed.message.tokens.completion_ids == [4, 5]
assert parsed.message.tokens.completion_logprobs == [-0.1, -0.2]
4 changes: 2 additions & 2 deletions tests/test_renderer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ async def fake_generate(**kwargs):

assert response == {"content": "ok"}
assert len(calls) == 1
assert calls[0]["transport"] == "dynamo_chat_nvext"
assert calls[0]["transport"] == "dynamo"
assert calls[0]["prompt_ids"] == [10, 20]


Expand Down Expand Up @@ -469,7 +469,7 @@ async def test_get_incremental_prompt_ids_accepts_multimodal_tool_user_tail():
"auto",
id="nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16",
),
pytest.param("openai/gpt-oss-20b", "gpt_oss", id="openai/gpt-oss-20b"),
pytest.param("openai/gpt-oss-20b", "gpt-oss", id="openai/gpt-oss-20b"),
]


Expand Down
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 36 additions & 0 deletions verifiers/clients/openai_chat_completions_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,44 @@ def parse_finish_reason(response: OpenAIChatResponse) -> FinishReason:
case _:
return None

def _graft_engine_data(response: OpenAIChatResponse) -> None:
nvext = getattr(response, "nvext", None)
if nvext is None and hasattr(response, "model_dump"):
nvext = response.model_dump().get("nvext")
if not isinstance(nvext, dict):
return

choice = response.choices[0]
engine_data = nvext.get("engine_data")
completion_token_ids_top = nvext.get("completion_token_ids")
prompt_token_ids_top = nvext.get("prompt_token_ids")

completion_token_ids: list[int] | None = None
prompt_token_ids: list[int] | None = None
if isinstance(engine_data, dict):
if engine_data.get("completion_token_ids") is not None:
completion_token_ids = list(engine_data["completion_token_ids"])
if engine_data.get("prompt_token_ids") is not None:
prompt_token_ids = list(engine_data["prompt_token_ids"])
if completion_token_ids is None and completion_token_ids_top is not None:
completion_token_ids = list(completion_token_ids_top)
if prompt_token_ids is None and prompt_token_ids_top is not None:
prompt_token_ids = list(prompt_token_ids_top)

if (
getattr(choice, "token_ids", None) is None
and completion_token_ids is not None
):
object.__setattr__(choice, "token_ids", completion_token_ids)
if (
getattr(response, "prompt_token_ids", None) is None
and prompt_token_ids is not None
):
object.__setattr__(response, "prompt_token_ids", prompt_token_ids)

def parse_tokens(response: OpenAIChatResponse) -> ResponseTokens | None:
assert len(response.choices) == 1, "Response should always have one choice"
_graft_engine_data(response)
choice = response.choices[0]
if not hasattr(choice, "token_ids"):
return None
Expand Down
Loading
Loading