Skip to content

Commit c132854

Browse files
committed
test(stdlib): add StreamingMockBackend and streaming orchestration tests
Adds test/stdlib/test_streaming.py with 9 unit tests covering: - Normal completion: validate() called at stream end, completed=True - Early exit on "fail": completed=False, streaming_failures populated - Clone isolation: originals never mutated across retries - quick_check_backend routing: validation uses alternate backend - Deadlock prevention: early exit with asyncio.wait_for timeout - as_thunk correctness: value=full_text, raises before acomplete() - astream() yields individual chunks (not accumulated text) - No requirements: streams without validation StreamingMockBackend subclasses Backend and feeds a fixed response string into a MOT queue char-by-char via asyncio.create_task, following the create_manual_mock_thunk() pattern from test_astream_mock.py. Assisted-by: Claude Code Signed-off-by: Nigel Jones <jonesn@uk.ibm.com>
1 parent a6c98ed commit c132854

1 file changed

Lines changed: 384 additions & 0 deletions

File tree

test/stdlib/test_streaming.py

Lines changed: 384 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,384 @@
1+
"""Tests for stream_with_chunking() and StreamChunkingResult.
2+
3+
Uses StreamingMockBackend — a deterministic test double that feeds tokens from a
4+
fixed response string into a MOT queue without network or LLM calls.
5+
6+
All tests are unit tests (no @pytest.mark.ollama needed).
7+
"""
8+
9+
import asyncio
10+
from typing import Any
11+
12+
import pytest
13+
14+
from mellea.core.backend import Backend
15+
from mellea.core.base import CBlock, Context, GenerateType, ModelOutputThunk
16+
from mellea.core.requirement import (
17+
PartialValidationResult,
18+
Requirement,
19+
ValidationResult,
20+
)
21+
from mellea.stdlib.context import SimpleContext
22+
from mellea.stdlib.streaming import stream_with_chunking
23+
24+
# ---------------------------------------------------------------------------
25+
# StreamingMockBackend
26+
# ---------------------------------------------------------------------------
27+
28+
29+
async def _mock_process(mot: ModelOutputThunk, chunk: Any) -> None:
    """Accumulate a streamed *chunk* onto ``mot._underlying_value``.

    A ``None`` chunk is the end-of-stream sentinel and contributes nothing,
    but still normalizes an unset value to the empty string.
    """
    accumulated = mot._underlying_value
    if accumulated is None:
        accumulated = ""
    mot._underlying_value = accumulated + chunk if chunk is not None else accumulated
34+
35+
36+
async def _mock_post_process(_mot: ModelOutputThunk) -> None:
    """No-op post-processing hook; the mock needs no finalization step."""
    return None
38+
39+
40+
def _make_mot() -> ModelOutputThunk:
    """Build an async ``ModelOutputThunk`` wired to the mock process hooks."""
    mot = ModelOutputThunk(value=None)
    mot._generate_type = GenerateType.ASYNC
    mot._action = CBlock("mock_action")
    mot._chunk_size = 0
    # Route streamed chunks / finalization through the mock coroutines above.
    mot._process = _mock_process
    mot._post_process = _mock_post_process
    return mot
48+
49+
50+
async def _feed_tokens(mot: ModelOutputThunk, response: str, token_size: int) -> None:
    """Feed *response* into ``mot``'s async queue in ``token_size``-char tokens.

    After each token the coroutine yields control (``asyncio.sleep(0)``) so a
    concurrent consumer can drain the queue. A trailing ``None`` sentinel
    marks end-of-stream.

    Raises:
        ValueError: if ``token_size`` is not positive. (The previous ``while``
            implementation would spin forever enqueueing empty tokens.)
    """
    if token_size <= 0:
        raise ValueError(f"token_size must be positive, got {token_size}")
    for start in range(0, len(response), token_size):
        await mot._async_queue.put(response[start : start + token_size])
        await asyncio.sleep(0)  # give the consumer a chance to run
    await mot._async_queue.put(None)  # end-of-stream sentinel
58+
59+
60+
class StreamingMockBackend(Backend):
    """Test double that streams a fixed response one token at a time.

    ``token_size`` controls how many characters constitute one token.
    Validation calls (via ``stream_validate`` / ``validate``) are delegated
    to the requirements themselves — this backend does not perform any real
    inference.
    """

    def __init__(self, response: str, token_size: int = 1) -> None:
        # NOTE(review): Backend.__init__ is not invoked here — confirm the
        # base class requires no initialization of its own.
        self._response = response
        self._token_size = token_size
        # Strong references to in-flight feeder tasks: the event loop keeps
        # only weak references to tasks, so a bare create_task() result can
        # be garbage-collected mid-stream (per the asyncio documentation).
        self._feed_tasks: set[asyncio.Task] = set()

    async def _generate_from_context(
        self,
        action: Any,
        ctx: Context,
        *,
        format: Any = None,
        model_options: dict | None = None,
        tool_calls: bool = False,
    ) -> tuple[ModelOutputThunk, Context]:
        """Return a thunk whose queue is fed asynchronously, plus the new ctx."""
        _ = format, model_options, tool_calls  # accepted but unused by the mock
        mot = _make_mot()
        task = asyncio.create_task(_feed_tokens(mot, self._response, self._token_size))
        # Retain the task until it finishes, then drop the reference.
        self._feed_tasks.add(task)
        task.add_done_callback(self._feed_tasks.discard)
        new_ctx = ctx.add(action).add(mot)
        return mot, new_ctx

    async def generate_from_raw(
        self, actions: Any, ctx: Any, **kwargs: Any
    ) -> list[ModelOutputThunk]:
        """Raw (context-free) generation is unsupported by this mock."""
        raise NotImplementedError
93+
94+
95+
# ---------------------------------------------------------------------------
96+
# Requirement test doubles
97+
# ---------------------------------------------------------------------------
98+
99+
100+
class AlwaysUnknownReq(Requirement):
    """stream_validate always returns 'unknown'; validate returns True."""

    def format_for_llm(self) -> str:
        return "always unknown"

    async def stream_validate(
        self, chunk: str, *, backend: Any, ctx: Any
    ) -> PartialValidationResult:
        # Never decide mid-stream — force validation to happen at stream end.
        verdict = PartialValidationResult("unknown")
        return verdict

    async def validate(
        self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None
    ) -> ValidationResult:
        # The final validation pass always succeeds.
        outcome = ValidationResult(result=True)
        return outcome
115+
116+
117+
class FailAfterWordsReq(Requirement):
    """Returns 'fail' once the accumulated text reaches *threshold* words."""

    def __init__(self, threshold: int) -> None:
        self._threshold = threshold

    def format_for_llm(self) -> str:
        return f"fail after {self._threshold} words"

    async def stream_validate(
        self, chunk: str, *, backend: Any, ctx: Any
    ) -> PartialValidationResult:
        word_count = len(chunk.split())
        if word_count < self._threshold:
            return PartialValidationResult("unknown")
        return PartialValidationResult("fail", reason="too many words")

    async def validate(
        self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None
    ) -> ValidationResult:
        # Final validation passes; only the streaming check ever fails.
        return ValidationResult(result=True)
137+
138+
139+
class BackendRecordingReq(Requirement):
    """Records which backend was passed to stream_validate and validate."""

    def __init__(self) -> None:
        self.seen_backends: list[Any] = []

    def __copy__(self) -> "BackendRecordingReq":
        # Clones start with their own empty recording list so mutations on a
        # clone can never leak back into the original instance.
        duplicate = BackendRecordingReq()
        duplicate.seen_backends = []
        return duplicate

    def format_for_llm(self) -> str:
        return "backend recorder"

    async def stream_validate(
        self, chunk: str, *, backend: Any, ctx: Any
    ) -> PartialValidationResult:
        _ = chunk
        self.seen_backends.append(backend)
        return PartialValidationResult("unknown")

    async def validate(
        self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None
    ) -> ValidationResult:
        self.seen_backends.append(backend)
        return ValidationResult(result=True)
165+
166+
167+
class MutationDetectorReq(Requirement):
    """Tracks how many times stream_validate was called on this instance."""

    def __init__(self) -> None:
        self._call_count = 0

    def format_for_llm(self) -> str:
        return "mutation detector"

    async def stream_validate(
        self, chunk: str, *, backend: Any, ctx: Any
    ) -> PartialValidationResult:
        _ = chunk, backend, ctx
        # Any call on this instance counts as a mutation of the original.
        self._call_count = self._call_count + 1
        return PartialValidationResult("unknown")

    async def validate(
        self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None
    ) -> ValidationResult:
        return ValidationResult(result=True)
187+
188+
189+
# ---------------------------------------------------------------------------
190+
# Helpers
191+
# ---------------------------------------------------------------------------
192+
193+
194+
def _ctx() -> SimpleContext:
    """Return a fresh, empty SimpleContext for one test invocation."""
    return SimpleContext()
196+
197+
198+
def _action() -> CBlock:
    """Return a throwaway prompt CBlock used as the generation action."""
    return CBlock("prompt")
200+
201+
202+
# ---------------------------------------------------------------------------
203+
# Tests
204+
# ---------------------------------------------------------------------------
205+
206+
207+
@pytest.mark.asyncio
async def test_normal_completion_calls_validate_at_stream_end() -> None:
    """All 'unknown' requirements → validate() called at stream end; completed=True."""
    expected = "Hello world. How are you. "
    mock_backend = StreamingMockBackend(expected, token_size=3)
    requirement = AlwaysUnknownReq()

    outcome = await stream_with_chunking(
        _action(),
        mock_backend,
        _ctx(),
        quick_check_requirements=[requirement],
        chunking="sentence",
    )
    await outcome.acomplete()

    assert outcome.completed is True
    assert outcome.full_text == expected
    assert len(outcome.final_validations) == 1
    assert outcome.final_validations[0].as_bool() is True
    assert outcome.streaming_failures == []
224+
225+
226+
@pytest.mark.asyncio
async def test_early_exit_on_fail() -> None:
    """Requirement fails mid-stream → completed=False, streaming_failures populated."""
    sample = "one two three four five six seven eight. "
    mock_backend = StreamingMockBackend(sample, token_size=2)
    requirement = FailAfterWordsReq(threshold=4)  # trips once four words have streamed

    outcome = await stream_with_chunking(
        _action(),
        mock_backend,
        _ctx(),
        quick_check_requirements=[requirement],
        chunking="word",
    )
    await outcome.acomplete()

    assert outcome.completed is False
    assert len(outcome.streaming_failures) == 1
    _failed_req, partial = outcome.streaming_failures[0]
    assert partial.success == "fail"
    assert partial.reason == "too many words"
    # Early exit skips the final validate() pass entirely.
    assert outcome.final_validations == []
246+
247+
248+
@pytest.mark.asyncio
async def test_clone_isolation_across_retries() -> None:
    """Originals must not be mutated; two invocations are independent."""
    detector = MutationDetectorReq()
    shared_reqs = [detector]
    mock_backend = StreamingMockBackend("Sentence one. Sentence two. ", token_size=4)

    # Run the same requirement list twice, as a retry loop would.
    for _ in range(2):
        outcome = await stream_with_chunking(
            _action(),
            mock_backend,
            _ctx(),
            quick_check_requirements=shared_reqs,
            chunking="sentence",
        )
        await outcome.acomplete()

    # Only clones were exercised; the original instance stayed untouched.
    assert detector._call_count == 0
277+
278+
279+
@pytest.mark.asyncio
async def test_quick_check_backend_routing() -> None:
    """stream_validate and validate receive quick_check_backend, not the main backend."""
    generator = StreamingMockBackend("One sentence. Two sentences. ", token_size=3)
    validator = StreamingMockBackend("unused", token_size=1)
    recorder = BackendRecordingReq()

    outcome = await stream_with_chunking(
        _action(),
        generator,
        _ctx(),
        quick_check_requirements=[recorder],
        chunking="sentence",
        quick_check_backend=validator,
    )
    await outcome.acomplete()

    # The stream ran to completion with validation routed to the alternate
    # backend (only clones of the recorder ever saw it).
    assert outcome.completed is True
    # Clone isolation: the original recorder never observed any backend.
    assert recorder.seen_backends == []
304+
305+
306+
@pytest.mark.asyncio
async def test_early_exit_does_not_deadlock() -> None:
    """Early failure with a high-throughput stream must not hang."""
    mock_backend = StreamingMockBackend("word " * 200, token_size=5)
    requirement = FailAfterWordsReq(threshold=3)

    outcome = await stream_with_chunking(
        _action(),
        mock_backend,
        _ctx(),
        quick_check_requirements=[requirement],
        chunking="word",
    )
    # Generous 5-second ceiling; a healthy run finishes in milliseconds.
    await asyncio.wait_for(outcome.acomplete(), timeout=5.0)

    assert outcome.completed is False
320+
321+
322+
@pytest.mark.asyncio
async def test_as_thunk_correctness() -> None:
    """as_thunk is computed, value matches full_text, generation metadata preserved."""
    text = "This is a test sentence. "
    mock_backend = StreamingMockBackend(text, token_size=4)

    outcome = await stream_with_chunking(
        _action(), mock_backend, _ctx(), chunking="sentence"
    )
    await outcome.acomplete()

    thunk = outcome.as_thunk
    assert thunk.is_computed()
    assert thunk.value == outcome.full_text == text
334+
335+
336+
@pytest.mark.asyncio
async def test_as_thunk_raises_before_acomplete() -> None:
    """as_thunk raises RuntimeError if accessed before acomplete()."""
    mock_backend = StreamingMockBackend("Some text. ", token_size=2)

    outcome = await stream_with_chunking(
        _action(), mock_backend, _ctx(), chunking="sentence"
    )

    # Accessing the thunk before completion is a programming error.
    with pytest.raises(RuntimeError, match="acomplete"):
        _ = outcome.as_thunk
346+
347+
348+
@pytest.mark.asyncio
async def test_astream_yields_individual_chunks() -> None:
    """Consumer via astream() receives individual chunks, not accumulated text."""
    text = "First sentence. Second sentence. Third sentence. "
    mock_backend = StreamingMockBackend(text, token_size=5)

    outcome = await stream_with_chunking(
        _action(), mock_backend, _ctx(), chunking="sentence"
    )

    received = [piece async for piece in outcome.astream()]

    await outcome.acomplete()

    # Three complete sentences, each its own chunk (never the running total).
    assert len(received) == 3
    assert all(piece.endswith(".") for piece in received)
    # Chunks carry no inter-sentence spaces; a space-join reappears in full_text.
    assert " ".join(received) in outcome.full_text
368+
369+
370+
@pytest.mark.asyncio
async def test_no_requirements_streams_without_validation() -> None:
    """quick_check_requirements=None → chunks produced, no validate() called."""
    text = "Chunk one. Chunk two. Chunk three. "
    mock_backend = StreamingMockBackend(text, token_size=3)

    outcome = await stream_with_chunking(
        _action(),
        mock_backend,
        _ctx(),
        quick_check_requirements=None,
        chunking="sentence",
    )
    await outcome.acomplete()

    assert outcome.completed is True
    assert outcome.full_text == text
    assert outcome.final_validations == []
    assert outcome.streaming_failures == []

0 commit comments

Comments
 (0)