Commit def10b6
fix(stdlib): address review feedback on streaming validation
Addresses issues raised by independent review on top of PR #942.

Orchestrator (mellea/stdlib/streaming.py):
- except Exception now calls mot.cancel_generation() before surfacing the
  exception to the consumer — previously the backend producer was left
  running, eventually blocking on mot._async_queue (maxsize=20). Cleanup
  failures are logged via MelleaLogger.warning with a TODO(#902) marker;
  #902 replaces the log with a proper ErrorEvent.
- RuntimeError catch in the astream() loop now re-raises unless
  mot.is_computed() is true, so only the documented "already computed"
  race is swallowed.
- astream() docstring now states the single-consumer contract explicitly;
  a second iteration blocks on an empty queue with no sentinel to deliver.
- as_thunk docstring now flags the early-exit case: cancel_generation
  forces is_computed=True without running post_processing(), so
  generation.usage and related telemetry fields may be None.

Chunker (mellea/stdlib/chunking.py):
- SentenceChunker.flush switches from .strip() to .rstrip() with a comment
  explaining why: the loop's lstrip has already removed leading whitespace,
  and trailing whitespace on a sentence fragment is non-semantic
  (consistent with split() returning sentences without trailing whitespace).
- ParagraphChunker.flush adds a docstring noting the deliberate asymmetry:
  paragraph fragments are returned byte-for-byte because internal
  whitespace (e.g. trailing \n of a list item) can be semantically
  meaningful.

Tests (test/stdlib/test_streaming.py):
- test_stream_validate_receives_individual_chunks now uses exact match on
  the captured chunk list, which directly regresses if someone reverts to
  accumulated-text semantics.
- test_multiple_chunks_in_one_batch_with_mid_batch_fail: response fed as
  one large token so split() yields 4 sentences at once; verifies chunk 1
  emits, chunk 2 fails (not emitted), chunks 3 and 4 are neither validated
  nor emitted.
- test_cancel_generation_invoked_on_fail: spies on
  ModelOutputThunk.cancel_generation and asserts it was called on the
  "fail" early-exit path.
- test_exception_in_stream_validate_cancels_generation: a requirement that
  raises must cause cancel_generation to run and the exception to surface
  via astream()/acomplete() without hanging.

Telemetry observability (orchestrator-level spans, metrics, span events)
remains deferred to #902 per the epic, which now has the acceptance
criteria updated to cover event emission, the OTEL bridge, and the
ErrorEvent type that will replace the MelleaLogger stopgap.

Assisted-by: Claude Code
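For context on the queue-blocking failure mode described in the commit message, a standalone sketch of the generic hazard: a producer task parks forever on a bounded asyncio.Queue once the consumer stops draining, unless something cancels it. Illustrative only; the names are hypothetical and this is not mellea code.

import asyncio
import contextlib

async def demo() -> None:
    q: asyncio.Queue[int] = asyncio.Queue(maxsize=2)

    async def producer() -> None:
        for i in range(10):
            await q.put(i)  # blocks once the queue is full and nobody drains it

    task = asyncio.create_task(producer())
    print(await q.get())  # consume one item, then stop draining
    await asyncio.sleep(0.1)
    print(task.done())  # False: the producer is parked on q.put()
    task.cancel()  # the role mot.cancel_generation() plays in the orchestrator
    with contextlib.suppress(asyncio.CancelledError):
        await task

asyncio.run(demo())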
1 parent 61448a9 commit def10b6

3 files changed

Lines changed: 253 additions & 12 deletions


mellea/stdlib/chunking.py

Lines changed: 19 additions & 3 deletions
@@ -116,7 +116,15 @@ def split(self, accumulated_text: str) -> list[str]:
         return chunks

     def flush(self, accumulated_text: str) -> list[str]:
-        """Return the trailing sentence fragment (if any) as a final chunk."""
+        """Return the trailing sentence fragment (if any) as a final chunk.
+
+        Trailing whitespace on the fragment is non-semantic for sentence
+        boundaries and is dropped via ``rstrip``. Leading whitespace is
+        already removed by the loop's ``lstrip`` on each advance, so no
+        ``lstrip`` is needed here. The result is the fragment's content
+        only, consistent with how :meth:`split` returns sentences without
+        trailing whitespace.
+        """
         if not accumulated_text:
             return []
         remaining = accumulated_text
@@ -125,7 +133,7 @@ def flush(self, accumulated_text: str) -> list[str]:
             if match is None:
                 break
             remaining = remaining[match.end() :].lstrip()
-        trailing = remaining.strip()
+        trailing = remaining.rstrip()
         return [trailing] if trailing else []


@@ -216,7 +224,15 @@ def split(self, accumulated_text: str) -> list[str]:
         return [p for p in parts if p]

     def flush(self, accumulated_text: str) -> list[str]:
-        """Return the trailing paragraph fragment (if any) as a final chunk."""
+        r"""Return the trailing paragraph fragment (if any) as a final chunk.
+
+        Unlike :class:`SentenceChunker.flush`, the fragment is returned
+        byte-for-byte without stripping. Internal whitespace — including
+        a trailing single ``\n`` — can be semantically meaningful inside
+        a paragraph (e.g. a list item or a deliberate line break), and a
+        consumer validating paragraph content should see the fragment as
+        it was withheld.
+        """
         if not accumulated_text:
             return []
         if _PARA_BOUNDARY_END.search(accumulated_text):
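To make the flush asymmetry concrete, a minimal sketch of the intended behavior, assuming both chunkers construct without arguments and the fragments contain no chunk boundary (hypothetical inputs, not from the test suite):

from mellea.stdlib.chunking import ParagraphChunker, SentenceChunker

# Sentence fragments: trailing whitespace is non-semantic and rstripped away.
assert SentenceChunker().flush("a trailing fragment  ") == ["a trailing fragment"]

# Paragraph fragments: returned byte-for-byte, so the trailing newline of a
# withheld list item survives for the validator to see.
assert ParagraphChunker().flush("- a list item\n") == ["- a list item\n"]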

mellea/stdlib/streaming.py

Lines changed: 41 additions & 1 deletion
@@ -16,6 +16,7 @@
 from ..core.backend import Backend
 from ..core.base import CBlock, Component, Context, ModelOutputThunk
 from ..core.requirement import PartialValidationResult, Requirement, ValidationResult
+from ..core.utils import MelleaLogger
 from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker

 _CHUNKING_ALIASES: dict[str, type[ChunkingStrategy]] = {
@@ -75,6 +76,14 @@ async def astream(self) -> AsyncIterator[str]:
         all chunks have been yielded, whether the stream completed normally or
         was cancelled early on a ``"fail"`` result.

+        **Single-consumer.** Chunks are delivered via an
+        :class:`asyncio.Queue` that this method drains; calling
+        ``astream()`` a second time on the same result blocks indefinitely
+        because the queue is empty and the terminating ``None`` sentinel
+        has already been consumed. If you need the chunks after
+        iteration, capture them into a list during the first pass or use
+        :attr:`full_text` after :meth:`acomplete`.
+
         Yields:
             str: A validated text chunk from the chunking strategy.
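A consumer-side sketch of the single-consumer contract documented above; action, backend, and ctx are hypothetical stand-ins, and the call shape follows the tests below rather than a definitive API reference:

result = await stream_with_chunking(action, backend, ctx, chunking="sentence")

chunks: list[str] = []
async for chunk in result.astream():  # the first and only pass over the queue
    chunks.append(chunk)  # capture now; a second astream() would block forever
await result.acomplete()

replayable = chunks  # reuse the captured list for any later pass...
joined = result.full_text  # ...or read the joined text after acomplete()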
@@ -116,6 +125,15 @@ def as_thunk(self) -> ModelOutputThunk:
         early-exit results; ``value`` will reflect whatever was accumulated
         before cancellation.

+        Note:
+            On early exit, ``cancel_generation()`` forces the MOT into a
+            computed state without running the backend's
+            ``post_processing()``. Telemetry fields on the returned thunk
+            (``generation.usage``, ``generation.ttfb_ms``, etc.) may
+            therefore be ``None`` or reflect the partial state at
+            cancellation time. ``value`` and ``streaming`` are reliable;
+            usage totals are not.
+
         Returns:
             ModelOutputThunk: A computed thunk containing the streamed output.
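A defensive-read sketch for the early-exit case this Note describes; the generation attribute shape is assumed from the docstring above and should be treated as illustrative:

thunk = result.as_thunk()
text = thunk.value  # reliable: whatever accumulated before cancellation

# Usage totals are best-effort after an early exit; guard before reading.
gen = getattr(thunk, "generation", None)
if gen is None or gen.usage is None:
    print("usage unavailable: early exit skipped post_processing()")
else:
    print(f"tokens used: {gen.usage}")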
@@ -178,7 +196,12 @@ async def _validate_and_emit(c: str) -> bool:
         try:
             delta = await mot.astream()
         except RuntimeError:
-            break
+            # Expected race: mot.is_computed() was False at the top of the
+            # loop but the stream finished before we re-entered astream().
+            # Any other RuntimeError is a real bug and must propagate.
+            if mot.is_computed():
+                break
+            raise

         accumulated += delta
         chunks = chunking.split(accumulated)
@@ -220,6 +243,23 @@ async def _validate_and_emit(c: str) -> bool:
             )

     except Exception as exc:
+        # Orchestrator is leaving — we must stop the backend producer too,
+        # otherwise mot._async_queue (maxsize=20) fills and the feeder task
+        # blocks indefinitely. The spec (#891, #901) calls this out for the
+        # "fail" path; the same reasoning applies to any unplanned exit.
+        try:
+            await mot.cancel_generation()
+        except Exception as cleanup_exc:
+            # Never let cleanup mask the original exception: log loudly and
+            # continue to surface `exc` to the consumer.
+            # TODO(#902): replace this log with an ErrorEvent emission.
+            MelleaLogger.get_logger().warning(
+                "stream_with_chunking: cancel_generation() raised during "
+                "exception cleanup (original: %r, cleanup: %r)",
+                exc,
+                cleanup_exc,
+            )
+        result.completed = False
         await result._chunk_queue.put(exc)
     finally:
         await result._chunk_queue.put(None)

test/stdlib/test_streaming.py

Lines changed: 193 additions & 8 deletions
@@ -433,14 +433,12 @@ def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq:

     assert len(captured) == 1
     seen = captured[0].seen_chunks
-    # Three complete sentences → three separate stream_validate calls.
-    assert len(seen) == 3
-    # Each chunk is one sentence, not a prefix of accumulated text.
-    for chunk in seen:
-        assert chunk.endswith(".")
-    # Lengths must not be monotonically growing (which would indicate accumulated text).
-    # With per-chunk semantics, each chunk is roughly the same length as one sentence.
-    assert not all(len(seen[i]) < len(seen[i + 1]) for i in range(len(seen) - 1))
+    # Exact match: three separate calls, one per complete sentence,
+    # each call receiving that sentence and nothing more. Under the old
+    # accumulated-text semantics, seen would have been
+    # ["First sentence.", "First sentence. Second sentence.", ...] —
+    # exact match against the per-chunk list is the direct regression guard.
+    assert seen == ["First sentence.", "Second sentence.", "Third sentence."]


 @pytest.mark.asyncio
@@ -576,3 +574,190 @@ async def test_no_requirements_streams_without_validation() -> None:
     assert result.full_text == response
     assert result.final_validations == []
     assert result.streaming_failures == []
+
+
+@pytest.mark.asyncio
+async def test_multiple_chunks_in_one_batch_with_mid_batch_fail() -> None:
+    """When one astream() delta produces several complete chunks and one in
+    the middle fails, earlier chunks emit, failing chunk is recorded, later
+    chunks are neither validated nor emitted."""
+
+    captured: list[Any] = []
+
+    class FailOnNthChunk(Requirement):
+        def __init__(self, n: int) -> None:
+            self._n = n
+            self._calls = 0
+            self.seen: list[str] = []
+
+        def __copy__(self) -> "FailOnNthChunk":
+            clone = FailOnNthChunk(self._n)
+            captured.append(clone)
+            return clone
+
+        def format_for_llm(self) -> str:
+            return f"fail on chunk {self._n}"
+
+        async def stream_validate(
+            self, chunk: str, *, backend: Any, ctx: Any
+        ) -> PartialValidationResult:
+            _ = backend, ctx
+            self._calls += 1
+            self.seen.append(chunk)
+            if self._calls == self._n:
+                return PartialValidationResult("fail", reason=f"n={self._n}")
+            return PartialValidationResult("unknown")
+
+        async def validate(
+            self,
+            backend: Any,
+            ctx: Any,
+            *,
+            format: Any = None,
+            model_options: Any = None,
+        ) -> ValidationResult:
+            _ = backend, ctx, format, model_options
+            return ValidationResult(result=True)
+
+    # token_size larger than the whole response → one astream() delta delivers
+    # the full text, so chunking.split produces 4 sentences in a single batch.
+    response = "One. Two. Three. Four. "
+    backend = StreamingMockBackend(response, token_size=100)
+    req = FailOnNthChunk(n=2)
+
+    result = await stream_with_chunking(
+        _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence"
+    )
+    yielded: list[str] = []
+    async for c in result.astream():
+        yielded.append(c)
+    await result.acomplete()
+
+    assert result.completed is False
+    assert len(result.streaming_failures) == 1
+    # Chunk 1 was validated and emitted; chunk 2 was validated and failed
+    # (NOT emitted); chunks 3 and 4 were NEITHER validated NOR emitted.
+    assert yielded == ["One."]
+    assert len(captured) == 1
+    assert captured[0].seen == ["One.", "Two."]
+    assert captured[0]._calls == 2
+
+
+@pytest.mark.asyncio
+async def test_cancel_generation_invoked_on_fail() -> None:
+    """Early exit on 'fail' must call mot.cancel_generation() — the spec reason
+    is that asyncio.Queue(maxsize=20) will block the producer if the consumer
+    stops without cancelling."""
+
+    from mellea.core.base import ModelOutputThunk
+
+    response = "word " * 50
+    backend = StreamingMockBackend(response, token_size=3)
+
+    class FailOnFirstChunk(Requirement):
+        def format_for_llm(self) -> str:
+            return "fail immediately"
+
+        async def stream_validate(
+            self, chunk: str, *, backend: Any, ctx: Any
+        ) -> PartialValidationResult:
+            _ = chunk, backend, ctx
+            return PartialValidationResult("fail", reason="nope")
+
+        async def validate(
+            self,
+            backend: Any,
+            ctx: Any,
+            *,
+            format: Any = None,
+            model_options: Any = None,
+        ) -> ValidationResult:
+            _ = backend, ctx, format, model_options
+            return ValidationResult(result=True)
+
+    call_count = 0
+    real_cancel = ModelOutputThunk.cancel_generation
+
+    async def spy_cancel(self: ModelOutputThunk) -> None:
+        nonlocal call_count
+        call_count += 1
+        await real_cancel(self)
+
+    ModelOutputThunk.cancel_generation = spy_cancel  # type: ignore[method-assign]
+    try:
+        result = await stream_with_chunking(
+            _action(),
+            backend,
+            _ctx(),
+            quick_check_requirements=[FailOnFirstChunk()],
+            chunking="word",
+        )
+        await asyncio.wait_for(result.acomplete(), timeout=5.0)
+    finally:
+        ModelOutputThunk.cancel_generation = real_cancel  # type: ignore[method-assign]

+    assert result.completed is False
+    assert call_count >= 1
+
+
+@pytest.mark.asyncio
+async def test_exception_in_stream_validate_cancels_generation() -> None:
+    """If stream_validate raises, the orchestrator must still call
+    cancel_generation() — otherwise the backend producer blocks on the
+    (maxsize=20) queue — and surface the exception to the consumer via
+    astream()/acomplete()."""
+
+    from mellea.core.base import ModelOutputThunk
+
+    class RaisingReq(Requirement):
+        def format_for_llm(self) -> str:
+            return "raises"
+
+        async def stream_validate(
+            self, chunk: str, *, backend: Any, ctx: Any
+        ) -> PartialValidationResult:
+            _ = chunk, backend, ctx
+            raise ValueError("boom")
+
+        async def validate(
+            self,
+            backend: Any,
+            ctx: Any,
+            *,
+            format: Any = None,
+            model_options: Any = None,
+        ) -> ValidationResult:
+            _ = backend, ctx, format, model_options
+            return ValidationResult(result=True)
+
+    response = "word " * 50  # enough to fill maxsize=20 queue without cleanup
+    backend = StreamingMockBackend(response, token_size=3)
+
+    call_count = 0
+    real_cancel = ModelOutputThunk.cancel_generation
+
+    async def spy_cancel(self: ModelOutputThunk) -> None:
+        nonlocal call_count
+        call_count += 1
+        await real_cancel(self)
+
+    ModelOutputThunk.cancel_generation = spy_cancel  # type: ignore[method-assign]
+    try:
+        result = await stream_with_chunking(
+            _action(),
+            backend,
+            _ctx(),
+            quick_check_requirements=[RaisingReq()],
+            chunking="word",
+        )
+        with pytest.raises(ValueError, match="boom"):
+            async for _chunk in result.astream():
+                pass
+        # acomplete must complete (not hang) even though the orchestration
+        # task raised, because cancel_generation was called in the except path.
+        await asyncio.wait_for(result.acomplete(), timeout=5.0)
+    finally:
+        ModelOutputThunk.cancel_generation = real_cancel  # type: ignore[method-assign]
+
+    assert result.completed is False
+    assert call_count >= 1
