|
| 1 | +"""Streaming generation with per-chunk validation. |
| 2 | +
|
| 3 | +Provides :func:`stream_with_chunking`, the core orchestration primitive that |
| 4 | +consumes a streaming :class:`~mellea.core.base.ModelOutputThunk`, applies a |
| 5 | +:class:`~mellea.stdlib.chunking.ChunkingStrategy` to produce semantic chunks, |
| 6 | +and runs :meth:`~mellea.core.requirement.Requirement.stream_validate` on each |
| 7 | +chunk in parallel. Higher-level streaming APIs build on this function. |
| 8 | +""" |
| 9 | + |
| 10 | +import asyncio |
| 11 | +from collections.abc import AsyncIterator, Sequence |
| 12 | +from copy import copy |
| 13 | +from typing import Any |
| 14 | + |
| 15 | +from ..backends.model_options import ModelOption |
| 16 | +from ..core.backend import Backend |
| 17 | +from ..core.base import CBlock, Component, Context, ModelOutputThunk |
| 18 | +from ..core.requirement import PartialValidationResult, Requirement, ValidationResult |
| 19 | +from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker |
| 20 | + |
# String aliases accepted by stream_with_chunking()'s ``chunking`` argument,
# mapped to the ChunkingStrategy subclass each alias instantiates.
_CHUNKING_ALIASES: dict[str, type[ChunkingStrategy]] = {
    "sentence": SentenceChunker,
    "word": WordChunker,
    "paragraph": ParagraphChunker,
}
| 26 | + |
| 27 | + |
| 28 | +class StreamChunkingResult: |
| 29 | + """Result of a :func:`stream_with_chunking` operation. |
| 30 | +
|
| 31 | + Provides async iteration over validated text chunks as they complete |
| 32 | + (:meth:`astream`), a blocking :meth:`acomplete` for awaiting the full |
| 33 | + result including final validation, and :attr:`as_thunk` for wrapping the |
| 34 | + output as a :class:`~mellea.core.base.ModelOutputThunk`. |
| 35 | +
|
| 36 | + Instances are created by :func:`stream_with_chunking`; do not instantiate |
| 37 | + directly. |
| 38 | +
|
| 39 | + Attributes: |
| 40 | + completed: ``False`` if the stream exited early because a requirement |
| 41 | + returned ``"fail"`` during streaming; ``True`` otherwise. |
| 42 | + full_text: The complete generated text accumulated during streaming. |
| 43 | + Available after :meth:`acomplete` returns. |
| 44 | + final_validations: :class:`~mellea.core.requirement.ValidationResult` |
| 45 | + objects from the final :meth:`~mellea.core.requirement.Requirement.validate` |
| 46 | + calls on all non-failed requirements. Available after |
| 47 | + :meth:`acomplete` returns. |
| 48 | + streaming_failures: ``(Requirement, PartialValidationResult)`` pairs |
| 49 | + for every requirement that returned ``"fail"`` during streaming. |
| 50 | + """ |
| 51 | + |
| 52 | + def __init__(self, mot: ModelOutputThunk, ctx: Context) -> None: |
| 53 | + """Initialise with the MOT and context from the backend call.""" |
| 54 | + self._mot = mot |
| 55 | + self._ctx = ctx |
| 56 | + self._chunk_queue: asyncio.Queue[str | None | Exception] = asyncio.Queue() |
| 57 | + self._orchestration_task: asyncio.Task[None] | None = None |
| 58 | + self._done = asyncio.Event() |
| 59 | + |
| 60 | + self.completed: bool = True |
| 61 | + self.full_text: str = "" |
| 62 | + self.final_validations: list[ValidationResult] = [] |
| 63 | + self.streaming_failures: list[tuple[Requirement, PartialValidationResult]] = [] |
| 64 | + |
| 65 | + async def astream(self) -> AsyncIterator[str]: |
| 66 | + """Yield validated text chunks as they complete. |
| 67 | +
|
| 68 | + Each yielded string is a chunk that has passed per-chunk streaming |
| 69 | + validation (or the stream had no requirements). Iteration ends when |
| 70 | + all chunks have been yielded, whether the stream completed normally or |
| 71 | + was cancelled early on a ``"fail"`` result. |
| 72 | +
|
| 73 | + Yields: |
| 74 | + str: A validated text chunk from the chunking strategy. |
| 75 | +
|
| 76 | + Raises: |
| 77 | + Exception: Propagates any error from the background orchestration |
| 78 | + task. |
| 79 | + """ |
| 80 | + while True: |
| 81 | + item = await self._chunk_queue.get() |
| 82 | + if item is None: |
| 83 | + return |
| 84 | + if isinstance(item, Exception): |
| 85 | + raise item |
| 86 | + yield item |
| 87 | + |
| 88 | + async def acomplete(self) -> None: |
| 89 | + """Await full completion, including final validation. |
| 90 | +
|
| 91 | + After this method returns, :attr:`full_text`, :attr:`completed`, |
| 92 | + :attr:`final_validations`, and :attr:`streaming_failures` are all |
| 93 | + populated. If :meth:`astream` has already been consumed to |
| 94 | + exhaustion, this call is effectively a no-op. |
| 95 | +
|
| 96 | + Raises: |
| 97 | + Exception: Propagates any error from the orchestration task. |
| 98 | + """ |
| 99 | + await self._done.wait() |
| 100 | + if self._orchestration_task is not None and self._orchestration_task.done(): |
| 101 | + exc = self._orchestration_task.exception() |
| 102 | + if exc is not None: |
| 103 | + raise exc |
| 104 | + |
| 105 | + @property |
| 106 | + def as_thunk(self) -> ModelOutputThunk: |
| 107 | + """Wrap the output as a computed :class:`~mellea.core.base.ModelOutputThunk`. |
| 108 | +
|
| 109 | + Returns a new thunk with ``value`` set to :attr:`full_text` and |
| 110 | + generation metadata copied from the original MOT. Safe to call on |
| 111 | + early-exit results; ``value`` will reflect whatever was accumulated |
| 112 | + before cancellation. |
| 113 | +
|
| 114 | + Returns: |
| 115 | + ModelOutputThunk: A computed thunk containing the streamed output. |
| 116 | +
|
| 117 | + Raises: |
| 118 | + RuntimeError: If called before :meth:`acomplete` has returned. |
| 119 | + """ |
| 120 | + if not self._done.is_set(): |
| 121 | + raise RuntimeError( |
| 122 | + "as_thunk accessed before acomplete() — await acomplete() first" |
| 123 | + ) |
| 124 | + thunk = ModelOutputThunk(value=self.full_text) |
| 125 | + thunk.generation = copy(self._mot.generation) |
| 126 | + return thunk |
| 127 | + |
| 128 | + |
async def _orchestrate_streaming(
    result: StreamChunkingResult,
    mot: ModelOutputThunk,
    ctx: Context,
    cloned_reqs: list[Requirement],
    chunking: ChunkingStrategy,
    val_backend: Backend,
) -> None:
    """Background task: consume the MOT stream, chunk, validate, publish.

    Communicates with the consumer only through *result*: validated chunks
    (or a caught ``Exception``) go into ``result._chunk_queue``, ``None``
    is the end-of-stream sentinel, and ``result._done`` signals completion.
    *cloned_reqs* are already copies made by :func:`stream_with_chunking`,
    so using them here never mutates caller-owned requirements.
    """
    accumulated = ""  # full model output seen so far
    prev_chunk_count = 0  # number of chunks already pushed to the queue
    failed_indices: set[int] = set()  # indices into cloned_reqs that returned "fail"
    early_exit = False

    try:
        while not mot.is_computed():
            try:
                delta = await mot.astream()
            except RuntimeError:
                # NOTE(review): presumably astream() raises RuntimeError when
                # the stream is exhausted/unusable — treated as end-of-stream
                # here; confirm against ModelOutputThunk.astream.
                break

            accumulated += delta
            # Re-split the whole accumulated text; only chunks past the ones
            # already emitted count as new. NOTE(review): this assumes the
            # strategy keeps earlier chunk boundaries stable as text grows.
            chunks = chunking.split(accumulated)
            new_chunks = chunks[prev_chunk_count:]

            if new_chunks:
                # Requirements that have not failed yet, with their indices.
                active = [
                    (i, req)
                    for i, req in enumerate(cloned_reqs)
                    if i not in failed_indices
                ]
                if active:
                    # stream_validate receives the *accumulated* text, not
                    # the delta; all active requirements run in parallel.
                    pvrs: list[PartialValidationResult] = list(
                        await asyncio.gather(
                            *[
                                req.stream_validate(
                                    accumulated, backend=val_backend, ctx=ctx
                                )
                                for _, req in active
                            ]
                        )
                    )
                    for (idx, req), pvr in zip(active, pvrs):
                        if pvr.success == "fail":
                            failed_indices.add(idx)
                            result.streaming_failures.append((req, pvr))

                if failed_indices:
                    # Hard failure: cancel generation, flush the chunks we
                    # already produced this round, and stop streaming. Final
                    # validate() below is skipped via early_exit.
                    early_exit = True
                    result.completed = False
                    await mot.cancel_generation()
                    for c in new_chunks:
                        await result._chunk_queue.put(c)
                    break

                for c in new_chunks:
                    await result._chunk_queue.put(c)
                prev_chunk_count = len(chunks)

        result.full_text = accumulated

        # Final full validation on everything that never failed mid-stream;
        # skipped entirely after an early exit.
        non_failed = [
            req for i, req in enumerate(cloned_reqs) if i not in failed_indices
        ]
        if non_failed and not early_exit:
            result.final_validations = list(
                await asyncio.gather(
                    *[req.validate(val_backend, ctx) for req in non_failed]
                )
            )

    except Exception as exc:
        # Surface the error to the consumer through astream(); note that a
        # caller using only acomplete() will not see queue-delivered errors.
        await result._chunk_queue.put(exc)
    finally:
        # Always terminate the chunk stream and wake acomplete() waiters.
        await result._chunk_queue.put(None)
        result._done.set()
| 204 | + |
| 205 | + |
async def stream_with_chunking(
    action: Component[Any] | CBlock,
    backend: Backend,
    ctx: Context,
    *,
    quick_check_requirements: Sequence[Requirement] | None = None,
    chunking: str | ChunkingStrategy = "sentence",
    quick_check_backend: Backend | None = None,
) -> StreamChunkingResult:
    """Kick off a streaming generation whose output is validated chunk by chunk.

    A streaming generation is started on *backend* and its
    :class:`~mellea.core.base.ModelOutputThunk` is consumed by a single
    background task. The accumulated text is split with *chunking*, and
    whenever new chunks appear every requirement in
    *quick_check_requirements* has
    :meth:`~mellea.core.requirement.Requirement.stream_validate` run in
    parallel. A ``"fail"`` result cancels the generation immediately (via
    :meth:`~mellea.core.base.ModelOutputThunk.cancel_generation`) and sets
    :attr:`StreamChunkingResult.completed` to ``False``.

    Once the stream ends (naturally or via early exit), ``validate()`` runs
    on every requirement that never returned ``"fail"``. Requirements are
    cloned with ``copy(req)`` first, so caller-owned objects are never
    mutated.

    ``stream_validate`` always receives the *accumulated* model output, not
    just the latest delta; the chunking strategy only controls *when* it is
    invoked (at chunk boundaries). Requirements that want delta-only
    processing can track ``self._seen_len`` and slice
    ``accumulated[self._seen_len:]``.

    Note:
        v1 retry is simple re-invocation of this function. Plugin hooks
        (``SAMPLING_LOOP_START``, ``SAMPLING_REPAIR``, etc.) do not fire
        on retries — use the ``#902`` event types for observability instead.

    Args:
        action: The component or content block to generate from.
        backend: Backend used for generation, and for validation unless
            *quick_check_backend* is supplied.
        ctx: The generation context.
        quick_check_requirements: Requirements checked against each chunk
            during streaming. ``None`` disables streaming validation; chunks
            are still produced and no final ``validate()`` is called.
        chunking: Either a :class:`~mellea.stdlib.chunking.ChunkingStrategy`
            instance or one of the string aliases ``"sentence"`` (default),
            ``"word"``, or ``"paragraph"``.
        quick_check_backend: Optional alternate backend for both
            ``stream_validate`` and final ``validate`` calls; *backend* is
            used when ``None``.

    Returns:
        StreamChunkingResult: Consume chunks incrementally with
        :meth:`~StreamChunkingResult.astream`, or block until everything is
        done with :meth:`~StreamChunkingResult.acomplete`.

    Raises:
        ValueError: If *chunking* is a string that is not a known alias.
    """
    # Resolve a string alias into a concrete strategy instance.
    if isinstance(chunking, str):
        if chunking not in _CHUNKING_ALIASES:
            raise ValueError(
                f"Unknown chunking alias {chunking!r}. Choose from: {list(_CHUNKING_ALIASES)}"
            )
        chunking = _CHUNKING_ALIASES[chunking]()

    model_options: dict[str, Any] = {ModelOption.STREAM: True}
    mot, gen_ctx = await backend.generate_from_context(
        action, ctx, model_options=model_options
    )

    stream_result = StreamChunkingResult(mot, gen_ctx)

    # Clone so streaming validation never mutates caller-owned requirements.
    requirements = [copy(req) for req in (quick_check_requirements or [])]
    validation_backend = backend if quick_check_backend is None else quick_check_backend

    stream_result._orchestration_task = asyncio.create_task(
        _orchestrate_streaming(
            stream_result, mot, gen_ctx, requirements, chunking, validation_backend
        )
    )

    return stream_result