Skip to content

Commit da41a06

Browse files
committed
fix(stdlib): address second-round review feedback
Three items from the second independent review: cancel_generation(error=) — accept an optional Exception parameter. When the orchestrator enters the except Exception path, it now passes the caught exception to cancel_generation() so the backend telemetry span records the real cause via set_span_error instead of a generic RuntimeError("Generation cancelled"). The original exception still surfaces to the consumer via astream()/acomplete(); this is purely an OTEL accuracy fix. Backward-compatible: the default None preserves the previous "Generation cancelled" message for the normal fail path. stream_with_chunking docstring — the "After the stream ends (naturally or via early exit), validate() is called" wording overstated behaviour. The orchestrator actually skips final validate() on early exit (test_early_exit_on_fail verifies final_validations == []). Docstring now correctly says final validate() runs only on natural completion. test_exception_in_stream_validate_cancels_generation docstring — the test fails on chunk 1 so the queue never actually fills; it verifies the cancel-on-exception path and the no-hang guarantee but does not directly prove the worst-case "producer blocked on full queue" scenario. Docstring now states what it actually covers and points at test/core/ for the cancel_generation drain logic. Assisted-by: Claude Code
1 parent def10b6 commit da41a06

3 files changed

Lines changed: 40 additions & 14 deletions

File tree

mellea/core/base.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def _record_ttfb(self) -> None:
364364
).total_seconds() * 1000
365365
self._first_chunk_received = True
366366

367-
async def cancel_generation(self) -> None:
367+
async def cancel_generation(self, error: Exception | None = None) -> None:
368368
"""Cancel an in-progress streaming generation, drain the queue, and close any open telemetry span.
369369
370370
Safe to call at any point during streaming. After this method returns,
@@ -375,6 +375,14 @@ async def cancel_generation(self) -> None:
375375
Draining the internal queue after cancellation is necessary to release
376376
any ``asyncio.Queue.put()`` call that the generation task was blocked on
377377
(queue maxsize=20).
378+
379+
Args:
380+
error: Optional cause attributed to the open telemetry span. When
381+
provided, this exception is recorded via ``set_span_error`` so
382+
the span reflects the actual reason for cancellation (e.g. the
383+
requirement failure or an unhandled exception from a streaming
384+
validator). When ``None``, a generic
385+
``RuntimeError("Generation cancelled")`` is recorded.
378386
"""
379387
if self._computed:
380388
return
@@ -414,7 +422,10 @@ def _drain() -> None:
414422
if span is not None:
415423
from ..telemetry import end_backend_span, set_span_error
416424

417-
set_span_error(span, RuntimeError("Generation cancelled"))
425+
recorded: Exception = (
426+
error if error is not None else RuntimeError("Generation cancelled")
427+
)
428+
set_span_error(span, recorded)
418429
end_backend_span(span)
419430
del self._meta["_telemetry_span"]
420431

mellea/stdlib/streaming.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,10 @@ async def _validate_and_emit(c: str) -> bool:
247247
# otherwise mot._async_queue (maxsize=20) fills and the feeder task
248248
# blocks indefinitely. The spec (#891, #901) calls this out for the
249249
# "fail" path; the same reasoning applies to any unplanned exit.
250+
# Pass `exc` so the backend telemetry span records the real cause
251+
# rather than a generic "Generation cancelled".
250252
try:
251-
await mot.cancel_generation()
253+
await mot.cancel_generation(error=exc)
252254
except Exception as cleanup_exc:
253255
# Never let cleanup mask the original exception: log loudly and
254256
# continue to surface `exc` to the consumer.
@@ -304,9 +306,12 @@ async def stream_with_chunking(
304306
same terms as the regular chunks. On early exit, the trailing fragment
305307
is discarded because the generation was cancelled mid-token.
306308
307-
After the stream ends (naturally or via early exit), ``validate()`` is
308-
called on all requirements that did not return ``"fail"``. Requirements
309-
are cloned (``copy(req)``) before use so originals are never mutated.
309+
After the stream ends naturally, ``validate()`` is called on every
310+
requirement that did not return ``"fail"`` — both ``"pass"`` and
311+
``"unknown"`` trigger final validation. On early exit, no ``validate()``
312+
call is made; :attr:`StreamChunkingResult.final_validations` remains
313+
empty. Requirements are cloned (``copy(req)``) before use so originals
314+
are never mutated.
310315
311316
Requirements that need context beyond the current chunk should
312317
accumulate it themselves across ``stream_validate`` calls (e.g.

test/stdlib/test_streaming.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -678,10 +678,12 @@ async def validate(
678678
call_count = 0
679679
real_cancel = ModelOutputThunk.cancel_generation
680680

681-
async def spy_cancel(self: ModelOutputThunk) -> None:
681+
async def spy_cancel(
682+
self: ModelOutputThunk, error: Exception | None = None
683+
) -> None:
682684
nonlocal call_count
683685
call_count += 1
684-
await real_cancel(self)
686+
await real_cancel(self, error)
685687

686688
ModelOutputThunk.cancel_generation = spy_cancel # type: ignore[method-assign]
687689
try:
@@ -702,10 +704,16 @@ async def spy_cancel(self: ModelOutputThunk) -> None:
702704

703705
@pytest.mark.asyncio
704706
async def test_exception_in_stream_validate_cancels_generation() -> None:
705-
"""If stream_validate raises, the orchestrator must still call
706-
cancel_generation() — otherwise the backend producer blocks on the
707-
(maxsize=20) queue — and surface the exception to the consumer via
708-
astream()/acomplete()."""
707+
"""Verifies the orchestrator's exception-path cleanup: if stream_validate
708+
raises, cancel_generation() is called and the exception surfaces to the
709+
consumer via astream()/acomplete() without hanging.
710+
711+
This covers the cancel-on-exception path and the no-hang guarantee.
712+
It does not directly exercise the worst-case "producer already blocked on
713+
full queue" scenario (here the fail happens on chunk 1 so the queue never
714+
fills); the cancel_generation drain logic is covered by its own tests in
715+
test/core/.
716+
"""
709717

710718
from mellea.core.base import ModelOutputThunk
711719

@@ -736,10 +744,12 @@ async def validate(
736744
call_count = 0
737745
real_cancel = ModelOutputThunk.cancel_generation
738746

739-
async def spy_cancel(self: ModelOutputThunk) -> None:
747+
async def spy_cancel(
748+
self: ModelOutputThunk, error: Exception | None = None
749+
) -> None:
740750
nonlocal call_count
741751
call_count += 1
742-
await real_cancel(self)
752+
await real_cancel(self, error)
743753

744754
ModelOutputThunk.cancel_generation = spy_cancel # type: ignore[method-assign]
745755
try:

0 commit comments

Comments
 (0)