Skip to content

Commit 61448a9

Browse files
committed
feat(stdlib): flush trailing chunk fragment at end of stream
ChunkingStrategy.split() withholds the trailing fragment by design (#899). Previously the orchestrator discarded it — it appeared in full_text and the final validate() saw it, but it was never yielded to astream() consumers and never seen by stream_validate. For a response that did not end in a chunk terminator (e.g. "Sentence one. Sentence two." with no trailing whitespace under SentenceChunker), the last sentence silently bypassed streaming validation. Adds ChunkingStrategy.flush(accumulated_text) -> list[str]: - Default in the ABC returns [] (backward-compatible — external chunkers retain the old discard behaviour until they opt in). - SentenceChunker, WordChunker, ParagraphChunker each override to return the withheld trailing fragment as a single-element list. _orchestrate_streaming calls chunking.flush(accumulated) after the main loop (only when the stream ended naturally, not on early exit — a cancelled stream's trailing fragment is by definition incomplete). Each flushed chunk goes through the same stream_validate / emit path as regular chunks, so the "no unvalidated content reaches the consumer" invariant extends to the trailing fragment, and a fail on the fragment still records a streaming failure and skips final validate(). Tests: - 13 new chunker tests covering the default-discard behaviour and each built-in's flush logic (empty input, fragment-present, already- terminated cases). - test_trailing_fragment_is_flushed_to_consumer: stream_validate sees the fragment and astream yields it. - test_early_exit_on_trailing_fragment: fail on the flushed fragment propagates to streaming_failures and skips final validation. Assisted-by: Claude Code
1 parent 35df77f commit 61448a9

4 files changed

Lines changed: 298 additions & 22 deletions

File tree

mellea/stdlib/chunking.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,27 @@ def split(self, accumulated_text: str) -> list[str]:
3535
"""
3636
...
3737

38+
def flush(self, accumulated_text: str) -> list[str]:
39+
"""Return any trailing fragment that ``split`` withheld.
40+
41+
Called once by the orchestrator after the stream has ended naturally
42+
(not on early-exit cancellation). Gives the chunker a chance to
43+
release the final fragment that did not reach a terminator.
44+
45+
The default implementation returns an empty list — the trailing
46+
fragment is discarded. Built-in chunkers override this to return
47+
the withheld fragment as a single-element list when non-empty.
48+
49+
Args:
50+
accumulated_text: The full accumulated text at stream end.
51+
52+
Returns:
53+
The trailing fragment as ``[fragment]`` if it should be treated
54+
as a final chunk, or an empty list to discard it.
55+
"""
56+
_ = accumulated_text
57+
return []
58+
3859

3960
# Sentence boundary: sentence-ending punctuation, optionally followed by a closing
4061
# quote or paren, then whitespace.
@@ -94,6 +115,19 @@ def split(self, accumulated_text: str) -> list[str]:
94115

95116
return chunks
96117

118+
def flush(self, accumulated_text: str) -> list[str]:
119+
"""Return the trailing sentence fragment (if any) as a final chunk."""
120+
if not accumulated_text:
121+
return []
122+
remaining = accumulated_text
123+
while True:
124+
match = _SENTENCE_BOUNDARY.search(remaining)
125+
if match is None:
126+
break
127+
remaining = remaining[match.end() :].lstrip()
128+
trailing = remaining.strip()
129+
return [trailing] if trailing else []
130+
97131

98132
class WordChunker(ChunkingStrategy):
99133
"""Splits accumulated text on whitespace boundaries.
@@ -134,6 +168,18 @@ def split(self, accumulated_text: str) -> list[str]:
134168

135169
return parts
136170

171+
def flush(self, accumulated_text: str) -> list[str]:
172+
"""Return the trailing word fragment (if any) as a final chunk."""
173+
if not accumulated_text:
174+
return []
175+
if accumulated_text[-1].isspace():
176+
return []
177+
parts = _WHITESPACE.split(accumulated_text)
178+
for part in reversed(parts):
179+
if part:
180+
return [part]
181+
return []
182+
137183

138184
class ParagraphChunker(ChunkingStrategy):
139185
r"""Splits accumulated text on double-newline paragraph boundaries.
@@ -168,3 +214,13 @@ def split(self, accumulated_text: str) -> list[str]:
168214

169215
# _PARA_BOUNDARY.split on leading \n\n produces an empty first element.
170216
return [p for p in parts if p]
217+
218+
def flush(self, accumulated_text: str) -> list[str]:
219+
"""Return the trailing paragraph fragment (if any) as a final chunk."""
220+
if not accumulated_text:
221+
return []
222+
if _PARA_BOUNDARY_END.search(accumulated_text):
223+
return []
224+
parts = _PARA_BOUNDARY.split(accumulated_text)
225+
trailing = parts[-1] if parts else ""
226+
return [trailing] if trailing else []

mellea/stdlib/streaming.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,35 @@ async def _orchestrate_streaming(
144144
failed_indices: set[int] = set()
145145
early_exit = False
146146

147+
async def _validate_and_emit(c: str) -> bool:
148+
"""Run stream_validate on chunk c across active requirements.
149+
150+
Returns True if a failure was recorded (caller should early-exit),
151+
False otherwise (chunk was emitted to the consumer queue).
152+
"""
153+
active = [
154+
(i, req) for i, req in enumerate(cloned_reqs) if i not in failed_indices
155+
]
156+
if active:
157+
pvrs: list[PartialValidationResult] = list(
158+
await asyncio.gather(
159+
*[
160+
req.stream_validate(c, backend=val_backend, ctx=ctx)
161+
for _, req in active
162+
]
163+
)
164+
)
165+
for (idx, req), pvr in zip(active, pvrs):
166+
if pvr.success == "fail":
167+
failed_indices.add(idx)
168+
result.streaming_failures.append((req, pvr))
169+
170+
if failed_indices:
171+
return True
172+
173+
await result._chunk_queue.put(c)
174+
return False
175+
147176
try:
148177
while not mot.is_computed():
149178
try:
@@ -157,36 +186,27 @@ async def _orchestrate_streaming(
157186
prev_chunk_count = len(chunks)
158187

159188
for c in new_chunks:
160-
active = [
161-
(i, req)
162-
for i, req in enumerate(cloned_reqs)
163-
if i not in failed_indices
164-
]
165-
if active:
166-
pvrs: list[PartialValidationResult] = list(
167-
await asyncio.gather(
168-
*[
169-
req.stream_validate(c, backend=val_backend, ctx=ctx)
170-
for _, req in active
171-
]
172-
)
173-
)
174-
for (idx, req), pvr in zip(active, pvrs):
175-
if pvr.success == "fail":
176-
failed_indices.add(idx)
177-
result.streaming_failures.append((req, pvr))
178-
179-
if failed_indices:
189+
failed = await _validate_and_emit(c)
190+
if failed:
180191
early_exit = True
181192
result.completed = False
182193
await mot.cancel_generation()
183194
break
184195

185-
await result._chunk_queue.put(c)
186-
187196
if early_exit:
188197
break
189198

199+
# Stream ended naturally: flush any withheld trailing fragment and
200+
# run stream_validate on it. Skipped on early exit — the generation
201+
# was cancelled, the trailing fragment is incomplete.
202+
if not early_exit:
203+
for c in chunking.flush(accumulated):
204+
failed = await _validate_and_emit(c)
205+
if failed:
206+
early_exit = True
207+
result.completed = False
208+
break
209+
190210
result.full_text = accumulated
191211

192212
non_failed = [
@@ -238,6 +258,12 @@ async def stream_with_chunking(
238258
failing chunk is not emitted to the consumer; use
239259
:attr:`StreamChunkingResult.streaming_failures` to inspect what failed.
240260
261+
When the stream ends naturally, any trailing fragment withheld by the
262+
chunking strategy (see :meth:`~mellea.stdlib.chunking.ChunkingStrategy.flush`)
263+
is released as a final chunk and run through ``stream_validate`` on the
264+
same terms as the regular chunks. On early exit, the trailing fragment
265+
is discarded because the generation was cancelled mid-token.
266+
241267
After the stream ends (naturally or via early exit), ``validate()`` is
242268
called on all requirements that did not return ``"fail"``. Requirements
243269
are cloned (``copy(req)``) before use so originals are never mutated.

test/stdlib/test_chunking.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,79 @@ def test_paragraph_chunker_incremental_simulation():
242242
"First paragraph.",
243243
"Second paragraph.",
244244
]
245+
246+
247+
# ---------------------------------------------------------------------------
248+
# flush() — trailing-fragment release at end of stream
249+
# ---------------------------------------------------------------------------
250+
251+
252+
def test_default_flush_returns_empty_list():
253+
"""The ABC default discards the trailing fragment."""
254+
255+
class Minimal(ChunkingStrategy):
256+
def split(self, accumulated_text: str) -> list[str]:
257+
_ = accumulated_text
258+
return []
259+
260+
assert Minimal().flush("anything at all") == []
261+
assert Minimal().flush("") == []
262+
263+
264+
def test_sentence_chunker_flush_empty():
265+
assert SentenceChunker().flush("") == []
266+
267+
268+
def test_sentence_chunker_flush_only_complete():
269+
"""All text ends in a complete sentence with trailing whitespace → no fragment."""
270+
assert SentenceChunker().flush("One. Two. ") == []
271+
272+
273+
def test_sentence_chunker_flush_trailing_fragment():
274+
"""Final sentence without trailing whitespace is released by flush."""
275+
assert SentenceChunker().flush("One. Two without period") == ["Two without period"]
276+
277+
278+
def test_sentence_chunker_flush_terminated_no_trailing_space():
279+
"""Final sentence with terminator but no trailing whitespace is a fragment
280+
under split() semantics and gets released by flush()."""
281+
assert SentenceChunker().flush("One. Two.") == ["Two."]
282+
283+
284+
def test_sentence_chunker_flush_single_sentence_no_terminator():
285+
assert SentenceChunker().flush("Incomplete sentence") == ["Incomplete sentence"]
286+
287+
288+
def test_word_chunker_flush_empty():
289+
assert WordChunker().flush("") == []
290+
291+
292+
def test_word_chunker_flush_trailing_whitespace():
293+
"""Trailing whitespace means all words are complete → no fragment."""
294+
assert WordChunker().flush("one two three ") == []
295+
296+
297+
def test_word_chunker_flush_trailing_fragment():
298+
assert WordChunker().flush("one two three") == ["three"]
299+
300+
301+
def test_word_chunker_flush_single_word():
302+
assert WordChunker().flush("solo") == ["solo"]
303+
304+
305+
def test_paragraph_chunker_flush_empty():
306+
assert ParagraphChunker().flush("") == []
307+
308+
309+
def test_paragraph_chunker_flush_only_complete():
310+
assert ParagraphChunker().flush("Para one.\n\nPara two.\n\n") == []
311+
312+
313+
def test_paragraph_chunker_flush_trailing_fragment():
314+
assert ParagraphChunker().flush("Para one.\n\nPara two (no sep)") == [
315+
"Para two (no sep)"
316+
]
317+
318+
319+
def test_paragraph_chunker_flush_single_paragraph_no_separator():
320+
assert ParagraphChunker().flush("Only paragraph") == ["Only paragraph"]

test/stdlib/test_streaming.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,124 @@ def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq:
443443
assert not all(len(seen[i]) < len(seen[i + 1]) for i in range(len(seen) - 1))
444444

445445

446+
@pytest.mark.asyncio
447+
async def test_trailing_fragment_is_flushed_to_consumer() -> None:
448+
"""Response without trailing whitespace: final sentence reaches astream() and stream_validate."""
449+
450+
class ChunkRecordingReq(Requirement):
451+
def __init__(self) -> None:
452+
self.seen_chunks: list[str] = []
453+
454+
def __copy__(self) -> "ChunkRecordingReq":
455+
clone = ChunkRecordingReq()
456+
clone.seen_chunks = []
457+
return clone
458+
459+
def format_for_llm(self) -> str:
460+
return "chunk recorder"
461+
462+
async def stream_validate(
463+
self, chunk: str, *, backend: Any, ctx: Any
464+
) -> PartialValidationResult:
465+
self.seen_chunks.append(chunk)
466+
return PartialValidationResult("unknown")
467+
468+
async def validate(
469+
self,
470+
backend: Any,
471+
ctx: Any,
472+
*,
473+
format: Any = None,
474+
model_options: Any = None,
475+
) -> ValidationResult:
476+
return ValidationResult(result=True)
477+
478+
# No trailing whitespace after the final sentence — SentenceChunker withholds it.
479+
response = "First sentence. Second sentence."
480+
backend = StreamingMockBackend(response, token_size=4)
481+
req = ChunkRecordingReq()
482+
483+
captured: list[ChunkRecordingReq] = []
484+
original_copy = ChunkRecordingReq.__copy__
485+
486+
def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq:
487+
clone = original_copy(self)
488+
captured.append(clone)
489+
return clone
490+
491+
ChunkRecordingReq.__copy__ = _capturing_copy # type: ignore[method-assign]
492+
try:
493+
result = await stream_with_chunking(
494+
_action(),
495+
backend,
496+
_ctx(),
497+
quick_check_requirements=[req],
498+
chunking="sentence",
499+
)
500+
yielded: list[str] = []
501+
async for chunk in result.astream():
502+
yielded.append(chunk)
503+
await result.acomplete()
504+
finally:
505+
ChunkRecordingReq.__copy__ = original_copy # type: ignore[method-assign]
506+
507+
# Both sentences reach the consumer, including the terminating one without trailing whitespace.
508+
assert yielded == ["First sentence.", "Second sentence."]
509+
# stream_validate was called on both — the flush path is not a shortcut.
510+
assert captured[0].seen_chunks == ["First sentence.", "Second sentence."]
511+
assert result.completed is True
512+
513+
514+
@pytest.mark.asyncio
515+
async def test_early_exit_on_trailing_fragment() -> None:
516+
"""A fail on the flushed fragment records a streaming failure and skips final validate()."""
517+
518+
class FailOnSecondSentence(Requirement):
519+
def __init__(self) -> None:
520+
self._count = 0
521+
522+
def format_for_llm(self) -> str:
523+
return "fail on second sentence"
524+
525+
async def stream_validate(
526+
self, chunk: str, *, backend: Any, ctx: Any
527+
) -> PartialValidationResult:
528+
_ = chunk, backend, ctx
529+
self._count += 1
530+
if self._count >= 2:
531+
return PartialValidationResult("fail", reason="second sentence hit")
532+
return PartialValidationResult("unknown")
533+
534+
async def validate(
535+
self,
536+
backend: Any,
537+
ctx: Any,
538+
*,
539+
format: Any = None,
540+
model_options: Any = None,
541+
) -> ValidationResult:
542+
return ValidationResult(result=True)
543+
544+
response = "First sentence. Second sentence."
545+
backend = StreamingMockBackend(response, token_size=4)
546+
req = FailOnSecondSentence()
547+
548+
result = await stream_with_chunking(
549+
_action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence"
550+
)
551+
yielded: list[str] = []
552+
async for chunk in result.astream():
553+
yielded.append(chunk)
554+
await result.acomplete()
555+
556+
assert result.completed is False
557+
assert len(result.streaming_failures) == 1
558+
# First sentence was emitted; second (the flushed fragment) failed and wasn't emitted.
559+
assert yielded == ["First sentence."]
560+
# Early exit on fail skips final validate().
561+
assert result.final_validations == []
562+
563+
446564
@pytest.mark.asyncio
447565
async def test_no_requirements_streams_without_validation() -> None:
448566
"""quick_check_requirements=None → chunks produced, no validate() called."""

0 commit comments

Comments
 (0)