Skip to content

Commit a6c98ed

Browse files
committed
feat(stdlib): add stream_with_chunking() with per-chunk validation (#901)
Adds stream_with_chunking() — the core streaming orchestration primitive that consumes a ModelOutputThunk's async stream via a background asyncio.Task, applies a ChunkingStrategy to produce semantic chunks, and runs stream_validate() in parallel across all requirements at each chunk boundary. Key behaviours: - Early exit: if any requirement returns "fail" during streaming, generation is cancelled immediately via cancel_generation() and StreamChunkingResult.completed is set to False. - Final validation: after natural completion, validate() is called on all non-failed requirements. - Clone-per-call: requirements are cloned (copy(req)) before each invocation; originals are never mutated. - String aliases: "sentence", "word", "paragraph" map to the corresponding ChunkingStrategy subclasses. StreamChunkingResult exposes: - astream() — async iterator yielding individual validated chunks - acomplete() — await full completion including final validation - as_thunk — wrap full_text as a computed ModelOutputThunk - completed, full_text, final_validations, streaming_failures Re-exports StreamChunkingResult and stream_with_chunking from mellea.stdlib for day-to-day use. Assisted-by: Claude Code Signed-off-by: Nigel Jones <jonesn@uk.ibm.com>
1 parent 69c9c3f commit a6c98ed

2 files changed

Lines changed: 295 additions & 2 deletions

File tree

mellea/stdlib/__init__.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,20 @@
1010
``mellea.stdlib.session`` — for day-to-day use.
1111
1212
Streaming chunking strategies (for use with streaming validation) are available at
13-
``mellea.stdlib.chunking`` and re-exported here for convenience.
13+
``mellea.stdlib.chunking`` and re-exported here for convenience. The core streaming
14+
orchestration primitive :func:`~mellea.stdlib.streaming.stream_with_chunking` and
15+
its result type :class:`~mellea.stdlib.streaming.StreamChunkingResult` are also
16+
re-exported here.
1417
"""
1518

1619
from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker
20+
from .streaming import StreamChunkingResult, stream_with_chunking
1721

18-
__all__ = ["ChunkingStrategy", "ParagraphChunker", "SentenceChunker", "WordChunker"]
22+
__all__ = [
23+
"ChunkingStrategy",
24+
"ParagraphChunker",
25+
"SentenceChunker",
26+
"StreamChunkingResult",
27+
"WordChunker",
28+
"stream_with_chunking",
29+
]

mellea/stdlib/streaming.py

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
"""Streaming generation with per-chunk validation.
2+
3+
Provides :func:`stream_with_chunking`, the core orchestration primitive that
4+
consumes a streaming :class:`~mellea.core.base.ModelOutputThunk`, applies a
5+
:class:`~mellea.stdlib.chunking.ChunkingStrategy` to produce semantic chunks,
6+
and runs :meth:`~mellea.core.requirement.Requirement.stream_validate` on each
7+
chunk in parallel. Higher-level streaming APIs build on this function.
8+
"""
9+
10+
import asyncio
11+
from collections.abc import AsyncIterator, Sequence
12+
from copy import copy
13+
from typing import Any
14+
15+
from ..backends.model_options import ModelOption
16+
from ..core.backend import Backend
17+
from ..core.base import CBlock, Component, Context, ModelOutputThunk
18+
from ..core.requirement import PartialValidationResult, Requirement, ValidationResult
19+
from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker
20+
21+
_CHUNKING_ALIASES: dict[str, type[ChunkingStrategy]] = {
22+
"sentence": SentenceChunker,
23+
"word": WordChunker,
24+
"paragraph": ParagraphChunker,
25+
}
26+
27+
28+
class StreamChunkingResult:
29+
"""Result of a :func:`stream_with_chunking` operation.
30+
31+
Provides async iteration over validated text chunks as they complete
32+
(:meth:`astream`), a blocking :meth:`acomplete` for awaiting the full
33+
result including final validation, and :attr:`as_thunk` for wrapping the
34+
output as a :class:`~mellea.core.base.ModelOutputThunk`.
35+
36+
Instances are created by :func:`stream_with_chunking`; do not instantiate
37+
directly.
38+
39+
Attributes:
40+
completed: ``False`` if the stream exited early because a requirement
41+
returned ``"fail"`` during streaming; ``True`` otherwise.
42+
full_text: The complete generated text accumulated during streaming.
43+
Available after :meth:`acomplete` returns.
44+
final_validations: :class:`~mellea.core.requirement.ValidationResult`
45+
objects from the final :meth:`~mellea.core.requirement.Requirement.validate`
46+
calls on all non-failed requirements. Available after
47+
:meth:`acomplete` returns.
48+
streaming_failures: ``(Requirement, PartialValidationResult)`` pairs
49+
for every requirement that returned ``"fail"`` during streaming.
50+
"""
51+
52+
def __init__(self, mot: ModelOutputThunk, ctx: Context) -> None:
53+
"""Initialise with the MOT and context from the backend call."""
54+
self._mot = mot
55+
self._ctx = ctx
56+
self._chunk_queue: asyncio.Queue[str | None | Exception] = asyncio.Queue()
57+
self._orchestration_task: asyncio.Task[None] | None = None
58+
self._done = asyncio.Event()
59+
60+
self.completed: bool = True
61+
self.full_text: str = ""
62+
self.final_validations: list[ValidationResult] = []
63+
self.streaming_failures: list[tuple[Requirement, PartialValidationResult]] = []
64+
65+
async def astream(self) -> AsyncIterator[str]:
66+
"""Yield validated text chunks as they complete.
67+
68+
Each yielded string is a chunk that has passed per-chunk streaming
69+
validation (or the stream had no requirements). Iteration ends when
70+
all chunks have been yielded, whether the stream completed normally or
71+
was cancelled early on a ``"fail"`` result.
72+
73+
Yields:
74+
str: A validated text chunk from the chunking strategy.
75+
76+
Raises:
77+
Exception: Propagates any error from the background orchestration
78+
task.
79+
"""
80+
while True:
81+
item = await self._chunk_queue.get()
82+
if item is None:
83+
return
84+
if isinstance(item, Exception):
85+
raise item
86+
yield item
87+
88+
async def acomplete(self) -> None:
89+
"""Await full completion, including final validation.
90+
91+
After this method returns, :attr:`full_text`, :attr:`completed`,
92+
:attr:`final_validations`, and :attr:`streaming_failures` are all
93+
populated. If :meth:`astream` has already been consumed to
94+
exhaustion, this call is effectively a no-op.
95+
96+
Raises:
97+
Exception: Propagates any error from the orchestration task.
98+
"""
99+
await self._done.wait()
100+
if self._orchestration_task is not None and self._orchestration_task.done():
101+
exc = self._orchestration_task.exception()
102+
if exc is not None:
103+
raise exc
104+
105+
@property
106+
def as_thunk(self) -> ModelOutputThunk:
107+
"""Wrap the output as a computed :class:`~mellea.core.base.ModelOutputThunk`.
108+
109+
Returns a new thunk with ``value`` set to :attr:`full_text` and
110+
generation metadata copied from the original MOT. Safe to call on
111+
early-exit results; ``value`` will reflect whatever was accumulated
112+
before cancellation.
113+
114+
Returns:
115+
ModelOutputThunk: A computed thunk containing the streamed output.
116+
117+
Raises:
118+
RuntimeError: If called before :meth:`acomplete` has returned.
119+
"""
120+
if not self._done.is_set():
121+
raise RuntimeError(
122+
"as_thunk accessed before acomplete() — await acomplete() first"
123+
)
124+
thunk = ModelOutputThunk(value=self.full_text)
125+
thunk.generation = copy(self._mot.generation)
126+
return thunk
127+
128+
129+
async def _orchestrate_streaming(
130+
result: StreamChunkingResult,
131+
mot: ModelOutputThunk,
132+
ctx: Context,
133+
cloned_reqs: list[Requirement],
134+
chunking: ChunkingStrategy,
135+
val_backend: Backend,
136+
) -> None:
137+
accumulated = ""
138+
prev_chunk_count = 0
139+
failed_indices: set[int] = set()
140+
early_exit = False
141+
142+
try:
143+
while not mot.is_computed():
144+
try:
145+
delta = await mot.astream()
146+
except RuntimeError:
147+
break
148+
149+
accumulated += delta
150+
chunks = chunking.split(accumulated)
151+
new_chunks = chunks[prev_chunk_count:]
152+
153+
if new_chunks:
154+
active = [
155+
(i, req)
156+
for i, req in enumerate(cloned_reqs)
157+
if i not in failed_indices
158+
]
159+
if active:
160+
pvrs: list[PartialValidationResult] = list(
161+
await asyncio.gather(
162+
*[
163+
req.stream_validate(
164+
accumulated, backend=val_backend, ctx=ctx
165+
)
166+
for _, req in active
167+
]
168+
)
169+
)
170+
for (idx, req), pvr in zip(active, pvrs):
171+
if pvr.success == "fail":
172+
failed_indices.add(idx)
173+
result.streaming_failures.append((req, pvr))
174+
175+
if failed_indices:
176+
early_exit = True
177+
result.completed = False
178+
await mot.cancel_generation()
179+
for c in new_chunks:
180+
await result._chunk_queue.put(c)
181+
break
182+
183+
for c in new_chunks:
184+
await result._chunk_queue.put(c)
185+
prev_chunk_count = len(chunks)
186+
187+
result.full_text = accumulated
188+
189+
non_failed = [
190+
req for i, req in enumerate(cloned_reqs) if i not in failed_indices
191+
]
192+
if non_failed and not early_exit:
193+
result.final_validations = list(
194+
await asyncio.gather(
195+
*[req.validate(val_backend, ctx) for req in non_failed]
196+
)
197+
)
198+
199+
except Exception as exc:
200+
await result._chunk_queue.put(exc)
201+
finally:
202+
await result._chunk_queue.put(None)
203+
result._done.set()
204+
205+
206+
async def stream_with_chunking(
207+
action: Component[Any] | CBlock,
208+
backend: Backend,
209+
ctx: Context,
210+
*,
211+
quick_check_requirements: Sequence[Requirement] | None = None,
212+
chunking: str | ChunkingStrategy = "sentence",
213+
quick_check_backend: Backend | None = None,
214+
) -> StreamChunkingResult:
215+
"""Generate a streaming response with per-chunk validation.
216+
217+
Starts a backend generation with streaming enabled, consumes the
218+
:class:`~mellea.core.base.ModelOutputThunk`'s async stream in a single
219+
background task, splits the accumulated text using *chunking*, and runs
220+
:meth:`~mellea.core.requirement.Requirement.stream_validate` on each new
221+
chunk in parallel across all requirements.
222+
223+
If any requirement returns ``"fail"`` during streaming validation, the
224+
generation is cancelled immediately (via
225+
:meth:`~mellea.core.base.ModelOutputThunk.cancel_generation`) and
226+
:attr:`StreamChunkingResult.completed` is set to ``False``.
227+
228+
After the stream ends (naturally or via early exit), ``validate()`` is
229+
called on all requirements that did not return ``"fail"``. Requirements
230+
are cloned (``copy(req)``) before use so originals are never mutated.
231+
232+
``stream_validate`` receives the *accumulated* model output so far, not
233+
just the current chunk. The chunking strategy determines *when* it is
234+
called (at chunk boundaries). Requirements that want delta-only
235+
processing track ``self._seen_len`` and slice
236+
``accumulated[self._seen_len:]``.
237+
238+
Note:
239+
v1 retry is simple re-invocation of this function. Plugin hooks
240+
(``SAMPLING_LOOP_START``, ``SAMPLING_REPAIR``, etc.) do not fire
241+
on retries — use the ``#902`` event types for observability instead.
242+
243+
Args:
244+
action: The component or content block to generate from.
245+
backend: The backend used for generation and final validation.
246+
ctx: The generation context.
247+
quick_check_requirements: Sequence of requirements to validate against
248+
each chunk during streaming. ``None`` disables streaming validation
249+
(chunks are still produced; ``validate()`` is not called at stream end).
250+
chunking: Chunking strategy — either a :class:`~mellea.stdlib.chunking.ChunkingStrategy`
251+
instance or one of the string aliases ``"sentence"`` (default),
252+
``"word"``, or ``"paragraph"``.
253+
quick_check_backend: Optional alternate backend for both
254+
``stream_validate`` and final ``validate`` calls. When ``None``,
255+
*backend* is used for validation.
256+
257+
Returns:
258+
StreamChunkingResult: A result object providing :meth:`~StreamChunkingResult.astream`
259+
for incremental chunk consumption and
260+
:meth:`~StreamChunkingResult.acomplete` for blocking until done.
261+
"""
262+
if isinstance(chunking, str):
263+
cls = _CHUNKING_ALIASES.get(chunking)
264+
if cls is None:
265+
raise ValueError(
266+
f"Unknown chunking alias {chunking!r}. Choose from: {list(_CHUNKING_ALIASES)}"
267+
)
268+
chunking = cls()
269+
270+
opts: dict[str, Any] = {ModelOption.STREAM: True}
271+
mot, gen_ctx = await backend.generate_from_context(action, ctx, model_options=opts)
272+
273+
result = StreamChunkingResult(mot, gen_ctx)
274+
275+
cloned_reqs = [copy(req) for req in (quick_check_requirements or [])]
276+
val_backend = quick_check_backend if quick_check_backend is not None else backend
277+
278+
result._orchestration_task = asyncio.create_task(
279+
_orchestrate_streaming(result, mot, gen_ctx, cloned_reqs, chunking, val_backend)
280+
)
281+
282+
return result

0 commit comments

Comments
 (0)