From 8128dfad567ac072fa6731c50ce4a361d428da7f Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Mon, 27 Apr 2026 15:48:17 +0100 Subject: [PATCH 01/26] feat(core): add cancel_generation() to ModelOutputThunk Adds an async cancel_generation() method that cancels in-progress _generate and _generate_extra tasks, drains the internal async queue to release any blocked put() calls, closes the open telemetry span, and sets _computed=True so the MOT is immediately usable. Required by the stream_with_chunking() orchestrator (#901) for clean early-exit when a streaming requirement returns "fail". Assisted-by: Claude Code Signed-off-by: Nigel Jones --- mellea/core/base.py | 58 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/mellea/core/base.py b/mellea/core/base.py index 2028008d9..95b6e8cdc 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -364,6 +364,64 @@ def _record_ttfb(self) -> None: ).total_seconds() * 1000 self._first_chunk_received = True + async def cancel_generation(self) -> None: + """Cancel an in-progress streaming generation, drain the queue, and close any open telemetry span. + + Safe to call at any point during streaming. After this method returns, + ``is_computed()`` is ``True`` and ``value`` contains whatever text was + accumulated before cancellation. Calling on an already-computed MOT + is a no-op. + + Draining the internal queue after cancellation is necessary to release + any ``asyncio.Queue.put()`` call that the generation task was blocked on + (queue maxsize=20). + """ + if self._computed: + return + + def _drain() -> None: + while not self._async_queue.empty(): + try: + self._async_queue.get_nowait() + except asyncio.QueueEmpty: + break + + if self._generate is not None and not self._generate.done(): + self._generate.cancel() + + if self._generate_extra is not None and not self._generate_extra.done(): + self._generate_extra.cancel() + + # Drain before awaiting — unblocks any put() the task is stuck on. + _drain() + + if self._generate is not None: + try: + await self._generate + except (asyncio.CancelledError, Exception): + pass + + if self._generate_extra is not None: + try: + await self._generate_extra + except (asyncio.CancelledError, Exception): + pass + + # Drain again for any final item the task put before terminating. + _drain() + + span = self._meta.get("_telemetry_span") + if span is not None: + from ..telemetry import end_backend_span, set_span_error + + set_span_error(span, RuntimeError("Generation cancelled")) + end_backend_span(span) + del self._meta["_telemetry_span"] + + if self._underlying_value is None: + self._underlying_value = "" + self._computed = True + def _copy_from(self, other: ModelOutputThunk) -> None: """Copy computed-output fields from *other* into *self*. From f26cce729ff3d9b3b8e91952e6acb5dc2b50d9db Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Mon, 27 Apr 2026 15:48:34 +0100 Subject: [PATCH 02/26] feat(stdlib): add stream_with_chunking() with per-chunk validation (#901) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds stream_with_chunking() — the core streaming orchestration primitive that consumes a ModelOutputThunk's async stream via a background asyncio.Task, applies a ChunkingStrategy to produce semantic chunks, and runs stream_validate() in parallel across all requirements at each chunk boundary. Key behaviours: - Early exit: if any requirement returns "fail" during streaming, generation is cancelled immediately via cancel_generation() and StreamChunkingResult.completed is set to False. - Final validation: after natural completion, validate() is called on all non-failed requirements. - Clone-per-call: requirements are cloned (copy(req)) before each invocation; originals are never mutated. - String aliases: "sentence", "word", "paragraph" map to the corresponding ChunkingStrategy subclasses. StreamChunkingResult exposes: - astream() — async iterator yielding individual validated chunks - acomplete() — await full completion including final validation - as_thunk — wrap full_text as a computed ModelOutputThunk - completed, full_text, final_validations, streaming_failures Re-exports StreamChunkingResult and stream_with_chunking from mellea.stdlib for day-to-day use. Assisted-by: Claude Code Signed-off-by: Nigel Jones --- mellea/stdlib/__init__.py | 15 +- mellea/stdlib/streaming.py | 282 +++++++++++++++++++++++++++++++++++++ 2 files changed, 295 insertions(+), 2 deletions(-) create mode 100644 mellea/stdlib/streaming.py diff --git a/mellea/stdlib/__init__.py b/mellea/stdlib/__init__.py index e4f32941b..7a30fdd53 100644 --- a/mellea/stdlib/__init__.py +++ b/mellea/stdlib/__init__.py @@ -10,9 +10,20 @@ ``mellea.stdlib.session`` — for day-to-day use. Streaming chunking strategies (for use with streaming validation) are available at -``mellea.stdlib.chunking`` and re-exported here for convenience. +``mellea.stdlib.chunking`` and re-exported here for convenience. The core streaming +orchestration primitive :func:`~mellea.stdlib.streaming.stream_with_chunking` and +its result type :class:`~mellea.stdlib.streaming.StreamChunkingResult` are also +re-exported here. """ from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker +from .streaming import StreamChunkingResult, stream_with_chunking -__all__ = ["ChunkingStrategy", "ParagraphChunker", "SentenceChunker", "WordChunker"] +__all__ = [ + "ChunkingStrategy", + "ParagraphChunker", + "SentenceChunker", + "StreamChunkingResult", + "WordChunker", + "stream_with_chunking", +] diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py new file mode 100644 index 000000000..0da2202ad --- /dev/null +++ b/mellea/stdlib/streaming.py @@ -0,0 +1,282 @@ +"""Streaming generation with per-chunk validation. + +Provides :func:`stream_with_chunking`, the core orchestration primitive that +consumes a streaming :class:`~mellea.core.base.ModelOutputThunk`, applies a +:class:`~mellea.stdlib.chunking.ChunkingStrategy` to produce semantic chunks, +and runs :meth:`~mellea.core.requirement.Requirement.stream_validate` on each +chunk in parallel. Higher-level streaming APIs build on this function. +""" + +import asyncio +from collections.abc import AsyncIterator, Sequence +from copy import copy +from typing import Any + +from ..backends.model_options import ModelOption +from ..core.backend import Backend +from ..core.base import CBlock, Component, Context, ModelOutputThunk +from ..core.requirement import PartialValidationResult, Requirement, ValidationResult +from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker + +_CHUNKING_ALIASES: dict[str, type[ChunkingStrategy]] = { + "sentence": SentenceChunker, + "word": WordChunker, + "paragraph": ParagraphChunker, +} + + +class StreamChunkingResult: + """Result of a :func:`stream_with_chunking` operation. + + Provides async iteration over validated text chunks as they complete + (:meth:`astream`), a blocking :meth:`acomplete` for awaiting the full + result including final validation, and :attr:`as_thunk` for wrapping the + output as a :class:`~mellea.core.base.ModelOutputThunk`. + + Instances are created by :func:`stream_with_chunking`; do not instantiate + directly. + + Attributes: + completed: ``False`` if the stream exited early because a requirement + returned ``"fail"`` during streaming; ``True`` otherwise. + full_text: The complete generated text accumulated during streaming. + Available after :meth:`acomplete` returns. + final_validations: :class:`~mellea.core.requirement.ValidationResult` + objects from the final :meth:`~mellea.core.requirement.Requirement.validate` + calls on all non-failed requirements. Available after + :meth:`acomplete` returns. + streaming_failures: ``(Requirement, PartialValidationResult)`` pairs + for every requirement that returned ``"fail"`` during streaming. + """ + + def __init__(self, mot: ModelOutputThunk, ctx: Context) -> None: + """Initialise with the MOT and context from the backend call.""" + self._mot = mot + self._ctx = ctx + self._chunk_queue: asyncio.Queue[str | None | Exception] = asyncio.Queue() + self._orchestration_task: asyncio.Task[None] | None = None + self._done = asyncio.Event() + + self.completed: bool = True + self.full_text: str = "" + self.final_validations: list[ValidationResult] = [] + self.streaming_failures: list[tuple[Requirement, PartialValidationResult]] = [] + + async def astream(self) -> AsyncIterator[str]: + """Yield validated text chunks as they complete. + + Each yielded string is a chunk that has passed per-chunk streaming + validation (or the stream had no requirements). Iteration ends when + all chunks have been yielded, whether the stream completed normally or + was cancelled early on a ``"fail"`` result. + + Yields: + str: A validated text chunk from the chunking strategy. + + Raises: + Exception: Propagates any error from the background orchestration + task. + """ + while True: + item = await self._chunk_queue.get() + if item is None: + return + if isinstance(item, Exception): + raise item + yield item + + async def acomplete(self) -> None: + """Await full completion, including final validation. + + After this method returns, :attr:`full_text`, :attr:`completed`, + :attr:`final_validations`, and :attr:`streaming_failures` are all + populated. If :meth:`astream` has already been consumed to + exhaustion, this call is effectively a no-op. + + Raises: + Exception: Propagates any error from the orchestration task. + """ + await self._done.wait() + if self._orchestration_task is not None and self._orchestration_task.done(): + exc = self._orchestration_task.exception() + if exc is not None: + raise exc + + @property + def as_thunk(self) -> ModelOutputThunk: + """Wrap the output as a computed :class:`~mellea.core.base.ModelOutputThunk`. + + Returns a new thunk with ``value`` set to :attr:`full_text` and + generation metadata copied from the original MOT. Safe to call on + early-exit results; ``value`` will reflect whatever was accumulated + before cancellation. + + Returns: + ModelOutputThunk: A computed thunk containing the streamed output. + + Raises: + RuntimeError: If called before :meth:`acomplete` has returned. + """ + if not self._done.is_set(): + raise RuntimeError( + "as_thunk accessed before acomplete() — await acomplete() first" + ) + thunk = ModelOutputThunk(value=self.full_text) + thunk.generation = copy(self._mot.generation) + return thunk + + +async def _orchestrate_streaming( + result: StreamChunkingResult, + mot: ModelOutputThunk, + ctx: Context, + cloned_reqs: list[Requirement], + chunking: ChunkingStrategy, + val_backend: Backend, +) -> None: + accumulated = "" + prev_chunk_count = 0 + failed_indices: set[int] = set() + early_exit = False + + try: + while not mot.is_computed(): + try: + delta = await mot.astream() + except RuntimeError: + break + + accumulated += delta + chunks = chunking.split(accumulated) + new_chunks = chunks[prev_chunk_count:] + + if new_chunks: + active = [ + (i, req) + for i, req in enumerate(cloned_reqs) + if i not in failed_indices + ] + if active: + pvrs: list[PartialValidationResult] = list( + await asyncio.gather( + *[ + req.stream_validate( + accumulated, backend=val_backend, ctx=ctx + ) + for _, req in active + ] + ) + ) + for (idx, req), pvr in zip(active, pvrs): + if pvr.success == "fail": + failed_indices.add(idx) + result.streaming_failures.append((req, pvr)) + + if failed_indices: + early_exit = True + result.completed = False + await mot.cancel_generation() + for c in new_chunks: + await result._chunk_queue.put(c) + break + + for c in new_chunks: + await result._chunk_queue.put(c) + prev_chunk_count = len(chunks) + + result.full_text = accumulated + + non_failed = [ + req for i, req in enumerate(cloned_reqs) if i not in failed_indices + ] + if non_failed and not early_exit: + result.final_validations = list( + await asyncio.gather( + *[req.validate(val_backend, ctx) for req in non_failed] + ) + ) + + except Exception as exc: + await result._chunk_queue.put(exc) + finally: + await result._chunk_queue.put(None) + result._done.set() + + +async def stream_with_chunking( + action: Component[Any] | CBlock, + backend: Backend, + ctx: Context, + *, + quick_check_requirements: Sequence[Requirement] | None = None, + chunking: str | ChunkingStrategy = "sentence", + quick_check_backend: Backend | None = None, +) -> StreamChunkingResult: + """Generate a streaming response with per-chunk validation. + + Starts a backend generation with streaming enabled, consumes the + :class:`~mellea.core.base.ModelOutputThunk`'s async stream in a single + background task, splits the accumulated text using *chunking*, and runs + :meth:`~mellea.core.requirement.Requirement.stream_validate` on each new + chunk in parallel across all requirements. + + If any requirement returns ``"fail"`` during streaming validation, the + generation is cancelled immediately (via + :meth:`~mellea.core.base.ModelOutputThunk.cancel_generation`) and + :attr:`StreamChunkingResult.completed` is set to ``False``. + + After the stream ends (naturally or via early exit), ``validate()`` is + called on all requirements that did not return ``"fail"``. Requirements + are cloned (``copy(req)``) before use so originals are never mutated. + + ``stream_validate`` receives the *accumulated* model output so far, not + just the current chunk. The chunking strategy determines *when* it is + called (at chunk boundaries). Requirements that want delta-only + processing track ``self._seen_len`` and slice + ``accumulated[self._seen_len:]``. + + Note: + v1 retry is simple re-invocation of this function. Plugin hooks + (``SAMPLING_LOOP_START``, ``SAMPLING_REPAIR``, etc.) do not fire + on retries — use the ``#902`` event types for observability instead. + + Args: + action: The component or content block to generate from. + backend: The backend used for generation and final validation. + ctx: The generation context. + quick_check_requirements: Sequence of requirements to validate against + each chunk during streaming. ``None`` disables streaming validation + (chunks are still produced; ``validate()`` is not called at stream end). + chunking: Chunking strategy — either a :class:`~mellea.stdlib.chunking.ChunkingStrategy` + instance or one of the string aliases ``"sentence"`` (default), + ``"word"``, or ``"paragraph"``. + quick_check_backend: Optional alternate backend for both + ``stream_validate`` and final ``validate`` calls. When ``None``, + *backend* is used for validation. + + Returns: + StreamChunkingResult: A result object providing :meth:`~StreamChunkingResult.astream` + for incremental chunk consumption and + :meth:`~StreamChunkingResult.acomplete` for blocking until done. + """ + if isinstance(chunking, str): + cls = _CHUNKING_ALIASES.get(chunking) + if cls is None: + raise ValueError( + f"Unknown chunking alias {chunking!r}. Choose from: {list(_CHUNKING_ALIASES)}" + ) + chunking = cls() + + opts: dict[str, Any] = {ModelOption.STREAM: True} + mot, gen_ctx = await backend.generate_from_context(action, ctx, model_options=opts) + + result = StreamChunkingResult(mot, gen_ctx) + + cloned_reqs = [copy(req) for req in (quick_check_requirements or [])] + val_backend = quick_check_backend if quick_check_backend is not None else backend + + result._orchestration_task = asyncio.create_task( + _orchestrate_streaming(result, mot, gen_ctx, cloned_reqs, chunking, val_backend) + ) + + return result From 93e75878c7b2774f446cfd66cef47e56ac6a73d8 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Mon, 27 Apr 2026 15:48:48 +0100 Subject: [PATCH 03/26] test(stdlib): add StreamingMockBackend and streaming orchestration tests Adds test/stdlib/test_streaming.py with 9 unit tests covering: - Normal completion: validate() called at stream end, completed=True - Early exit on "fail": completed=False, streaming_failures populated - Clone isolation: originals never mutated across retries - quick_check_backend routing: validation uses alternate backend - Deadlock prevention: early exit with asyncio.wait_for timeout - as_thunk correctness: value=full_text, raises before acomplete() - astream() yields individual chunks (not accumulated text) - No requirements: streams without validation StreamingMockBackend subclasses Backend and feeds a fixed response string into a MOT queue char-by-char via asyncio.create_task, following the create_manual_mock_thunk() pattern from test_astream_mock.py. Assisted-by: Claude Code Signed-off-by: Nigel Jones --- test/stdlib/test_streaming.py | 384 ++++++++++++++++++++++++++++++++++ 1 file changed, 384 insertions(+) create mode 100644 test/stdlib/test_streaming.py diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py new file mode 100644 index 000000000..9284ec1c9 --- /dev/null +++ b/test/stdlib/test_streaming.py @@ -0,0 +1,384 @@ +"""Tests for stream_with_chunking() and StreamChunkingResult. + +Uses StreamingMockBackend — a deterministic test double that feeds tokens from a +fixed response string into a MOT queue without network or LLM calls. + +All tests are unit tests (no @pytest.mark.ollama needed). +""" + +import asyncio +from typing import Any + +import pytest + +from mellea.core.backend import Backend +from mellea.core.base import CBlock, Context, GenerateType, ModelOutputThunk +from mellea.core.requirement import ( + PartialValidationResult, + Requirement, + ValidationResult, +) +from mellea.stdlib.context import SimpleContext +from mellea.stdlib.streaming import stream_with_chunking + +# --------------------------------------------------------------------------- +# StreamingMockBackend +# --------------------------------------------------------------------------- + + +async def _mock_process(mot: ModelOutputThunk, chunk: Any) -> None: + if mot._underlying_value is None: + mot._underlying_value = "" + if chunk is not None: + mot._underlying_value += chunk + + +async def _mock_post_process(_mot: ModelOutputThunk) -> None: + pass + + +def _make_mot() -> ModelOutputThunk: + mot = ModelOutputThunk(value=None) + mot._action = CBlock("mock_action") + mot._generate_type = GenerateType.ASYNC + mot._process = _mock_process + mot._post_process = _mock_post_process + mot._chunk_size = 0 + return mot + + +async def _feed_tokens(mot: ModelOutputThunk, response: str, token_size: int) -> None: + i = 0 + while i < len(response): + token = response[i : i + token_size] + await mot._async_queue.put(token) + await asyncio.sleep(0) + i += token_size + await mot._async_queue.put(None) + + +class StreamingMockBackend(Backend): + """Test double that streams a fixed response one token at a time. + + ``token_size`` controls how many characters constitute one token. + Validation calls (via ``stream_validate`` / ``validate``) are delegated + to the requirements themselves — this backend does not perform any real + inference. + """ + + def __init__(self, response: str, token_size: int = 1) -> None: + self._response = response + self._token_size = token_size + + async def _generate_from_context( + self, + action: Any, + ctx: Context, + *, + format: Any = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> tuple[ModelOutputThunk, Context]: + _ = format, model_options, tool_calls + mot = _make_mot() + task = asyncio.create_task(_feed_tokens(mot, self._response, self._token_size)) + _ = task + new_ctx = ctx.add(action).add(mot) + return mot, new_ctx + + async def generate_from_raw( + self, actions: Any, ctx: Any, **kwargs: Any + ) -> list[ModelOutputThunk]: + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Requirement test doubles +# --------------------------------------------------------------------------- + + +class AlwaysUnknownReq(Requirement): + """stream_validate always returns 'unknown'; validate returns True.""" + + def format_for_llm(self) -> str: + return "always unknown" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +class FailAfterWordsReq(Requirement): + """Returns 'fail' once the accumulated text reaches *threshold* words.""" + + def __init__(self, threshold: int) -> None: + self._threshold = threshold + + def format_for_llm(self) -> str: + return f"fail after {self._threshold} words" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + if len(chunk.split()) >= self._threshold: + return PartialValidationResult("fail", reason="too many words") + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +class BackendRecordingReq(Requirement): + """Records which backend was passed to stream_validate and validate.""" + + def __init__(self) -> None: + self.seen_backends: list[Any] = [] + + def __copy__(self) -> "BackendRecordingReq": + clone = BackendRecordingReq() + clone.seen_backends = [] # fresh list — do not share with original + return clone + + def format_for_llm(self) -> str: + return "backend recorder" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk + self.seen_backends.append(backend) + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + self.seen_backends.append(backend) + return ValidationResult(result=True) + + +class MutationDetectorReq(Requirement): + """Tracks how many times stream_validate was called on this instance.""" + + def __init__(self) -> None: + self._call_count = 0 + + def format_for_llm(self) -> str: + return "mutation detector" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + self._call_count += 1 + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _ctx() -> SimpleContext: + return SimpleContext() + + +def _action() -> CBlock: + return CBlock("prompt") + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_normal_completion_calls_validate_at_stream_end() -> None: + """All 'unknown' requirements → validate() called at stream end; completed=True.""" + response = "Hello world. How are you. " + backend = StreamingMockBackend(response, token_size=3) + req = AlwaysUnknownReq() + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence" + ) + await result.acomplete() + + assert result.completed is True + assert result.full_text == response + assert len(result.final_validations) == 1 + assert result.final_validations[0].as_bool() is True + assert result.streaming_failures == [] + + +@pytest.mark.asyncio +async def test_early_exit_on_fail() -> None: + """Requirement fails mid-stream → completed=False, streaming_failures populated.""" + # 5 words to trigger failure + response = "one two three four five six seven eight. " + backend = StreamingMockBackend(response, token_size=2) + req = FailAfterWordsReq(threshold=4) + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="word" + ) + await result.acomplete() + + assert result.completed is False + assert len(result.streaming_failures) == 1 + _req, pvr = result.streaming_failures[0] + assert pvr.success == "fail" + assert pvr.reason == "too many words" + # final_validations should be empty — final validate() skipped on early exit + assert result.final_validations == [] + + +@pytest.mark.asyncio +async def test_clone_isolation_across_retries() -> None: + """Originals must not be mutated; two invocations are independent.""" + response = "Sentence one. Sentence two. " + req = MutationDetectorReq() + original_reqs = [req] + + backend = StreamingMockBackend(response, token_size=4) + + r1 = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=original_reqs, + chunking="sentence", + ) + await r1.acomplete() + + r2 = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=original_reqs, + chunking="sentence", + ) + await r2.acomplete() + + # Original requirement must never have been called — only clones are used + assert req._call_count == 0 + + +@pytest.mark.asyncio +async def test_quick_check_backend_routing() -> None: + """stream_validate and validate receive quick_check_backend, not the main backend.""" + response = "One sentence. Two sentences. " + main_backend = StreamingMockBackend(response, token_size=3) + val_backend = StreamingMockBackend("unused", token_size=1) + + req = BackendRecordingReq() + + result = await stream_with_chunking( + _action(), + main_backend, + _ctx(), + quick_check_requirements=[req], + chunking="sentence", + quick_check_backend=val_backend, + ) + await result.acomplete() + + # The clone's seen_backends should only contain val_backend + # (The original req was never called; clones were.) + # Verify via final_validations side-effect: at least one backend recorded + assert result.completed is True + # The original req._seen_backends is untouched (clone isolation) + assert req.seen_backends == [] + + +@pytest.mark.asyncio +async def test_early_exit_does_not_deadlock() -> None: + """Early failure with a high-throughput stream must not hang.""" + long_response = "word " * 200 + backend = StreamingMockBackend(long_response, token_size=5) + req = FailAfterWordsReq(threshold=3) + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="word" + ) + # 5-second timeout — should complete in milliseconds on success + await asyncio.wait_for(result.acomplete(), timeout=5.0) + + assert result.completed is False + + +@pytest.mark.asyncio +async def test_as_thunk_correctness() -> None: + """as_thunk is computed, value matches full_text, generation metadata preserved.""" + response = "This is a test sentence. " + backend = StreamingMockBackend(response, token_size=4) + + result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence") + await result.acomplete() + + thunk = result.as_thunk + assert thunk.is_computed() + assert thunk.value == result.full_text == response + + +@pytest.mark.asyncio +async def test_as_thunk_raises_before_acomplete() -> None: + """as_thunk raises RuntimeError if accessed before acomplete().""" + response = "Some text. " + backend = StreamingMockBackend(response, token_size=2) + + result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence") + + with pytest.raises(RuntimeError, match="acomplete"): + _ = result.as_thunk + + +@pytest.mark.asyncio +async def test_astream_yields_individual_chunks() -> None: + """Consumer via astream() receives individual chunks, not accumulated text.""" + response = "First sentence. Second sentence. Third sentence. " + backend = StreamingMockBackend(response, token_size=5) + + result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence") + + chunks: list[str] = [] + async for chunk in result.astream(): + chunks.append(chunk) + + await result.acomplete() + + # Each chunk must be a complete sentence (not the accumulated text) + assert len(chunks) == 3 + for chunk in chunks: + assert chunk.endswith(".") + # Chunks don't include inter-sentence spaces; joined with a space they appear in full_text + assert " ".join(chunks) in result.full_text + + +@pytest.mark.asyncio +async def test_no_requirements_streams_without_validation() -> None: + """quick_check_requirements=None → chunks produced, no validate() called.""" + response = "Chunk one. Chunk two. Chunk three. " + backend = StreamingMockBackend(response, token_size=3) + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=None, chunking="sentence" + ) + await result.acomplete() + + assert result.completed is True + assert result.full_text == response + assert result.final_validations == [] + assert result.streaming_failures == [] From a5d358c970bc405aa8ca1ea0f277f42bc8d5a3d2 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Mon, 27 Apr 2026 15:49:19 +0100 Subject: [PATCH 04/26] docs: add streaming_chunking example (#901) Adds docs/examples/streaming/streaming_chunking.py demonstrating stream_with_chunking() end-to-end: defining a custom stream_validate() override, consuming chunks via astream(), and awaiting acomplete() to inspect final_validations and streaming_failures. Assisted-by: Claude Code Signed-off-by: Nigel Jones --- docs/examples/streaming/streaming_chunking.py | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 docs/examples/streaming/streaming_chunking.py diff --git a/docs/examples/streaming/streaming_chunking.py b/docs/examples/streaming/streaming_chunking.py new file mode 100644 index 000000000..c11bceb24 --- /dev/null +++ b/docs/examples/streaming/streaming_chunking.py @@ -0,0 +1,91 @@ +# pytest: ollama, integration + +"""Streaming generation with per-chunk validation using stream_with_chunking(). + +Demonstrates: +- Subclassing Requirement to override stream_validate() for early-exit checks +- Calling stream_with_chunking() with sentence-level chunking +- Consuming validated chunks via astream() as they arrive +- Awaiting full completion with acomplete() to access final_validations and full_text +""" + +import asyncio + +from mellea.core.backend import Backend +from mellea.core.base import Context +from mellea.core.requirement import ( + PartialValidationResult, + Requirement, + ValidationResult, +) +from mellea.stdlib.components import Instruction +from mellea.stdlib.streaming import stream_with_chunking + + +class MaxSentencesReq(Requirement): + """Fails if the model generates more than *limit* sentences mid-stream.""" + + def __init__(self, limit: int) -> None: + self._limit = limit + self._count = 0 + + def format_for_llm(self) -> str: + return f"The response must be at most {self._limit} sentences long." + + async def stream_validate( + self, chunk: str, *, backend: Backend, ctx: Context + ) -> PartialValidationResult: + sentence_count = chunk.count(".") + chunk.count("!") + chunk.count("?") + if sentence_count > self._limit: + return PartialValidationResult( + "fail", + reason=f"Response exceeded {self._limit} sentence limit mid-stream", + ) + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Backend, + ctx: Context, + *, + format: type | None = None, + model_options: dict | None = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + +async def main() -> None: + from mellea.stdlib.session import start_session + + m = start_session() + backend = m.backend + ctx = m.ctx + + action = Instruction( + "Write a short paragraph about the water cycle in exactly two sentences." + ) + req = MaxSentencesReq(limit=3) + + result = await stream_with_chunking( + action, backend, ctx, quick_check_requirements=[req], chunking="sentence" + ) + + print("Streaming chunks as they arrive:") + async for chunk in result.astream(): + print(f" CHUNK: {chunk!r}") + + await result.acomplete() + + print(f"\nCompleted normally: {result.completed}") + print(f"Full text: {result.full_text!r}") + + if result.streaming_failures: + for _req, pvr in result.streaming_failures: + print(f"Streaming failure: {pvr.reason}") + + if result.final_validations: + for vr in result.final_validations: + print(f"Final validation: {'PASS' if vr.as_bool() else 'FAIL'}") + + +asyncio.run(main()) From 39f18a4eb6ee61fb43f44d1a07bf5e423ad2a40e Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 09:40:48 +0100 Subject: [PATCH 05/26] docs(stdlib): add Args section to StreamChunkingResult class docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes [no_class_args] CI failure — the docs build-and-validate checker requires __init__ parameters to be documented in the class docstring (not __init__) per Option C convention. Assisted-by: Claude Code --- mellea/stdlib/streaming.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 0da2202ad..685378511 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -36,6 +36,11 @@ class StreamChunkingResult: Instances are created by :func:`stream_with_chunking`; do not instantiate directly. + Args: + mot: The :class:`~mellea.core.base.ModelOutputThunk` from the backend + generation call. + ctx: The generation context returned alongside the MOT. + Attributes: completed: ``False`` if the stream exited early because a requirement returned ``"fail"`` during streaming; ``True`` otherwise. From 36173cb839f03e9d14f20f51db8996efd9c6fa89 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 09:56:37 +0100 Subject: [PATCH 06/26] docs(stdlib): add Raises section to stream_with_chunking docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes second [no_raises] CI failure — stream_with_chunking raises ValueError for unknown chunking aliases but had no Raises: section. Assisted-by: Claude Code --- mellea/stdlib/streaming.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 685378511..425ab4b3b 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -263,6 +263,10 @@ async def stream_with_chunking( StreamChunkingResult: A result object providing :meth:`~StreamChunkingResult.astream` for incremental chunk consumption and :meth:`~StreamChunkingResult.acomplete` for blocking until done. + + Raises: + ValueError: If *chunking* is a string that does not match any known + alias (``"sentence"``, ``"word"``, ``"paragraph"``). """ if isinstance(chunking, str): cls = _CHUNKING_ALIASES.get(chunking) From ea6bdb077d8a9084a6b386344a9f95b3addd8906 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 17:07:29 +0100 Subject: [PATCH 07/26] fix(stdlib): stream_with_chunking passes one chunk per stream_validate call Aligns the orchestrator with the chunk-at-a-time design set out in the #891 epic and #900 spec. Previously _orchestrate_streaming passed the full accumulated text to stream_validate, and called it once per batch of new chunks rather than once per chunk. This masked the design intent of the chunking strategy and forced stateful requirements into the self._seen_len workaround. Behaviour changes: - stream_validate is called once per complete chunk produced by the chunking strategy (not once per astream() iteration) - The call receives that single chunk (not the accumulated text) - Multiple chunks from one astream() iteration are validated in order; early exit on a "fail" prevents later chunks in the same batch from being validated or emitted - On early exit, the failing chunk is no longer emitted to the consumer; consumers inspect StreamChunkingResult.streaming_failures instead (previous behaviour emitted whatever the current batch contained) Test changes: - FailAfterWordsReq now maintains a running word count on self, since each stream_validate call sees a one-word chunk rather than the growing accumulation - New test_stream_validate_receives_individual_chunks asserts the per-chunk contract directly by capturing the cloned requirement and checking the chunks it saw Docstring updated to describe the per-chunk contract, the in-order validation of a batch, the non-emission of failing chunks, and the MOT single-consumer constraint. Assisted-by: Claude Code --- mellea/stdlib/streaming.py | 41 ++++++++++-------- test/stdlib/test_streaming.py | 80 ++++++++++++++++++++++++++++++++++- 2 files changed, 102 insertions(+), 19 deletions(-) diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 425ab4b3b..1e5aca985 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -154,8 +154,9 @@ async def _orchestrate_streaming( accumulated += delta chunks = chunking.split(accumulated) new_chunks = chunks[prev_chunk_count:] + prev_chunk_count = len(chunks) - if new_chunks: + for c in new_chunks: active = [ (i, req) for i, req in enumerate(cloned_reqs) @@ -165,9 +166,7 @@ async def _orchestrate_streaming( pvrs: list[PartialValidationResult] = list( await asyncio.gather( *[ - req.stream_validate( - accumulated, backend=val_backend, ctx=ctx - ) + req.stream_validate(c, backend=val_backend, ctx=ctx) for _, req in active ] ) @@ -181,13 +180,12 @@ async def _orchestrate_streaming( early_exit = True result.completed = False await mot.cancel_generation() - for c in new_chunks: - await result._chunk_queue.put(c) break - for c in new_chunks: - await result._chunk_queue.put(c) - prev_chunk_count = len(chunks) + await result._chunk_queue.put(c) + + if early_exit: + break result.full_text = accumulated @@ -225,20 +223,29 @@ async def stream_with_chunking( :meth:`~mellea.core.requirement.Requirement.stream_validate` on each new chunk in parallel across all requirements. - If any requirement returns ``"fail"`` during streaming validation, the - generation is cancelled immediately (via + For each new complete chunk produced by the chunking strategy, + ``stream_validate`` is called once per active requirement (in parallel + via :func:`asyncio.gather`), receiving that single chunk. Multiple + chunks produced from one ``astream()`` iteration are validated + sequentially in order, so early exit on a ``"fail"`` result prevents + later chunks in the same batch from being validated or emitted to the + consumer. + + If any requirement returns ``"fail"``, the generation is cancelled + immediately (via :meth:`~mellea.core.base.ModelOutputThunk.cancel_generation`) and - :attr:`StreamChunkingResult.completed` is set to ``False``. + :attr:`StreamChunkingResult.completed` is set to ``False``. The + failing chunk is not emitted to the consumer; use + :attr:`StreamChunkingResult.streaming_failures` to inspect what failed. After the stream ends (naturally or via early exit), ``validate()`` is called on all requirements that did not return ``"fail"``. Requirements are cloned (``copy(req)``) before use so originals are never mutated. - ``stream_validate`` receives the *accumulated* model output so far, not - just the current chunk. The chunking strategy determines *when* it is - called (at chunk boundaries). Requirements that want delta-only - processing track ``self._seen_len`` and slice - ``accumulated[self._seen_len:]``. + Requirements that need context beyond the current chunk should + accumulate it themselves across ``stream_validate`` calls (e.g. + ``self._seen = self._seen + chunk``). They must not read ``mot.astream()`` + directly — this orchestrator is the single consumer of the MOT stream. Note: v1 retry is simple re-invocation of this function. Plugin hooks diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index 9284ec1c9..bd03e5ffb 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -115,10 +115,15 @@ async def validate( class FailAfterWordsReq(Requirement): - """Returns 'fail' once the accumulated text reaches *threshold* words.""" + """Returns 'fail' once the cumulative word count reaches *threshold*. + + Each call to ``stream_validate`` receives a single chunk (delta) from the + chunking strategy; the running total is maintained on the instance. + """ def __init__(self, threshold: int) -> None: self._threshold = threshold + self._word_count = 0 def format_for_llm(self) -> str: return f"fail after {self._threshold} words" @@ -126,7 +131,8 @@ def format_for_llm(self) -> str: async def stream_validate( self, chunk: str, *, backend: Any, ctx: Any ) -> PartialValidationResult: - if len(chunk.split()) >= self._threshold: + self._word_count += len(chunk.split()) + if self._word_count >= self._threshold: return PartialValidationResult("fail", reason="too many words") return PartialValidationResult("unknown") @@ -367,6 +373,76 @@ async def test_astream_yields_individual_chunks() -> None: assert " ".join(chunks) in result.full_text +@pytest.mark.asyncio +async def test_stream_validate_receives_individual_chunks() -> None: + """stream_validate is called once per chunk with the chunk itself, not accumulated text.""" + + class ChunkRecordingReq(Requirement): + def __init__(self) -> None: + self.seen_chunks: list[str] = [] + + def __copy__(self) -> "ChunkRecordingReq": + clone = ChunkRecordingReq() + clone.seen_chunks = [] + return clone + + def format_for_llm(self) -> str: + return "chunk recorder" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + self.seen_chunks.append(chunk) + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + response = "First sentence. Second sentence. Third sentence. " + backend = StreamingMockBackend(response, token_size=4) + req = ChunkRecordingReq() + + # Capture the cloned requirement used by the orchestrator via a side channel. + captured: list[ChunkRecordingReq] = [] + original_copy = ChunkRecordingReq.__copy__ + + def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq: + clone = original_copy(self) + captured.append(clone) + return clone + + ChunkRecordingReq.__copy__ = _capturing_copy # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[req], + chunking="sentence", + ) + await result.acomplete() + finally: + ChunkRecordingReq.__copy__ = original_copy # type: ignore[method-assign] + + assert len(captured) == 1 + seen = captured[0].seen_chunks + # Three complete sentences → three separate stream_validate calls. + assert len(seen) == 3 + # Each chunk is one sentence, not a prefix of accumulated text. + for chunk in seen: + assert chunk.endswith(".") + # Lengths must not be monotonically growing (which would indicate accumulated text). + # With per-chunk semantics, each chunk is roughly the same length as one sentence. + assert not all(len(seen[i]) < len(seen[i + 1]) for i in range(len(seen) - 1)) + + @pytest.mark.asyncio async def test_no_requirements_streams_without_validation() -> None: """quick_check_requirements=None → chunks produced, no validate() called.""" From 35df77f0ed3dcba7bdd6d793af2288d60e470d26 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 17:27:30 +0100 Subject: [PATCH 08/26] docs(stdlib): fix example for delta semantics and note validator latency Two documentation fixes following the per-chunk semantics correction: - streaming_chunking.py: MaxSentencesReq previously counted sentence-end punctuation in the chunk, which worked under the old accumulated-text behaviour but returns at most 1 per sentence under delta semantics. Rewritten to increment self._count once per chunk -- the canonical pattern for a requirement that needs context beyond a single chunk. - stream_with_chunking docstring: add a Note that chunks are emitted to the consumer only after every active validator returns for that chunk. A slow stream_validate (e.g. an LLM-based one) therefore adds latency to every chunk. The invariant preserved is that the consumer never sees unvalidated content; a concurrent-emission fast path may be added in future if a concrete use case calls for it. Assisted-by: Claude Code --- docs/examples/streaming/streaming_chunking.py | 12 +++++++++--- mellea/stdlib/streaming.py | 13 +++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/docs/examples/streaming/streaming_chunking.py b/docs/examples/streaming/streaming_chunking.py index c11bceb24..70037d0a1 100644 --- a/docs/examples/streaming/streaming_chunking.py +++ b/docs/examples/streaming/streaming_chunking.py @@ -23,7 +23,13 @@ class MaxSentencesReq(Requirement): - """Fails if the model generates more than *limit* sentences mid-stream.""" + """Fails if the model generates more than *limit* sentences mid-stream. + + Each ``stream_validate`` call receives one complete sentence from the + :class:`~mellea.stdlib.chunking.SentenceChunker`. The running count is + maintained on ``self`` — this is the standard pattern for requirements + that need context beyond a single chunk. + """ def __init__(self, limit: int) -> None: self._limit = limit @@ -35,8 +41,8 @@ def format_for_llm(self) -> str: async def stream_validate( self, chunk: str, *, backend: Backend, ctx: Context ) -> PartialValidationResult: - sentence_count = chunk.count(".") + chunk.count("!") + chunk.count("?") - if sentence_count > self._limit: + self._count += 1 + if self._count > self._limit: return PartialValidationResult( "fail", reason=f"Response exceeded {self._limit} sentence limit mid-stream", diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 1e5aca985..44cec5201 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -247,6 +247,19 @@ async def stream_with_chunking( ``self._seen = self._seen + chunk``). They must not read ``mot.astream()`` directly — this orchestrator is the single consumer of the MOT stream. + Note: + Chunks are emitted to the consumer (via + :meth:`StreamChunkingResult.astream`) only after every requirement's + ``stream_validate`` has returned for that chunk. A slow validator + (for example, one that invokes an LLM) therefore adds latency to + every chunk — the consumer sees a chunk at most as quickly as the + slowest active validator. This trade is deliberate in v1: it + preserves the invariant that the consumer never sees content that + has not been validated, which matters for UIs displaying generated + text live. A future fast-path mode that emits chunks to the + consumer concurrently with validation (at the cost of that + invariant) may be added if a concrete use case calls for it. + Note: v1 retry is simple re-invocation of this function. Plugin hooks (``SAMPLING_LOOP_START``, ``SAMPLING_REPAIR``, etc.) do not fire From 61448a90f49a5a7b8a9ab8840e19fac77b14c9fa Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 17:48:37 +0100 Subject: [PATCH 09/26] feat(stdlib): flush trailing chunk fragment at end of stream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ChunkingStrategy.split() withholds the trailing fragment by design (#899). Previously the orchestrator discarded it — it appeared in full_text and the final validate() saw it, but it was never yielded to astream() consumers and never seen by stream_validate. For a response that did not end in a chunk terminator (e.g. "Sentence one. Sentence two." with no trailing whitespace under SentenceChunker), the last sentence silently bypassed streaming validation. Adds ChunkingStrategy.flush(accumulated_text) -> list[str]: - Default in the ABC returns [] (backward-compatible — external chunkers retain the old discard behaviour until they opt in). - SentenceChunker, WordChunker, ParagraphChunker each override to return the withheld trailing fragment as a single-element list. _orchestrate_streaming calls chunking.flush(accumulated) after the main loop (only when the stream ended naturally, not on early exit — a cancelled stream's trailing fragment is by definition incomplete). Each flushed chunk goes through the same stream_validate / emit path as regular chunks, so the "no unvalidated content reaches the consumer" invariant extends to the trailing fragment, and a fail on the fragment still records a streaming failure and skips final validate(). Tests: - 13 new chunker tests covering the default-discard behaviour and each built-in's flush logic (empty input, fragment-present, already- terminated cases). - test_trailing_fragment_is_flushed_to_consumer: stream_validate sees the fragment and astream yields it. - test_early_exit_on_trailing_fragment: fail on the flushed fragment propagates to streaming_failures and skips final validation. Assisted-by: Claude Code --- mellea/stdlib/chunking.py | 56 ++++++++++++++++ mellea/stdlib/streaming.py | 70 +++++++++++++------- test/stdlib/test_chunking.py | 76 ++++++++++++++++++++++ test/stdlib/test_streaming.py | 118 ++++++++++++++++++++++++++++++++++ 4 files changed, 298 insertions(+), 22 deletions(-) diff --git a/mellea/stdlib/chunking.py b/mellea/stdlib/chunking.py index 6b9091780..d4a5f79e1 100644 --- a/mellea/stdlib/chunking.py +++ b/mellea/stdlib/chunking.py @@ -35,6 +35,27 @@ def split(self, accumulated_text: str) -> list[str]: """ ... + def flush(self, accumulated_text: str) -> list[str]: + """Return any trailing fragment that ``split`` withheld. + + Called once by the orchestrator after the stream has ended naturally + (not on early-exit cancellation). Gives the chunker a chance to + release the final fragment that did not reach a terminator. + + The default implementation returns an empty list — the trailing + fragment is discarded. Built-in chunkers override this to return + the withheld fragment as a single-element list when non-empty. + + Args: + accumulated_text: The full accumulated text at stream end. + + Returns: + The trailing fragment as ``[fragment]`` if it should be treated + as a final chunk, or an empty list to discard it. + """ + _ = accumulated_text + return [] + # Sentence boundary: sentence-ending punctuation, optionally followed by a closing # quote or paren, then whitespace. @@ -94,6 +115,19 @@ def split(self, accumulated_text: str) -> list[str]: return chunks + def flush(self, accumulated_text: str) -> list[str]: + """Return the trailing sentence fragment (if any) as a final chunk.""" + if not accumulated_text: + return [] + remaining = accumulated_text + while True: + match = _SENTENCE_BOUNDARY.search(remaining) + if match is None: + break + remaining = remaining[match.end() :].lstrip() + trailing = remaining.strip() + return [trailing] if trailing else [] + class WordChunker(ChunkingStrategy): """Splits accumulated text on whitespace boundaries. @@ -134,6 +168,18 @@ def split(self, accumulated_text: str) -> list[str]: return parts + def flush(self, accumulated_text: str) -> list[str]: + """Return the trailing word fragment (if any) as a final chunk.""" + if not accumulated_text: + return [] + if accumulated_text[-1].isspace(): + return [] + parts = _WHITESPACE.split(accumulated_text) + for part in reversed(parts): + if part: + return [part] + return [] + class ParagraphChunker(ChunkingStrategy): r"""Splits accumulated text on double-newline paragraph boundaries. @@ -168,3 +214,13 @@ def split(self, accumulated_text: str) -> list[str]: # _PARA_BOUNDARY.split on leading \n\n produces an empty first element. return [p for p in parts if p] + + def flush(self, accumulated_text: str) -> list[str]: + """Return the trailing paragraph fragment (if any) as a final chunk.""" + if not accumulated_text: + return [] + if _PARA_BOUNDARY_END.search(accumulated_text): + return [] + parts = _PARA_BOUNDARY.split(accumulated_text) + trailing = parts[-1] if parts else "" + return [trailing] if trailing else [] diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 44cec5201..0346fb519 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -144,6 +144,35 @@ async def _orchestrate_streaming( failed_indices: set[int] = set() early_exit = False + async def _validate_and_emit(c: str) -> bool: + """Run stream_validate on chunk c across active requirements. + + Returns True if a failure was recorded (caller should early-exit), + False otherwise (chunk was emitted to the consumer queue). + """ + active = [ + (i, req) for i, req in enumerate(cloned_reqs) if i not in failed_indices + ] + if active: + pvrs: list[PartialValidationResult] = list( + await asyncio.gather( + *[ + req.stream_validate(c, backend=val_backend, ctx=ctx) + for _, req in active + ] + ) + ) + for (idx, req), pvr in zip(active, pvrs): + if pvr.success == "fail": + failed_indices.add(idx) + result.streaming_failures.append((req, pvr)) + + if failed_indices: + return True + + await result._chunk_queue.put(c) + return False + try: while not mot.is_computed(): try: @@ -157,36 +186,27 @@ async def _orchestrate_streaming( prev_chunk_count = len(chunks) for c in new_chunks: - active = [ - (i, req) - for i, req in enumerate(cloned_reqs) - if i not in failed_indices - ] - if active: - pvrs: list[PartialValidationResult] = list( - await asyncio.gather( - *[ - req.stream_validate(c, backend=val_backend, ctx=ctx) - for _, req in active - ] - ) - ) - for (idx, req), pvr in zip(active, pvrs): - if pvr.success == "fail": - failed_indices.add(idx) - result.streaming_failures.append((req, pvr)) - - if failed_indices: + failed = await _validate_and_emit(c) + if failed: early_exit = True result.completed = False await mot.cancel_generation() break - await result._chunk_queue.put(c) - if early_exit: break + # Stream ended naturally: flush any withheld trailing fragment and + # run stream_validate on it. Skipped on early exit — the generation + # was cancelled, the trailing fragment is incomplete. + if not early_exit: + for c in chunking.flush(accumulated): + failed = await _validate_and_emit(c) + if failed: + early_exit = True + result.completed = False + break + result.full_text = accumulated non_failed = [ @@ -238,6 +258,12 @@ async def stream_with_chunking( failing chunk is not emitted to the consumer; use :attr:`StreamChunkingResult.streaming_failures` to inspect what failed. + When the stream ends naturally, any trailing fragment withheld by the + chunking strategy (see :meth:`~mellea.stdlib.chunking.ChunkingStrategy.flush`) + is released as a final chunk and run through ``stream_validate`` on the + same terms as the regular chunks. On early exit, the trailing fragment + is discarded because the generation was cancelled mid-token. + After the stream ends (naturally or via early exit), ``validate()`` is called on all requirements that did not return ``"fail"``. Requirements are cloned (``copy(req)``) before use so originals are never mutated. diff --git a/test/stdlib/test_chunking.py b/test/stdlib/test_chunking.py index fbaf727a2..7b965350f 100644 --- a/test/stdlib/test_chunking.py +++ b/test/stdlib/test_chunking.py @@ -242,3 +242,79 @@ def test_paragraph_chunker_incremental_simulation(): "First paragraph.", "Second paragraph.", ] + + +# --------------------------------------------------------------------------- +# flush() — trailing-fragment release at end of stream +# --------------------------------------------------------------------------- + + +def test_default_flush_returns_empty_list(): + """The ABC default discards the trailing fragment.""" + + class Minimal(ChunkingStrategy): + def split(self, accumulated_text: str) -> list[str]: + _ = accumulated_text + return [] + + assert Minimal().flush("anything at all") == [] + assert Minimal().flush("") == [] + + +def test_sentence_chunker_flush_empty(): + assert SentenceChunker().flush("") == [] + + +def test_sentence_chunker_flush_only_complete(): + """All text ends in a complete sentence with trailing whitespace → no fragment.""" + assert SentenceChunker().flush("One. Two. ") == [] + + +def test_sentence_chunker_flush_trailing_fragment(): + """Final sentence without trailing whitespace is released by flush.""" + assert SentenceChunker().flush("One. Two without period") == ["Two without period"] + + +def test_sentence_chunker_flush_terminated_no_trailing_space(): + """Final sentence with terminator but no trailing whitespace is a fragment + under split() semantics and gets released by flush().""" + assert SentenceChunker().flush("One. Two.") == ["Two."] + + +def test_sentence_chunker_flush_single_sentence_no_terminator(): + assert SentenceChunker().flush("Incomplete sentence") == ["Incomplete sentence"] + + +def test_word_chunker_flush_empty(): + assert WordChunker().flush("") == [] + + +def test_word_chunker_flush_trailing_whitespace(): + """Trailing whitespace means all words are complete → no fragment.""" + assert WordChunker().flush("one two three ") == [] + + +def test_word_chunker_flush_trailing_fragment(): + assert WordChunker().flush("one two three") == ["three"] + + +def test_word_chunker_flush_single_word(): + assert WordChunker().flush("solo") == ["solo"] + + +def test_paragraph_chunker_flush_empty(): + assert ParagraphChunker().flush("") == [] + + +def test_paragraph_chunker_flush_only_complete(): + assert ParagraphChunker().flush("Para one.\n\nPara two.\n\n") == [] + + +def test_paragraph_chunker_flush_trailing_fragment(): + assert ParagraphChunker().flush("Para one.\n\nPara two (no sep)") == [ + "Para two (no sep)" + ] + + +def test_paragraph_chunker_flush_single_paragraph_no_separator(): + assert ParagraphChunker().flush("Only paragraph") == ["Only paragraph"] diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index bd03e5ffb..7c3d97793 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -443,6 +443,124 @@ def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq: assert not all(len(seen[i]) < len(seen[i + 1]) for i in range(len(seen) - 1)) +@pytest.mark.asyncio +async def test_trailing_fragment_is_flushed_to_consumer() -> None: + """Response without trailing whitespace: final sentence reaches astream() and stream_validate.""" + + class ChunkRecordingReq(Requirement): + def __init__(self) -> None: + self.seen_chunks: list[str] = [] + + def __copy__(self) -> "ChunkRecordingReq": + clone = ChunkRecordingReq() + clone.seen_chunks = [] + return clone + + def format_for_llm(self) -> str: + return "chunk recorder" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + self.seen_chunks.append(chunk) + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + # No trailing whitespace after the final sentence — SentenceChunker withholds it. + response = "First sentence. Second sentence." + backend = StreamingMockBackend(response, token_size=4) + req = ChunkRecordingReq() + + captured: list[ChunkRecordingReq] = [] + original_copy = ChunkRecordingReq.__copy__ + + def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq: + clone = original_copy(self) + captured.append(clone) + return clone + + ChunkRecordingReq.__copy__ = _capturing_copy # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[req], + chunking="sentence", + ) + yielded: list[str] = [] + async for chunk in result.astream(): + yielded.append(chunk) + await result.acomplete() + finally: + ChunkRecordingReq.__copy__ = original_copy # type: ignore[method-assign] + + # Both sentences reach the consumer, including the terminating one without trailing whitespace. + assert yielded == ["First sentence.", "Second sentence."] + # stream_validate was called on both — the flush path is not a shortcut. + assert captured[0].seen_chunks == ["First sentence.", "Second sentence."] + assert result.completed is True + + +@pytest.mark.asyncio +async def test_early_exit_on_trailing_fragment() -> None: + """A fail on the flushed fragment records a streaming failure and skips final validate().""" + + class FailOnSecondSentence(Requirement): + def __init__(self) -> None: + self._count = 0 + + def format_for_llm(self) -> str: + return "fail on second sentence" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + self._count += 1 + if self._count >= 2: + return PartialValidationResult("fail", reason="second sentence hit") + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + response = "First sentence. Second sentence." + backend = StreamingMockBackend(response, token_size=4) + req = FailOnSecondSentence() + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence" + ) + yielded: list[str] = [] + async for chunk in result.astream(): + yielded.append(chunk) + await result.acomplete() + + assert result.completed is False + assert len(result.streaming_failures) == 1 + # First sentence was emitted; second (the flushed fragment) failed and wasn't emitted. + assert yielded == ["First sentence."] + # Early exit on fail skips final validate(). + assert result.final_validations == [] + + @pytest.mark.asyncio async def test_no_requirements_streams_without_validation() -> None: """quick_check_requirements=None → chunks produced, no validate() called.""" From def10b6f87ce30c2b14901b9ba99acb083df2d1a Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 18:23:10 +0100 Subject: [PATCH 10/26] fix(stdlib): address review feedback on streaming validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses issues raised by independent review on top of PR #942. Orchestrator (mellea/stdlib/streaming.py): - except Exception now calls mot.cancel_generation() before surfacing the exception to the consumer — previously the backend producer was left running, eventually blocking on mot._async_queue (maxsize=20). Cleanup failures are logged via MelleaLogger.warning with a TODO(#902) marker; #902 replaces the log with a proper ErrorEvent. - RuntimeError catch in the astream() loop now re-raises unless mot.is_computed() is true, so only the documented "already computed" race is swallowed. - astream() docstring now states the single-consumer contract explicitly; a second iteration blocks on an empty queue with no sentinel to deliver. - as_thunk docstring now flags the early-exit case: cancel_generation forces is_computed=True without running post_processing(), so generation.usage and related telemetry fields may be None. Chunker (mellea/stdlib/chunking.py): - SentenceChunker.flush switches from .strip() to .rstrip() with a comment explaining why: the loop's lstrip has already removed leading whitespace, and trailing whitespace on a sentence fragment is non-semantic (consistent with split() returning sentences without trailing whitespace). - ParagraphChunker.flush adds a docstring noting the deliberate asymmetry: paragraph fragments are returned byte-for-byte because internal whitespace (e.g. trailing \n of a list item) can be semantically meaningful. Tests (test/stdlib/test_streaming.py): - test_stream_validate_receives_individual_chunks now uses exact- match on the captured chunk list, which directly regresses if someone reverts to accumulated-text semantics. - test_multiple_chunks_in_one_batch_with_mid_batch_fail: response fed as one large token so split() yields 4 sentences at once; verifies chunk 1 emits, chunk 2 fails (not emitted), chunks 3 and 4 are neither validated nor emitted. - test_cancel_generation_invoked_on_fail: spies on ModelOutputThunk.cancel_generation and asserts it was called on the "fail" early-exit path. - test_exception_in_stream_validate_cancels_generation: a requirement that raises must cause cancel_generation to run and the exception to surface via astream()/acomplete() without hanging. Telemetry observability (orchestrator-level spans, metrics, span events) remains deferred to #902 per the epic, which now has the acceptance criteria updated to cover event emission, the OTEL bridge, and the ErrorEvent type that will replace the MelleaLogger stopgap. Assisted-by: Claude Code --- mellea/stdlib/chunking.py | 22 +++- mellea/stdlib/streaming.py | 42 ++++++- test/stdlib/test_streaming.py | 201 ++++++++++++++++++++++++++++++++-- 3 files changed, 253 insertions(+), 12 deletions(-) diff --git a/mellea/stdlib/chunking.py b/mellea/stdlib/chunking.py index d4a5f79e1..fb3521ec2 100644 --- a/mellea/stdlib/chunking.py +++ b/mellea/stdlib/chunking.py @@ -116,7 +116,15 @@ def split(self, accumulated_text: str) -> list[str]: return chunks def flush(self, accumulated_text: str) -> list[str]: - """Return the trailing sentence fragment (if any) as a final chunk.""" + """Return the trailing sentence fragment (if any) as a final chunk. + + Trailing whitespace on the fragment is non-semantic for sentence + boundaries and is dropped via ``rstrip``. Leading whitespace is + already removed by the loop's ``lstrip`` on each advance, so no + ``lstrip`` is needed here. The result is the fragment's content + only, consistent with how :meth:`split` returns sentences without + trailing whitespace. + """ if not accumulated_text: return [] remaining = accumulated_text @@ -125,7 +133,7 @@ def flush(self, accumulated_text: str) -> list[str]: if match is None: break remaining = remaining[match.end() :].lstrip() - trailing = remaining.strip() + trailing = remaining.rstrip() return [trailing] if trailing else [] @@ -216,7 +224,15 @@ def split(self, accumulated_text: str) -> list[str]: return [p for p in parts if p] def flush(self, accumulated_text: str) -> list[str]: - """Return the trailing paragraph fragment (if any) as a final chunk.""" + r"""Return the trailing paragraph fragment (if any) as a final chunk. + + Unlike :class:`SentenceChunker.flush`, the fragment is returned + byte-for-byte without stripping. Internal whitespace — including + a trailing single ``\n`` — can be semantically meaningful inside + a paragraph (e.g. a list item or a deliberate line break), and a + consumer validating paragraph content should see the fragment as + it was withheld. + """ if not accumulated_text: return [] if _PARA_BOUNDARY_END.search(accumulated_text): diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 0346fb519..267e14848 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -16,6 +16,7 @@ from ..core.backend import Backend from ..core.base import CBlock, Component, Context, ModelOutputThunk from ..core.requirement import PartialValidationResult, Requirement, ValidationResult +from ..core.utils import MelleaLogger from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker _CHUNKING_ALIASES: dict[str, type[ChunkingStrategy]] = { @@ -75,6 +76,14 @@ async def astream(self) -> AsyncIterator[str]: all chunks have been yielded, whether the stream completed normally or was cancelled early on a ``"fail"`` result. + **Single-consumer.** Chunks are delivered via an + :class:`asyncio.Queue` that this method drains; calling + ``astream()`` a second time on the same result blocks indefinitely + because the queue is empty and the terminating ``None`` sentinel + has already been consumed. If you need the chunks after + iteration, capture them into a list during the first pass or use + :attr:`full_text` after :meth:`acomplete`. + Yields: str: A validated text chunk from the chunking strategy. @@ -116,6 +125,15 @@ def as_thunk(self) -> ModelOutputThunk: early-exit results; ``value`` will reflect whatever was accumulated before cancellation. + Note: + On early exit, ``cancel_generation()`` forces the MOT into a + computed state without running the backend's + ``post_processing()``. Telemetry fields on the returned thunk + (``generation.usage``, ``generation.ttfb_ms``, etc.) may + therefore be ``None`` or reflect the partial state at + cancellation time. ``value`` and ``streaming`` are reliable; + usage totals are not. + Returns: ModelOutputThunk: A computed thunk containing the streamed output. @@ -178,7 +196,12 @@ async def _validate_and_emit(c: str) -> bool: try: delta = await mot.astream() except RuntimeError: - break + # Expected race: mot.is_computed() was False at the top of the + # loop but the stream finished before we re-entered astream(). + # Any other RuntimeError is a real bug and must propagate. + if mot.is_computed(): + break + raise accumulated += delta chunks = chunking.split(accumulated) @@ -220,6 +243,23 @@ async def _validate_and_emit(c: str) -> bool: ) except Exception as exc: + # Orchestrator is leaving — we must stop the backend producer too, + # otherwise mot._async_queue (maxsize=20) fills and the feeder task + # blocks indefinitely. The spec (#891, #901) calls this out for the + # "fail" path; the same reasoning applies to any unplanned exit. + try: + await mot.cancel_generation() + except Exception as cleanup_exc: + # Never let cleanup mask the original exception: log loudly and + # continue to surface `exc` to the consumer. + # TODO(#902): replace this log with an ErrorEvent emission. + MelleaLogger.get_logger().warning( + "stream_with_chunking: cancel_generation() raised during " + "exception cleanup (original: %r, cleanup: %r)", + exc, + cleanup_exc, + ) + result.completed = False await result._chunk_queue.put(exc) finally: await result._chunk_queue.put(None) diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index 7c3d97793..d52ce18a1 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -433,14 +433,12 @@ def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq: assert len(captured) == 1 seen = captured[0].seen_chunks - # Three complete sentences → three separate stream_validate calls. - assert len(seen) == 3 - # Each chunk is one sentence, not a prefix of accumulated text. - for chunk in seen: - assert chunk.endswith(".") - # Lengths must not be monotonically growing (which would indicate accumulated text). - # With per-chunk semantics, each chunk is roughly the same length as one sentence. - assert not all(len(seen[i]) < len(seen[i + 1]) for i in range(len(seen) - 1)) + # Exact match: three separate calls, one per complete sentence, + # each call receiving that sentence and nothing more. Under the old + # accumulated-text semantics, seen would have been + # ["First sentence.", "First sentence. Second sentence.", ...] — + # exact match against the per-chunk list is the direct regression guard. + assert seen == ["First sentence.", "Second sentence.", "Third sentence."] @pytest.mark.asyncio @@ -576,3 +574,190 @@ async def test_no_requirements_streams_without_validation() -> None: assert result.full_text == response assert result.final_validations == [] assert result.streaming_failures == [] + + +@pytest.mark.asyncio +async def test_multiple_chunks_in_one_batch_with_mid_batch_fail() -> None: + """When one astream() delta produces several complete chunks and one in + the middle fails, earlier chunks emit, failing chunk is recorded, later + chunks are neither validated nor emitted.""" + + captured: list[Any] = [] + + class FailOnNthChunk(Requirement): + def __init__(self, n: int) -> None: + self._n = n + self._calls = 0 + self.seen: list[str] = [] + + def __copy__(self) -> "FailOnNthChunk": + clone = FailOnNthChunk(self._n) + captured.append(clone) + return clone + + def format_for_llm(self) -> str: + return f"fail on chunk {self._n}" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = backend, ctx + self._calls += 1 + self.seen.append(chunk) + if self._calls == self._n: + return PartialValidationResult("fail", reason=f"n={self._n}") + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + # token_size larger than the whole response → one astream() delta delivers + # the full text, so chunking.split produces 4 sentences in a single batch. + response = "One. Two. Three. Four. " + backend = StreamingMockBackend(response, token_size=100) + req = FailOnNthChunk(n=2) + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence" + ) + yielded: list[str] = [] + async for c in result.astream(): + yielded.append(c) + await result.acomplete() + + assert result.completed is False + assert len(result.streaming_failures) == 1 + # Chunk 1 was validated and emitted; chunk 2 was validated and failed + # (NOT emitted); chunks 3 and 4 were NEITHER validated NOR emitted. + assert yielded == ["One."] + assert len(captured) == 1 + assert captured[0].seen == ["One.", "Two."] + assert captured[0]._calls == 2 + + +@pytest.mark.asyncio +async def test_cancel_generation_invoked_on_fail() -> None: + """Early exit on 'fail' must call mot.cancel_generation() — the spec reason + is that asyncio.Queue(maxsize=20) will block the producer if the consumer + stops without cancelling.""" + + from mellea.core.base import ModelOutputThunk + + response = "word " * 50 + backend = StreamingMockBackend(response, token_size=3) + + class FailOnFirstChunk(Requirement): + def format_for_llm(self) -> str: + return "fail immediately" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + return PartialValidationResult("fail", reason="nope") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + call_count = 0 + real_cancel = ModelOutputThunk.cancel_generation + + async def spy_cancel(self: ModelOutputThunk) -> None: + nonlocal call_count + call_count += 1 + await real_cancel(self) + + ModelOutputThunk.cancel_generation = spy_cancel # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[FailOnFirstChunk()], + chunking="word", + ) + await asyncio.wait_for(result.acomplete(), timeout=5.0) + finally: + ModelOutputThunk.cancel_generation = real_cancel # type: ignore[method-assign] + + assert result.completed is False + assert call_count >= 1 + + +@pytest.mark.asyncio +async def test_exception_in_stream_validate_cancels_generation() -> None: + """If stream_validate raises, the orchestrator must still call + cancel_generation() — otherwise the backend producer blocks on the + (maxsize=20) queue — and surface the exception to the consumer via + astream()/acomplete().""" + + from mellea.core.base import ModelOutputThunk + + class RaisingReq(Requirement): + def format_for_llm(self) -> str: + return "raises" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + raise ValueError("boom") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + response = "word " * 50 # enough to fill maxsize=20 queue without cleanup + backend = StreamingMockBackend(response, token_size=3) + + call_count = 0 + real_cancel = ModelOutputThunk.cancel_generation + + async def spy_cancel(self: ModelOutputThunk) -> None: + nonlocal call_count + call_count += 1 + await real_cancel(self) + + ModelOutputThunk.cancel_generation = spy_cancel # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[RaisingReq()], + chunking="word", + ) + with pytest.raises(ValueError, match="boom"): + async for _chunk in result.astream(): + pass + # acomplete must complete (not hang) even though the orchestration + # task raised, because cancel_generation was called in the except path. + await asyncio.wait_for(result.acomplete(), timeout=5.0) + finally: + ModelOutputThunk.cancel_generation = real_cancel # type: ignore[method-assign] + + assert result.completed is False + assert call_count >= 1 From da41a06a0ce0926d482f8d74371da5f6fc7f4a41 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 19:09:15 +0100 Subject: [PATCH 11/26] fix(stdlib): address second-round review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three items from the second independent review: cancel_generation(error=) — accept an optional Exception parameter. When the orchestrator enters the except Exception path, it now passes the caught exception to cancel_generation() so the backend telemetry span records the real cause via set_span_error instead of a generic RuntimeError("Generation cancelled"). The original exception still surfaces to the consumer via astream()/acomplete(); this is purely an OTEL accuracy fix. Backward-compatible: the default None preserves the previous "Generation cancelled" message for the normal fail path. stream_with_chunking docstring — the "After the stream ends (naturally or via early exit), validate() is called" wording overstated behaviour. The orchestrator actually skips final validate() on early exit (test_early_exit_on_fail verifies final_validations == []). Docstring now correctly says final validate() runs only on natural completion. test_exception_in_stream_validate_cancels_generation docstring — the test fails on chunk 1 so the queue never actually fills; it verifies the cancel-on-exception path and the no-hang guarantee but does not directly prove the worst-case "producer blocked on full queue" scenario. Docstring now states what it actually covers and points at test/core/ for the cancel_generation drain logic. Assisted-by: Claude Code --- mellea/core/base.py | 15 +++++++++++++-- mellea/stdlib/streaming.py | 13 +++++++++---- test/stdlib/test_streaming.py | 26 ++++++++++++++++++-------- 3 files changed, 40 insertions(+), 14 deletions(-) diff --git a/mellea/core/base.py b/mellea/core/base.py index 95b6e8cdc..28ab78783 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -364,7 +364,7 @@ def _record_ttfb(self) -> None: ).total_seconds() * 1000 self._first_chunk_received = True - async def cancel_generation(self) -> None: + async def cancel_generation(self, error: Exception | None = None) -> None: """Cancel an in-progress streaming generation, drain the queue, and close any open telemetry span. Safe to call at any point during streaming. After this method returns, @@ -375,6 +375,14 @@ async def cancel_generation(self) -> None: Draining the internal queue after cancellation is necessary to release any ``asyncio.Queue.put()`` call that the generation task was blocked on (queue maxsize=20). + + Args: + error: Optional cause attributed to the open telemetry span. When + provided, this exception is recorded via ``set_span_error`` so + the span reflects the actual reason for cancellation (e.g. the + requirement failure or an unhandled exception from a streaming + validator). When ``None``, a generic + ``RuntimeError("Generation cancelled")`` is recorded. """ if self._computed: return @@ -414,7 +422,10 @@ def _drain() -> None: if span is not None: from ..telemetry import end_backend_span, set_span_error - set_span_error(span, RuntimeError("Generation cancelled")) + recorded: Exception = ( + error if error is not None else RuntimeError("Generation cancelled") + ) + set_span_error(span, recorded) end_backend_span(span) del self._meta["_telemetry_span"] diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 267e14848..16cafbe9d 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -247,8 +247,10 @@ async def _validate_and_emit(c: str) -> bool: # otherwise mot._async_queue (maxsize=20) fills and the feeder task # blocks indefinitely. The spec (#891, #901) calls this out for the # "fail" path; the same reasoning applies to any unplanned exit. + # Pass `exc` so the backend telemetry span records the real cause + # rather than a generic "Generation cancelled". try: - await mot.cancel_generation() + await mot.cancel_generation(error=exc) except Exception as cleanup_exc: # Never let cleanup mask the original exception: log loudly and # continue to surface `exc` to the consumer. @@ -304,9 +306,12 @@ async def stream_with_chunking( same terms as the regular chunks. On early exit, the trailing fragment is discarded because the generation was cancelled mid-token. - After the stream ends (naturally or via early exit), ``validate()`` is - called on all requirements that did not return ``"fail"``. Requirements - are cloned (``copy(req)``) before use so originals are never mutated. + After the stream ends naturally, ``validate()`` is called on every + requirement that did not return ``"fail"`` — both ``"pass"`` and + ``"unknown"`` trigger final validation. On early exit, no ``validate()`` + call is made; :attr:`StreamChunkingResult.final_validations` remains + empty. Requirements are cloned (``copy(req)``) before use so originals + are never mutated. Requirements that need context beyond the current chunk should accumulate it themselves across ``stream_validate`` calls (e.g. diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index d52ce18a1..560b94ec9 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -678,10 +678,12 @@ async def validate( call_count = 0 real_cancel = ModelOutputThunk.cancel_generation - async def spy_cancel(self: ModelOutputThunk) -> None: + async def spy_cancel( + self: ModelOutputThunk, error: Exception | None = None + ) -> None: nonlocal call_count call_count += 1 - await real_cancel(self) + await real_cancel(self, error) ModelOutputThunk.cancel_generation = spy_cancel # type: ignore[method-assign] try: @@ -702,10 +704,16 @@ async def spy_cancel(self: ModelOutputThunk) -> None: @pytest.mark.asyncio async def test_exception_in_stream_validate_cancels_generation() -> None: - """If stream_validate raises, the orchestrator must still call - cancel_generation() — otherwise the backend producer blocks on the - (maxsize=20) queue — and surface the exception to the consumer via - astream()/acomplete().""" + """Verifies the orchestrator's exception-path cleanup: if stream_validate + raises, cancel_generation() is called and the exception surfaces to the + consumer via astream()/acomplete() without hanging. + + This covers the cancel-on-exception path and the no-hang guarantee. + It does not directly exercise the worst-case "producer already blocked on + full queue" scenario (here the fail happens on chunk 1 so the queue never + fills); the cancel_generation drain logic is covered by its own tests in + test/core/. + """ from mellea.core.base import ModelOutputThunk @@ -736,10 +744,12 @@ async def validate( call_count = 0 real_cancel = ModelOutputThunk.cancel_generation - async def spy_cancel(self: ModelOutputThunk) -> None: + async def spy_cancel( + self: ModelOutputThunk, error: Exception | None = None + ) -> None: nonlocal call_count call_count += 1 - await real_cancel(self) + await real_cancel(self, error) ModelOutputThunk.cancel_generation = spy_cancel # type: ignore[method-assign] try: From 74c009d245e94d2d7dc4dbb721c60a135e096120 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 19:38:00 +0100 Subject: [PATCH 12/26] docs(stdlib): add Args and Returns sections to chunker flush overrides The Docs CI docstring quality gate [no_class_args]-equivalent check requires every documented method with typed params to have an Args section and a Returns section matching the return annotation. SentenceChunker.flush, WordChunker.flush, and ParagraphChunker.flush all took accumulated_text and returned list[str] without the sections. Add both to each override, documenting each flush's specific semantics (rstrip for sentences, whitespace-split trailing fragment for words, byte-for-byte for paragraphs). Assisted-by: Claude Code --- mellea/stdlib/chunking.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/mellea/stdlib/chunking.py b/mellea/stdlib/chunking.py index fb3521ec2..6c81105c5 100644 --- a/mellea/stdlib/chunking.py +++ b/mellea/stdlib/chunking.py @@ -124,6 +124,15 @@ def flush(self, accumulated_text: str) -> list[str]: ``lstrip`` is needed here. The result is the fragment's content only, consistent with how :meth:`split` returns sentences without trailing whitespace. + + Args: + accumulated_text: The full accumulated text at stream end. + + Returns: + A single-element list containing the trailing sentence fragment + with leading and trailing whitespace stripped, or an empty list + when there is no fragment (all content ended in a sentence + boundary or the input is empty/whitespace-only). """ if not accumulated_text: return [] @@ -177,7 +186,21 @@ def split(self, accumulated_text: str) -> list[str]: return parts def flush(self, accumulated_text: str) -> list[str]: - """Return the trailing word fragment (if any) as a final chunk.""" + """Return the trailing word fragment (if any) as a final chunk. + + The trailing fragment is the text after the last whitespace run when + the accumulated text does not end with whitespace. When it does end + with whitespace, every word is already complete and no fragment is + released. + + Args: + accumulated_text: The full accumulated text at stream end. + + Returns: + A single-element list containing the trailing word fragment, or + an empty list when the input ends with whitespace (every word + already complete) or is empty. + """ if not accumulated_text: return [] if accumulated_text[-1].isspace(): @@ -232,6 +255,14 @@ def flush(self, accumulated_text: str) -> list[str]: a paragraph (e.g. a list item or a deliberate line break), and a consumer validating paragraph content should see the fragment as it was withheld. + + Args: + accumulated_text: The full accumulated text at stream end. + + Returns: + A single-element list containing the trailing paragraph fragment + byte-for-byte, or an empty list when the input ends with a + paragraph boundary (``\n\n`` or more) or is empty. """ if not accumulated_text: return [] From 3fb501ef56c78c74fe18e617e1e3c17170611fab Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Wed, 29 Apr 2026 11:14:10 +0100 Subject: [PATCH 13/26] fix(stdlib): address third-round review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - _orchestrate_streaming: add cancel_generation() in finally block so the backend producer is stopped even on external CancelledError (BaseException bypasses except Exception, leaving _generate hung on a full queue) - cancel_generation: replace .get + del on _telemetry_span with .pop to prevent KeyError if two coroutines race before _computed is set - Example and test doubles: add super().__init__() to Requirement subclasses so description/validation_fn/_output are always initialised - docs/examples: fix pytest tier marker integration → e2e (Ollama example must be e2e per MARKERS_GUIDE; all peer examples use e2e) - test_quick_check_backend_routing: capture clone via __copy__ intercept and assert all seen_backends are val_backend, not just clone-isolation check Assisted-by: Claude Code --- docs/examples/streaming/streaming_chunking.py | 3 +- mellea/core/base.py | 3 +- mellea/stdlib/streaming.py | 9 ++++ test/stdlib/test_streaming.py | 44 +++++++++++++------ 4 files changed, 43 insertions(+), 16 deletions(-) diff --git a/docs/examples/streaming/streaming_chunking.py b/docs/examples/streaming/streaming_chunking.py index 70037d0a1..737ea7486 100644 --- a/docs/examples/streaming/streaming_chunking.py +++ b/docs/examples/streaming/streaming_chunking.py @@ -1,4 +1,4 @@ -# pytest: ollama, integration +# pytest: ollama, e2e """Streaming generation with per-chunk validation using stream_with_chunking(). @@ -32,6 +32,7 @@ class MaxSentencesReq(Requirement): """ def __init__(self, limit: int) -> None: + super().__init__() self._limit = limit self._count = 0 diff --git a/mellea/core/base.py b/mellea/core/base.py index 28ab78783..a8f35e79d 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -418,7 +418,7 @@ def _drain() -> None: # Drain again for any final item the task put before terminating. _drain() - span = self._meta.get("_telemetry_span") + span = self._meta.pop("_telemetry_span", None) if span is not None: from ..telemetry import end_backend_span, set_span_error @@ -427,7 +427,6 @@ def _drain() -> None: ) set_span_error(span, recorded) end_backend_span(span) - del self._meta["_telemetry_span"] if self._underlying_value is None: self._underlying_value = "" diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 16cafbe9d..c417f9057 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -264,6 +264,15 @@ async def _validate_and_emit(c: str) -> bool: result.completed = False await result._chunk_queue.put(exc) finally: + # CancelledError (BaseException, not Exception) bypasses the except + # block above, so cancel_generation() may not have been called. + # Guard here ensures the backend producer is always stopped, even on + # external task cancellation (e.g. asyncio.wait_for timeout). + if not mot.is_computed(): + try: + await mot.cancel_generation() + except BaseException: + pass await result._chunk_queue.put(None) result._done.set() diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index 560b94ec9..46550d9a0 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -122,6 +122,7 @@ class FailAfterWordsReq(Requirement): """ def __init__(self, threshold: int) -> None: + super().__init__() self._threshold = threshold self._word_count = 0 @@ -146,6 +147,7 @@ class BackendRecordingReq(Requirement): """Records which backend was passed to stream_validate and validate.""" def __init__(self) -> None: + super().__init__() self.seen_backends: list[Any] = [] def __copy__(self) -> "BackendRecordingReq": @@ -174,6 +176,7 @@ class MutationDetectorReq(Requirement): """Tracks how many times stream_validate was called on this instance.""" def __init__(self) -> None: + super().__init__() self._call_count = 0 def format_for_llm(self) -> str: @@ -291,22 +294,37 @@ async def test_quick_check_backend_routing() -> None: req = BackendRecordingReq() - result = await stream_with_chunking( - _action(), - main_backend, - _ctx(), - quick_check_requirements=[req], - chunking="sentence", - quick_check_backend=val_backend, - ) - await result.acomplete() + # Capture the cloned requirement so we can inspect which backends it saw. + captured: list[BackendRecordingReq] = [] + original_copy = BackendRecordingReq.__copy__ + + def _capturing_copy(self: BackendRecordingReq) -> BackendRecordingReq: + clone = original_copy(self) + captured.append(clone) + return clone + + BackendRecordingReq.__copy__ = _capturing_copy # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + main_backend, + _ctx(), + quick_check_requirements=[req], + chunking="sentence", + quick_check_backend=val_backend, + ) + await result.acomplete() + finally: + BackendRecordingReq.__copy__ = original_copy # type: ignore[method-assign] - # The clone's seen_backends should only contain val_backend - # (The original req was never called; clones were.) - # Verify via final_validations side-effect: at least one backend recorded assert result.completed is True - # The original req._seen_backends is untouched (clone isolation) + # The original was never called — only clones are used. assert req.seen_backends == [] + # The clone must have seen val_backend for every call (stream_validate + validate), + # never main_backend. This is the actual routing assertion. + assert len(captured) == 1 + assert len(captured[0].seen_backends) > 0 + assert all(b is val_backend for b in captured[0].seen_backends) @pytest.mark.asyncio From 5850f924787960b96c6915864911255a89db62b9 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Fri, 1 May 2026 14:26:22 +0100 Subject: [PATCH 14/26] fix(stdlib): stash orchestrator exception and narrow finally except Addresses review feedback on `_orchestrate_streaming` cleanup: - Exceptions caught by the orchestrator were only pushed to the chunk queue, so callers who skipped `astream()` and went straight to `acomplete()` saw the call return silently. Stash the exception on the result and raise it from `acomplete()` with raise-once semantics (cleared by whichever of astream/acomplete reads it first). - The finally cleanup caught `BaseException`, silently absorbing CancelledError/KeyboardInterrupt/SystemExit. Narrow to `except Exception` and switch the terminator to `put_nowait(None)` + `set()` so the sync ops always run even when the task is being cancelled (otherwise acomplete consumers hang). Two regression tests: - test_acomplete_surfaces_exception_without_astream - test_external_task_cancellation_releases_consumers Assisted-by: Claude Code Signed-off-by: Nigel Jones --- mellea/stdlib/streaming.py | 36 +++++++++++--- test/stdlib/test_streaming.py | 88 +++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 7 deletions(-) diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index c417f9057..dcdd6e894 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -62,6 +62,10 @@ def __init__(self, mot: ModelOutputThunk, ctx: Context) -> None: self._chunk_queue: asyncio.Queue[str | None | Exception] = asyncio.Queue() self._orchestration_task: asyncio.Task[None] | None = None self._done = asyncio.Event() + # Stashed so acomplete() surfaces orchestrator failures even when the + # consumer never iterates astream(). Cleared once consumed by + # whichever of the two reads it first. + self._orchestration_exception: BaseException | None = None self.completed: bool = True self.full_text: str = "" @@ -96,6 +100,10 @@ async def astream(self) -> AsyncIterator[str]: if item is None: return if isinstance(item, Exception): + if self._orchestration_exception is None: + # Already surfaced by acomplete(); don't raise twice. + continue + self._orchestration_exception = None raise item yield item @@ -111,10 +119,16 @@ async def acomplete(self) -> None: Exception: Propagates any error from the orchestration task. """ await self._done.wait() + # Raise-once: if astream() already consumed the exception, the stash + # is already None and this is a no-op. + exc = self._orchestration_exception + if exc is not None: + self._orchestration_exception = None + raise exc if self._orchestration_task is not None and self._orchestration_task.done(): - exc = self._orchestration_task.exception() - if exc is not None: - raise exc + task_exc = self._orchestration_task.exception() + if task_exc is not None: + raise task_exc @property def as_thunk(self) -> ModelOutputThunk: @@ -262,18 +276,26 @@ async def _validate_and_emit(c: str) -> bool: cleanup_exc, ) result.completed = False + result._orchestration_exception = exc await result._chunk_queue.put(exc) finally: # CancelledError (BaseException, not Exception) bypasses the except # block above, so cancel_generation() may not have been called. - # Guard here ensures the backend producer is always stopped, even on - # external task cancellation (e.g. asyncio.wait_for timeout). + # Catch only Exception here so CancelledError / KeyboardInterrupt / + # SystemExit still propagate to the caller. if not mot.is_computed(): try: await mot.cancel_generation() - except BaseException: + except Exception: pass - await result._chunk_queue.put(None) + # put_nowait + set() are synchronous — no await point, so they cannot + # be interrupted by task cancellation. Consumers waiting on + # _done.wait() are always released, even if the task was cancelled + # mid-cleanup. The queue is unbounded, so QueueFull cannot occur. + try: + result._chunk_queue.put_nowait(None) + except asyncio.QueueFull: + pass result._done.set() diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index 46550d9a0..759d8d272 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -789,3 +789,91 @@ async def spy_cancel( assert result.completed is False assert call_count >= 1 + + +@pytest.mark.asyncio +async def test_acomplete_surfaces_exception_without_astream() -> None: + """acomplete() must surface orchestrator exceptions even when the + consumer never iterates astream(). + + The alternative — only delivering the exception through the chunk queue + — silently swallows validator failures for callers who skip astream(). + """ + + class RaisingReq(Requirement): + def format_for_llm(self) -> str: + return "raises" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + raise ValueError("surfaced-without-astream") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + response = "word " * 50 + backend = StreamingMockBackend(response, token_size=3) + + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[RaisingReq()], + chunking="word", + ) + # Deliberately skip astream(). wait_for bounds any hang. + with pytest.raises(ValueError, match="surfaced-without-astream"): + await asyncio.wait_for(result.acomplete(), timeout=5.0) + + assert result.completed is False + # Raise-once: a second acomplete() must not re-raise. + await asyncio.wait_for(result.acomplete(), timeout=5.0) + + +@pytest.mark.asyncio +async def test_external_task_cancellation_releases_consumers() -> None: + """External cancellation of the orchestration task must still set _done. + + If the finally cleanup itself contains an ``await`` (e.g. awaiting a + terminator put into the chunk queue), CancelledError re-raises at that + await and ``_done.set()`` never runs — any consumer blocked on + ``acomplete()`` hangs forever. The cleanup must therefore end with + synchronous operations only. + """ + response = "word " * 200 # long enough that streaming is still in progress + backend = StreamingMockBackend(response, token_size=2) + + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[AlwaysUnknownReq()], + chunking="word", + ) + + assert result._orchestration_task is not None + # Yield once so the orchestration task enters its main loop before we + # cancel it. + await asyncio.sleep(0.01) + + # Same mechanism asyncio.wait_for uses on timeout. + result._orchestration_task.cancel() + + # _done must be set by the finally cleanup. A hang would time out here. + await asyncio.wait_for(result._done.wait(), timeout=2.0) + assert result._done.is_set() + + # acomplete() surfaces the CancelledError via task.exception() and must + # not hang. + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(result.acomplete(), timeout=2.0) From 4f508fd4b12d035733e6f069b2805fbc0982a4f2 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 5 May 2026 10:11:17 +0100 Subject: [PATCH 15/26] feat(core): add cancelled flag on ModelOutputThunk Adds a `_cancelled` attribute (False by default) on `ModelOutputThunk`, set to True inside `cancel_generation()` just before `_computed = True`, exposed via a read-only `cancelled` property. Propagated through `StreamChunkingResult.as_thunk` so consumers that only hold the wrapped thunk can still distinguish cancellation from a natural completion. Addresses nrfulton's review feedback on #942 and pre-stages the cancel-vs-complete signal that #902's `CompletedEvent` needs to surface. Assisted-by: Claude Code Signed-off-by: Nigel Jones --- mellea/core/base.py | 13 +++++++++++++ mellea/stdlib/streaming.py | 1 + 2 files changed, 14 insertions(+) diff --git a/mellea/core/base.py b/mellea/core/base.py index a8f35e79d..60a079c8a 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -320,6 +320,7 @@ def __init__( # Set computed to True if a value is passed in. self._computed: bool = True if value is not None else False + self._cancelled: bool = False # Additional fields that should be standardized across apis. self.tool_calls = tool_calls @@ -430,8 +431,20 @@ def _drain() -> None: if self._underlying_value is None: self._underlying_value = "" + self._cancelled = True self._computed = True + @property + def cancelled(self) -> bool: + """``True`` if :meth:`cancel_generation` ran to completion on this MOT. + + A normally-completed MOT leaves this ``False``; only an actual + cancellation via :meth:`cancel_generation` flips it. Consumers holding + a computed MOT can use this to distinguish a genuine result from one + cut short (for example by a streaming requirement failure). + """ + return self._cancelled + def _copy_from(self, other: ModelOutputThunk) -> None: """Copy computed-output fields from *other* into *self*. diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index dcdd6e894..2ba815096 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -159,6 +159,7 @@ def as_thunk(self) -> ModelOutputThunk: "as_thunk accessed before acomplete() — await acomplete() first" ) thunk = ModelOutputThunk(value=self.full_text) + thunk._cancelled = self._mot._cancelled thunk.generation = copy(self._mot.generation) return thunk From 5075a4794cf29578c4bbcc4f3affa224bf53f289 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 5 May 2026 11:59:32 +0100 Subject: [PATCH 16/26] docs(stdlib): note ChunkingStrategy is text-only Adds a short note on the ChunkingStrategy class docstring stating that the ABC operates on text streams only and does not support multi-modal output (audio segments, image regions). Addresses review feedback on #942 without expanding scope. Assisted-by: Claude Code Signed-off-by: Nigel Jones --- mellea/stdlib/chunking.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mellea/stdlib/chunking.py b/mellea/stdlib/chunking.py index 6c81105c5..462af81a8 100644 --- a/mellea/stdlib/chunking.py +++ b/mellea/stdlib/chunking.py @@ -19,6 +19,10 @@ class ChunkingStrategy(ABC): take the full accumulated text, identify everything after the last returned chunk boundary, and handle it appropriately (e.g. pass to a final validator or discard). + + Note: this ABC operates on text streams only. Multi-modal output (audio + segments, image regions) is not supported — the ``accumulated_text: str`` + signatures on ``split`` and ``flush`` preclude it. """ @abstractmethod From f0f93b3acf8b73d4a7be09e082c5bfc1fe3598a6 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 5 May 2026 12:06:33 +0100 Subject: [PATCH 17/26] test(stdlib): assert cancelled flag reflects cancellation state Adds test_cancelled_flag_reflects_cancellation_state covering both the early-exit path (cancelled is True, is_computed True, propagates through as_thunk) and the normal-completion path (cancelled is False). Pairs with the cancellation flag added in the prior commit. Addresses nrfulton's review feedback on #942. Assisted-by: Claude Code Signed-off-by: Nigel Jones --- test/stdlib/test_streaming.py | 61 +++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index 759d8d272..2010c9b3e 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -720,6 +720,67 @@ async def spy_cancel( assert call_count >= 1 +@pytest.mark.asyncio +async def test_cancelled_flag_reflects_cancellation_state() -> None: + """The ``cancelled`` property on ModelOutputThunk distinguishes an early-exit + cancellation from a normal completion and propagates through ``as_thunk``.""" + + # Early exit → cancelled is True, is_computed True, propagates through as_thunk. + fail_response = "word " * 50 + fail_backend = StreamingMockBackend(fail_response, token_size=3) + + class FailImmediately(Requirement): + def format_for_llm(self) -> str: + return "fail immediately" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + return PartialValidationResult("fail", reason="nope") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + fail_result = await stream_with_chunking( + _action(), + fail_backend, + _ctx(), + quick_check_requirements=[FailImmediately()], + chunking="word", + ) + await asyncio.wait_for(fail_result.acomplete(), timeout=5.0) + + assert fail_result.completed is False + assert fail_result.as_thunk.cancelled is True + assert fail_result.as_thunk.is_computed() is True + + # Normal completion → cancelled is False. + ok_response = "Hello world. How are you. " + ok_backend = StreamingMockBackend(ok_response, token_size=3) + + ok_result = await stream_with_chunking( + _action(), + ok_backend, + _ctx(), + quick_check_requirements=[AlwaysUnknownReq()], + chunking="sentence", + ) + await ok_result.acomplete() + + assert ok_result.completed is True + assert ok_result.as_thunk.cancelled is False + assert ok_result.as_thunk.is_computed() is True + + @pytest.mark.asyncio async def test_exception_in_stream_validate_cancels_generation() -> None: """Verifies the orchestrator's exception-path cleanup: if stream_validate From 18bfe02ec7113dd26f89b4ac29edb6b50ff924eb Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Mon, 11 May 2026 12:30:49 +0100 Subject: [PATCH 18/26] fix(stdlib): address psschwei review comments on streaming Three fixes: 1. Raise-once regression (5850f924): replace the _orchestration_exception stash as the already-surfaced sentinel with a dedicated _exception_surfaced: bool flag. The previous guard conflated 'stash never set' with 'already cleared by acomplete()' -- both show as None, so a subsequent astream() call could silently skip a queued exception item and return zero chunks with no error. 2. _copy_from / __copy__ / __deepcopy__ omitted _cancelled: the cancelled flag added in 4f508fd4 was not propagated by any of the three copy paths on ModelOutputThunk. 3. full_text on early exit now reflects only validated-and-emitted chunks: accumulated += delta ran unconditionally before chunk iteration, so when a multi-chunk delta contained a failing chunk, full_text (and as_thunk.value) included the failing chunk's text and any later chunks from the same delta that were never validated. A new emitted_text local tracks only the chunks emitted to the consumer queue; full_text is set from emitted_text on early exit and from accumulated on natural completion. Three new regression tests cover each fix. Inline comment added on the while-loop break to address indentation question. Assisted-by: Claude Code Signed-off-by: Nigel Jones --- mellea/core/base.py | 3 + mellea/stdlib/streaming.py | 30 +++++-- test/stdlib/test_streaming.py | 142 ++++++++++++++++++++++++++++++++++ 3 files changed, 168 insertions(+), 7 deletions(-) diff --git a/mellea/core/base.py b/mellea/core/base.py index 60a079c8a..a8ff2f3aa 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -459,6 +459,7 @@ def _copy_from(self, other: ModelOutputThunk) -> None: self._thinking = other._thinking self.generation = other.generation self._generate_log = other._generate_log + self._cancelled = other._cancelled def is_computed(self) -> bool: """Returns true only if this Thunk has already been filled. @@ -687,6 +688,7 @@ def __copy__(self) -> ModelOutputThunk: copied.parsed_repr = copied # type: ignore copied._computed = self._computed + copied._cancelled = self._cancelled copied._thinking = self._thinking copied._action = self._action copied._context = self._context @@ -715,6 +717,7 @@ def __deepcopy__(self, memo: dict) -> ModelOutputThunk: deepcopied._meta = deepcopy(self._meta) deepcopied.tool_calls = deepcopy(self.tool_calls) deepcopied._computed = self._computed + deepcopied._cancelled = self._cancelled deepcopied._thinking = self._thinking deepcopied._action = deepcopy(self._action) deepcopied._context = copy( diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 2ba815096..5f4b94dac 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -66,6 +66,13 @@ def __init__(self, mot: ModelOutputThunk, ctx: Context) -> None: # consumer never iterates astream(). Cleared once consumed by # whichever of the two reads it first. self._orchestration_exception: BaseException | None = None + # Tracks whether the exception has already been surfaced to the caller + # (by astream OR acomplete). A separate flag rather than reusing the + # stash slot avoids the race where acomplete() clears the stash, a + # subsequent astream() dequeues the exception item, sees the stash is + # None, and silently skips it — leaving the caller with zero chunks + # and no error. + self._exception_surfaced: bool = False self.completed: bool = True self.full_text: str = "" @@ -100,9 +107,10 @@ async def astream(self) -> AsyncIterator[str]: if item is None: return if isinstance(item, Exception): - if self._orchestration_exception is None: + if self._exception_surfaced: # Already surfaced by acomplete(); don't raise twice. continue + self._exception_surfaced = True self._orchestration_exception = None raise item yield item @@ -119,10 +127,10 @@ async def acomplete(self) -> None: Exception: Propagates any error from the orchestration task. """ await self._done.wait() - # Raise-once: if astream() already consumed the exception, the stash - # is already None and this is a no-op. + # Raise-once: if astream() already surfaced the exception, skip. exc = self._orchestration_exception - if exc is not None: + if exc is not None and not self._exception_surfaced: + self._exception_surfaced = True self._orchestration_exception = None raise exc if self._orchestration_task is not None and self._orchestration_task.done(): @@ -173,6 +181,7 @@ async def _orchestrate_streaming( val_backend: Backend, ) -> None: accumulated = "" + emitted_text = "" prev_chunk_count = 0 failed_indices: set[int] = set() early_exit = False @@ -230,9 +239,10 @@ async def _validate_and_emit(c: str) -> bool: result.completed = False await mot.cancel_generation() break + emitted_text += c if early_exit: - break + break # break the while loop; cancel_generation() already set _computed=True # Stream ended naturally: flush any withheld trailing fragment and # run stream_validate on it. Skipped on early exit — the generation @@ -244,8 +254,14 @@ async def _validate_and_emit(c: str) -> bool: early_exit = True result.completed = False break - - result.full_text = accumulated + emitted_text += c + + # On early exit, full_text reflects only validated-and-emitted chunks + # so it matches exactly what the consumer received via astream(). + # On natural completion emitted_text == accumulated (every character + # ends up in some chunk or flushed fragment), so either value is + # equivalent; accumulated is used to preserve the original raw text. + result.full_text = emitted_text if early_exit else accumulated non_failed = [ req for i, req in enumerate(cloned_reqs) if i not in failed_indices diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index 2010c9b3e..76f1e8b47 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -938,3 +938,145 @@ async def test_external_task_cancellation_releases_consumers() -> None: # not hang. with pytest.raises(asyncio.CancelledError): await asyncio.wait_for(result.acomplete(), timeout=2.0) + + +@pytest.mark.asyncio +async def test_raise_once_acomplete_then_astream() -> None: + """Regression for the raise-once stash bug: acomplete() first, astream() second. + + Prior to the fix, acomplete() cleared _orchestration_exception, so a + subsequent astream() call dequeued the exception item, saw the stash was + None, silently skipped it, and returned zero chunks with no error. + """ + + class RaisingReq(Requirement): + def format_for_llm(self) -> str: + return "raises" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + raise ValueError("raise-once-regression") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + response = "word " * 10 + backend = StreamingMockBackend(response, token_size=3) + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[RaisingReq()], + chunking="word", + ) + + # acomplete() sees the exception first and raises it. + with pytest.raises(ValueError, match="raise-once-regression"): + await asyncio.wait_for(result.acomplete(), timeout=5.0) + + # astream() must NOT re-raise (raise-once semantics). Because the + # exception fired before any chunk was emitted, the queue contains + # [exc, None]. With the separate _exception_surfaced flag, astream() + # correctly skips the exception item and terminates cleanly. Without + # the flag the behaviour is the same, but the guard conflates + # "already surfaced" with "stash was never set" — the flag makes the + # intent unambiguous. + chunks: list[str] = [] + async for chunk in result.astream(): + chunks.append(chunk) + assert chunks == [] # no partial chunks before the exception + + +@pytest.mark.asyncio +async def test_full_text_contains_only_validated_chunks_on_early_exit() -> None: + """full_text must equal exactly what was emitted to the consumer on early exit. + + When one astream() delta produces N chunks and chunk K fails, full_text + must contain chunks 0..K-1 only — not the failed chunk or any unvalidated + chunks after it in the same delta. + """ + + class FailOnNthChunkText(Requirement): + def __init__(self, n: int) -> None: + self._n = n + self._calls = 0 + + def __copy__(self) -> "FailOnNthChunkText": + return FailOnNthChunkText(self._n) + + def format_for_llm(self) -> str: + return f"fail on chunk {self._n}" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + self._calls += 1 + if self._calls == self._n: + return PartialValidationResult("fail") + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + # token_size > full response → single delta with 4 sentences; fail on chunk 2. + response = "One. Two. Three. Four. " + backend = StreamingMockBackend(response, token_size=100) + req = FailOnNthChunkText(n=2) + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence" + ) + yielded: list[str] = [] + async for chunk in result.astream(): + yielded.append(chunk) + await result.acomplete() + + assert result.completed is False + # Consumer received only chunk 1. + assert yielded == ["One."] + # full_text must match what the consumer received — not the raw delta. + assert result.full_text == "One." + # as_thunk.value must agree with full_text. + assert result.as_thunk.value == result.full_text + + +@pytest.mark.asyncio +async def test_cancelled_flag_propagates_through_copy_methods() -> None: + """_cancelled must survive __copy__, __deepcopy__, and _copy_from.""" + from copy import deepcopy + + mot = ModelOutputThunk(value="result") + mot._cancelled = True + + # __copy__ + shallow = mot.__copy__() + assert shallow._cancelled is True, "__copy__ must propagate _cancelled" + + # __deepcopy__ + deep = deepcopy(mot) + assert deep._cancelled is True, "__deepcopy__ must propagate _cancelled" + + # _copy_from + target = ModelOutputThunk(value="original") + assert target._cancelled is False + target._copy_from(mot) + assert target._cancelled is True, "_copy_from must propagate _cancelled" + + # Sanity: default-constructed MOT has _cancelled=False. + fresh = ModelOutputThunk(value="x") + assert fresh._cancelled is False From 7fc40a4e968f88477dc89b597f5c01e8a6b42d2e Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Wed, 13 May 2026 07:52:07 +0100 Subject: [PATCH 19/26] fix(stdlib): clone requirements before backend start; cancel peer validators on gather failure - Reorder copy(req) before generate_from_context so a raising __copy__ cannot leave a backend feeder task wedged against a full _async_queue - Add try/except BaseException around the post-generate setup window; on failure cancel_generation() is called and the exception re-raised - Replace both asyncio.gather sites in _orchestrate_streaming with asyncio.TaskGroup so peer validators are cancelled on first failure - Unwrap ExceptionGroup from TaskGroup in the exception handler; log any suppressed siblings before forwarding the first exception - Extend Requirement.stream_validate docstring with __copy__ failure contract Assisted-by: Claude Code Signed-off-by: Nigel Jones --- mellea/core/requirement.py | 7 +- mellea/stdlib/streaming.py | 87 ++++++++++++----- test/stdlib/test_streaming.py | 174 ++++++++++++++++++++++++++++++++++ 3 files changed, 242 insertions(+), 26 deletions(-) diff --git a/mellea/core/requirement.py b/mellea/core/requirement.py index e8c55564b..141af22b0 100644 --- a/mellea/core/requirement.py +++ b/mellea/core/requirement.py @@ -300,7 +300,12 @@ async def stream_validate( are shared by reference under ``copy()``. Reassign rather than mutate in place (``self._buffer = self._buffer + [chunk]``, not ``self._buffer.append(chunk)``), or override ``__copy__`` for proper - isolation. + isolation. If an override raises, the enclosing + :func:`~mellea.stdlib.streaming.stream_with_chunking` call aborts before + any backend generation starts and the exception propagates unchanged. + Overrides with externally visible side effects (file writes, network + calls) should perform them only after any logic that could raise, since + the framework cannot roll them back. Implementations must not call ``mot.astream()`` or otherwise read the underlying stream; the orchestrator is the single consumer of the MOT diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 5f4b94dac..df9ad5a80 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -196,14 +196,12 @@ async def _validate_and_emit(c: str) -> bool: (i, req) for i, req in enumerate(cloned_reqs) if i not in failed_indices ] if active: - pvrs: list[PartialValidationResult] = list( - await asyncio.gather( - *[ - req.stream_validate(c, backend=val_backend, ctx=ctx) - for _, req in active - ] - ) - ) + async with asyncio.TaskGroup() as tg: + _tasks = [ + tg.create_task(req.stream_validate(c, backend=val_backend, ctx=ctx)) + for _, req in active + ] + pvrs: list[PartialValidationResult] = [t.result() for t in _tasks] for (idx, req), pvr in zip(active, pvrs): if pvr.success == "fail": failed_indices.add(idx) @@ -267,11 +265,11 @@ async def _validate_and_emit(c: str) -> bool: req for i, req in enumerate(cloned_reqs) if i not in failed_indices ] if non_failed and not early_exit: - result.final_validations = list( - await asyncio.gather( - *[req.validate(val_backend, ctx) for req in non_failed] - ) - ) + async with asyncio.TaskGroup() as tg: + _final_tasks = [ + tg.create_task(req.validate(val_backend, ctx)) for req in non_failed + ] + result.final_validations = [t.result() for t in _final_tasks] except Exception as exc: # Orchestrator is leaving — we must stop the backend producer too, @@ -280,8 +278,22 @@ async def _validate_and_emit(c: str) -> bool: # "fail" path; the same reasoning applies to any unplanned exit. # Pass `exc` so the backend telemetry span records the real cause # rather than a generic "Generation cancelled". + # TaskGroup wraps failures in ExceptionGroup; unwrap so telemetry and + # the chunk queue see the original exception, not the wrapper. + # ExceptionGroup (not BaseExceptionGroup) guarantees Exception elements. + if isinstance(exc, ExceptionGroup) and exc.exceptions: + reported_exc: Exception = exc.exceptions[0] + if len(exc.exceptions) > 1: + MelleaLogger.get_logger().warning( + "stream_with_chunking: %d validator(s) failed simultaneously; " + "reporting first, suppressing rest: %r", + len(exc.exceptions) - 1, + exc.exceptions[1:], + ) + else: + reported_exc = exc try: - await mot.cancel_generation(error=exc) + await mot.cancel_generation(error=reported_exc) except Exception as cleanup_exc: # Never let cleanup mask the original exception: log loudly and # continue to surface `exc` to the consumer. @@ -289,12 +301,12 @@ async def _validate_and_emit(c: str) -> bool: MelleaLogger.get_logger().warning( "stream_with_chunking: cancel_generation() raised during " "exception cleanup (original: %r, cleanup: %r)", - exc, + reported_exc, cleanup_exc, ) result.completed = False - result._orchestration_exception = exc - await result._chunk_queue.put(exc) + result._orchestration_exception = reported_exc + await result._chunk_queue.put(reported_exc) finally: # CancelledError (BaseException, not Exception) bypasses the except # block above, so cancel_generation() may not have been called. @@ -358,8 +370,9 @@ async def stream_with_chunking( requirement that did not return ``"fail"`` — both ``"pass"`` and ``"unknown"`` trigger final validation. On early exit, no ``validate()`` call is made; :attr:`StreamChunkingResult.final_validations` remains - empty. Requirements are cloned (``copy(req)``) before use so originals - are never mutated. + empty. Requirements are cloned (``copy(req)``) before backend generation + begins, so the originals are never mutated and a raising ``__copy__`` + cannot leak an in-flight backend task. Requirements that need context beyond the current chunk should accumulate it themselves across ``stream_validate`` calls (e.g. @@ -406,6 +419,12 @@ async def stream_with_chunking( Raises: ValueError: If *chunking* is a string that does not match any known alias (``"sentence"``, ``"word"``, ``"paragraph"``). + + Note: + Any exception raised by ``copy(req)`` on a ``quick_check_requirements`` + entry propagates to the caller; no backend generation is started in + that case. See :class:`~mellea.core.Requirement` for the ``__copy__`` + override contract. """ if isinstance(chunking, str): cls = _CHUNKING_ALIASES.get(chunking) @@ -416,15 +435,33 @@ async def stream_with_chunking( chunking = cls() opts: dict[str, Any] = {ModelOption.STREAM: True} - mot, gen_ctx = await backend.generate_from_context(action, ctx, model_options=opts) - - result = StreamChunkingResult(mot, gen_ctx) + # Clone requirements before starting backend generation so that a raising + # __copy__ (an advertised extension point on Requirement) cannot leave the + # backend feeder task wedged against a full _async_queue with no consumer. cloned_reqs = [copy(req) for req in (quick_check_requirements or [])] val_backend = quick_check_backend if quick_check_backend is not None else backend - result._orchestration_task = asyncio.create_task( - _orchestrate_streaming(result, mot, gen_ctx, cloned_reqs, chunking, val_backend) - ) + mot, gen_ctx = await backend.generate_from_context(action, ctx, model_options=opts) + try: + result = StreamChunkingResult(mot, gen_ctx) + coro = _orchestrate_streaming( + result, mot, gen_ctx, cloned_reqs, chunking, val_backend + ) + try: + result._orchestration_task = asyncio.create_task(coro) + except BaseException: + coro.close() # prevent "coroutine was never awaited" RuntimeWarning + raise + except BaseException: + try: + await mot.cancel_generation() + except Exception as cleanup_exc: + MelleaLogger.get_logger().warning( + "stream_with_chunking: cancel_generation() raised during " + "setup-path cleanup (cleanup: %r)", + cleanup_exc, + ) + raise return result diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index 76f1e8b47..f5d817201 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -1080,3 +1080,177 @@ async def test_cancelled_flag_propagates_through_copy_methods() -> None: # Sanity: default-constructed MOT has _cancelled=False. fresh = ModelOutputThunk(value="x") assert fresh._cancelled is False + + +# --------------------------------------------------------------------------- +# Fix 1 — setup-path backend leak: copy(req) before generate_from_context +# --------------------------------------------------------------------------- + + +class _PlainReq(Requirement): + """Default shallow copy — cannot raise.""" + + def format_for_llm(self) -> str: + return "plain" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +class _RaisingCopyReq(Requirement): + """__copy__ raises — simulates a user-defined Requirement with a faulty override.""" + + def __copy__(self) -> "_RaisingCopyReq": + raise ValueError("copy boom") + + def format_for_llm(self) -> str: + return "raising copy" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +class _InstrumentedBackend(StreamingMockBackend): + """Counts generate_from_context calls and exposes the last MOT produced.""" + + def __init__(self, response: str, token_size: int = 1) -> None: + super().__init__(response, token_size) + self.generate_from_context_call_count = 0 + self.last_mot: ModelOutputThunk | None = None + + async def _generate_from_context( + self, + action: Any, + ctx: Any, + *, + format: Any = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> tuple[ModelOutputThunk, Any]: + self.generate_from_context_call_count += 1 + mot, new_ctx = await super()._generate_from_context( + action, + ctx, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + self.last_mot = mot + return mot, new_ctx + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "req_cls,expect_raise", [(_PlainReq, False), (_RaisingCopyReq, True)] +) +async def test_stream_with_chunking_requirement_copy_contract( + req_cls: type, expect_raise: bool +) -> None: + """Fix 1: copy(req) runs before generate_from_context. + + On __copy__ failure the backend is never started (call_count == 0). + On success the backend is called exactly once. + """ + backend = _InstrumentedBackend("Hello world. ", token_size=2) + req = req_cls() + if expect_raise: + with pytest.raises(ValueError, match="copy boom"): + await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req] + ) + # Hard invariant: reorder ensures backend never starts on copy failure. + assert backend.generate_from_context_call_count == 0 + assert backend.last_mot is None + else: + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req] + ) + await result.acomplete() + assert backend.generate_from_context_call_count == 1 + assert backend.last_mot is not None + + +# --------------------------------------------------------------------------- +# Fix 3 — TaskGroup cancels peer validators on first failure +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stream_with_chunking_cancels_peer_validators() -> None: + """Fix 3: a failing stream_validate causes TaskGroup to cancel peer validators. + + One requirement raises immediately in stream_validate; the second sleeps + for 5 s and sets a flag on completion. Without TaskGroup the slow sibling + runs detached; with it the cancellation is observed. + """ + reached_final_stage = asyncio.Event() + + class _RaisingReq(Requirement): + def format_for_llm(self) -> str: + return "raiser" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + raise RuntimeError("validator failed") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=False) + + class _SlowReq(Requirement): + def format_for_llm(self) -> str: + return "slow" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + try: + await asyncio.sleep(5.0) + reached_final_stage.set() + return PartialValidationResult("pass") + except asyncio.CancelledError: + raise # propagate so TaskGroup knows we were cancelled + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + backend = StreamingMockBackend("Hello world. ", token_size=2) + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[_RaisingReq(), _SlowReq()] + ) + with pytest.raises(RuntimeError, match="validator failed"): + await result.acomplete() + + # Give the loop a tick; the slow sibling must NOT have run to completion. + await asyncio.sleep(0.05) + assert not reached_final_stage.is_set(), ( + "slow sibling was not cancelled by TaskGroup" + ) From d8018ddb0062e4896528d31f31ee6cb648109db5 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Wed, 13 May 2026 07:52:25 +0100 Subject: [PATCH 20/26] fix(core,hf): cooperative cancel via StoppingCriteria backed by threading.Event MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add _cancel_hook: Callable[[], None] | None to ModelOutputThunk; called before asyncio task cancellation so backends running generation in a thread receive a cooperative stop signal; exceptions logged and suppressed - Reset _cancel_hook = None in all three copy sites (_copy_from, __copy__, __deepcopy__) — copied MOTs are distinct computations - Add _EventStoppingCriteria and _install_cancel_stopping_criteria helpers in huggingface.py; both streaming paths wire output._cancel_hook = event.set before creating the generation task - Move import logging to module level in base.py Assisted-by: Claude Code Signed-off-by: Nigel Jones --- mellea/backends/huggingface.py | 57 ++++++++++++++++++++++++ mellea/core/base.py | 28 ++++++++++++ test/core/test_base.py | 81 ++++++++++++++++++++++++++++++++++ 3 files changed, 166 insertions(+) diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 1ccb9abf7..58b6edcd7 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -22,6 +22,10 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.cache_utils import DynamicCache from transformers.generation.logits_process import LogitsProcessorList + from transformers.generation.stopping_criteria import ( + StoppingCriteria, + StoppingCriteriaList, + ) from transformers.generation.streamers import AsyncTextIteratorStreamer from transformers.generation.utils import GenerateDecoderOnlyOutput from transformers.modeling_utils import PreTrainedModel @@ -74,6 +78,45 @@ ) from .utils import to_chat, to_tool_calls + +class _EventStoppingCriteria(StoppingCriteria): + """StoppingCriteria that signals the model to stop when a threading.Event is set. + + Used by LocalHFBackend to implement cooperative cancellation: when + ``cancel_generation`` is called, it sets the backing event via + ``_cancel_hook`` before cancelling the asyncio task, giving the HF + ``model.generate`` thread a chance to exit cleanly rather than running + to completion. + """ + + def __init__(self, event: threading.Event) -> None: + self._event = event + + def __call__(self, input_ids: Any, scores: Any, **kwargs: Any) -> bool: # type: ignore[override] + return self._event.is_set() + + +def _install_cancel_stopping_criteria( + generate_options: dict[str, Any], streaming_kwargs: dict[str, Any] +) -> threading.Event: + """Wire a cooperative-cancel event into the generate call's stopping criteria. + + Pops any caller-supplied ``stopping_criteria`` from *generate_options* (to + avoid passing it twice via both ``**generate_options`` and + ``**streaming_kwargs``), prepends an :class:`_EventStoppingCriteria` backed + by a fresh ``threading.Event``, and stores the merged list in + *streaming_kwargs*. Returns the event so the caller can arm + ``output._cancel_hook = event.set``. + """ + cancel_event = threading.Event() + user_sc = generate_options.pop("stopping_criteria", None) + streaming_kwargs["stopping_criteria"] = StoppingCriteriaList( + [_EventStoppingCriteria(cancel_event)] + + (list(user_sc) if user_sc is not None else []) + ) + return cancel_event + + """A configuration type for the unhappy path: Tokenizer * Model * torch device string Huggingface backends can initialize themselves from a model string if the transformers `Auto*` classes can be used. Therefore, a TransformersTorchConfig usually isn't required. However, sometimes a model needs special care to instantiate properly, or a custom device type needs to bse used. Instead of trying to do a lot of partial magic, we basically have two modaliites: either the constructor can figure out everything from the model_id, or the user has to provide an entire config. @@ -839,6 +882,10 @@ async def _generate_from_context_with_kv_cache( # Filter out chat template-only options before passing to generate() generate_options = self._filter_chat_template_only_options(model_options) + _cancel_event = _install_cancel_stopping_criteria( + generate_options, streaming_kwargs + ) + linearized_ctx = ctx.view_for_generation() assert linearized_ctx is not None _input_text, input_ids, merged_cache, attention_mask = ( @@ -867,6 +914,9 @@ async def _generate_from_context_with_kv_cache( ) output = ModelOutputThunk(None) + # Arm the cancel hook before creating tasks so a cancel racing + # task creation still finds the hook set. + output._cancel_hook = _cancel_event.set output._start = datetime.datetime.now() output._context = ctx.view_for_generation() output._action = action @@ -1002,6 +1052,10 @@ async def _generate_from_context_standard( # Filter out chat template-only options before passing to generate() generate_options = self._filter_chat_template_only_options(model_options) + _cancel_event = _install_cancel_stopping_criteria( + generate_options, streaming_kwargs + ) + chat_response = asyncio.to_thread( self._generate_with_adapter_lock, "", # Empty for no adapters. @@ -1016,6 +1070,9 @@ async def _generate_from_context_standard( ) output = ModelOutputThunk(None) + # Arm the cancel hook before creating tasks so a cancel racing + # task creation still finds the hook set. + output._cancel_hook = _cancel_event.set output._start = datetime.datetime.now() output._context = ctx.view_for_generation() output._action = action diff --git a/mellea/core/base.py b/mellea/core/base.py index a8ff2f3aa..ec5111ee2 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -17,6 +17,7 @@ import binascii import datetime import enum +import logging from collections.abc import Callable, Coroutine, Iterable, Mapping from copy import copy, deepcopy from dataclasses import dataclass @@ -345,6 +346,14 @@ def __init__( self._generate_extra: asyncio.Task[Any] | None = ( None # Currently only used by hf. ) + # Optional cooperative-cancel hook called before asyncio task cancellation. + # Backends that run generation in a thread (e.g. HuggingFace via + # asyncio.to_thread) set this to a non-blocking callable (e.g. + # threading.Event.set) so the thread receives a stop signal before the + # task wrapper is cancelled. Must be non-blocking; exceptions are logged + # and suppressed. Copied MOTs reset this to None — each computation owns + # its own thread signal. + self._cancel_hook: Callable[[], None] | None = None self._process: Callable[[ModelOutputThunk, Any], Coroutine] | None = None self._post_process: Callable[[ModelOutputThunk], Coroutine] | None = None self._on_computed: Callable[[ModelOutputThunk], Coroutine] | None = None @@ -395,6 +404,16 @@ def _drain() -> None: except asyncio.QueueEmpty: break + # Signal any backend thread before cancelling the asyncio task wrapper + # so the thread can stop cooperatively instead of running to completion. + if self._cancel_hook is not None: + try: + self._cancel_hook() + except Exception as hook_exc: + logging.getLogger(__name__).warning( + "cancel_generation: _cancel_hook raised (suppressed): %r", hook_exc + ) + if self._generate is not None and not self._generate.done(): self._generate.cancel() @@ -460,6 +479,9 @@ def _copy_from(self, other: ModelOutputThunk) -> None: self.generation = other.generation self._generate_log = other._generate_log self._cancelled = other._cancelled + # _cancel_hook is deliberately not copied: _copy_from swaps output state, + # not backend-thread plumbing, which is tied to the original computation. + self._cancel_hook = None def is_computed(self) -> bool: """Returns true only if this Thunk has already been filled. @@ -689,6 +711,9 @@ def __copy__(self) -> ModelOutputThunk: copied._computed = self._computed copied._cancelled = self._cancelled + # _cancel_hook is not forwarded: a copied MOT is a distinct computation + # and must not share the original's backend thread signal. + copied._cancel_hook = None copied._thinking = self._thinking copied._action = self._action copied._context = self._context @@ -718,6 +743,9 @@ def __deepcopy__(self, memo: dict) -> ModelOutputThunk: deepcopied.tool_calls = deepcopy(self.tool_calls) deepcopied._computed = self._computed deepcopied._cancelled = self._cancelled + # _cancel_hook is not forwarded: a deepcopied MOT is a distinct computation + # and must not share the original's backend thread signal. + deepcopied._cancel_hook = None deepcopied._thinking = self._thinking deepcopied._action = deepcopy(self._action) deepcopied._context = copy( diff --git a/test/core/test_base.py b/test/core/test_base.py index 32ad9ab10..fb98b0f7c 100644 --- a/test/core/test_base.py +++ b/test/core/test_base.py @@ -191,3 +191,84 @@ def test_mot_deep_copy_clones_generation(): if __name__ == "__main__": pytest.main([__file__]) + + +# --------------------------------------------------------------------------- +# Fix 2 — cancel_generation invokes _cancel_hook before task cancellation +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_cancel_generation_invokes_cancel_hook_before_task_cancel() -> None: + """Fix 2: _cancel_hook fires and cancel_generation() returns promptly. + + Simulates a backend thread that blocks for 5 s unless the hook sets the + event. Without the hook, cancel_generation() would only observe the asyncio + task as CancelledError but the thread would keep running — on a slow box + that can mean the task wrapper hangs past the 1 s timeout here. With the + hook, the event is set first, the thread unblocks, and the whole path + completes within the timeout. + """ + import asyncio + import threading + + hook_called = threading.Event() + thread_released = threading.Event() + + def hook() -> None: + hook_called.set() + thread_released.set() + + mot = ModelOutputThunk(value=None) + mot._cancel_hook = hook # type: ignore[attr-defined] + + # Task that blocks in a thread until thread_released is set. + async def spin() -> None: + await asyncio.to_thread(thread_released.wait, 5.0) + + mot._generate = asyncio.create_task(spin()) # type: ignore[attr-defined] + await asyncio.sleep(0) # let the task reach to_thread + + # Must return within 1 s; without the hook it would hang ~5 s. + await asyncio.wait_for(mot.cancel_generation(), timeout=1.0) # type: ignore[attr-defined] + + assert hook_called.is_set(), "_cancel_hook was never called" + assert mot._cancelled is True # type: ignore[attr-defined] + + +def test_cancel_hook_not_forwarded_by_copy_methods() -> None: + """Fix 2: copied MOTs must not inherit _cancel_hook (distinct computation).""" + import copy as copy_mod + + def _hook() -> None: + pass + + mot = ModelOutputThunk(value="x") + mot._cancel_hook = _hook # type: ignore[attr-defined] + + shallow = copy_mod.copy(mot) + assert shallow._cancel_hook is None, "__copy__ must reset _cancel_hook to None" # type: ignore[attr-defined] + + deep = copy_mod.deepcopy(mot) + assert deep._cancel_hook is None, "__deepcopy__ must reset _cancel_hook to None" # type: ignore[attr-defined] + + target = ModelOutputThunk(value="original") + target._copy_from(mot) # type: ignore[attr-defined] + assert target._cancel_hook is None, "_copy_from must reset _cancel_hook to None" # type: ignore[attr-defined] + + +@pytest.mark.asyncio +async def test_cancel_generation_hook_exception_is_suppressed() -> None: + """Fix 2: a faulty _cancel_hook must not mask cancel_generation itself.""" + import asyncio + + def _bad_hook() -> None: + raise RuntimeError("hook exploded") + + mot = ModelOutputThunk(value=None) + mot._cancel_hook = _bad_hook # type: ignore[attr-defined] + + # No _generate task — cancel_generation still runs the hook path. + # The hook raises, but cancel_generation must complete without propagating. + await mot.cancel_generation() # type: ignore[attr-defined] + assert mot._cancelled is True # type: ignore[attr-defined] From bf9a62bcea6b6e6a62c34faa65f39b960dc3e625 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Wed, 13 May 2026 12:10:16 +0100 Subject: [PATCH 21/26] fix(stdlib,core,hf): three pre-merge correctness fixes - huggingface.py: guard _install_cancel_stopping_criteria behind `if stream:` at both call sites (chat and non-chat paths). Previously the StoppingCriteria was installed unconditionally, silently wrapping any caller-supplied stopping_criteria and adding per-token overhead on all non-streaming HF calls. Non-streaming paths have no orchestrator to call cancel_generation(), so the hook was dead code on that path. - base.py: in cancel_generation(), use task.cancelling() > 0 to distinguish a child task's CancelledError (expected after .cancel()) from an external CancelledError injected into the outer task. Previously both were caught and swallowed, meaning an external cancellation of the orchestration task during cleanup would be silently absorbed rather than propagated. - streaming.py: assert the returned MOT is not already computed before starting the orchestrator. A backend that silently ignores ModelOption.STREAM would previously cause the orchestrator loop to skip, produce empty full_text, and pass final validators against "". The new RuntimeError fails fast with an actionable message. Assisted-by: Claude Code --- mellea/backends/huggingface.py | 26 ++++++++++++++++++-------- mellea/core/base.py | 16 ++++++++++++++-- mellea/stdlib/streaming.py | 5 +++++ test/stdlib/test_streaming.py | 30 ++++++++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 10 deletions(-) diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 58b6edcd7..8105b202c 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -882,9 +882,14 @@ async def _generate_from_context_with_kv_cache( # Filter out chat template-only options before passing to generate() generate_options = self._filter_chat_template_only_options(model_options) - _cancel_event = _install_cancel_stopping_criteria( - generate_options, streaming_kwargs - ) + # Only install cooperative-cancel plumbing on the streaming path. + # Non-streaming calls have no orchestrator calling cancel_generation(), + # so the hook would be dead code and the StoppingCriteria would silently + # wrap any user-supplied stopping_criteria on every decode step. + if stream: + _cancel_event = _install_cancel_stopping_criteria( + generate_options, streaming_kwargs + ) linearized_ctx = ctx.view_for_generation() assert linearized_ctx is not None @@ -916,7 +921,7 @@ async def _generate_from_context_with_kv_cache( output = ModelOutputThunk(None) # Arm the cancel hook before creating tasks so a cancel racing # task creation still finds the hook set. - output._cancel_hook = _cancel_event.set + output._cancel_hook = _cancel_event.set if stream else None output._start = datetime.datetime.now() output._context = ctx.view_for_generation() output._action = action @@ -1052,9 +1057,14 @@ async def _generate_from_context_standard( # Filter out chat template-only options before passing to generate() generate_options = self._filter_chat_template_only_options(model_options) - _cancel_event = _install_cancel_stopping_criteria( - generate_options, streaming_kwargs - ) + # Only install cooperative-cancel plumbing on the streaming path. + # Non-streaming calls have no orchestrator calling cancel_generation(), + # so the hook would be dead code and the StoppingCriteria would silently + # wrap any user-supplied stopping_criteria on every decode step. + if stream: + _cancel_event = _install_cancel_stopping_criteria( + generate_options, streaming_kwargs + ) chat_response = asyncio.to_thread( self._generate_with_adapter_lock, @@ -1072,7 +1082,7 @@ async def _generate_from_context_standard( output = ModelOutputThunk(None) # Arm the cancel hook before creating tasks so a cancel racing # task creation still finds the hook set. - output._cancel_hook = _cancel_event.set + output._cancel_hook = _cancel_event.set if stream else None output._start = datetime.datetime.now() output._context = ctx.view_for_generation() output._action = action diff --git a/mellea/core/base.py b/mellea/core/base.py index ec5111ee2..71e8e9b7d 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -426,13 +426,25 @@ def _drain() -> None: if self._generate is not None: try: await self._generate - except (asyncio.CancelledError, Exception): + except asyncio.CancelledError: + # Re-raise if the *outer* task is being cancelled (Python 3.12+ + # task.cancelling() > 0) so we don't silently absorb external + # cancellation. For the inner task's own CancelledError (the + # expected result of .cancel() above), cancelling() is 0. + cur = asyncio.current_task() + if cur is not None and cur.cancelling() > 0: + raise + except Exception: pass if self._generate_extra is not None: try: await self._generate_extra - except (asyncio.CancelledError, Exception): + except asyncio.CancelledError: + cur = asyncio.current_task() + if cur is not None and cur.cancelling() > 0: + raise + except Exception: pass # Drain again for any final item the task put before terminating. diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index df9ad5a80..f9889cad3 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -443,6 +443,11 @@ async def stream_with_chunking( val_backend = quick_check_backend if quick_check_backend is not None else backend mot, gen_ctx = await backend.generate_from_context(action, ctx, model_options=opts) + if mot.is_computed(): + raise RuntimeError( + "stream_with_chunking() requires a streaming backend; the backend returned " + "an already-computed MOT. Ensure the backend honours ModelOption.STREAM." + ) try: result = StreamChunkingResult(mot, gen_ctx) coro = _orchestrate_streaming( diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index f5d817201..0818d6a49 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -1254,3 +1254,33 @@ async def validate( assert not reached_final_stage.is_set(), ( "slow sibling was not cancelled by TaskGroup" ) + + +@pytest.mark.asyncio +async def test_stream_with_chunking_rejects_precomputed_mot() -> None: + """Backend returning an already-computed MOT raises RuntimeError immediately. + + stream_with_chunking() requires streaming; a pre-computed MOT would cause + the orchestrator loop to skip entirely, producing empty output and silently + passing all final validators against an empty string. + """ + + class PrecomputedBackend(Backend): + async def _generate_from_context( + self, + action: Any, + ctx: Any, + *, + format: Any = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> tuple[ModelOutputThunk, Any]: + return ModelOutputThunk(value="already done"), ctx + + async def generate_from_raw( + self, actions: Any, ctx: Any, **kwargs: Any + ) -> list[ModelOutputThunk]: + raise NotImplementedError + + with pytest.raises(RuntimeError, match="already-computed MOT"): + await stream_with_chunking(_action(), PrecomputedBackend(), _ctx()) From 9a715d62519733ddcededdaad019963a47517373 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Wed, 13 May 2026 12:31:15 +0100 Subject: [PATCH 22/26] fix: address second-review feedback on bf9a62bc MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - huggingface.py: replace `_cancel_event.set if stream else None` ternary with `if stream: output._cancel_hook = _cancel_event.set` to silence Pyright possibly-unbound warning and improve readability (no-op change since _cancel_hook defaults to None) - base.py: correct misleading comment "Python 3.12+" → "Python 3.11+"; Task.cancelling() was added in 3.11, the comment was wrong - streaming.py: add RuntimeError to Raises: section of stream_with_chunking() docstring (already raised since bf9a62bc) - test/core/test_base.py: add regression test for the outer-cancellation re-raise path in cancel_generation(); exercises cur.cancelling() > 0 branch that was added in bf9a62bc but had no direct test coverage Assisted-by: Claude Code --- mellea/backends/huggingface.py | 6 +++-- mellea/core/base.py | 2 +- mellea/stdlib/streaming.py | 4 +++ test/core/test_base.py | 45 ++++++++++++++++++++++++++++++++++ 4 files changed, 54 insertions(+), 3 deletions(-) diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 8105b202c..337c27266 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -921,7 +921,8 @@ async def _generate_from_context_with_kv_cache( output = ModelOutputThunk(None) # Arm the cancel hook before creating tasks so a cancel racing # task creation still finds the hook set. - output._cancel_hook = _cancel_event.set if stream else None + if stream: + output._cancel_hook = _cancel_event.set output._start = datetime.datetime.now() output._context = ctx.view_for_generation() output._action = action @@ -1082,7 +1083,8 @@ async def _generate_from_context_standard( output = ModelOutputThunk(None) # Arm the cancel hook before creating tasks so a cancel racing # task creation still finds the hook set. - output._cancel_hook = _cancel_event.set if stream else None + if stream: + output._cancel_hook = _cancel_event.set output._start = datetime.datetime.now() output._context = ctx.view_for_generation() output._action = action diff --git a/mellea/core/base.py b/mellea/core/base.py index 71e8e9b7d..8b03250cc 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -427,7 +427,7 @@ def _drain() -> None: try: await self._generate except asyncio.CancelledError: - # Re-raise if the *outer* task is being cancelled (Python 3.12+ + # Re-raise if the *outer* task is being cancelled (Python 3.11+ # task.cancelling() > 0) so we don't silently absorb external # cancellation. For the inner task's own CancelledError (the # expected result of .cancel() above), cancelling() is 0. diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index f9889cad3..456ad1164 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -419,6 +419,10 @@ async def stream_with_chunking( Raises: ValueError: If *chunking* is a string that does not match any known alias (``"sentence"``, ``"word"``, ``"paragraph"``). + RuntimeError: If the backend returns an already-computed + :class:`~mellea.core.base.ModelOutputThunk` instead of a streaming + one. This indicates the backend is not honouring + ``ModelOption.STREAM``. Note: Any exception raised by ``copy(req)`` on a ``quick_check_requirements`` diff --git a/test/core/test_base.py b/test/core/test_base.py index fb98b0f7c..213a16e6e 100644 --- a/test/core/test_base.py +++ b/test/core/test_base.py @@ -272,3 +272,48 @@ def _bad_hook() -> None: # The hook raises, but cancel_generation must complete without propagating. await mot.cancel_generation() # type: ignore[attr-defined] assert mot._cancelled is True # type: ignore[attr-defined] + + +@pytest.mark.asyncio +async def test_cancel_generation_propagates_outer_cancellation() -> None: + """Outer cancellation of the cancel_generation() task must re-raise CancelledError. + + When cancel_generation() is awaiting self._generate and the *cancel_generation* + task is itself cancelled from outside, cur.cancelling() > 0 and the + CancelledError must propagate — not be swallowed by the bare ``pass`` path. + """ + import asyncio + + inner_cancelled = asyncio.Event() + + async def _absorbs_first_cancel() -> None: + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + # Signal that cancel_generation() has called .cancel() and is + # now blocked at ``await self._generate``. + inner_cancelled.set() + # Absorb this cancel so cancel_generation() stays at the await. + await asyncio.sleep(60) + + mot = ModelOutputThunk(value=None) + mot._generate = asyncio.create_task(_absorbs_first_cancel()) # type: ignore[attr-defined] + await asyncio.sleep(0) + + cg_task = asyncio.create_task(mot.cancel_generation()) # type: ignore[attr-defined] + # Wait until _generate has absorbed cancel_generation()'s .cancel() call — + # at that point cg_task is blocked at ``await self._generate``. + await asyncio.wait_for(inner_cancelled.wait(), timeout=2.0) + + # Cancel cancel_generation() from outside (simulates asyncio.wait_for timeout + # or an outer TaskGroup cancelling this coroutine). + cg_task.cancel() + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(cg_task, timeout=2.0) + + # Cleanup: stop the still-running _generate task. + mot._generate.cancel() # type: ignore[attr-defined] + try: + await asyncio.wait_for(mot._generate, timeout=1.0) # type: ignore[attr-defined] + except (TimeoutError, asyncio.CancelledError): + pass From f3e3501bb428a232b364d6be318e048ed6a07b0f Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Wed, 13 May 2026 13:05:21 +0100 Subject: [PATCH 23/26] docs(core): add Raises section to cancel_generation() docstring The CancelledError re-raise path added in bf9a62bc means the method now raises asyncio.CancelledError when the calling task is being cancelled (cur.cancelling() > 0). The docstring quality gate (build-and-validate CI) flags any function that raises without a Raises: section. Assisted-by: Claude Code --- mellea/core/base.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mellea/core/base.py b/mellea/core/base.py index 8b03250cc..3ea17e088 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -393,6 +393,13 @@ async def cancel_generation(self, error: Exception | None = None) -> None: requirement failure or an unhandled exception from a streaming validator). When ``None``, a generic ``RuntimeError("Generation cancelled")`` is recorded. + + Raises: + asyncio.CancelledError: Re-raised when the *calling* task itself is + being cancelled (``asyncio.current_task().cancelling() > 0``). + This prevents external cancellation (e.g. ``asyncio.wait_for`` + timeout) from being silently absorbed while awaiting the inner + generation task. """ if self._computed: return From 2f2e352b85628136964ae9583cb88b905e887d78 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Wed, 13 May 2026 13:06:43 +0100 Subject: [PATCH 24/26] docs(agents): add docstring quality gate to self-review checklist After triggering the build-and-validate quality gate twice in the same PR (missing Raises: on stream_with_chunking and cancel_generation), add an explicit pre-push check to the self-review checklist. Any diff that adds raise statements to library code should run the audit tool locally before pushing. Assisted-by: Claude Code --- AGENTS.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index cc068f107..170dbca22 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -126,6 +126,11 @@ Use the tool's common name (e.g., GitHub Copilot, Cursor, etc.). 3. New functions typed with concise docstrings? 4. Unit tests added for new functionality? 5. Avoided over-engineering? +6. If the diff adds `raise` statements to library code (`mellea/` but not `test/`), run the docstring quality gate before pushing: + ```bash + uv run python tooling/docs-autogen/audit_coverage.py --docs-dir docs/docs/api --quality --fail-on-quality --threshold 100 --orphans + ``` + Every new `raise` in a public function requires a matching `Raises:` entry — the `build-and-validate` CI job enforces this with `--fail-on-quality`. ## 11. Writing Tests From 66260fedba7383748e9d761f9bfbe5ad000bf292 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Mon, 18 May 2026 13:51:10 +0100 Subject: [PATCH 25/26] fix: address review feedback from psschwei + jakelorocco MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit streaming.py — raise-once contract on the task fallback path Per psschwei's review: acomplete()'s fallback that reads self._orchestration_task.exception() did not honour the raise-once contract that the _orchestration_exception path already enforced. In addition to the missing flag, ``task.exception()`` itself raises CancelledError on a cancelled task (rather than returning it), so the original ``raise task_exc`` was unreachable for the cancellation path — the cancel was surfaced as a side-effect of ``.exception()`` and every subsequent acomplete() call hit it again. Restructure to: short-circuit on _exception_surfaced, branch on task.cancelled() before calling .exception(), and set the flag in both branches before raising. Adds test_external_cancellation_acomplete_raise_once to lock in the raise-once contract for the fallback path. streaming_chunking.py — chunker-agnostic example requirement Per jakelorocco's review: MaxSentencesReq counted stream_validate calls, which only equates to sentences when paired with the sentence chunker. The example now counts sentence-end punctuation in the chunk text via a small regex, so the same instance behaves correctly under any chunker. Updated the docstring to teach the content-driven pattern over chunker-coupled logic. Assisted-by: Claude Code --- docs/examples/streaming/streaming_chunking.py | 22 ++++++++++--- mellea/stdlib/streaming.py | 11 +++++++ test/stdlib/test_streaming.py | 33 +++++++++++++++++++ 3 files changed, 61 insertions(+), 5 deletions(-) diff --git a/docs/examples/streaming/streaming_chunking.py b/docs/examples/streaming/streaming_chunking.py index 737ea7486..8eb0b18f1 100644 --- a/docs/examples/streaming/streaming_chunking.py +++ b/docs/examples/streaming/streaming_chunking.py @@ -10,6 +10,7 @@ """ import asyncio +import re from mellea.core.backend import Backend from mellea.core.base import Context @@ -21,14 +22,25 @@ from mellea.stdlib.components import Instruction from mellea.stdlib.streaming import stream_with_chunking +# Crude sentence-terminator detector. A run of ``.``/``!``/``?`` counts once +# (so "..." and "!!!" are a single terminator). Good enough for an example; +# production code might use spaCy/NLTK for proper sentence segmentation. +_SENTENCE_END = re.compile(r"[.!?]+") + class MaxSentencesReq(Requirement): """Fails if the model generates more than *limit* sentences mid-stream. - Each ``stream_validate`` call receives one complete sentence from the - :class:`~mellea.stdlib.chunking.SentenceChunker`. The running count is - maintained on ``self`` — this is the standard pattern for requirements - that need context beyond a single chunk. + Counts sentence terminators in the chunk *text* rather than counting + ``stream_validate`` calls. This makes the requirement **chunker-agnostic**: + the same instance behaves correctly with sentence, word, or paragraph + chunking, because the semantics depend on content, not on the chunker's + structural decisions. + + When writing your own streaming requirements, prefer this content-driven + pattern over coupling the requirement to a specific chunker. Reach for + chunker-coupled logic only when the requirement is genuinely a property + of chunk boundaries (e.g. "no chunk longer than N tokens"). """ def __init__(self, limit: int) -> None: @@ -42,7 +54,7 @@ def format_for_llm(self) -> str: async def stream_validate( self, chunk: str, *, backend: Backend, ctx: Context ) -> PartialValidationResult: - self._count += 1 + self._count += len(_SENTENCE_END.findall(chunk)) if self._count > self._limit: return PartialValidationResult( "fail", diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 456ad1164..8404d8cbb 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -134,8 +134,19 @@ async def acomplete(self) -> None: self._orchestration_exception = None raise exc if self._orchestration_task is not None and self._orchestration_task.done(): + # Raise-once: a prior call already surfaced the exception. + if self._exception_surfaced: + return + # ``task.exception()`` raises CancelledError on a cancelled task + # (rather than returning it), so check cancelled status first. + # This branch covers BaseException paths that bypass the + # ``except Exception`` handler in ``_orchestrate_streaming``. + if self._orchestration_task.cancelled(): + self._exception_surfaced = True + raise asyncio.CancelledError() task_exc = self._orchestration_task.exception() if task_exc is not None: + self._exception_surfaced = True raise task_exc @property diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index 0818d6a49..a6642d1f4 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -940,6 +940,39 @@ async def test_external_task_cancellation_releases_consumers() -> None: await asyncio.wait_for(result.acomplete(), timeout=2.0) +@pytest.mark.asyncio +async def test_external_cancellation_acomplete_raise_once() -> None: + """Raise-once contract holds for the task-fallback path on external cancel. + + CancelledError bypasses the orchestrator's ``except Exception`` handler, + so ``_orchestration_exception`` is never set. ``acomplete()`` surfaces the + cancel via ``self._orchestration_task.exception()`` instead — and that + branch must also flip ``_exception_surfaced`` so a second ``acomplete()`` + call does not raise the same exception twice. + """ + response = "word " * 200 + backend = StreamingMockBackend(response, token_size=2) + + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[AlwaysUnknownReq()], + chunking="word", + ) + + assert result._orchestration_task is not None + await asyncio.sleep(0.01) + result._orchestration_task.cancel() + await asyncio.wait_for(result._done.wait(), timeout=2.0) + + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(result.acomplete(), timeout=2.0) + + # Second call must NOT re-raise — raise-once contract. + await asyncio.wait_for(result.acomplete(), timeout=2.0) + + @pytest.mark.asyncio async def test_raise_once_acomplete_then_astream() -> None: """Regression for the raise-once stash bug: acomplete() first, astream() second. From 2cac22c47bc68155d1f7fb640a75e9f0fa9470a3 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 19 May 2026 11:03:35 +0100 Subject: [PATCH 26/26] fix(stdlib): address jakelorocco review feedback on stream_with_chunking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename quick_check_requirements → requirements and quick_check_backend → validation_backend throughout (signature, docstring, tests, example); the quick_check_* names implied an early-success semantics that was never implemented — pass mid-stream is informational only (#924, #900) - Fix full_text on early exit to preserve original inter-chunk spacing: replace emitted_text (chunk concatenation, loses whitespace) with an emitted_end cursor into accumulated so full_text = accumulated[:emitted_end] - Set parsed_repr = value in as_thunk on cancelled results for consistency with plain-text normal completion; document typed-output limitation - Tighten as_thunk return type to ModelOutputThunk[str] - Fix MaxSentencesReq example: validate() now mirrors stream_validate() result instead of always returning True - Add 3-chunk early-exit test to cover the whitespace preservation fix Assisted-by: Claude Code Signed-off-by: Nigel Jones --- docs/examples/streaming/streaming_chunking.py | 4 +- mellea/stdlib/streaming.py | 61 +++++----- test/stdlib/test_streaming.py | 110 +++++++----------- 3 files changed, 81 insertions(+), 94 deletions(-) diff --git a/docs/examples/streaming/streaming_chunking.py b/docs/examples/streaming/streaming_chunking.py index 8eb0b18f1..81056abaf 100644 --- a/docs/examples/streaming/streaming_chunking.py +++ b/docs/examples/streaming/streaming_chunking.py @@ -70,7 +70,7 @@ async def validate( format: type | None = None, model_options: dict | None = None, ) -> ValidationResult: - return ValidationResult(result=True) + return ValidationResult(result=self._count <= self._limit) async def main() -> None: @@ -86,7 +86,7 @@ async def main() -> None: req = MaxSentencesReq(limit=3) result = await stream_with_chunking( - action, backend, ctx, quick_check_requirements=[req], chunking="sentence" + action, backend, ctx, requirements=[req], chunking="sentence" ) print("Streaming chunks as they arrive:") diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 8404d8cbb..dfdbcb232 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -150,7 +150,7 @@ async def acomplete(self) -> None: raise task_exc @property - def as_thunk(self) -> ModelOutputThunk: + def as_thunk(self) -> ModelOutputThunk[str]: """Wrap the output as a computed :class:`~mellea.core.base.ModelOutputThunk`. Returns a new thunk with ``value`` set to :attr:`full_text` and @@ -161,14 +161,17 @@ def as_thunk(self) -> ModelOutputThunk: Note: On early exit, ``cancel_generation()`` forces the MOT into a computed state without running the backend's - ``post_processing()``. Telemetry fields on the returned thunk - (``generation.usage``, ``generation.ttfb_ms``, etc.) may - therefore be ``None`` or reflect the partial state at - cancellation time. ``value`` and ``streaming`` are reliable; - usage totals are not. + ``post_processing()``. ``value`` and ``streaming`` are + reliable. ``parsed_repr`` is set to the raw text (same as + ``value``) — consistent with normal completion for plain-text + outputs, but for typed outputs the backend-parsed representation + will not be available. Telemetry fields (``generation.usage``, + ``generation.ttfb_ms``, etc.) may be ``None`` or reflect the + partial state at cancellation time; usage totals are not + recoverable. Returns: - ModelOutputThunk: A computed thunk containing the streamed output. + ModelOutputThunk[str]: A computed thunk containing the streamed output. Raises: RuntimeError: If called before :meth:`acomplete` has returned. @@ -180,6 +183,7 @@ def as_thunk(self) -> ModelOutputThunk: thunk = ModelOutputThunk(value=self.full_text) thunk._cancelled = self._mot._cancelled thunk.generation = copy(self._mot.generation) + thunk.parsed_repr = thunk.value # type: ignore[assignment] return thunk @@ -192,7 +196,7 @@ async def _orchestrate_streaming( val_backend: Backend, ) -> None: accumulated = "" - emitted_text = "" + emitted_end = 0 # byte offset into accumulated after the last emitted chunk prev_chunk_count = 0 failed_indices: set[int] = set() early_exit = False @@ -248,7 +252,9 @@ async def _validate_and_emit(c: str) -> bool: result.completed = False await mot.cancel_generation() break - emitted_text += c + pos = accumulated.find(c, emitted_end) + if pos >= 0: + emitted_end = pos + len(c) if early_exit: break # break the while loop; cancel_generation() already set _computed=True @@ -263,14 +269,15 @@ async def _validate_and_emit(c: str) -> bool: early_exit = True result.completed = False break - emitted_text += c + pos = accumulated.find(c, emitted_end) + if pos >= 0: + emitted_end = pos + len(c) - # On early exit, full_text reflects only validated-and-emitted chunks - # so it matches exactly what the consumer received via astream(). - # On natural completion emitted_text == accumulated (every character - # ends up in some chunk or flushed fragment), so either value is - # equivalent; accumulated is used to preserve the original raw text. - result.full_text = emitted_text if early_exit else accumulated + # On early exit, full_text is the prefix of accumulated up to and + # including the last emitted chunk — preserving original inter-chunk + # spacing from the token stream (chunk concatenation would strip it). + # On natural completion, accumulated is used directly. + result.full_text = accumulated[:emitted_end] if early_exit else accumulated non_failed = [ req for i, req in enumerate(cloned_reqs) if i not in failed_indices @@ -344,9 +351,9 @@ async def stream_with_chunking( backend: Backend, ctx: Context, *, - quick_check_requirements: Sequence[Requirement] | None = None, + requirements: Sequence[Requirement] | None = None, chunking: str | ChunkingStrategy = "sentence", - quick_check_backend: Backend | None = None, + validation_backend: Backend | None = None, ) -> StreamChunkingResult: """Generate a streaming response with per-chunk validation. @@ -412,13 +419,13 @@ async def stream_with_chunking( action: The component or content block to generate from. backend: The backend used for generation and final validation. ctx: The generation context. - quick_check_requirements: Sequence of requirements to validate against - each chunk during streaming. ``None`` disables streaming validation - (chunks are still produced; ``validate()`` is not called at stream end). + requirements: Sequence of requirements to validate against each chunk + during streaming. ``None`` disables streaming validation (chunks + are still produced; ``validate()`` is not called at stream end). chunking: Chunking strategy — either a :class:`~mellea.stdlib.chunking.ChunkingStrategy` instance or one of the string aliases ``"sentence"`` (default), ``"word"``, or ``"paragraph"``. - quick_check_backend: Optional alternate backend for both + validation_backend: Optional alternate backend for both ``stream_validate`` and final ``validate`` calls. When ``None``, *backend* is used for validation. @@ -436,9 +443,9 @@ async def stream_with_chunking( ``ModelOption.STREAM``. Note: - Any exception raised by ``copy(req)`` on a ``quick_check_requirements`` - entry propagates to the caller; no backend generation is started in - that case. See :class:`~mellea.core.Requirement` for the ``__copy__`` + Any exception raised by ``copy(req)`` on a ``requirements`` entry + propagates to the caller; no backend generation is started in that + case. See :class:`~mellea.core.Requirement` for the ``__copy__`` override contract. """ if isinstance(chunking, str): @@ -454,8 +461,8 @@ async def stream_with_chunking( # Clone requirements before starting backend generation so that a raising # __copy__ (an advertised extension point on Requirement) cannot leave the # backend feeder task wedged against a full _async_queue with no consumer. - cloned_reqs = [copy(req) for req in (quick_check_requirements or [])] - val_backend = quick_check_backend if quick_check_backend is not None else backend + cloned_reqs = [copy(req) for req in (requirements or [])] + val_backend = validation_backend if validation_backend is not None else backend mot, gen_ctx = await backend.generate_from_context(action, ctx, model_options=opts) if mot.is_computed(): diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index a6642d1f4..ad042df56 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -221,7 +221,7 @@ async def test_normal_completion_calls_validate_at_stream_end() -> None: req = AlwaysUnknownReq() result = await stream_with_chunking( - _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence" + _action(), backend, _ctx(), requirements=[req], chunking="sentence" ) await result.acomplete() @@ -241,7 +241,7 @@ async def test_early_exit_on_fail() -> None: req = FailAfterWordsReq(threshold=4) result = await stream_with_chunking( - _action(), backend, _ctx(), quick_check_requirements=[req], chunking="word" + _action(), backend, _ctx(), requirements=[req], chunking="word" ) await result.acomplete() @@ -264,20 +264,12 @@ async def test_clone_isolation_across_retries() -> None: backend = StreamingMockBackend(response, token_size=4) r1 = await stream_with_chunking( - _action(), - backend, - _ctx(), - quick_check_requirements=original_reqs, - chunking="sentence", + _action(), backend, _ctx(), requirements=original_reqs, chunking="sentence" ) await r1.acomplete() r2 = await stream_with_chunking( - _action(), - backend, - _ctx(), - quick_check_requirements=original_reqs, - chunking="sentence", + _action(), backend, _ctx(), requirements=original_reqs, chunking="sentence" ) await r2.acomplete() @@ -286,8 +278,8 @@ async def test_clone_isolation_across_retries() -> None: @pytest.mark.asyncio -async def test_quick_check_backend_routing() -> None: - """stream_validate and validate receive quick_check_backend, not the main backend.""" +async def test_validation_backend_routing() -> None: + """stream_validate and validate receive validation_backend, not the main backend.""" response = "One sentence. Two sentences. " main_backend = StreamingMockBackend(response, token_size=3) val_backend = StreamingMockBackend("unused", token_size=1) @@ -309,9 +301,9 @@ def _capturing_copy(self: BackendRecordingReq) -> BackendRecordingReq: _action(), main_backend, _ctx(), - quick_check_requirements=[req], + requirements=[req], chunking="sentence", - quick_check_backend=val_backend, + validation_backend=val_backend, ) await result.acomplete() finally: @@ -335,7 +327,7 @@ async def test_early_exit_does_not_deadlock() -> None: req = FailAfterWordsReq(threshold=3) result = await stream_with_chunking( - _action(), backend, _ctx(), quick_check_requirements=[req], chunking="word" + _action(), backend, _ctx(), requirements=[req], chunking="word" ) # 5-second timeout — should complete in milliseconds on success await asyncio.wait_for(result.acomplete(), timeout=5.0) @@ -439,11 +431,7 @@ def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq: ChunkRecordingReq.__copy__ = _capturing_copy # type: ignore[method-assign] try: result = await stream_with_chunking( - _action(), - backend, - _ctx(), - quick_check_requirements=[req], - chunking="sentence", + _action(), backend, _ctx(), requirements=[req], chunking="sentence" ) await result.acomplete() finally: @@ -507,11 +495,7 @@ def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq: ChunkRecordingReq.__copy__ = _capturing_copy # type: ignore[method-assign] try: result = await stream_with_chunking( - _action(), - backend, - _ctx(), - quick_check_requirements=[req], - chunking="sentence", + _action(), backend, _ctx(), requirements=[req], chunking="sentence" ) yielded: list[str] = [] async for chunk in result.astream(): @@ -562,7 +546,7 @@ async def validate( req = FailOnSecondSentence() result = await stream_with_chunking( - _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence" + _action(), backend, _ctx(), requirements=[req], chunking="sentence" ) yielded: list[str] = [] async for chunk in result.astream(): @@ -579,12 +563,12 @@ async def validate( @pytest.mark.asyncio async def test_no_requirements_streams_without_validation() -> None: - """quick_check_requirements=None → chunks produced, no validate() called.""" + """requirements=None → chunks produced, no validate() called.""" response = "Chunk one. Chunk two. Chunk three. " backend = StreamingMockBackend(response, token_size=3) result = await stream_with_chunking( - _action(), backend, _ctx(), quick_check_requirements=None, chunking="sentence" + _action(), backend, _ctx(), requirements=None, chunking="sentence" ) await result.acomplete() @@ -644,7 +628,7 @@ async def validate( req = FailOnNthChunk(n=2) result = await stream_with_chunking( - _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence" + _action(), backend, _ctx(), requirements=[req], chunking="sentence" ) yielded: list[str] = [] async for c in result.astream(): @@ -709,7 +693,7 @@ async def spy_cancel( _action(), backend, _ctx(), - quick_check_requirements=[FailOnFirstChunk()], + requirements=[FailOnFirstChunk()], chunking="word", ) await asyncio.wait_for(result.acomplete(), timeout=5.0) @@ -754,7 +738,7 @@ async def validate( _action(), fail_backend, _ctx(), - quick_check_requirements=[FailImmediately()], + requirements=[FailImmediately()], chunking="word", ) await asyncio.wait_for(fail_result.acomplete(), timeout=5.0) @@ -771,7 +755,7 @@ async def validate( _action(), ok_backend, _ctx(), - quick_check_requirements=[AlwaysUnknownReq()], + requirements=[AlwaysUnknownReq()], chunking="sentence", ) await ok_result.acomplete() @@ -833,11 +817,7 @@ async def spy_cancel( ModelOutputThunk.cancel_generation = spy_cancel # type: ignore[method-assign] try: result = await stream_with_chunking( - _action(), - backend, - _ctx(), - quick_check_requirements=[RaisingReq()], - chunking="word", + _action(), backend, _ctx(), requirements=[RaisingReq()], chunking="word" ) with pytest.raises(ValueError, match="boom"): async for _chunk in result.astream(): @@ -886,11 +866,7 @@ async def validate( backend = StreamingMockBackend(response, token_size=3) result = await stream_with_chunking( - _action(), - backend, - _ctx(), - quick_check_requirements=[RaisingReq()], - chunking="word", + _action(), backend, _ctx(), requirements=[RaisingReq()], chunking="word" ) # Deliberately skip astream(). wait_for bounds any hang. with pytest.raises(ValueError, match="surfaced-without-astream"): @@ -915,11 +891,7 @@ async def test_external_task_cancellation_releases_consumers() -> None: backend = StreamingMockBackend(response, token_size=2) result = await stream_with_chunking( - _action(), - backend, - _ctx(), - quick_check_requirements=[AlwaysUnknownReq()], - chunking="word", + _action(), backend, _ctx(), requirements=[AlwaysUnknownReq()], chunking="word" ) assert result._orchestration_task is not None @@ -954,11 +926,7 @@ async def test_external_cancellation_acomplete_raise_once() -> None: backend = StreamingMockBackend(response, token_size=2) result = await stream_with_chunking( - _action(), - backend, - _ctx(), - quick_check_requirements=[AlwaysUnknownReq()], - chunking="word", + _action(), backend, _ctx(), requirements=[AlwaysUnknownReq()], chunking="word" ) assert result._orchestration_task is not None @@ -1004,11 +972,7 @@ async def validate( response = "word " * 10 backend = StreamingMockBackend(response, token_size=3) result = await stream_with_chunking( - _action(), - backend, - _ctx(), - quick_check_requirements=[RaisingReq()], - chunking="word", + _action(), backend, _ctx(), requirements=[RaisingReq()], chunking="word" ) # acomplete() sees the exception first and raises it. @@ -1072,7 +1036,7 @@ async def validate( req = FailOnNthChunkText(n=2) result = await stream_with_chunking( - _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence" + _action(), backend, _ctx(), requirements=[req], chunking="sentence" ) yielded: list[str] = [] async for chunk in result.astream(): @@ -1087,6 +1051,24 @@ async def validate( # as_thunk.value must agree with full_text. assert result.as_thunk.value == result.full_text + # Fail on chunk 3: two chunks emitted before early exit. full_text must + # preserve the original inter-sentence spacing from the token stream, not + # the stripped chunk concatenation ("One.Two." would be wrong). + backend2 = StreamingMockBackend(response, token_size=100) + req2 = FailOnNthChunkText(n=3) + result2 = await stream_with_chunking( + _action(), backend2, _ctx(), requirements=[req2], chunking="sentence" + ) + yielded2: list[str] = [] + async for chunk in result2.astream(): + yielded2.append(chunk) + await result2.acomplete() + + assert result2.completed is False + assert yielded2 == ["One.", "Two."] + assert result2.full_text == "One. Two." + assert result2.as_thunk.value == result2.full_text + @pytest.mark.asyncio async def test_cancelled_flag_propagates_through_copy_methods() -> None: @@ -1202,15 +1184,13 @@ async def test_stream_with_chunking_requirement_copy_contract( req = req_cls() if expect_raise: with pytest.raises(ValueError, match="copy boom"): - await stream_with_chunking( - _action(), backend, _ctx(), quick_check_requirements=[req] - ) + await stream_with_chunking(_action(), backend, _ctx(), requirements=[req]) # Hard invariant: reorder ensures backend never starts on copy failure. assert backend.generate_from_context_call_count == 0 assert backend.last_mot is None else: result = await stream_with_chunking( - _action(), backend, _ctx(), quick_check_requirements=[req] + _action(), backend, _ctx(), requirements=[req] ) await result.acomplete() assert backend.generate_from_context_call_count == 1 @@ -1277,7 +1257,7 @@ async def validate( backend = StreamingMockBackend("Hello world. ", token_size=2) result = await stream_with_chunking( - _action(), backend, _ctx(), quick_check_requirements=[_RaisingReq(), _SlowReq()] + _action(), backend, _ctx(), requirements=[_RaisingReq(), _SlowReq()] ) with pytest.raises(RuntimeError, match="validator failed"): await result.acomplete()