Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions docs/examples/streaming/streaming_chunking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# pytest: ollama, e2e
Comment thread
planetf1 marked this conversation as resolved.

"""Streaming generation with per-chunk validation using stream_with_chunking().

Demonstrates:
- Subclassing Requirement to override stream_validate() for early-exit checks
- Calling stream_with_chunking() with sentence-level chunking
- Consuming validated chunks via astream() as they arrive
- Awaiting full completion with acomplete() to access final_validations and full_text
"""

import asyncio

from mellea.core.backend import Backend
from mellea.core.base import Context
from mellea.core.requirement import (
PartialValidationResult,
Requirement,
ValidationResult,
)
from mellea.stdlib.components import Instruction
from mellea.stdlib.streaming import stream_with_chunking


class MaxSentencesReq(Requirement):
    """Requirement that fails mid-stream once more than *limit* sentences arrive.

    Every ``stream_validate`` call corresponds to one complete sentence
    produced by the :class:`~mellea.stdlib.chunking.SentenceChunker`. The
    running sentence count is kept as instance state on ``self`` — the
    standard pattern for requirements whose verdict depends on more than a
    single chunk.
    """

    def __init__(self, limit: int) -> None:
        super().__init__()
        self._limit = limit
        self._count = 0

    def format_for_llm(self) -> str:
        return f"The response must be at most {self._limit} sentences long."

    async def stream_validate(
        self, chunk: str, *, backend: Backend, ctx: Context
    ) -> PartialValidationResult:
        # One call per sentence chunk; fail fast once the budget is exceeded.
        self._count += 1
        within_budget = self._count <= self._limit
        if within_budget:
            return PartialValidationResult("unknown")
        return PartialValidationResult(
            "fail",
            reason=f"Response exceeded {self._limit} sentence limit mid-stream",
        )

    async def validate(
        self,
        backend: Backend,
        ctx: Context,
        *,
        format: type | None = None,
        model_options: dict | None = None,
    ) -> ValidationResult:
        # Final (post-stream) validation always passes; the mid-stream check
        # in stream_validate() is the only gate this requirement enforces.
        return ValidationResult(result=True)


async def main() -> None:
    """Run the streaming-chunking example end to end.

    Starts a session, streams a generation with sentence-level chunking and
    a mid-stream sentence-count requirement, prints each validated chunk as
    it arrives, then awaits completion and reports any streaming failures
    and final validation results.
    """
    # Imported lazily so the module can be imported/inspected without a
    # configured backend.
    from mellea.stdlib.session import start_session

    m = start_session()
    backend = m.backend
    ctx = m.ctx

    action = Instruction(
        "Write a short paragraph about the water cycle in exactly two sentences."
    )
    req = MaxSentencesReq(limit=3)

    result = await stream_with_chunking(
        action, backend, ctx, quick_check_requirements=[req], chunking="sentence"
    )

    print("Streaming chunks as they arrive:")
    async for chunk in result.astream():
        print(f"  CHUNK: {chunk!r}")

    # Wait for the generation to finish so final_validations / full_text
    # are populated.
    await result.acomplete()

    print(f"\nCompleted normally: {result.completed}")
    print(f"Full text: {result.full_text!r}")

    if result.streaming_failures:
        for _req, pvr in result.streaming_failures:
            print(f"Streaming failure: {pvr.reason}")

    if result.final_validations:
        for vr in result.final_validations:
            print(f"Final validation: {'PASS' if vr.as_bool() else 'FAIL'}")


if __name__ == "__main__":
    # Guard the entry point: importing this example (e.g. during pytest
    # collection or docs builds) must not trigger a live model call.
    asyncio.run(main())
68 changes: 68 additions & 0 deletions mellea/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,74 @@ def _record_ttfb(self) -> None:
).total_seconds() * 1000
self._first_chunk_received = True

async def cancel_generation(self, error: Exception | None = None) -> None:
    """Cancel an in-progress streaming generation, drain the queue, and close any open telemetry span.

    Safe to call at any point during streaming. After this method returns,
    ``is_computed()`` is ``True`` and ``value`` contains whatever text was
    accumulated before cancellation. Calling on an already-computed MOT
    is a no-op.

    Draining the internal queue after cancellation is necessary to release
    any ``asyncio.Queue.put()`` call that the generation task was blocked on
    (queue maxsize=20).

    Args:
        error: Optional cause attributed to the open telemetry span. When
            provided, this exception is recorded via ``set_span_error`` so
            the span reflects the actual reason for cancellation (e.g. the
            requirement failure or an unhandled exception from a streaming
            validator). When ``None``, a generic
            ``RuntimeError("Generation cancelled")`` is recorded.
    """
    # Already finished (or previously cancelled): nothing to do.
    if self._computed:
        return

    def _drain() -> None:
        # Discard queued-but-unconsumed chunks. Each get_nowait() frees a
        # slot in the bounded queue, which wakes a producer blocked in put().
        while not self._async_queue.empty():
            try:
                self._async_queue.get_nowait()
            except asyncio.QueueEmpty:
                break

    if self._generate is not None and not self._generate.done():
        self._generate.cancel()

    if self._generate_extra is not None and not self._generate_extra.done():
        self._generate_extra.cancel()

    # Drain before awaiting — unblocks any put() the task is stuck on.
    _drain()

    if self._generate is not None:
        try:
            await self._generate
        except (asyncio.CancelledError, Exception):
            # CancelledError is the expected outcome after .cancel().
            # Any other exception from the dying task is deliberately
            # suppressed: cancellation itself must always succeed.
            pass

    if self._generate_extra is not None:
        try:
            await self._generate_extra
        except (asyncio.CancelledError, Exception):
            pass

    # Drain again for any final item the task put before terminating.
    _drain()

    # Close the telemetry span (if one was opened for this generation),
    # recording the cancellation cause so traces show why it ended early.
    span = self._meta.pop("_telemetry_span", None)
    if span is not None:
        from ..telemetry import end_backend_span, set_span_error

        recorded: Exception = (
            error if error is not None else RuntimeError("Generation cancelled")
        )
        set_span_error(span, recorded)
        end_backend_span(span)

    # Mark the thunk computed with whatever text accumulated so far
    # (possibly none), so callers see a consistent terminal state.
    if self._underlying_value is None:
        self._underlying_value = ""
    self._computed = True

def _copy_from(self, other: ModelOutputThunk) -> None:
"""Copy computed-output fields from *other* into *self*.

Expand Down
15 changes: 13 additions & 2 deletions mellea/stdlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,20 @@
``mellea.stdlib.session`` — for day-to-day use.

Streaming chunking strategies (for use with streaming validation) are available at
``mellea.stdlib.chunking`` and re-exported here for convenience.
``mellea.stdlib.chunking`` and re-exported here for convenience. The core streaming
orchestration primitive :func:`~mellea.stdlib.streaming.stream_with_chunking` and
its result type :class:`~mellea.stdlib.streaming.StreamChunkingResult` are also
re-exported here.
"""

from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker
from .streaming import StreamChunkingResult, stream_with_chunking

__all__ = ["ChunkingStrategy", "ParagraphChunker", "SentenceChunker", "WordChunker"]
# Public streaming/chunking API re-exported at the package level.
__all__ = [
    "ChunkingStrategy",
    "ParagraphChunker",
    "SentenceChunker",
    "StreamChunkingResult",
    "WordChunker",
    "stream_with_chunking",
]
103 changes: 103 additions & 0 deletions mellea/stdlib/chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,27 @@ def split(self, accumulated_text: str) -> list[str]:
"""
...

def flush(self, accumulated_text: str) -> list[str]:
    """Return any trailing fragment that ``split`` withheld.

    Called once by the orchestrator after the stream has ended naturally
    (not on early-exit cancellation), giving the chunker a chance to
    release a final fragment that never reached a terminator.

    This default implementation discards the trailing fragment by
    returning an empty list. Built-in chunkers override it to return
    the withheld fragment as a single-element list when non-empty.

    Args:
        accumulated_text: The full accumulated text at stream end.

    Returns:
        ``[fragment]`` if the trailing fragment should be treated as a
        final chunk, or an empty list to discard it.
    """
    del accumulated_text  # intentionally unused by the default implementation
    return []


# Sentence boundary: sentence-ending punctuation, optionally followed by a closing
# quote or paren, then whitespace.
Expand Down Expand Up @@ -94,6 +115,36 @@ def split(self, accumulated_text: str) -> list[str]:

return chunks

def flush(self, accumulated_text: str) -> list[str]:
    """Return the trailing sentence fragment (if any) as a final chunk.

    Trailing whitespace on the fragment is non-semantic for sentence
    boundaries and is dropped via ``rstrip``. Leading whitespace is
    already removed by the loop's ``lstrip`` on each advance, so no
    extra ``lstrip`` is needed at the end. The result is the fragment's
    content only, consistent with how :meth:`split` returns sentences
    without trailing whitespace.

    Args:
        accumulated_text: The full accumulated text at stream end.

    Returns:
        A single-element list containing the trailing sentence fragment
        with leading and trailing whitespace stripped, or an empty list
        when there is no fragment (all content ended in a sentence
        boundary or the input is empty/whitespace-only).
    """
    if not accumulated_text:
        return []
    # Skip past every complete sentence; whatever is left is the fragment.
    remaining = accumulated_text
    while (boundary := _SENTENCE_BOUNDARY.search(remaining)) is not None:
        remaining = remaining[boundary.end() :].lstrip()
    fragment = remaining.rstrip()
    return [fragment] if fragment else []


class WordChunker(ChunkingStrategy):
"""Splits accumulated text on whitespace boundaries.
Expand Down Expand Up @@ -134,6 +185,32 @@ def split(self, accumulated_text: str) -> list[str]:

return parts

def flush(self, accumulated_text: str) -> list[str]:
    """Return the trailing word fragment (if any) as a final chunk.

    The trailing fragment is the text after the last whitespace run when
    the accumulated text does not end with whitespace. When it does end
    with whitespace, every word is already complete and no fragment is
    released.

    Args:
        accumulated_text: The full accumulated text at stream end.

    Returns:
        A single-element list containing the trailing word fragment, or
        an empty list when the input is empty or ends with whitespace
        (every word already complete).
    """
    # Empty input, or text ending in whitespace: nothing was withheld.
    if not accumulated_text or accumulated_text[-1].isspace():
        return []
    # Last non-empty piece after splitting on whitespace runs.
    fragment = next(
        (piece for piece in reversed(_WHITESPACE.split(accumulated_text)) if piece),
        "",
    )
    return [fragment] if fragment else []


class ParagraphChunker(ChunkingStrategy):
r"""Splits accumulated text on double-newline paragraph boundaries.
Expand Down Expand Up @@ -168,3 +245,29 @@ def split(self, accumulated_text: str) -> list[str]:

# _PARA_BOUNDARY.split on leading \n\n produces an empty first element.
return [p for p in parts if p]

def flush(self, accumulated_text: str) -> list[str]:
    r"""Return the trailing paragraph fragment (if any) as a final chunk.

    Unlike :class:`SentenceChunker.flush`, the fragment is returned
    byte-for-byte without stripping. Internal whitespace — including
    a trailing single ``\n`` — can be semantically meaningful inside
    a paragraph (e.g. a list item or a deliberate line break), and a
    consumer validating paragraph content should see the fragment as
    it was withheld.

    Args:
        accumulated_text: The full accumulated text at stream end.

    Returns:
        A single-element list containing the trailing paragraph fragment
        byte-for-byte, or an empty list when the input is empty or ends
        with a paragraph boundary (``\n\n`` or more).
    """
    # No withheld fragment when empty or terminated by a paragraph break.
    if not accumulated_text or _PARA_BOUNDARY_END.search(accumulated_text):
        return []
    pieces = _PARA_BOUNDARY.split(accumulated_text)
    tail = pieces[-1] if pieces else ""
    return [tail] if tail else []
Loading
Loading