diff --git a/docs/examples/streaming/streaming_chunking.py b/docs/examples/streaming/streaming_chunking.py new file mode 100644 index 000000000..737ea7486 --- /dev/null +++ b/docs/examples/streaming/streaming_chunking.py @@ -0,0 +1,98 @@ +# pytest: ollama, e2e + +"""Streaming generation with per-chunk validation using stream_with_chunking(). + +Demonstrates: +- Subclassing Requirement to override stream_validate() for early-exit checks +- Calling stream_with_chunking() with sentence-level chunking +- Consuming validated chunks via astream() as they arrive +- Awaiting full completion with acomplete() to access final_validations and full_text +""" + +import asyncio + +from mellea.core.backend import Backend +from mellea.core.base import Context +from mellea.core.requirement import ( + PartialValidationResult, + Requirement, + ValidationResult, +) +from mellea.stdlib.components import Instruction +from mellea.stdlib.streaming import stream_with_chunking + + +class MaxSentencesReq(Requirement): + """Fails if the model generates more than *limit* sentences mid-stream. + + Each ``stream_validate`` call receives one complete sentence from the + :class:`~mellea.stdlib.chunking.SentenceChunker`. The running count is + maintained on ``self`` — this is the standard pattern for requirements + that need context beyond a single chunk. + """ + + def __init__(self, limit: int) -> None: + super().__init__() + self._limit = limit + self._count = 0 + + def format_for_llm(self) -> str: + return f"The response must be at most {self._limit} sentences long." + + async def stream_validate( + self, chunk: str, *, backend: Backend, ctx: Context + ) -> PartialValidationResult: + self._count += 1 + if self._count > self._limit: + return PartialValidationResult( + "fail", + reason=f"Response exceeded {self._limit} sentence limit mid-stream", + ) + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Backend, + ctx: Context, + *, + format: type | None = None, + model_options: dict | None = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + +async def main() -> None: + from mellea.stdlib.session import start_session + + m = start_session() + backend = m.backend + ctx = m.ctx + + action = Instruction( + "Write a short paragraph about the water cycle in exactly two sentences." + ) + req = MaxSentencesReq(limit=3) + + result = await stream_with_chunking( + action, backend, ctx, quick_check_requirements=[req], chunking="sentence" + ) + + print("Streaming chunks as they arrive:") + async for chunk in result.astream(): + print(f" CHUNK: {chunk!r}") + + await result.acomplete() + + print(f"\nCompleted normally: {result.completed}") + print(f"Full text: {result.full_text!r}") + + if result.streaming_failures: + for _req, pvr in result.streaming_failures: + print(f"Streaming failure: {pvr.reason}") + + if result.final_validations: + for vr in result.final_validations: + print(f"Final validation: {'PASS' if vr.as_bool() else 'FAIL'}") + + +asyncio.run(main()) diff --git a/mellea/core/base.py b/mellea/core/base.py index 2028008d9..a8f35e79d 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -364,6 +364,74 @@ def _record_ttfb(self) -> None: ).total_seconds() * 1000 self._first_chunk_received = True + async def cancel_generation(self, error: Exception | None = None) -> None: + """Cancel an in-progress streaming generation, drain the queue, and close any open telemetry span. + + Safe to call at any point during streaming. 
After this method returns, + ``is_computed()`` is ``True`` and ``value`` contains whatever text was + accumulated before cancellation. Calling on an already-computed MOT + is a no-op. + + Draining the internal queue after cancellation is necessary to release + any ``asyncio.Queue.put()`` call that the generation task was blocked on + (queue maxsize=20). + + Args: + error: Optional cause attributed to the open telemetry span. When + provided, this exception is recorded via ``set_span_error`` so + the span reflects the actual reason for cancellation (e.g. the + requirement failure or an unhandled exception from a streaming + validator). When ``None``, a generic + ``RuntimeError("Generation cancelled")`` is recorded. + """ + if self._computed: + return + + def _drain() -> None: + while not self._async_queue.empty(): + try: + self._async_queue.get_nowait() + except asyncio.QueueEmpty: + break + + if self._generate is not None and not self._generate.done(): + self._generate.cancel() + + if self._generate_extra is not None and not self._generate_extra.done(): + self._generate_extra.cancel() + + # Drain before awaiting — unblocks any put() the task is stuck on. + _drain() + + if self._generate is not None: + try: + await self._generate + except (asyncio.CancelledError, Exception): + pass + + if self._generate_extra is not None: + try: + await self._generate_extra + except (asyncio.CancelledError, Exception): + pass + + # Drain again for any final item the task put before terminating. + _drain() + + span = self._meta.pop("_telemetry_span", None) + if span is not None: + from ..telemetry import end_backend_span, set_span_error + + recorded: Exception = ( + error if error is not None else RuntimeError("Generation cancelled") + ) + set_span_error(span, recorded) + end_backend_span(span) + + if self._underlying_value is None: + self._underlying_value = "" + self._computed = True + def _copy_from(self, other: ModelOutputThunk) -> None: """Copy computed-output fields from *other* into *self*. diff --git a/mellea/stdlib/__init__.py b/mellea/stdlib/__init__.py index e4f32941b..7a30fdd53 100644 --- a/mellea/stdlib/__init__.py +++ b/mellea/stdlib/__init__.py @@ -10,9 +10,20 @@ ``mellea.stdlib.session`` — for day-to-day use. Streaming chunking strategies (for use with streaming validation) are available at -``mellea.stdlib.chunking`` and re-exported here for convenience. +``mellea.stdlib.chunking`` and re-exported here for convenience. The core streaming +orchestration primitive :func:`~mellea.stdlib.streaming.stream_with_chunking` and +its result type :class:`~mellea.stdlib.streaming.StreamChunkingResult` are also +re-exported here. """ from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker +from .streaming import StreamChunkingResult, stream_with_chunking -__all__ = ["ChunkingStrategy", "ParagraphChunker", "SentenceChunker", "WordChunker"] +__all__ = [ + "ChunkingStrategy", + "ParagraphChunker", + "SentenceChunker", + "StreamChunkingResult", + "WordChunker", + "stream_with_chunking", +] diff --git a/mellea/stdlib/chunking.py b/mellea/stdlib/chunking.py index 6b9091780..6c81105c5 100644 --- a/mellea/stdlib/chunking.py +++ b/mellea/stdlib/chunking.py @@ -35,6 +35,27 @@ def split(self, accumulated_text: str) -> list[str]: """ ... + def flush(self, accumulated_text: str) -> list[str]: + """Return any trailing fragment that ``split`` withheld. + + Called once by the orchestrator after the stream has ended naturally + (not on early-exit cancellation). 
Gives the chunker a chance to + release the final fragment that did not reach a terminator. + + The default implementation returns an empty list — the trailing + fragment is discarded. Built-in chunkers override this to return + the withheld fragment as a single-element list when non-empty. + + Args: + accumulated_text: The full accumulated text at stream end. + + Returns: + The trailing fragment as ``[fragment]`` if it should be treated + as a final chunk, or an empty list to discard it. + """ + _ = accumulated_text + return [] + # Sentence boundary: sentence-ending punctuation, optionally followed by a closing # quote or paren, then whitespace. @@ -94,6 +115,36 @@ def split(self, accumulated_text: str) -> list[str]: return chunks + def flush(self, accumulated_text: str) -> list[str]: + """Return the trailing sentence fragment (if any) as a final chunk. + + Trailing whitespace on the fragment is non-semantic for sentence + boundaries and is dropped via ``rstrip``. Leading whitespace is + already removed by the loop's ``lstrip`` on each advance, so no + ``lstrip`` is needed here. The result is the fragment's content + only, consistent with how :meth:`split` returns sentences without + trailing whitespace. + + Args: + accumulated_text: The full accumulated text at stream end. + + Returns: + A single-element list containing the trailing sentence fragment + with leading and trailing whitespace stripped, or an empty list + when there is no fragment (all content ended in a sentence + boundary or the input is empty/whitespace-only). + """ + if not accumulated_text: + return [] + remaining = accumulated_text + while True: + match = _SENTENCE_BOUNDARY.search(remaining) + if match is None: + break + remaining = remaining[match.end() :].lstrip() + trailing = remaining.rstrip() + return [trailing] if trailing else [] + class WordChunker(ChunkingStrategy): """Splits accumulated text on whitespace boundaries. @@ -134,6 +185,32 @@ def split(self, accumulated_text: str) -> list[str]: return parts + def flush(self, accumulated_text: str) -> list[str]: + """Return the trailing word fragment (if any) as a final chunk. + + The trailing fragment is the text after the last whitespace run when + the accumulated text does not end with whitespace. When it does end + with whitespace, every word is already complete and no fragment is + released. + + Args: + accumulated_text: The full accumulated text at stream end. + + Returns: + A single-element list containing the trailing word fragment, or + an empty list when the input ends with whitespace (every word + already complete) or is empty. + """ + if not accumulated_text: + return [] + if accumulated_text[-1].isspace(): + return [] + parts = _WHITESPACE.split(accumulated_text) + for part in reversed(parts): + if part: + return [part] + return [] + class ParagraphChunker(ChunkingStrategy): r"""Splits accumulated text on double-newline paragraph boundaries. @@ -168,3 +245,29 @@ def split(self, accumulated_text: str) -> list[str]: # _PARA_BOUNDARY.split on leading \n\n produces an empty first element. return [p for p in parts if p] + + def flush(self, accumulated_text: str) -> list[str]: + r"""Return the trailing paragraph fragment (if any) as a final chunk. + + Unlike :class:`SentenceChunker.flush`, the fragment is returned + byte-for-byte without stripping. Internal whitespace — including + a trailing single ``\n`` — can be semantically meaningful inside + a paragraph (e.g. 
a list item or a deliberate line break), and a + consumer validating paragraph content should see the fragment as + it was withheld. + + Args: + accumulated_text: The full accumulated text at stream end. + + Returns: + A single-element list containing the trailing paragraph fragment + byte-for-byte, or an empty list when the input ends with a + paragraph boundary (``\n\n`` or more) or is empty. + """ + if not accumulated_text: + return [] + if _PARA_BOUNDARY_END.search(accumulated_text): + return [] + parts = _PARA_BOUNDARY.split(accumulated_text) + trailing = parts[-1] if parts else "" + return [trailing] if trailing else [] diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py new file mode 100644 index 000000000..dcdd6e894 --- /dev/null +++ b/mellea/stdlib/streaming.py @@ -0,0 +1,413 @@ +"""Streaming generation with per-chunk validation. + +Provides :func:`stream_with_chunking`, the core orchestration primitive that +consumes a streaming :class:`~mellea.core.base.ModelOutputThunk`, applies a +:class:`~mellea.stdlib.chunking.ChunkingStrategy` to produce semantic chunks, +and runs :meth:`~mellea.core.requirement.Requirement.stream_validate` on each +chunk in parallel. Higher-level streaming APIs build on this function. +""" + +import asyncio +from collections.abc import AsyncIterator, Sequence +from copy import copy +from typing import Any + +from ..backends.model_options import ModelOption +from ..core.backend import Backend +from ..core.base import CBlock, Component, Context, ModelOutputThunk +from ..core.requirement import PartialValidationResult, Requirement, ValidationResult +from ..core.utils import MelleaLogger +from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker + +_CHUNKING_ALIASES: dict[str, type[ChunkingStrategy]] = { + "sentence": SentenceChunker, + "word": WordChunker, + "paragraph": ParagraphChunker, +} + + +class StreamChunkingResult: + """Result of a :func:`stream_with_chunking` operation. + + Provides async iteration over validated text chunks as they complete + (:meth:`astream`), a blocking :meth:`acomplete` for awaiting the full + result including final validation, and :attr:`as_thunk` for wrapping the + output as a :class:`~mellea.core.base.ModelOutputThunk`. + + Instances are created by :func:`stream_with_chunking`; do not instantiate + directly. + + Args: + mot: The :class:`~mellea.core.base.ModelOutputThunk` from the backend + generation call. + ctx: The generation context returned alongside the MOT. + + Attributes: + completed: ``False`` if the stream exited early because a requirement + returned ``"fail"`` during streaming; ``True`` otherwise. + full_text: The complete generated text accumulated during streaming. + Available after :meth:`acomplete` returns. + final_validations: :class:`~mellea.core.requirement.ValidationResult` + objects from the final :meth:`~mellea.core.requirement.Requirement.validate` + calls on all non-failed requirements. Available after + :meth:`acomplete` returns. + streaming_failures: ``(Requirement, PartialValidationResult)`` pairs + for every requirement that returned ``"fail"`` during streaming. 
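+
+    Example:
+        Illustrative consumption sketch (``render`` and ``log`` are stand-ins
+        for application code, not part of this module)::
+
+            result = await stream_with_chunking(
+                action, backend, ctx, quick_check_requirements=[req]
+            )
+            async for chunk in result.astream():
+                render(chunk)  # show each validated chunk as it arrives
+            await result.acomplete()
+            if not result.completed:
+                for _req, pvr in result.streaming_failures:
+                    log(pvr.reason)  # inspect what failed mid-stream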
+ """ + + def __init__(self, mot: ModelOutputThunk, ctx: Context) -> None: + """Initialise with the MOT and context from the backend call.""" + self._mot = mot + self._ctx = ctx + self._chunk_queue: asyncio.Queue[str | None | Exception] = asyncio.Queue() + self._orchestration_task: asyncio.Task[None] | None = None + self._done = asyncio.Event() + # Stashed so acomplete() surfaces orchestrator failures even when the + # consumer never iterates astream(). Cleared once consumed by + # whichever of the two reads it first. + self._orchestration_exception: BaseException | None = None + + self.completed: bool = True + self.full_text: str = "" + self.final_validations: list[ValidationResult] = [] + self.streaming_failures: list[tuple[Requirement, PartialValidationResult]] = [] + + async def astream(self) -> AsyncIterator[str]: + """Yield validated text chunks as they complete. + + Each yielded string is a chunk that has passed per-chunk streaming + validation (or the stream had no requirements). Iteration ends when + all chunks have been yielded, whether the stream completed normally or + was cancelled early on a ``"fail"`` result. + + **Single-consumer.** Chunks are delivered via an + :class:`asyncio.Queue` that this method drains; calling + ``astream()`` a second time on the same result blocks indefinitely + because the queue is empty and the terminating ``None`` sentinel + has already been consumed. If you need the chunks after + iteration, capture them into a list during the first pass or use + :attr:`full_text` after :meth:`acomplete`. + + Yields: + str: A validated text chunk from the chunking strategy. + + Raises: + Exception: Propagates any error from the background orchestration + task. + """ + while True: + item = await self._chunk_queue.get() + if item is None: + return + if isinstance(item, Exception): + if self._orchestration_exception is None: + # Already surfaced by acomplete(); don't raise twice. + continue + self._orchestration_exception = None + raise item + yield item + + async def acomplete(self) -> None: + """Await full completion, including final validation. + + After this method returns, :attr:`full_text`, :attr:`completed`, + :attr:`final_validations`, and :attr:`streaming_failures` are all + populated. If :meth:`astream` has already been consumed to + exhaustion, this call is effectively a no-op. + + Raises: + Exception: Propagates any error from the orchestration task. + """ + await self._done.wait() + # Raise-once: if astream() already consumed the exception, the stash + # is already None and this is a no-op. + exc = self._orchestration_exception + if exc is not None: + self._orchestration_exception = None + raise exc + if self._orchestration_task is not None and self._orchestration_task.done(): + task_exc = self._orchestration_task.exception() + if task_exc is not None: + raise task_exc + + @property + def as_thunk(self) -> ModelOutputThunk: + """Wrap the output as a computed :class:`~mellea.core.base.ModelOutputThunk`. + + Returns a new thunk with ``value`` set to :attr:`full_text` and + generation metadata copied from the original MOT. Safe to call on + early-exit results; ``value`` will reflect whatever was accumulated + before cancellation. + + Note: + On early exit, ``cancel_generation()`` forces the MOT into a + computed state without running the backend's + ``post_processing()``. Telemetry fields on the returned thunk + (``generation.usage``, ``generation.ttfb_ms``, etc.) may + therefore be ``None`` or reflect the partial state at + cancellation time. 
``value`` and ``streaming`` are reliable; + usage totals are not. + + Returns: + ModelOutputThunk: A computed thunk containing the streamed output. + + Raises: + RuntimeError: If called before :meth:`acomplete` has returned. + """ + if not self._done.is_set(): + raise RuntimeError( + "as_thunk accessed before acomplete() — await acomplete() first" + ) + thunk = ModelOutputThunk(value=self.full_text) + thunk.generation = copy(self._mot.generation) + return thunk + + +async def _orchestrate_streaming( + result: StreamChunkingResult, + mot: ModelOutputThunk, + ctx: Context, + cloned_reqs: list[Requirement], + chunking: ChunkingStrategy, + val_backend: Backend, +) -> None: + accumulated = "" + prev_chunk_count = 0 + failed_indices: set[int] = set() + early_exit = False + + async def _validate_and_emit(c: str) -> bool: + """Run stream_validate on chunk c across active requirements. + + Returns True if a failure was recorded (caller should early-exit), + False otherwise (chunk was emitted to the consumer queue). + """ + active = [ + (i, req) for i, req in enumerate(cloned_reqs) if i not in failed_indices + ] + if active: + pvrs: list[PartialValidationResult] = list( + await asyncio.gather( + *[ + req.stream_validate(c, backend=val_backend, ctx=ctx) + for _, req in active + ] + ) + ) + for (idx, req), pvr in zip(active, pvrs): + if pvr.success == "fail": + failed_indices.add(idx) + result.streaming_failures.append((req, pvr)) + + if failed_indices: + return True + + await result._chunk_queue.put(c) + return False + + try: + while not mot.is_computed(): + try: + delta = await mot.astream() + except RuntimeError: + # Expected race: mot.is_computed() was False at the top of the + # loop but the stream finished before we re-entered astream(). + # Any other RuntimeError is a real bug and must propagate. + if mot.is_computed(): + break + raise + + accumulated += delta + chunks = chunking.split(accumulated) + new_chunks = chunks[prev_chunk_count:] + prev_chunk_count = len(chunks) + + for c in new_chunks: + failed = await _validate_and_emit(c) + if failed: + early_exit = True + result.completed = False + await mot.cancel_generation() + break + + if early_exit: + break + + # Stream ended naturally: flush any withheld trailing fragment and + # run stream_validate on it. Skipped on early exit — the generation + # was cancelled, the trailing fragment is incomplete. + if not early_exit: + for c in chunking.flush(accumulated): + failed = await _validate_and_emit(c) + if failed: + early_exit = True + result.completed = False + break + + result.full_text = accumulated + + non_failed = [ + req for i, req in enumerate(cloned_reqs) if i not in failed_indices + ] + if non_failed and not early_exit: + result.final_validations = list( + await asyncio.gather( + *[req.validate(val_backend, ctx) for req in non_failed] + ) + ) + + except Exception as exc: + # Orchestrator is leaving — we must stop the backend producer too, + # otherwise mot._async_queue (maxsize=20) fills and the feeder task + # blocks indefinitely. The spec (#891, #901) calls this out for the + # "fail" path; the same reasoning applies to any unplanned exit. + # Pass `exc` so the backend telemetry span records the real cause + # rather than a generic "Generation cancelled". + try: + await mot.cancel_generation(error=exc) + except Exception as cleanup_exc: + # Never let cleanup mask the original exception: log loudly and + # continue to surface `exc` to the consumer. + # TODO(#902): replace this log with an ErrorEvent emission. 
+ MelleaLogger.get_logger().warning( + "stream_with_chunking: cancel_generation() raised during " + "exception cleanup (original: %r, cleanup: %r)", + exc, + cleanup_exc, + ) + result.completed = False + result._orchestration_exception = exc + await result._chunk_queue.put(exc) + finally: + # CancelledError (BaseException, not Exception) bypasses the except + # block above, so cancel_generation() may not have been called. + # Catch only Exception here so CancelledError / KeyboardInterrupt / + # SystemExit still propagate to the caller. + if not mot.is_computed(): + try: + await mot.cancel_generation() + except Exception: + pass + # put_nowait + set() are synchronous — no await point, so they cannot + # be interrupted by task cancellation. Consumers waiting on + # _done.wait() are always released, even if the task was cancelled + # mid-cleanup. The queue is unbounded, so QueueFull cannot occur. + try: + result._chunk_queue.put_nowait(None) + except asyncio.QueueFull: + pass + result._done.set() + + +async def stream_with_chunking( + action: Component[Any] | CBlock, + backend: Backend, + ctx: Context, + *, + quick_check_requirements: Sequence[Requirement] | None = None, + chunking: str | ChunkingStrategy = "sentence", + quick_check_backend: Backend | None = None, +) -> StreamChunkingResult: + """Generate a streaming response with per-chunk validation. + + Starts a backend generation with streaming enabled, consumes the + :class:`~mellea.core.base.ModelOutputThunk`'s async stream in a single + background task, splits the accumulated text using *chunking*, and runs + :meth:`~mellea.core.requirement.Requirement.stream_validate` on each new + chunk in parallel across all requirements. + + For each new complete chunk produced by the chunking strategy, + ``stream_validate`` is called once per active requirement (in parallel + via :func:`asyncio.gather`), receiving that single chunk. Multiple + chunks produced from one ``astream()`` iteration are validated + sequentially in order, so early exit on a ``"fail"`` result prevents + later chunks in the same batch from being validated or emitted to the + consumer. + + If any requirement returns ``"fail"``, the generation is cancelled + immediately (via + :meth:`~mellea.core.base.ModelOutputThunk.cancel_generation`) and + :attr:`StreamChunkingResult.completed` is set to ``False``. The + failing chunk is not emitted to the consumer; use + :attr:`StreamChunkingResult.streaming_failures` to inspect what failed. + + When the stream ends naturally, any trailing fragment withheld by the + chunking strategy (see :meth:`~mellea.stdlib.chunking.ChunkingStrategy.flush`) + is released as a final chunk and run through ``stream_validate`` on the + same terms as the regular chunks. On early exit, the trailing fragment + is discarded because the generation was cancelled mid-token. + + After the stream ends naturally, ``validate()`` is called on every + requirement that did not return ``"fail"`` — both ``"pass"`` and + ``"unknown"`` trigger final validation. On early exit, no ``validate()`` + call is made; :attr:`StreamChunkingResult.final_validations` remains + empty. Requirements are cloned (``copy(req)``) before use so originals + are never mutated. + + Requirements that need context beyond the current chunk should + accumulate it themselves across ``stream_validate`` calls (e.g. + ``self._seen = self._seen + chunk``). They must not read ``mot.astream()`` + directly — this orchestrator is the single consumer of the MOT stream. 
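+
+    For example, a requirement that bans a phrase anywhere in the output can
+    accumulate the text it has seen so far. This is an illustrative sketch;
+    the banned-phrase check is a stand-in, not part of this change::
+
+        class NoBannedPhraseReq(Requirement):
+            def __init__(self, phrase: str) -> None:
+                super().__init__()
+                self._phrase = phrase
+                self._seen = ""
+
+            def format_for_llm(self) -> str:
+                return f"The response must not contain {self._phrase!r}."
+
+            async def stream_validate(self, chunk, *, backend, ctx):
+                self._seen += chunk  # accumulate context on self across calls
+                if self._phrase in self._seen:
+                    return PartialValidationResult("fail", reason="banned phrase")
+                return PartialValidationResult("unknown")
+
+            async def validate(self, backend, ctx, *, format=None, model_options=None):
+                return ValidationResult(result=self._phrase not in self._seen)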
+ + Note: + Chunks are emitted to the consumer (via + :meth:`StreamChunkingResult.astream`) only after every requirement's + ``stream_validate`` has returned for that chunk. A slow validator + (for example, one that invokes an LLM) therefore adds latency to + every chunk — the consumer sees a chunk at most as quickly as the + slowest active validator. This trade is deliberate in v1: it + preserves the invariant that the consumer never sees content that + has not been validated, which matters for UIs displaying generated + text live. A future fast-path mode that emits chunks to the + consumer concurrently with validation (at the cost of that + invariant) may be added if a concrete use case calls for it. + + Note: + v1 retry is simple re-invocation of this function. Plugin hooks + (``SAMPLING_LOOP_START``, ``SAMPLING_REPAIR``, etc.) do not fire + on retries — use the ``#902`` event types for observability instead. + + Args: + action: The component or content block to generate from. + backend: The backend used for generation and final validation. + ctx: The generation context. + quick_check_requirements: Sequence of requirements to validate against + each chunk during streaming. ``None`` disables streaming validation + (chunks are still produced; ``validate()`` is not called at stream end). + chunking: Chunking strategy — either a :class:`~mellea.stdlib.chunking.ChunkingStrategy` + instance or one of the string aliases ``"sentence"`` (default), + ``"word"``, or ``"paragraph"``. + quick_check_backend: Optional alternate backend for both + ``stream_validate`` and final ``validate`` calls. When ``None``, + *backend* is used for validation. + + Returns: + StreamChunkingResult: A result object providing :meth:`~StreamChunkingResult.astream` + for incremental chunk consumption and + :meth:`~StreamChunkingResult.acomplete` for blocking until done. + + Raises: + ValueError: If *chunking* is a string that does not match any known + alias (``"sentence"``, ``"word"``, ``"paragraph"``). + """ + if isinstance(chunking, str): + cls = _CHUNKING_ALIASES.get(chunking) + if cls is None: + raise ValueError( + f"Unknown chunking alias {chunking!r}. 
Choose from: {list(_CHUNKING_ALIASES)}" + ) + chunking = cls() + + opts: dict[str, Any] = {ModelOption.STREAM: True} + mot, gen_ctx = await backend.generate_from_context(action, ctx, model_options=opts) + + result = StreamChunkingResult(mot, gen_ctx) + + cloned_reqs = [copy(req) for req in (quick_check_requirements or [])] + val_backend = quick_check_backend if quick_check_backend is not None else backend + + result._orchestration_task = asyncio.create_task( + _orchestrate_streaming(result, mot, gen_ctx, cloned_reqs, chunking, val_backend) + ) + + return result diff --git a/test/stdlib/test_chunking.py b/test/stdlib/test_chunking.py index fbaf727a2..7b965350f 100644 --- a/test/stdlib/test_chunking.py +++ b/test/stdlib/test_chunking.py @@ -242,3 +242,79 @@ def test_paragraph_chunker_incremental_simulation(): "First paragraph.", "Second paragraph.", ] + + +# --------------------------------------------------------------------------- +# flush() — trailing-fragment release at end of stream +# --------------------------------------------------------------------------- + + +def test_default_flush_returns_empty_list(): + """The ABC default discards the trailing fragment.""" + + class Minimal(ChunkingStrategy): + def split(self, accumulated_text: str) -> list[str]: + _ = accumulated_text + return [] + + assert Minimal().flush("anything at all") == [] + assert Minimal().flush("") == [] + + +def test_sentence_chunker_flush_empty(): + assert SentenceChunker().flush("") == [] + + +def test_sentence_chunker_flush_only_complete(): + """All text ends in a complete sentence with trailing whitespace → no fragment.""" + assert SentenceChunker().flush("One. Two. ") == [] + + +def test_sentence_chunker_flush_trailing_fragment(): + """Final sentence without trailing whitespace is released by flush.""" + assert SentenceChunker().flush("One. Two without period") == ["Two without period"] + + +def test_sentence_chunker_flush_terminated_no_trailing_space(): + """Final sentence with terminator but no trailing whitespace is a fragment + under split() semantics and gets released by flush().""" + assert SentenceChunker().flush("One. 
Two.") == ["Two."] + + +def test_sentence_chunker_flush_single_sentence_no_terminator(): + assert SentenceChunker().flush("Incomplete sentence") == ["Incomplete sentence"] + + +def test_word_chunker_flush_empty(): + assert WordChunker().flush("") == [] + + +def test_word_chunker_flush_trailing_whitespace(): + """Trailing whitespace means all words are complete → no fragment.""" + assert WordChunker().flush("one two three ") == [] + + +def test_word_chunker_flush_trailing_fragment(): + assert WordChunker().flush("one two three") == ["three"] + + +def test_word_chunker_flush_single_word(): + assert WordChunker().flush("solo") == ["solo"] + + +def test_paragraph_chunker_flush_empty(): + assert ParagraphChunker().flush("") == [] + + +def test_paragraph_chunker_flush_only_complete(): + assert ParagraphChunker().flush("Para one.\n\nPara two.\n\n") == [] + + +def test_paragraph_chunker_flush_trailing_fragment(): + assert ParagraphChunker().flush("Para one.\n\nPara two (no sep)") == [ + "Para two (no sep)" + ] + + +def test_paragraph_chunker_flush_single_paragraph_no_separator(): + assert ParagraphChunker().flush("Only paragraph") == ["Only paragraph"] diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py new file mode 100644 index 000000000..759d8d272 --- /dev/null +++ b/test/stdlib/test_streaming.py @@ -0,0 +1,879 @@ +"""Tests for stream_with_chunking() and StreamChunkingResult. + +Uses StreamingMockBackend — a deterministic test double that feeds tokens from a +fixed response string into a MOT queue without network or LLM calls. + +All tests are unit tests (no @pytest.mark.ollama needed). +""" + +import asyncio +from typing import Any + +import pytest + +from mellea.core.backend import Backend +from mellea.core.base import CBlock, Context, GenerateType, ModelOutputThunk +from mellea.core.requirement import ( + PartialValidationResult, + Requirement, + ValidationResult, +) +from mellea.stdlib.context import SimpleContext +from mellea.stdlib.streaming import stream_with_chunking + +# --------------------------------------------------------------------------- +# StreamingMockBackend +# --------------------------------------------------------------------------- + + +async def _mock_process(mot: ModelOutputThunk, chunk: Any) -> None: + if mot._underlying_value is None: + mot._underlying_value = "" + if chunk is not None: + mot._underlying_value += chunk + + +async def _mock_post_process(_mot: ModelOutputThunk) -> None: + pass + + +def _make_mot() -> ModelOutputThunk: + mot = ModelOutputThunk(value=None) + mot._action = CBlock("mock_action") + mot._generate_type = GenerateType.ASYNC + mot._process = _mock_process + mot._post_process = _mock_post_process + mot._chunk_size = 0 + return mot + + +async def _feed_tokens(mot: ModelOutputThunk, response: str, token_size: int) -> None: + i = 0 + while i < len(response): + token = response[i : i + token_size] + await mot._async_queue.put(token) + await asyncio.sleep(0) + i += token_size + await mot._async_queue.put(None) + + +class StreamingMockBackend(Backend): + """Test double that streams a fixed response one token at a time. + + ``token_size`` controls how many characters constitute one token. + Validation calls (via ``stream_validate`` / ``validate``) are delegated + to the requirements themselves — this backend does not perform any real + inference. 
+ """ + + def __init__(self, response: str, token_size: int = 1) -> None: + self._response = response + self._token_size = token_size + + async def _generate_from_context( + self, + action: Any, + ctx: Context, + *, + format: Any = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> tuple[ModelOutputThunk, Context]: + _ = format, model_options, tool_calls + mot = _make_mot() + task = asyncio.create_task(_feed_tokens(mot, self._response, self._token_size)) + _ = task + new_ctx = ctx.add(action).add(mot) + return mot, new_ctx + + async def generate_from_raw( + self, actions: Any, ctx: Any, **kwargs: Any + ) -> list[ModelOutputThunk]: + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Requirement test doubles +# --------------------------------------------------------------------------- + + +class AlwaysUnknownReq(Requirement): + """stream_validate always returns 'unknown'; validate returns True.""" + + def format_for_llm(self) -> str: + return "always unknown" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +class FailAfterWordsReq(Requirement): + """Returns 'fail' once the cumulative word count reaches *threshold*. + + Each call to ``stream_validate`` receives a single chunk (delta) from the + chunking strategy; the running total is maintained on the instance. + """ + + def __init__(self, threshold: int) -> None: + super().__init__() + self._threshold = threshold + self._word_count = 0 + + def format_for_llm(self) -> str: + return f"fail after {self._threshold} words" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + self._word_count += len(chunk.split()) + if self._word_count >= self._threshold: + return PartialValidationResult("fail", reason="too many words") + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +class BackendRecordingReq(Requirement): + """Records which backend was passed to stream_validate and validate.""" + + def __init__(self) -> None: + super().__init__() + self.seen_backends: list[Any] = [] + + def __copy__(self) -> "BackendRecordingReq": + clone = BackendRecordingReq() + clone.seen_backends = [] # fresh list — do not share with original + return clone + + def format_for_llm(self) -> str: + return "backend recorder" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk + self.seen_backends.append(backend) + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + self.seen_backends.append(backend) + return ValidationResult(result=True) + + +class MutationDetectorReq(Requirement): + """Tracks how many times stream_validate was called on this instance.""" + + def __init__(self) -> None: + super().__init__() + self._call_count = 0 + + def format_for_llm(self) -> str: + return "mutation detector" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = 
chunk, backend, ctx + self._call_count += 1 + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _ctx() -> SimpleContext: + return SimpleContext() + + +def _action() -> CBlock: + return CBlock("prompt") + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_normal_completion_calls_validate_at_stream_end() -> None: + """All 'unknown' requirements → validate() called at stream end; completed=True.""" + response = "Hello world. How are you. " + backend = StreamingMockBackend(response, token_size=3) + req = AlwaysUnknownReq() + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence" + ) + await result.acomplete() + + assert result.completed is True + assert result.full_text == response + assert len(result.final_validations) == 1 + assert result.final_validations[0].as_bool() is True + assert result.streaming_failures == [] + + +@pytest.mark.asyncio +async def test_early_exit_on_fail() -> None: + """Requirement fails mid-stream → completed=False, streaming_failures populated.""" + # 5 words to trigger failure + response = "one two three four five six seven eight. " + backend = StreamingMockBackend(response, token_size=2) + req = FailAfterWordsReq(threshold=4) + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="word" + ) + await result.acomplete() + + assert result.completed is False + assert len(result.streaming_failures) == 1 + _req, pvr = result.streaming_failures[0] + assert pvr.success == "fail" + assert pvr.reason == "too many words" + # final_validations should be empty — final validate() skipped on early exit + assert result.final_validations == [] + + +@pytest.mark.asyncio +async def test_clone_isolation_across_retries() -> None: + """Originals must not be mutated; two invocations are independent.""" + response = "Sentence one. Sentence two. " + req = MutationDetectorReq() + original_reqs = [req] + + backend = StreamingMockBackend(response, token_size=4) + + r1 = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=original_reqs, + chunking="sentence", + ) + await r1.acomplete() + + r2 = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=original_reqs, + chunking="sentence", + ) + await r2.acomplete() + + # Original requirement must never have been called — only clones are used + assert req._call_count == 0 + + +@pytest.mark.asyncio +async def test_quick_check_backend_routing() -> None: + """stream_validate and validate receive quick_check_backend, not the main backend.""" + response = "One sentence. Two sentences. " + main_backend = StreamingMockBackend(response, token_size=3) + val_backend = StreamingMockBackend("unused", token_size=1) + + req = BackendRecordingReq() + + # Capture the cloned requirement so we can inspect which backends it saw. 
+ captured: list[BackendRecordingReq] = [] + original_copy = BackendRecordingReq.__copy__ + + def _capturing_copy(self: BackendRecordingReq) -> BackendRecordingReq: + clone = original_copy(self) + captured.append(clone) + return clone + + BackendRecordingReq.__copy__ = _capturing_copy # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + main_backend, + _ctx(), + quick_check_requirements=[req], + chunking="sentence", + quick_check_backend=val_backend, + ) + await result.acomplete() + finally: + BackendRecordingReq.__copy__ = original_copy # type: ignore[method-assign] + + assert result.completed is True + # The original was never called — only clones are used. + assert req.seen_backends == [] + # The clone must have seen val_backend for every call (stream_validate + validate), + # never main_backend. This is the actual routing assertion. + assert len(captured) == 1 + assert len(captured[0].seen_backends) > 0 + assert all(b is val_backend for b in captured[0].seen_backends) + + +@pytest.mark.asyncio +async def test_early_exit_does_not_deadlock() -> None: + """Early failure with a high-throughput stream must not hang.""" + long_response = "word " * 200 + backend = StreamingMockBackend(long_response, token_size=5) + req = FailAfterWordsReq(threshold=3) + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="word" + ) + # 5-second timeout — should complete in milliseconds on success + await asyncio.wait_for(result.acomplete(), timeout=5.0) + + assert result.completed is False + + +@pytest.mark.asyncio +async def test_as_thunk_correctness() -> None: + """as_thunk is computed, value matches full_text, generation metadata preserved.""" + response = "This is a test sentence. " + backend = StreamingMockBackend(response, token_size=4) + + result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence") + await result.acomplete() + + thunk = result.as_thunk + assert thunk.is_computed() + assert thunk.value == result.full_text == response + + +@pytest.mark.asyncio +async def test_as_thunk_raises_before_acomplete() -> None: + """as_thunk raises RuntimeError if accessed before acomplete().""" + response = "Some text. " + backend = StreamingMockBackend(response, token_size=2) + + result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence") + + with pytest.raises(RuntimeError, match="acomplete"): + _ = result.as_thunk + + +@pytest.mark.asyncio +async def test_astream_yields_individual_chunks() -> None: + """Consumer via astream() receives individual chunks, not accumulated text.""" + response = "First sentence. Second sentence. Third sentence. 
" + backend = StreamingMockBackend(response, token_size=5) + + result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence") + + chunks: list[str] = [] + async for chunk in result.astream(): + chunks.append(chunk) + + await result.acomplete() + + # Each chunk must be a complete sentence (not the accumulated text) + assert len(chunks) == 3 + for chunk in chunks: + assert chunk.endswith(".") + # Chunks don't include inter-sentence spaces; joined with a space they appear in full_text + assert " ".join(chunks) in result.full_text + + +@pytest.mark.asyncio +async def test_stream_validate_receives_individual_chunks() -> None: + """stream_validate is called once per chunk with the chunk itself, not accumulated text.""" + + class ChunkRecordingReq(Requirement): + def __init__(self) -> None: + self.seen_chunks: list[str] = [] + + def __copy__(self) -> "ChunkRecordingReq": + clone = ChunkRecordingReq() + clone.seen_chunks = [] + return clone + + def format_for_llm(self) -> str: + return "chunk recorder" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + self.seen_chunks.append(chunk) + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + response = "First sentence. Second sentence. Third sentence. " + backend = StreamingMockBackend(response, token_size=4) + req = ChunkRecordingReq() + + # Capture the cloned requirement used by the orchestrator via a side channel. + captured: list[ChunkRecordingReq] = [] + original_copy = ChunkRecordingReq.__copy__ + + def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq: + clone = original_copy(self) + captured.append(clone) + return clone + + ChunkRecordingReq.__copy__ = _capturing_copy # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[req], + chunking="sentence", + ) + await result.acomplete() + finally: + ChunkRecordingReq.__copy__ = original_copy # type: ignore[method-assign] + + assert len(captured) == 1 + seen = captured[0].seen_chunks + # Exact match: three separate calls, one per complete sentence, + # each call receiving that sentence and nothing more. Under the old + # accumulated-text semantics, seen would have been + # ["First sentence.", "First sentence. Second sentence.", ...] — + # exact match against the per-chunk list is the direct regression guard. 
+ assert seen == ["First sentence.", "Second sentence.", "Third sentence."] + + +@pytest.mark.asyncio +async def test_trailing_fragment_is_flushed_to_consumer() -> None: + """Response without trailing whitespace: final sentence reaches astream() and stream_validate.""" + + class ChunkRecordingReq(Requirement): + def __init__(self) -> None: + self.seen_chunks: list[str] = [] + + def __copy__(self) -> "ChunkRecordingReq": + clone = ChunkRecordingReq() + clone.seen_chunks = [] + return clone + + def format_for_llm(self) -> str: + return "chunk recorder" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + self.seen_chunks.append(chunk) + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + # No trailing whitespace after the final sentence — SentenceChunker withholds it. + response = "First sentence. Second sentence." + backend = StreamingMockBackend(response, token_size=4) + req = ChunkRecordingReq() + + captured: list[ChunkRecordingReq] = [] + original_copy = ChunkRecordingReq.__copy__ + + def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq: + clone = original_copy(self) + captured.append(clone) + return clone + + ChunkRecordingReq.__copy__ = _capturing_copy # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[req], + chunking="sentence", + ) + yielded: list[str] = [] + async for chunk in result.astream(): + yielded.append(chunk) + await result.acomplete() + finally: + ChunkRecordingReq.__copy__ = original_copy # type: ignore[method-assign] + + # Both sentences reach the consumer, including the terminating one without trailing whitespace. + assert yielded == ["First sentence.", "Second sentence."] + # stream_validate was called on both — the flush path is not a shortcut. + assert captured[0].seen_chunks == ["First sentence.", "Second sentence."] + assert result.completed is True + + +@pytest.mark.asyncio +async def test_early_exit_on_trailing_fragment() -> None: + """A fail on the flushed fragment records a streaming failure and skips final validate().""" + + class FailOnSecondSentence(Requirement): + def __init__(self) -> None: + self._count = 0 + + def format_for_llm(self) -> str: + return "fail on second sentence" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + self._count += 1 + if self._count >= 2: + return PartialValidationResult("fail", reason="second sentence hit") + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + response = "First sentence. Second sentence." + backend = StreamingMockBackend(response, token_size=4) + req = FailOnSecondSentence() + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence" + ) + yielded: list[str] = [] + async for chunk in result.astream(): + yielded.append(chunk) + await result.acomplete() + + assert result.completed is False + assert len(result.streaming_failures) == 1 + # First sentence was emitted; second (the flushed fragment) failed and wasn't emitted. 
+ assert yielded == ["First sentence."] + # Early exit on fail skips final validate(). + assert result.final_validations == [] + + +@pytest.mark.asyncio +async def test_no_requirements_streams_without_validation() -> None: + """quick_check_requirements=None → chunks produced, no validate() called.""" + response = "Chunk one. Chunk two. Chunk three. " + backend = StreamingMockBackend(response, token_size=3) + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=None, chunking="sentence" + ) + await result.acomplete() + + assert result.completed is True + assert result.full_text == response + assert result.final_validations == [] + assert result.streaming_failures == [] + + +@pytest.mark.asyncio +async def test_multiple_chunks_in_one_batch_with_mid_batch_fail() -> None: + """When one astream() delta produces several complete chunks and one in + the middle fails, earlier chunks emit, failing chunk is recorded, later + chunks are neither validated nor emitted.""" + + captured: list[Any] = [] + + class FailOnNthChunk(Requirement): + def __init__(self, n: int) -> None: + self._n = n + self._calls = 0 + self.seen: list[str] = [] + + def __copy__(self) -> "FailOnNthChunk": + clone = FailOnNthChunk(self._n) + captured.append(clone) + return clone + + def format_for_llm(self) -> str: + return f"fail on chunk {self._n}" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = backend, ctx + self._calls += 1 + self.seen.append(chunk) + if self._calls == self._n: + return PartialValidationResult("fail", reason=f"n={self._n}") + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + # token_size larger than the whole response → one astream() delta delivers + # the full text, so chunking.split produces 4 sentences in a single batch. + response = "One. Two. Three. Four. " + backend = StreamingMockBackend(response, token_size=100) + req = FailOnNthChunk(n=2) + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence" + ) + yielded: list[str] = [] + async for c in result.astream(): + yielded.append(c) + await result.acomplete() + + assert result.completed is False + assert len(result.streaming_failures) == 1 + # Chunk 1 was validated and emitted; chunk 2 was validated and failed + # (NOT emitted); chunks 3 and 4 were NEITHER validated NOR emitted. 
+ assert yielded == ["One."] + assert len(captured) == 1 + assert captured[0].seen == ["One.", "Two."] + assert captured[0]._calls == 2 + + +@pytest.mark.asyncio +async def test_cancel_generation_invoked_on_fail() -> None: + """Early exit on 'fail' must call mot.cancel_generation() — the spec reason + is that asyncio.Queue(maxsize=20) will block the producer if the consumer + stops without cancelling.""" + + from mellea.core.base import ModelOutputThunk + + response = "word " * 50 + backend = StreamingMockBackend(response, token_size=3) + + class FailOnFirstChunk(Requirement): + def format_for_llm(self) -> str: + return "fail immediately" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + return PartialValidationResult("fail", reason="nope") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + call_count = 0 + real_cancel = ModelOutputThunk.cancel_generation + + async def spy_cancel( + self: ModelOutputThunk, error: Exception | None = None + ) -> None: + nonlocal call_count + call_count += 1 + await real_cancel(self, error) + + ModelOutputThunk.cancel_generation = spy_cancel # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[FailOnFirstChunk()], + chunking="word", + ) + await asyncio.wait_for(result.acomplete(), timeout=5.0) + finally: + ModelOutputThunk.cancel_generation = real_cancel # type: ignore[method-assign] + + assert result.completed is False + assert call_count >= 1 + + +@pytest.mark.asyncio +async def test_exception_in_stream_validate_cancels_generation() -> None: + """Verifies the orchestrator's exception-path cleanup: if stream_validate + raises, cancel_generation() is called and the exception surfaces to the + consumer via astream()/acomplete() without hanging. + + This covers the cancel-on-exception path and the no-hang guarantee. + It does not directly exercise the worst-case "producer already blocked on + full queue" scenario (here the fail happens on chunk 1 so the queue never + fills); the cancel_generation drain logic is covered by its own tests in + test/core/. 
+ """ + + from mellea.core.base import ModelOutputThunk + + class RaisingReq(Requirement): + def format_for_llm(self) -> str: + return "raises" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + raise ValueError("boom") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + response = "word " * 50 # enough to fill maxsize=20 queue without cleanup + backend = StreamingMockBackend(response, token_size=3) + + call_count = 0 + real_cancel = ModelOutputThunk.cancel_generation + + async def spy_cancel( + self: ModelOutputThunk, error: Exception | None = None + ) -> None: + nonlocal call_count + call_count += 1 + await real_cancel(self, error) + + ModelOutputThunk.cancel_generation = spy_cancel # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[RaisingReq()], + chunking="word", + ) + with pytest.raises(ValueError, match="boom"): + async for _chunk in result.astream(): + pass + # acomplete must complete (not hang) even though the orchestration + # task raised, because cancel_generation was called in the except path. + await asyncio.wait_for(result.acomplete(), timeout=5.0) + finally: + ModelOutputThunk.cancel_generation = real_cancel # type: ignore[method-assign] + + assert result.completed is False + assert call_count >= 1 + + +@pytest.mark.asyncio +async def test_acomplete_surfaces_exception_without_astream() -> None: + """acomplete() must surface orchestrator exceptions even when the + consumer never iterates astream(). + + The alternative — only delivering the exception through the chunk queue + — silently swallows validator failures for callers who skip astream(). + """ + + class RaisingReq(Requirement): + def format_for_llm(self) -> str: + return "raises" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + raise ValueError("surfaced-without-astream") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + response = "word " * 50 + backend = StreamingMockBackend(response, token_size=3) + + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[RaisingReq()], + chunking="word", + ) + # Deliberately skip astream(). wait_for bounds any hang. + with pytest.raises(ValueError, match="surfaced-without-astream"): + await asyncio.wait_for(result.acomplete(), timeout=5.0) + + assert result.completed is False + # Raise-once: a second acomplete() must not re-raise. + await asyncio.wait_for(result.acomplete(), timeout=5.0) + + +@pytest.mark.asyncio +async def test_external_task_cancellation_releases_consumers() -> None: + """External cancellation of the orchestration task must still set _done. + + If the finally cleanup itself contains an ``await`` (e.g. awaiting a + terminator put into the chunk queue), CancelledError re-raises at that + await and ``_done.set()`` never runs — any consumer blocked on + ``acomplete()`` hangs forever. The cleanup must therefore end with + synchronous operations only. 
+ """ + response = "word " * 200 # long enough that streaming is still in progress + backend = StreamingMockBackend(response, token_size=2) + + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[AlwaysUnknownReq()], + chunking="word", + ) + + assert result._orchestration_task is not None + # Yield once so the orchestration task enters its main loop before we + # cancel it. + await asyncio.sleep(0.01) + + # Same mechanism asyncio.wait_for uses on timeout. + result._orchestration_task.cancel() + + # _done must be set by the finally cleanup. A hang would time out here. + await asyncio.wait_for(result._done.wait(), timeout=2.0) + assert result._done.is_set() + + # acomplete() surfaces the CancelledError via task.exception() and must + # not hang. + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(result.acomplete(), timeout=2.0)