"""Tests for stream_with_chunking() and StreamChunkingResult.

Uses StreamingMockBackend — a deterministic test double that feeds tokens from a
fixed response string into a ModelOutputThunk (MOT) queue without network or
LLM calls.

All tests are unit tests (no @pytest.mark.ollama needed).
"""

import asyncio
from typing import Any

import pytest

from mellea.core.backend import Backend
from mellea.core.base import CBlock, Context, GenerateType, ModelOutputThunk
from mellea.core.requirement import (
    PartialValidationResult,
    Requirement,
    ValidationResult,
)
from mellea.stdlib.context import SimpleContext
from mellea.stdlib.streaming import stream_with_chunking

# ---------------------------------------------------------------------------
# StreamingMockBackend
# ---------------------------------------------------------------------------


async def _mock_process(mot: ModelOutputThunk, chunk: Any) -> None:
    if mot._underlying_value is None:
        mot._underlying_value = ""
    if chunk is not None:
        mot._underlying_value += chunk


async def _mock_post_process(_mot: ModelOutputThunk) -> None:
    pass


def _make_mot() -> ModelOutputThunk:
    mot = ModelOutputThunk(value=None)
    mot._action = CBlock("mock_action")
    mot._generate_type = GenerateType.ASYNC
    mot._process = _mock_process
    mot._post_process = _mock_post_process
    mot._chunk_size = 0
    return mot


async def _feed_tokens(mot: ModelOutputThunk, response: str, token_size: int) -> None:
    i = 0
    while i < len(response):
        token = response[i : i + token_size]
        await mot._async_queue.put(token)
        await asyncio.sleep(0)  # yield so the consumer can drain the queue
        i += token_size
    await mot._async_queue.put(None)  # None marks end-of-stream


class StreamingMockBackend(Backend):
    """Test double that streams a fixed response one token at a time.

    ``token_size`` controls how many characters constitute one token.
    Validation calls (via ``stream_validate`` / ``validate``) are delegated
    to the requirements themselves — this backend does not perform any real
    inference.
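
    Example (illustrative; the response and token size are arbitrary): with
    ``token_size=4`` the response ``"Hello world. "`` is fed as the tokens
    ``"Hell"``, ``"o wo"``, ``"rld."``, ``" "``, followed by the ``None``
    end-of-stream sentinel.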
    """

    def __init__(self, response: str, token_size: int = 1) -> None:
        self._response = response
        self._token_size = token_size

    async def _generate_from_context(
        self,
        action: Any,
        ctx: Context,
        *,
        format: Any = None,
        model_options: dict | None = None,
        tool_calls: bool = False,
    ) -> tuple[ModelOutputThunk, Context]:
        _ = format, model_options, tool_calls
        mot = _make_mot()
        # Keep a reference so the feeder task is not garbage-collected
        # before the stream has been drained.
        self._feed_task = asyncio.create_task(
            _feed_tokens(mot, self._response, self._token_size)
        )
        new_ctx = ctx.add(action).add(mot)
        return mot, new_ctx

    async def generate_from_raw(
        self, actions: Any, ctx: Any, **kwargs: Any
    ) -> list[ModelOutputThunk]:
        raise NotImplementedError


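# Minimal sanity-check sketch (not a collected test): drains the mock
# backend's queue directly, bypassing stream_with_chunking. It relies only on
# the private test-double plumbing defined above (``_generate_from_context``
# and ``mot._async_queue``); the helper name is ours.
async def _demo_drain_mock_backend() -> str:
    backend = StreamingMockBackend("Hi there. ", token_size=2)
    mot, _ = await backend._generate_from_context(CBlock("prompt"), SimpleContext())
    text = ""
    while True:
        token = await mot._async_queue.get()
        if token is None:  # end-of-stream sentinel put by _feed_tokens
            break
        text += token
    return text  # -> "Hi there. "

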
# ---------------------------------------------------------------------------
# Requirement test doubles
# ---------------------------------------------------------------------------


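# The doubles below assume stream_with_chunking's tri-state streaming
# contract (inferred from the tests in this file): stream_validate returns
# PartialValidationResult("unknown") to defer judgment, or
# PartialValidationResult("fail", reason=...) to abort the stream early;
# validate() runs only for streams that reach the end without a "fail".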
class AlwaysUnknownReq(Requirement):
    """stream_validate always returns 'unknown'; validate returns True."""

    def format_for_llm(self) -> str:
        return "always unknown"

    async def stream_validate(
        self, chunk: str, *, backend: Any, ctx: Any
    ) -> PartialValidationResult:
        return PartialValidationResult("unknown")

    async def validate(
        self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None
    ) -> ValidationResult:
        return ValidationResult(result=True)


class FailAfterWordsReq(Requirement):
    """Returns 'fail' once the accumulated text reaches *threshold* words."""

    def __init__(self, threshold: int) -> None:
        self._threshold = threshold

    def format_for_llm(self) -> str:
        return f"fail after {self._threshold} words"

    async def stream_validate(
        self, chunk: str, *, backend: Any, ctx: Any
    ) -> PartialValidationResult:
        # ``chunk`` carries the accumulated text so far, per the docstring
        # above and the word-chunked early-exit test below.
        if len(chunk.split()) >= self._threshold:
            return PartialValidationResult("fail", reason="too many words")
        return PartialValidationResult("unknown")

    async def validate(
        self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None
    ) -> ValidationResult:
        return ValidationResult(result=True)


class BackendRecordingReq(Requirement):
    """Records which backend was passed to stream_validate and validate."""

    def __init__(self) -> None:
        self.seen_backends: list[Any] = []

    def __copy__(self) -> "BackendRecordingReq":
        clone = BackendRecordingReq()
        clone.seen_backends = []  # fresh list — do not share with original
        return clone
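
    # stream_with_chunking is assumed to shallow-copy requirements with
    # copy.copy(), which invokes __copy__ above. Sketch (hypothetical usage):
    #     clone = copy.copy(req)
    #     assert clone.seen_backends == [] and clone is not req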

    def format_for_llm(self) -> str:
        return "backend recorder"

    async def stream_validate(
        self, chunk: str, *, backend: Any, ctx: Any
    ) -> PartialValidationResult:
        _ = chunk
        self.seen_backends.append(backend)
        return PartialValidationResult("unknown")

    async def validate(
        self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None
    ) -> ValidationResult:
        self.seen_backends.append(backend)
        return ValidationResult(result=True)


class MutationDetectorReq(Requirement):
    """Tracks how many times stream_validate was called on this instance."""

    def __init__(self) -> None:
        self._call_count = 0

    def format_for_llm(self) -> str:
        return "mutation detector"

    async def stream_validate(
        self, chunk: str, *, backend: Any, ctx: Any
    ) -> PartialValidationResult:
        _ = chunk, backend, ctx
        self._call_count += 1
        return PartialValidationResult("unknown")

    async def validate(
        self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None
    ) -> ValidationResult:
        return ValidationResult(result=True)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _ctx() -> SimpleContext:
    return SimpleContext()


def _action() -> CBlock:
    return CBlock("prompt")


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_normal_completion_calls_validate_at_stream_end() -> None:
    """All 'unknown' requirements → validate() called at stream end; completed=True."""
    response = "Hello world. How are you. "
    backend = StreamingMockBackend(response, token_size=3)
    req = AlwaysUnknownReq()

    result = await stream_with_chunking(
        _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence"
    )
    await result.acomplete()

    assert result.completed is True
    assert result.full_text == response
    assert len(result.final_validations) == 1
    assert result.final_validations[0].as_bool() is True
    assert result.streaming_failures == []


@pytest.mark.asyncio
async def test_early_exit_on_fail() -> None:
    """Requirement fails mid-stream → completed=False, streaming_failures populated."""
    # threshold=4: failure fires once four words have accumulated
    response = "one two three four five six seven eight. "
    backend = StreamingMockBackend(response, token_size=2)
    req = FailAfterWordsReq(threshold=4)

    result = await stream_with_chunking(
        _action(), backend, _ctx(), quick_check_requirements=[req], chunking="word"
    )
    await result.acomplete()

    assert result.completed is False
    assert len(result.streaming_failures) == 1
    _req, pvr = result.streaming_failures[0]
    assert pvr.success == "fail"
    assert pvr.reason == "too many words"
    # final_validations should be empty — final validate() skipped on early exit
    assert result.final_validations == []


@pytest.mark.asyncio
async def test_clone_isolation_across_retries() -> None:
    """Originals must not be mutated; two invocations are independent."""
    response = "Sentence one. Sentence two. "
    req = MutationDetectorReq()
    original_reqs = [req]

    backend = StreamingMockBackend(response, token_size=4)

    r1 = await stream_with_chunking(
        _action(),
        backend,
        _ctx(),
        quick_check_requirements=original_reqs,
        chunking="sentence",
    )
    await r1.acomplete()

    r2 = await stream_with_chunking(
        _action(),
        backend,
        _ctx(),
        quick_check_requirements=original_reqs,
        chunking="sentence",
    )
    await r2.acomplete()

    # Original requirement must never have been called — only clones are used
    assert req._call_count == 0


@pytest.mark.asyncio
async def test_quick_check_backend_routing() -> None:
    """stream_validate and validate receive quick_check_backend, not the main backend."""
    response = "One sentence. Two sentences. "
    main_backend = StreamingMockBackend(response, token_size=3)
    val_backend = StreamingMockBackend("unused", token_size=1)

    req = BackendRecordingReq()

    result = await stream_with_chunking(
        _action(),
        main_backend,
        _ctx(),
        quick_check_requirements=[req],
        chunking="sentence",
        quick_check_backend=val_backend,
    )
    await result.acomplete()

    # Validation runs on clones, which are not exposed by the API, so the
    # directly observable assertions are that the stream completed and that
    # the original recorder saw no backend at all (clone isolation).
    assert result.completed is True
    assert req.seen_backends == []


@pytest.mark.asyncio
async def test_early_exit_does_not_deadlock() -> None:
    """Early failure with a high-throughput stream must not hang."""
    long_response = "word " * 200
    backend = StreamingMockBackend(long_response, token_size=5)
    req = FailAfterWordsReq(threshold=3)

    result = await stream_with_chunking(
        _action(), backend, _ctx(), quick_check_requirements=[req], chunking="word"
    )
    # 5-second timeout — should complete in milliseconds when early exit works
    await asyncio.wait_for(result.acomplete(), timeout=5.0)

    assert result.completed is False


@pytest.mark.asyncio
async def test_as_thunk_correctness() -> None:
    """as_thunk is computed, value matches full_text, generation metadata preserved."""
    response = "This is a test sentence. "
    backend = StreamingMockBackend(response, token_size=4)

    result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence")
    await result.acomplete()

    thunk = result.as_thunk
    assert thunk.is_computed()
    assert thunk.value == result.full_text == response


@pytest.mark.asyncio
async def test_as_thunk_raises_before_acomplete() -> None:
    """as_thunk raises RuntimeError if accessed before acomplete()."""
    response = "Some text. "
    backend = StreamingMockBackend(response, token_size=2)

    result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence")

    with pytest.raises(RuntimeError, match="acomplete"):
        _ = result.as_thunk


@pytest.mark.asyncio
async def test_astream_yields_individual_chunks() -> None:
    """Consumer via astream() receives individual chunks, not accumulated text."""
    response = "First sentence. Second sentence. Third sentence. "
    backend = StreamingMockBackend(response, token_size=5)

    result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence")

    chunks: list[str] = []
    async for chunk in result.astream():
        chunks.append(chunk)

    await result.acomplete()

    # Each chunk must be a complete sentence (not the accumulated text)
    assert len(chunks) == 3
    for chunk in chunks:
        assert chunk.endswith(".")
    # Chunks don't include inter-sentence spaces; joined with a space they appear in full_text
    assert " ".join(chunks) in result.full_text


@pytest.mark.asyncio
async def test_no_requirements_streams_without_validation() -> None:
    """quick_check_requirements=None → chunks produced, no validate() called."""
    response = "Chunk one. Chunk two. Chunk three. "
    backend = StreamingMockBackend(response, token_size=3)

    result = await stream_with_chunking(
        _action(), backend, _ctx(), quick_check_requirements=None, chunking="sentence"
    )
    await result.acomplete()

    assert result.completed is True
    assert result.full_text == response
    assert result.final_validations == []
    assert result.streaming_failures == []