Skip to content

Commit c1f3ab1

Browse files
committed
fix(core): fix stream_validate docstring and add missing stateful tests
The docstring incorrectly stated that implementations must not mutate self. Issue #900 spec explicitly allows stateful accumulation and requires the shallow-copy caveat to be documented. Fix the docstring to match the spec. Add two tests required by the issue acceptance criteria: - test_stateful_subclass_accumulates_state: verifies a subclass can accumulate state (bullet counter) across stream_validate calls - test_stateful_subclass_clone_isolation: verifies copy() gives an independent clone, confirming the orchestrator clone pattern Assisted-by: Claude Code Signed-off-by: Nigel Jones <jonesn@uk.ibm.com>
1 parent 6296007 commit c1f3ab1

2 files changed

Lines changed: 77 additions & 3 deletions

File tree

mellea/core/requirement.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,15 @@ async def stream_validate(
292292
— meaning insufficient data to decide yet. Subclasses override this method
293293
to inspect the accumulated chunk and return ``"pass"`` or ``"fail"`` early.
294294
295-
This method must not mutate ``self``. The orchestrator is responsible for
296-
cloning the requirement before each attempt; any state needed across chunks
297-
must be managed externally.
295+
Implementations may accumulate state on ``self`` across calls within a
296+
single attempt. The orchestrator clones the requirement (``copy(req)``)
297+
before each attempt, so state does not bleed across retries.
298+
299+
Shallow-copy caveat: mutable container fields (e.g. ``self._buffer = []``)
300+
are shared by reference under ``copy()``. Reassign rather than mutate in
301+
place (``self._buffer = self._buffer + [chunk]``, not
302+
``self._buffer.append(chunk)``), or override ``__copy__`` for proper
303+
isolation.
298304
299305
Args:
300306
chunk: The accumulated model output so far (not just the latest token).

test/core/test_stream_validate.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Unit tests for Requirement.stream_validate() hook."""
22

33
import inspect
4+
from copy import copy
45

56
import pytest
67

@@ -82,3 +83,70 @@ async def test_stream_validate_idempotent():
8283
assert result1.success == "unknown"
8384
assert result2.success == "unknown"
8485
assert req._output is None
86+
87+
88+
@pytest.mark.asyncio
89+
async def test_stateful_subclass_accumulates_state():
90+
"""Stateful subclass correctly accumulates state across stream_validate calls."""
91+
92+
class BulletCounter(Requirement):
93+
def __init__(self) -> None:
94+
super().__init__(description="no more than 3 bullets")
95+
self._bullet_count = 0
96+
97+
async def stream_validate(
98+
self, chunk: str, *, backend: Backend, ctx: Context
99+
) -> PartialValidationResult:
100+
self._bullet_count = chunk.count("\n-")
101+
if self._bullet_count > 3:
102+
return PartialValidationResult(
103+
"fail", reason=f"{self._bullet_count} bullets exceeds limit"
104+
)
105+
return PartialValidationResult("unknown")
106+
107+
req = BulletCounter()
108+
assert req._bullet_count == 0
109+
110+
await req.stream_validate("intro text", backend=None, ctx=None) # type: ignore[arg-type]
111+
assert req._bullet_count == 0
112+
113+
await req.stream_validate("intro\n- one\n- two", backend=None, ctx=None) # type: ignore[arg-type]
114+
assert req._bullet_count == 2
115+
116+
result = await req.stream_validate(
117+
"intro\n- one\n- two\n- three\n- four",
118+
backend=None,
119+
ctx=None, # type: ignore[arg-type]
120+
)
121+
assert req._bullet_count == 4
122+
assert result.success == "fail"
123+
assert result.reason is not None and "4" in result.reason
124+
125+
126+
@pytest.mark.asyncio
127+
async def test_stateful_subclass_clone_isolation():
128+
"""copy() of a stateful requirement gives an independent clone — orchestrator pattern."""
129+
130+
class CallCounter(Requirement):
131+
def __init__(self) -> None:
132+
super().__init__(description="call counter")
133+
self._calls = 0
134+
135+
async def stream_validate(
136+
self, chunk: str, *, backend: Backend, ctx: Context
137+
) -> PartialValidationResult:
138+
self._calls += 1
139+
return PartialValidationResult("unknown")
140+
141+
req = CallCounter()
142+
await req.stream_validate("a", backend=None, ctx=None) # type: ignore[arg-type]
143+
await req.stream_validate("b", backend=None, ctx=None) # type: ignore[arg-type]
144+
assert req._calls == 2
145+
146+
# Simulate orchestrator cloning before a new attempt
147+
cloned = copy(req)
148+
assert cloned._calls == 2 # clone inherits state at clone time
149+
150+
await cloned.stream_validate("c", backend=None, ctx=None) # type: ignore[arg-type]
151+
assert cloned._calls == 3 # clone advances independently
152+
assert req._calls == 2 # original is unchanged

0 commit comments

Comments
 (0)