Skip to content

Commit e1e0556

Browse files
committed
fix(core): address PR #925 review feedback on stream_validate
- Remove "In Phase 1" temporal qualifier from docstring — reworded to timeless statement about orchestrator responsibility - Add type annotations (str, Backend, Context) to test subclass overrides - Add idempotency test: multiple calls on the same Requirement instance leave state unchanged Assisted-by: Claude Code Signed-off-by: Nigel Jones <jonesn@uk.ibm.com>
1 parent 9be937b commit e1e0556

2 files changed

Lines changed: 20 additions & 4 deletions

File tree

mellea/core/requirement.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,8 @@ async def stream_validate(
304304
Returns:
305305
PartialValidationResult: ``"unknown"`` by default. Subclasses may return
306306
``"pass"`` (constraint satisfied so far) or ``"fail"`` (constraint violated,
307-
streaming should be aborted). In Phase 1, ``"pass"`` is informational and
308-
does not short-circuit the final ``validate()`` call.
307+
streaming should be aborted). ``"pass"`` does not short-circuit the final
308+
``validate()`` call; the orchestrator decides whether to skip it.
309309
"""
310310
return PartialValidationResult("unknown")
311311

test/core/test_stream_validate.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import pytest
66

77
from mellea.core import PartialValidationResult, Requirement
8+
from mellea.core.backend import Backend
9+
from mellea.core.base import Context
810

911

1012
@pytest.mark.asyncio
@@ -29,7 +31,9 @@ def test_stream_validate_is_coroutine():
2931
@pytest.mark.asyncio
3032
async def test_subclass_can_return_pass():
3133
class PassRequirement(Requirement):
32-
async def stream_validate(self, chunk, backend, ctx) -> PartialValidationResult:
34+
async def stream_validate(
35+
self, chunk: str, backend: Backend, ctx: Context
36+
) -> PartialValidationResult:
3337
return PartialValidationResult("pass")
3438

3539
req = PassRequirement(description="always passes")
@@ -40,7 +44,9 @@ async def stream_validate(self, chunk, backend, ctx) -> PartialValidationResult:
4044
@pytest.mark.asyncio
4145
async def test_subclass_can_return_fail():
4246
class FailRequirement(Requirement):
43-
async def stream_validate(self, chunk, backend, ctx) -> PartialValidationResult:
47+
async def stream_validate(
48+
self, chunk: str, backend: Backend, ctx: Context
49+
) -> PartialValidationResult:
4450
if "bad" in chunk:
4551
return PartialValidationResult("fail", reason="bad word detected")
4652
return PartialValidationResult("unknown")
@@ -66,3 +72,13 @@ async def test_does_not_mutate_requirement():
6672
assert req.description == original_description
6773
assert req._output == original_output
6874
assert req.validation_fn == original_validation_fn
75+
76+
77+
@pytest.mark.asyncio
78+
async def test_stream_validate_idempotent():
79+
req = Requirement(description="repeated calls")
80+
result1 = await req.stream_validate("chunk one", backend=None, ctx=None) # type: ignore[arg-type]
81+
result2 = await req.stream_validate("chunk two", backend=None, ctx=None) # type: ignore[arg-type]
82+
assert result1.success == "unknown"
83+
assert result2.success == "unknown"
84+
assert req._output is None

0 commit comments

Comments
 (0)