Skip to content

Commit 69c9c3f

Browse files
committed
feat(core): add cancel_generation() to ModelOutputThunk
Adds an async cancel_generation() method that cancels in-progress _generate and _generate_extra tasks, drains the internal async queue to release any blocked put() calls, closes the open telemetry span, and sets _computed=True so the MOT is immediately usable. Required by the stream_with_chunking() orchestrator (#901) for clean early-exit when a streaming requirement returns "fail". Assisted-by: Claude Code Signed-off-by: Nigel Jones <jonesn@uk.ibm.com>
1 parent 3919927 commit 69c9c3f

1 file changed

Lines changed: 58 additions & 0 deletions

File tree

mellea/core/base.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,64 @@ def _record_ttfb(self) -> None:
364364
).total_seconds() * 1000
365365
self._first_chunk_received = True
366366

367+
async def cancel_generation(self) -> None:
368+
"""Cancel an in-progress streaming generation, drain the queue, and close any open telemetry span.
369+
370+
Safe to call at any point during streaming. After this method returns,
371+
``is_computed()`` is ``True`` and ``value`` contains whatever text was
372+
accumulated before cancellation. Calling on an already-computed MOT
373+
is a no-op.
374+
375+
Draining the internal queue after cancellation is necessary to release
376+
any ``asyncio.Queue.put()`` call that the generation task was blocked on
377+
(queue maxsize=20).
378+
"""
379+
if self._computed:
380+
return
381+
382+
def _drain() -> None:
383+
while not self._async_queue.empty():
384+
try:
385+
self._async_queue.get_nowait()
386+
except asyncio.QueueEmpty:
387+
break
388+
389+
if self._generate is not None and not self._generate.done():
390+
self._generate.cancel()
391+
392+
if self._generate_extra is not None and not self._generate_extra.done():
393+
self._generate_extra.cancel()
394+
395+
# Drain before awaiting — unblocks any put() the task is stuck on.
396+
_drain()
397+
398+
if self._generate is not None:
399+
try:
400+
await self._generate
401+
except (asyncio.CancelledError, Exception):
402+
pass
403+
404+
if self._generate_extra is not None:
405+
try:
406+
await self._generate_extra
407+
except (asyncio.CancelledError, Exception):
408+
pass
409+
410+
# Drain again for any final item the task put before terminating.
411+
_drain()
412+
413+
span = self._meta.get("_telemetry_span")
414+
if span is not None:
415+
from ..telemetry import end_backend_span, set_span_error
416+
417+
set_span_error(span, RuntimeError("Generation cancelled"))
418+
end_backend_span(span)
419+
del self._meta["_telemetry_span"]
420+
421+
if self._underlying_value is None:
422+
self._underlying_value = ""
423+
self._computed = True
424+
367425
def _copy_from(self, other: ModelOutputThunk) -> None:
368426
"""Copy computed-output fields from *other* into *self*.
369427

0 commit comments

Comments
 (0)