mellea/backends/openai.py (12 additions, 8 deletions)

```diff
@@ -991,10 +991,13 @@ async def processing(
         if isinstance(chunk, ChatCompletion):
             message = chunk.choices[0].message

-            if hasattr(message, "reasoning_content"):
-                thinking_chunk = message.reasoning_content  # type: ignore
-                if thinking_chunk is not None:
-                    mot._thinking += thinking_chunk
+            # reasoning_content (Anthropic/DeepSeek attribute path) takes priority;
```
**Member** commented on the new comment line above:

> nit: I think Anthropic has a different API (?)
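For context, a rough sketch of the two response shapes at issue; the exact field names are recalled from the providers' docs, so treat them as assumptions:

```python
# DeepSeek-style OpenAI-compatible servers: the trace rides on the message
# object itself as `reasoning_content`, which is what the attribute path reads.
deepseek_style_message = {
    "role": "assistant",
    "content": "4",
    "reasoning_content": "2 + 2 equals 4.",
}

# Anthropic's native Messages API is not OpenAI-shaped at all: the trace comes
# back as a "thinking" content block, so it would never reach this code path.
anthropic_style_response = {
    "role": "assistant",
    "content": [
        {"type": "thinking", "thinking": "2 + 2 equals 4."},
        {"type": "text", "text": "4"},
    ],
}
```

If that recollection is right, the nit has a point: the attribute path is really a DeepSeek-style convention, and the comment's mention of Anthropic is loose.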
```diff
+            # fall back to the "reasoning" extra field used by vLLM and compatible servers.
+            thinking_chunk = getattr(message, "reasoning_content", None)
+            if thinking_chunk is None:
+                thinking_chunk = (message.model_extra or {}).get("reasoning")
+            if thinking_chunk is not None:
+                mot._thinking += thinking_chunk

             content_chunk = message.content
             if content_chunk is not None:
```
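The fallback line works because the OpenAI SDK's pydantic models tolerate unknown keys: a field like `reasoning` that is not declared on `ChatCompletionMessage` is kept in `model_extra` instead of being dropped. A minimal standalone sketch of the mechanics, using the same SDK types the new tests import:

```python
from openai.types.chat import ChatCompletionMessage

# vLLM-style payload: "reasoning" is not a declared field on the SDK model,
# so pydantic stashes it in model_extra rather than rejecting it.
msg = ChatCompletionMessage.model_validate(
    {"role": "assistant", "content": "4", "reasoning": "2 + 2 equals 4."}
)

print(getattr(msg, "reasoning_content", None))   # None -> attribute path misses
print((msg.model_extra or {}).get("reasoning"))  # '2 + 2 equals 4.' -> fallback hits
```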
```diff
@@ -1015,10 +1018,11 @@ async def processing(
             return

         message_delta = chunk.choices[0].delta
-        if hasattr(message_delta, "reasoning_content"):
-            thinking_chunk = message_delta.reasoning_content  # type: ignore
-            if thinking_chunk is not None:
-                mot._thinking += thinking_chunk
+        thinking_chunk = getattr(message_delta, "reasoning_content", None)
+        if thinking_chunk is None:
+            thinking_chunk = (message_delta.model_extra or {}).get("reasoning")
+        if thinking_chunk is not None:
+            mot._thinking += thinking_chunk
```
**Member** commented on lines +1021 to +1025:

> DRY (?)

```diff
         content_chunk = message_delta.content
         if content_chunk is not None:
```
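On the DRY question: the streaming and non-streaming branches now duplicate the same extraction logic. One possible shape for the deduplication the reviewer is hinting at, as a hypothetical helper rather than anything in this PR:

```python
def _extract_thinking(message) -> str | None:
    """Pull the reasoning trace off a message or delta, if the server sent one.

    Prefers the ``reasoning_content`` attribute path, then falls back to the
    raw ``reasoning`` extra field used by vLLM-style servers.
    """
    thinking = getattr(message, "reasoning_content", None)
    if thinking is None:
        thinking = (message.model_extra or {}).get("reasoning")
    return thinking
```

Both call sites would then collapse to:

```python
thinking_chunk = _extract_thinking(message)  # or message_delta in the streaming branch
if thinking_chunk is not None:
    mot._thinking += thinking_chunk
```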
test/backends/test_openai_unit.py (143 additions, 0 deletions)

```diff
@@ -5,9 +5,12 @@
 """

 import pytest
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
+from openai.types.chat.chat_completion import Choice

 from mellea.backends import ModelOption
 from mellea.backends.openai import OpenAIBackend
+from mellea.core.base import ModelOutputThunk


 def _make_backend(model_options: dict | None = None) -> OpenAIBackend:
@@ -168,5 +171,145 @@ def test_make_backend_specific_unknown_mellea_keys_removed(backend):
     assert ModelOption.SYSTEM_PROMPT not in result


+# --- processing(): reasoning / thinking trace extraction ---
+
+
+def _vllm_chat_completion(reasoning: str, content: str | None) -> ChatCompletion:
+    """Build a ChatCompletion that matches vLLM's thinking-model response shape."""
+    message = ChatCompletionMessage.model_validate(
+        {"role": "assistant", "content": content, "reasoning": reasoning}
+    )
+    return ChatCompletion(
+        id="vllm-test",
+        created=0,
+        model="qwen3",
+        object="chat.completion",
+        choices=[Choice(index=0, finish_reason="stop", message=message)],
+    )
+
+
+async def test_processing_captures_vllm_reasoning_field(backend):
+    """Non-streaming: mot._thinking captures the raw ``reasoning`` key from vLLM."""
+    mot: ModelOutputThunk = ModelOutputThunk(value=None)
+    chunk = _vllm_chat_completion(reasoning="2 + 2 equals 4.", content="4")
+    # Sanity check: the SDK object does not expose reasoning_content
+    assert not hasattr(chunk.choices[0].message, "reasoning_content")
+
+    await backend.processing(mot, chunk)
+
+    assert mot._thinking == "2 + 2 equals 4."
+    assert mot._underlying_value == "4"
+
+
+async def test_processing_vllm_reasoning_with_null_content(backend):
+    """Non-streaming: reasoning is captured even when ``content`` is null."""
+    mot: ModelOutputThunk = ModelOutputThunk(value=None)
+    chunk = _vllm_chat_completion(reasoning="some thinking", content=None)
+
+    await backend.processing(mot, chunk)
+
+    assert mot._thinking == "some thinking"
+    assert mot._underlying_value == ""
+
+
+async def test_processing_streaming_captures_vllm_reasoning_field(backend):
+    """Streaming: per-chunk ``reasoning`` deltas accumulate into mot._thinking."""
+    mot: ModelOutputThunk = ModelOutputThunk(value=None)
+    chunk_a = ChatCompletionChunk.model_validate(
+        {
+            "id": "vllm-stream",
+            "created": 0,
+            "model": "qwen3",
+            "object": "chat.completion.chunk",
+            "choices": [
+                {
+                    "index": 0,
+                    "delta": {
+                        "role": "assistant",
+                        "content": None,
+                        "reasoning": "first ",
+                    },
+                    "finish_reason": None,
+                }
+            ],
+        }
+    )
+    chunk_b = ChatCompletionChunk.model_validate(
+        {
+            "id": "vllm-stream",
+            "created": 0,
+            "model": "qwen3",
+            "object": "chat.completion.chunk",
+            "choices": [
+                {
+                    "index": 0,
+                    "delta": {"content": "ans", "reasoning": "second"},
+                    "finish_reason": None,
+                }
+            ],
+        }
+    )
+
+    await backend.processing(mot, chunk_a)
+    await backend.processing(mot, chunk_b)
+
+    assert mot._thinking == "first second"
+    assert mot._underlying_value == "ans"
+
+
+async def test_processing_reasoning_content_still_used(backend):
+    """Regression guard: the pre-existing ``reasoning_content`` path is preserved.
+
+    Some providers surface the trace as ``reasoning_content`` on the message
+    object itself. The fix must not regress that path in favour of the raw-dict
+    fallback.
+    """
+    message = ChatCompletionMessage.model_validate(
+        {
+            "role": "assistant",
+            "content": "answer",
+            "reasoning_content": "attribute-style trace",
+        }
+    )
+    chunk = ChatCompletion(
+        id="rc-test",
+        created=0,
+        model="fake",
+        object="chat.completion",
+        choices=[Choice(index=0, finish_reason="stop", message=message)],
+    )
+    assert hasattr(chunk.choices[0].message, "reasoning_content")
+
+    mot: ModelOutputThunk = ModelOutputThunk(value=None)
+    await backend.processing(mot, chunk)
+
+    assert mot._thinking == "attribute-style trace"
+    assert mot._underlying_value == "answer"
+
+
+async def test_processing_reasoning_content_takes_precedence_over_reasoning(backend):
+    """reasoning_content attribute wins when both it and raw ``reasoning`` are present."""
+    message = ChatCompletionMessage.model_validate(
+        {
+            "role": "assistant",
+            "content": "answer",
+            "reasoning_content": "attr-trace",
+            "reasoning": "raw-trace",
+        }
+    )
+    chunk = ChatCompletion(
+        id="prec-test",
+        created=0,
+        model="fake",
+        object="chat.completion",
+        choices=[Choice(index=0, finish_reason="stop", message=message)],
+    )
+    mot: ModelOutputThunk = ModelOutputThunk(value=None)
+    await backend.processing(mot, chunk)
+
+    assert mot._thinking == "attr-trace"
+    assert mot._underlying_value == "answer"
+
+
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])
```