diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py
index 1eea93511..63a0fc2b1 100644
--- a/mellea/backends/openai.py
+++ b/mellea/backends/openai.py
@@ -991,10 +991,13 @@ async def processing(
         if isinstance(chunk, ChatCompletion):
             message = chunk.choices[0].message
 
-            if hasattr(message, "reasoning_content"):
-                thinking_chunk = message.reasoning_content  # type: ignore
-                if thinking_chunk is not None:
-                    mot._thinking += thinking_chunk
+            # reasoning_content (Anthropic/DeepSeek attribute path) takes priority;
+            # fall back to the "reasoning" extra field used by vLLM and compatible servers.
+            thinking_chunk = getattr(message, "reasoning_content", None)
+            if thinking_chunk is None:
+                thinking_chunk = (message.model_extra or {}).get("reasoning")
+            if thinking_chunk is not None:
+                mot._thinking += thinking_chunk
 
             content_chunk = message.content
             if content_chunk is not None:
@@ -1015,10 +1018,11 @@ async def processing(
             return
 
         message_delta = chunk.choices[0].delta
-        if hasattr(message_delta, "reasoning_content"):
-            thinking_chunk = message_delta.reasoning_content  # type: ignore
-            if thinking_chunk is not None:
-                mot._thinking += thinking_chunk
+        thinking_chunk = getattr(message_delta, "reasoning_content", None)
+        if thinking_chunk is None:
+            thinking_chunk = (message_delta.model_extra or {}).get("reasoning")
+        if thinking_chunk is not None:
+            mot._thinking += thinking_chunk
 
         content_chunk = message_delta.content
         if content_chunk is not None:
diff --git a/test/backends/test_openai_unit.py b/test/backends/test_openai_unit.py
index 09524df8c..77a0745a4 100644
--- a/test/backends/test_openai_unit.py
+++ b/test/backends/test_openai_unit.py
@@ -5,9 +5,12 @@
 """
 
 import pytest
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
+from openai.types.chat.chat_completion import Choice
 
 from mellea.backends import ModelOption
 from mellea.backends.openai import OpenAIBackend
+from mellea.core.base import ModelOutputThunk
 
 
 def _make_backend(model_options: dict | None = None) -> OpenAIBackend:
@@ -168,5 +171,145 @@ def test_make_backend_specific_unknown_mellea_keys_removed(backend):
     assert ModelOption.SYSTEM_PROMPT not in result
 
 
+# --- processing(): reasoning / thinking trace extraction ---
+
+
+def _vllm_chat_completion(reasoning: str, content: str | None) -> ChatCompletion:
+    """Build a ChatCompletion that matches vLLM's thinking-model response shape."""
+    message = ChatCompletionMessage.model_validate(
+        {"role": "assistant", "content": content, "reasoning": reasoning}
+    )
+    return ChatCompletion(
+        id="vllm-test",
+        created=0,
+        model="qwen3",
+        object="chat.completion",
+        choices=[Choice(index=0, finish_reason="stop", message=message)],
+    )
+
+
+async def test_processing_captures_vllm_reasoning_field(backend):
+    """Non-streaming: mot._thinking captures the raw ``reasoning`` key from vLLM."""
+    mot: ModelOutputThunk = ModelOutputThunk(value=None)
+    chunk = _vllm_chat_completion(reasoning="2 + 2 equals 4.", content="4")
+    # Sanity check: the SDK object does not expose reasoning_content
+    assert not hasattr(chunk.choices[0].message, "reasoning_content")
+
+    await backend.processing(mot, chunk)
+
+    assert mot._thinking == "2 + 2 equals 4."
+    assert mot._underlying_value == "4"
+
+
+async def test_processing_vllm_reasoning_with_null_content(backend):
+    """Non-streaming: reasoning is captured even when ``content`` is null."""
+    mot: ModelOutputThunk = ModelOutputThunk(value=None)
+    chunk = _vllm_chat_completion(reasoning="some thinking", content=None)
+
+    await backend.processing(mot, chunk)
+
+    assert mot._thinking == "some thinking"
+    assert mot._underlying_value == ""
+
+
+async def test_processing_streaming_captures_vllm_reasoning_field(backend):
+    """Streaming: per-chunk ``reasoning`` deltas accumulate into mot._thinking."""
+    mot: ModelOutputThunk = ModelOutputThunk(value=None)
+    chunk_a = ChatCompletionChunk.model_validate(
+        {
+            "id": "vllm-stream",
+            "created": 0,
+            "model": "qwen3",
+            "object": "chat.completion.chunk",
+            "choices": [
+                {
+                    "index": 0,
+                    "delta": {
+                        "role": "assistant",
+                        "content": None,
+                        "reasoning": "first ",
+                    },
+                    "finish_reason": None,
+                }
+            ],
+        }
+    )
+    chunk_b = ChatCompletionChunk.model_validate(
+        {
+            "id": "vllm-stream",
+            "created": 0,
+            "model": "qwen3",
+            "object": "chat.completion.chunk",
+            "choices": [
+                {
+                    "index": 0,
+                    "delta": {"content": "ans", "reasoning": "second"},
+                    "finish_reason": None,
+                }
+            ],
+        }
+    )
+
+    await backend.processing(mot, chunk_a)
+    await backend.processing(mot, chunk_b)
+
+    assert mot._thinking == "first second"
+    assert mot._underlying_value == "ans"
+
+
+async def test_processing_reasoning_content_still_used(backend):
+    """Regression guard: the pre-existing ``reasoning_content`` path is preserved.
+
+    Some providers surface the trace as ``reasoning_content`` on the message
+    object itself. The fix must not regress that path in favour of the raw-dict
+    fallback.
+    """
+    message = ChatCompletionMessage.model_validate(
+        {
+            "role": "assistant",
+            "content": "answer",
+            "reasoning_content": "attribute-style trace",
+        }
+    )
+    chunk = ChatCompletion(
+        id="rc-test",
+        created=0,
+        model="fake",
+        object="chat.completion",
+        choices=[Choice(index=0, finish_reason="stop", message=message)],
+    )
+    assert hasattr(chunk.choices[0].message, "reasoning_content")
+
+    mot: ModelOutputThunk = ModelOutputThunk(value=None)
+    await backend.processing(mot, chunk)
+
+    assert mot._thinking == "attribute-style trace"
+    assert mot._underlying_value == "answer"
+
+
+async def test_processing_reasoning_content_takes_precedence_over_reasoning(backend):
+    """reasoning_content attribute wins when both it and raw ``reasoning`` are present."""
+    message = ChatCompletionMessage.model_validate(
+        {
+            "role": "assistant",
+            "content": "answer",
+            "reasoning_content": "attr-trace",
+            "reasoning": "raw-trace",
+        }
+    )
+    chunk = ChatCompletion(
+        id="prec-test",
+        created=0,
+        model="fake",
+        object="chat.completion",
+        choices=[Choice(index=0, finish_reason="stop", message=message)],
+    )
+    mot: ModelOutputThunk = ModelOutputThunk(value=None)
+    await backend.processing(mot, chunk)
+
+    assert mot._thinking == "attr-trace"
+    assert mot._underlying_value == "answer"
+
+
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])